AI + Milvus:将时尚应用搭建进行到底

在上一篇文章中,我们学习了如何利用人工智能技术(例如开源 AI 向量数据库 Milvus 和 Hugging Face 模型)寻找与自己穿搭风格相似的明星。在这篇文章中,我们将进一步介绍如何通过对上篇文章中的项目代码稍作修改,获得更详细和准确的结果,文末附赠彩蛋。

注:试用此项目应用,需要点击下载并使用 notebook

01.回顾前文

在深入探讨前,先简要回顾一下前一篇教程文章。

导入所需的图像处理库和工具

首先导入所有必要的图像处理库,包括用于特征提取的 torchtransformers 中的 segformer 对象、matplotlibtorchvision 中的 Resizemasks_to_boxescrop 等。

import torch
from torch import nn, tensor
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
from torchvision.transforms import Resize
import torchvision.transforms as T
from torchvision.ops import masks_to_boxes
from torchvision.transforms.functional import crop

预处理明星照片

在导入所有必要的图像处理库和工具后,就可以开始处理图像。以下三个函数 get_segmentationget_maskscrop_images 用于分割并裁剪图片中的时尚单品,以供后续使用。

import torch
def get_segmentation(extractor, model, image):inputs = extractor(images=image, return_tensors="pt")outputs = model(**inputs)logits = outputs.logits.cpu()upsampled_logits = nn.functional.interpolate(logits,size=image.size[::-1],mode="bilinear",align_corners=False,)pred_seg = upsampled_logits.argmax(dim=1)[0]return pred_seg# 返回两个 masks(tensor)列表和 obj_ids(int)
# 来自 Hugging Face 的 mattmdjaga/segformer_b2_clothes 模型
def get_masks(segmentation):obj_ids = torch.unique(segmentation)obj_ids = obj_ids[1:]masks = segmentation == obj_ids[:, None, None]return masks, obj_idsdef crop_images(masks, obj_ids, img):boxes = masks_to_boxes(masks)crop_boxes = []for box in boxes:crop_box = tensor([box[0], box[1], box[2]-box[0], box[3]-box[1]])crop_boxes.append(crop_box)preprocess = T.Compose([T.Resize(size=(256, 256)),T.ToTensor()])cropped_images = {}for i in range(len(crop_boxes)):crop_box = crop_boxes[i]cropped = crop(img, crop_box[1].item(), crop_box[0].item(), crop_box[3].item(), crop_box[2].item())cropped_images[obj_ids[i].item()] = preprocess(cropped)return cropped_images

将图像数据存储到向量数据库中

选择开源向量数据库 Milvus 来存储图像数据。开始前,需要先解压包含照片的 zip 文件,并在 notebook 相同的根目录中创建照片文件夹。完成后,可以运行以下代码来将图像数据存储在 Milvus 中。

import os
image_paths = []
for celeb in os.listdir("./photos"):for image in os.listdir(f"./photos/{celeb}/"):image_paths.append(f"./photos/{celeb}/{image}")from milvus import default_server
from pymilvus import utility, connections
default_server.start()
connections.connect(host="127.0.0.1", port=default_server.listen_port)
DIMENSION = 2048
BATCH_SIZE = 128
COLLECTION_NAME = "fashion"
TOP_K = 3
from pymilvus import FieldSchema, CollectionSchema, Collection, DataTypefields = [FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),FieldSchema(name="name", dtype=DataType.VARCHAR, max_length=200),FieldSchema(name="seg_id", dtype=DataType.INT64),FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
]schema = CollectionSchema(fields=fields)
collection = Collection(name=COLLECTION_NAME, schema=schema)
index_params = {"index_type": "IVF_FLAT","metric_type": "L2","params": {"nlist": 128},
}
collection.create_index(field_name="embedding", index_params=index_params)
collection.load()

接着,运行以下代码,使用来自 Hugging Face 的 Nvidia ResNet 50 模型生成 embedding 向量。

# 如遇 SSL 证书 URL 错误,请在导入 resnet50 模型前运行此步骤
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 并删除最后一层模型输出
embeddings_model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resnet50', pretrained=True)
embeddings_model = torch.nn.Sequential(*(list(embeddings_model.children())[:-1]))
embeddings_model.eval()

以下函数定义了如何将图像转换为向量并插入到 Milvus 向量数据库中。代码会循环遍历所有图像。(注意:如果需要开启 Milvus 全新特性动态 Schema,需要修改代码。)

def embed_insert(data, collection, model):with torch.no_grad():output = model(torch.stack(data[0])).squeeze()collection.insert([data[1], data[2], data[3], output.tolist()])
from PIL import Image
data_batch = [[], [], [], []]for path in image_paths:image = Image.open(path)path_split = path.split("/")name = " ".join(path_split[2].split("_"))segmentation = get_segmentation(extractor, model, image)masks, ids = get_masks(segmentation)cropped_images = crop_images(masks, ids, image)for key, image in cropped_images.items():data_batch[0].append(image)data_batch[1].append(path)data_batch[2].append(name)data_batch[3].append(key)if len(data_batch[0]) % BATCH_SIZE == 0:embed_insert(data_batch, collection, embeddings_model)data_batch = [[], [], [], []]if len(data_batch[0]) != 0:embed_insert(data_batch, collection, embeddings_model)collection.flush()

查询向量数据库

以下代码演示了如何使用输入图像查询 Milvus 向量数据库,以检索和上传衣服图像最相似的的前三个结果。

def embed_search_images(data, model):with torch.no_grad():output = model(torch.stack(data))if len(output) > 1:return output.squeeze().tolist()else:return torch.flatten(output, start_dim=1).tolist()
# data_batch[0]是 tensor 列表
# data_batch[1]是图像文件的文件路径(字符串)
# data_batch[2]是图像中人物的名称列表(字符串)
# data_batch[3]是分割键值列表(int)
data_batch = [[], [], [], []]search_paths = ["./photos/Taylor_Swift/Taylor_Swift_3.jpg", "./photos/Taylor_Swift/Taylor_Swift_8.jpg"]for path in search_paths:image = Image.open(path)path_split = path.split("/")name = " ".join(path_split[2].split("_"))segmentation = get_segmentation(extractor, model, image)masks, ids = get_masks(segmentation)cropped_images = crop_images(masks, ids, image)for key, image in cropped_images.items():data_batch[0].append(image)data_batch[1].append(path)data_batch[2].append(name)data_batch[3].append(key)embeds = embed_search_images(data_batch[0], embeddings_model)
import time
start = time.time()
res = collection.search(embeds,anns_field='embedding',param={"metric_type": "L2","params": {"nprobe": 10}},limit=TOP_K,output_fields=['filepath'])
finish = time.time()
print(finish - start)
for index, result in enumerate(res):print(index)print(result)

02.匹配更多风格:标示每张图像中的时尚单品

除了直接使用上述代码,查找与你着装风格最相似的 3 位明星以外,我们还可以稍微修改一下代码,拓展项目的应用场景。可以修改代码获取如下所示,不包含边界框的图像。

alt

接下来,将为大家介绍如何修改上述代码寻找更多匹配的穿衣风格。

导入所需的图像处理库和工具

同样,需要先导入所有必要的图像处理库。如果已经完成导入,请跳过此步骤。

import torch
from torch import nn, tensor
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
from torchvision.transforms import Resize
import torchvision.transforms as T
from torchvision.ops import masks_to_boxes
from torchvision.transforms.functional import crop

预处理图像

这个步骤涉及三个函数:get_segmentationget_maskscrop_images

无需修改 get_segmentation 函数部分的代码。

对于 get_masks 函数,只需要获取与 wanted 列表中的分割 ID 相对应的分割图像即可。

crop_image 函数做出更改。在前一篇文的教程中,此函数返回裁剪图像的列表。这里,我们进行一些调整,使函返回三个对象:裁剪图像对应的 embedding 向量、边界框在原始图像上的坐标列表,以及分割 ID 列表。这一改动将转化 embedding 向量的步骤提前了。

wanted = [1, 3, 4, 5, 6, 7, 8, 9, 10, 16, 17]
def get_segmentation(image):inputs = extractor(images=image, return_tensors="pt")outputs = segmentation_model(**inputs)logits = outputs.logits.cpu()upsampled_logits = nn.functional.interpolate(logits,size=image.size[::-1],mode="bilinear",align_corners=False,)pred_seg = upsampled_logits.argmax(dim=1)[0]return pred_seg# returns two lists masks (tensor) and obj_ids (int)
# "mattmdjaga/segformer_b2_clothes" from hugging face
def get_masks(segmentation):obj_ids = torch.unique(segmentation)obj_ids = obj_ids[1:]wanted_ids = [x.item() for x in obj_ids if x in wanted]wanted_ids = torch.Tensor(wanted_ids)masks = segmentation == wanted_ids[:, None, None]return masks, obj_idsdef crop_images(masks, obj_ids, img):boxes = masks_to_boxes(masks)crop_boxes = []for box in boxes:crop_box = tensor([box[0], box[1], box[2]-box[0], box[3]-box[1]])crop_boxes.append(crop_box)preprocess = T.Compose([T.Resize(size=(256, 256)),T.ToTensor()])cropped_images = []seg_ids = []for i in range(len(crop_boxes)):crop_box = crop_boxes[i]cropped = crop(img, crop_box[1].item(), crop_box[0].item(), crop_box[3].item(), crop_box[2].item())cropped_images.append(preprocess(cropped))seg_ids.append(obj_ids[i].item())with torch.no_grad():embeddings = embeddings_model(torch.stack(cropped_images)).squeeze().tolist()return embeddings, boxes.tolist(), seg_ids

有了图像数据之后,就可以加载数据了。这一步骤需要使用到批量插入功能,上篇文章的教程中也有涉及,但不同点在于,本文的教程中将数据作为 dictionary 列表一次性插入。这种插入方式更简洁,同时还允许我们在插入数据时动态新增 Schema 字段。

for path in image_paths:image = Image.open(path)path_split = path.split("/")name = " ".join(path_split[2].split("_"))segmentation = get_segmentation(image)masks, ids = get_masks(segmentation)embeddings, crop_corners, seg_ids = crop_images(masks, ids, image)inserts = [{"embedding": embeddings[x], "seg_id": seg_ids[x], "name": name, "filepath": path, "crop_corner": crop_corners[x]} for x in range(len(embeddings))]collection.insert(inserts)collection.flush()

查询向量数据库

现在可以开始在向量数据库 Milvus 中查询数据了。本文与上篇文章的教程有以下几点区别:

  • 将一张图像中匹配的时尚单品数量限制到 5 件。

  • 指定查询返回最相似的 3 张图像。

  • 添加函数获取图片的色彩图。

随后,在 matplotlib 中设置 figures 和 axes ,代码会循环遍历所有图像,将上文的 3 个函数应用到所有图像上,以获取分割结果和边界框。

查询数据时,可以根据每张图像中匹配的时尚单品数量来获得最相似的 3 张图像。

最终返回的结果图像中会带有标示出匹配单品的边界框。

from pprint import pprint
from PIL import ImageDraw
from collections import Counter
import matplotlib.patches as patchesLIMIT = 5 # 每张图像中匹配的时尚单品件数
CLOSEST = 3 # 返回的最相似图像数量。CLOSEST <= Limitsearch_paths = ["./photos/Taylor_Swift/Taylor_Swift_2.jpg", "./photos/Jenna_Ortega/Jenna_Ortega_6.jpg"] # Images to search fordef get_cmap(n, name='hsv'):'''Returns a function that maps each index in 0, 1, ..., n-1 to a distinctRGB color; the keyword argument name must be a standard mpl colormap name.Sourced from <https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib>'''return plt.cm.get_cmap(name, n)# 创建结果 subplot
f, axarr = plt.subplots(max(len(search_paths), 2), CLOSEST + 1)for search_i, path in enumerate(search_paths):# Generate crops and embeddings for all items foundimage = Image.open(path)segmentation = get_segmentation(image)masks, ids = get_masks(segmentation)embeddings, crop_corners, _ = crop_images(masks, ids, image)# 生成色彩图cmap = get_cmap(len(crop_corners))# Display the first box with image being searched foraxarr[search_i][0].imshow(image)axarr[search_i][0].set_title('Search Image')axarr[search_i][0].axis('off')for i, (x0, y0, x1, y1) in enumerate(crop_corners):rect = patches.Rectangle((x0, y0), x1-x0, y1-y0, linewidth=1, edgecolor=cmap(i), facecolor='none')axarr[search_i][0].add_patch(rect)# 查询向量数据库start = time.time()res = collection.search(embeddings,anns_field='embedding',param={"metric_type": "L2","params": {"nprobe": 10}, "offset": 0},limit=LIMIT,output_fields=['filepath', 'crop_corner'])finish = time.time()print("Total Search Time: ", finish - start)# 根据位置给查询结果增加不同的权重filepaths = []for hits in res:seen = set()for i, hit in enumerate(hits):if hit.entity.get("filepath") not in seen:seen.add(hit.entity.get("filepath"))filepaths.extend([hit.entity.get("filepath") for _ in range(len(hits) - i)])# 查找排名最高的图像counts = Counter(filepaths)most_common = [path for path, _ in counts.most_common(CLOSEST)]# 提取每张图像中与查询图像相关的时尚单品matches = {}for i, hits in enumerate(res):matches[i] = {}tracker = set(most_common)for hit in hits:if hit.entity.get("filepath") in tracker:matches[i][hit.entity.get("filepath")] = hit.entity.get("crop_corner")tracker.remove( hit.entity.get("filepath"))# 返回最相似图像:# 返回与查询图像临近的图像image = Image.open(res_path)axarr[search_i][res_i+1].imshow(image)axarr[search_i][res_i+1].set_title(" ".join(res_path.split("/")[2].split("_")))axarr[search_i][res_i+1].axis('off')
# 为匹配单品添加边界框if res_path in value:x0, y0, x1, y1 = value[res_path]rect = patches.Rectangle((x0, y0), x1-x0, y1-y0, linewidth=1, edgecolor=cmap(key), facecolor='none')axarr[search_i][res_i+1].add_patch(rect)

运行上述步骤后,结果如下所示:

alt

03.项目后续:探索更多应用场景

欢迎大家基于本项目拓展更多、更丰富的应用场景,例如:

  • 进一步延伸对比功能,例如将不同的单品归类到一起。同样,也可以上传更多图像到数据库中,丰富查询结果。

  • 将本项目转变为时尚探测仪或者时尚推荐系统。例如,将明星图像替换成可购买的衣服图像。这样一来,用户上传照片后,可以查询与他的衣服风格相似的其他衣服。

  • 还可以基于本项目搭建一个穿搭生成系统,很多方法都可以实现这个应用,但这个应用的搭建相对而言更有难度!本文提供了一种思路,系统可以根据用户上传的多张照片相应推荐穿搭。这里需要用到生成式图像模型,从而提供穿搭建议。

总之,不要限制你的想象力,搭建更丰富的应用。Milvus 之类的向量数据库为相似性搜索应用提供了无限可能。

04.总结

本文教程中,我们进一步拓展了时尚 AI 项目的应用场景。

本次教程使用了 Milvus 全新的 动态 Schema 功能,筛选了分割 ID,在返回图像中保留了边界框。同时,我们在查询中指定 Milvus 根据每张图像中匹配的时尚单品件数返回最相似的 3 张图像。Milvus 全新的动态 Schema 功能支持在上传数据时添加新的字段,改变了我们批量上传数据的方式。使用这个功能后,在上传数据时,无需改动 Schema 即可添加裁剪。在图像预处理步骤中,剔除了一些识别到的非着装类元素。同时,本教程保留了边界框,将转化向量的步骤提前至了裁剪图片的步骤。

当然,通过进一步调整代码,我们还可以搭建更多相关应用,例如:时尚推荐系统、帮助用户搭配着装的系统,甚至是生成式的时尚 AI 应用!

🌟「寻找 AIGC 时代的 CVP 实践之星」 专题活动即将启动!

Zilliz 将联合国内头部大模型厂商一同甄选应用场景, 由双方提供向量数据库与大模型顶级技术专家为用户赋能,一同打磨应用,提升落地效果,赋能业务本身。

如果你的应用也适合 CVP 框架,且正为应用落地和实际效果发愁,可直接申请参与活动,获得最专业的帮助和指导!联系邮箱为 business@zilliz.com。

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/55753.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[管理与领导-53]:IT基层管理者 - 8项核心技能 - 8 - 持续改进

前言&#xff1a; 管理者存在的价值就是制定目标&#xff0c;即目标管理、通过团队&#xff08;他人&#xff09;拿到结果。 要想通过他人拿到结果&#xff1a; &#xff08;1&#xff09;目标&#xff1a;制定符合SMART原则的符合业务需求的目标&#xff0c;团队跳一跳就可以…

Microsoft正在将Python引入Excel

Excel和Python这两个世界正在碰撞&#xff0c;这要归功于Microsoft的新集成&#xff0c;以促进数据分析和可视化 Microsoft正在将流行的编程语言Python引入Excel。该功能的公共预览版现已推出&#xff0c;允许Excel用户操作和分析来自Python的数据。 “您可以使用 Python 绘图…

Git向远程仓库与推送以及拉取远程仓库

理解分布式版本控制系统 1.中央服务器 我们⽬前所说的所有内容&#xff08;⼯作区&#xff0c;暂存区&#xff0c;版本库等等&#xff09;&#xff0c;都是在本地也就是在你的笔记本或者计算机上。⽽我们的 Git 其实是分布式版本控制系统&#xff01;什么意思呢? 那我们多人…

VUE3基础

一、vue-router v4.x 介绍 | Vue Router 1、安装 yarn add vue-routernext next代表最新的版本 2、路由配置 在src目录下&#xff0c;新建router/index.ts&#xff0c;具体配置如下 import {RouteRecordRaw,createRouter,createWebHashHistory} from vue-router const r…

NO.04 MyBatis的各种查询功能

目录 1、查询一个实体类对象 2、查询一个List集合 3、查询单个数据 5、查询多条数据并存储在Map集合中 5.1 方法一&#xff1a;将数据存储在map集合中&#xff0c;再将map集合存储在List集合中 5.2 方法二&#xff1a;将数据存储在map集合中 6、MyBatis中为Java中常用的…

JavaWeb 速通JQuery

目录 一、JQuery快速入门 1.基本介绍 : 2.入门案例 : 二、JQuery对象 1.基本介绍 : 2.DOM对象 --> JQuery对象 : 3.JQuery对象 --> DOM对象 : 三、JQuery选择器 1.简介 : 2.基本选择器 : 3.层次选择器 : 4.过滤选择器 : 4.1 基础过滤选择器 4.2 内容过滤选择…

精益求精:如何在 React 项目中巧妙运用 Redux 并优化项目结构

在前端开发中&#xff0c;React 和 Redux 已经成为构建可扩展、高效的应用程序的首选工具。本文将向您介绍如何在 React 项目中使用 Redux 并精心划分项目结构&#xff0c;以确保代码的可维护性和可扩展性。 第一步&#xff1a;安装 Redux 首先&#xff0c;确保您的项目已经集…

Vue中ElementUI结合transform使用时,发现弹框定位不准确问题

在近期开发中&#xff0c;需要将1920*1080放到更大像素大屏上演示&#xff0c;所以需要使用到transform来对页面进行缩放&#xff0c;但是此时发现弹框定位出错问题&#xff0c;无法准备定位到实际位置。 查看element-ui官方文档无果后&#xff0c;打算更换新的框架进行开发&am…

Java升级JDK17(更高版本同理),修改maven

记住三个网址就行&#xff1a;下面这个是oracle的 Java Platform, Standard Edition 17 ReferenceImplementations https://www.oracle.com/java/technologies/downloads/#jdk17-windows 另外一个 redhat旗下的&#xff1a;这个是开源的&#xff08;推荐这个&#xff01;&am…

国产系统下开发QT程序总结

国产系统下开发QT程序总结 1. 国产系统简介 开发国产系统客户端的过程中&#xff0c;会出现兼容性问题。以下介绍Kylin和UOS环境下开发QT程序&#xff0c; 首先麒麟和统信这两个系统基于Ubuntu开发的。所以在Ubuntu开发理论上在国产系统上也能运行。芯片架构又分为amd,arm,mi…

pythonapi接口怎么对接?

Python API接口对接是使用Python语言开发应用程序时&#xff0c;与外部API接口进行交互的一种方式。API&#xff08;应用程序接口&#xff09;是一种定义了程序或系统如何与另一个程序或系统进行交互的协议。通过使用Python API接口&#xff0c;可以轻松地访问和操作外部API提供…

飞凌嵌入式受邀参加「RISC-V芯片应用交流会」并发表主题演讲

8月23日下午&#xff0c;在第三届RISC-V中国峰会现场&#xff0c;由赛昉科技主办的「RISC-V芯片应用交流会」吸引了诸多行业伙伴和专家到场参与。此次会议旨在分享赛昉科技高性能RISC-V芯片的软件生态、应用产品、解决方案等全面进展&#xff0c;共同探讨RISC-V芯片的未来发展和…

Git 删除已经合并的本地分支

在使用 Git 的开发流程中&#xff0c;经常会创建很多的 Git 分支&#xff0c;包括功能分支&#xff08;features/*&#xff09;、发布分支&#xff08;release/*&#xff09;和 hotfix 分支&#xff08;hotfix/*&#xff09;。在开发了一段时间之后&#xff0c;本地就会有出现很…

javaee spring 自动注入,如果满足条件的类有多个如何区别

如图IDrinkDao有两个实现类 方法一 方法二 Resource(name“对象名”) Resource(name"oracleDrinkDao") private IDrinkDao drinkDao;

Java --- 异常处理

目录 一、什么是异常 二、异常抛出机制 三、如何对待异常 四、 Java异常体系 4.1、Throwable 4.2、Error 4.2、Exception 4.2.1、编译时异常 4.2.2、运行时期异常 五、异常处理 5.1、捕获异常&#xff08;try-catch&#xff09; 5.1.2、catch中异常处理方式 …

你对SPA单页面的理解,它的优缺点分别是什么?如何实现SPA应用呢?

一、什么是SPA SPA&#xff08;single-page application&#xff09;&#xff0c;翻译过来就是单页应用SPA是一种网络应用程序或网站的模型&#xff0c;它通过动态重写当前页面来与用户交互&#xff0c;这种方法避免了页面之间切换打断用户体验在单页应用中&#xff0c;所有必…

TCP的三次握手 四次挥手以及TCP的11种状态

三次握手流程&#xff1a; 客户端给服务端发送数据时&#xff0c;数据包中带有一个头&#xff0c;这个头就是前几十个字节&#xff0c;就是下面这张图。从源端口号&#xff0c;目的端口号&#xff0c;一直到序列号&#xff0c;直到Options。第一个包会将这前十几个字节中的SYN置…

12. 完整模型训练套路

12.1 CIFAR 10 model 网络模型 ① 下面用 CIFAR 10 model网络来完成分类问题&#xff0c;网络模型如下图所示。 12.2 DataLoader加载数据集 import torchvision from torch import nn from torch.utils.data import DataLoader# 准备数据集 train_data torchvision.dataset…

C#,《小白学程序》第二课:数组与排序

1 文本格式 /// <summary> /// 《小白学程序》第二课&#xff1a;数组与排序 /// </summary> /// <param name"sender"></param> /// <param name"e"></param> private void button2_Click(object sender, EventArgs …

RT-Thread内核学习

内核框架 内核是操作系统最基础也是最重要的部分&#xff0c;内核处于硬件层之上&#xff0c;内核部分包括内核库、实时内核实现。 内核库是为了保证内核能够独立运行的一套小型的类似C库的函数实现子集。这部分根据编译器不同自带C库的情况也会不同。 当使用GNU GCC编译器时&…