目标:通过Faiss和DINOv2进行场景识别,确保输入的照片和注册的图片,保持内容一致。
MetaAI 通过开源 DINOv2,在计算机视觉领域取得了一个显着的里程碑,这是一个在包含1.42 亿张图像的令人印象深刻的数据集上训练的模型。产生适用于图像级视觉任务(图像分类、实例检索、视频理解)以及像素级视觉任务(深度估计、语义分割)的通用特征。
Faiss是一个用于高效相似性搜索和密集向量聚类的库。它包含的算法可以搜索任意大小的向量集,甚至可能无法容纳在 RAM 中的向量集。
#!usr/bin/env python
# -*- coding:utf-8 -*-# pip install transformers faiss-gpu torch Pillow
import torch
import os
import concurrent.futures
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import faiss
from tqdm import tqdm
import numpy as np
from utils_img import *os.environ["CUDA_VISIBLE_DEVICES"] = "1"class SceneRecognition:def __init__(self, dimension=384, threshold=0.8, batch_size=128):"""初始化 SceneRecognition 类Parameters:dimension (int): 向量的维度,默认为 384"""# 加载模型和处理器self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')self.processor = AutoImageProcessor.from_pretrained('models/dinov2-small')self.model = AutoModel.from_pretrained('models/dinov2-small').to(self.device)# 保存特征向量的维度self.dimension = dimension# 搭配 faiss.normalize_L2 创建 Faiss 的余弦相似度索引self.index = faiss.IndexFlatIP(self.dimension)# 保存图片向量对应的 idself.ids = []# 保存特征向量数据库self.db_path = "vector.index"# 图片路径self.images_path = []# 相似度阈值self.threshold = thresholdself.batch_size = batch_size# 初始化self.init()def read_image_open(self, image_path):"""打开图像并返回图像和图像路径Parameters:image_path (str): 图像的路径Returns:Image: 打开的图像str: 图像的路径"""# 默认以 RGB 模式打开image = Image.open(image_path)return image, image_pathdef read_images_from_folder(self, file_path):"""从文件夹中读取图像Parameters:file_path (str or list): 文件夹的路径或文件路径列表Returns:list: 图像列表list: 图像路径列表"""image_list = []image_path_list = []try:if type(file_path) is not list:task_list = get_files(file_path)else:task_list = file_path# 使用线程池执行器with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:res = executor.map(self.read_image_open, task_list)image_list, image_path_list = list(zip(*res))return image_list, image_path_listexcept Exception as e:return None, Nonedef download_model(self):"""从 huggingface_hub 下载模型"""from huggingface_hub import snapshot_downloadsnapshot_download(repo_id="facebook/dinov2-small", # 模型 IDlocal_dir="./models/dinov2-small") # 指定本地地址保存模型def init(self):"""初始化模型"""batch_images = np.zeros((self.batch_size, 1920, 1080, 3),dtype=np.int8)batch_inputs = self.processor(images=batch_images,return_tensors="pt").to(self.device)batch_outputs = self.model(**batch_inputs)def add_vector_to_index(self, embedding, index):"""将向量添加到索引中Parameters:embedding (torch.Tensor): 特征向量index (faiss.Index): Faiss 索引"""vector = embedding.detach().cpu().numpy()vector = np.float32(vector)faiss.normalize_L2(vector)index.add(vector)def create_database_from_images(self, file_path):"""从图像创建数据库Parameters:file_path (str or list): 图像文件夹的路径或图像路径列表"""image_list, self.images_path = self.read_images_from_folder(file_path)self.extract_features_form_images(image_list)def create_database_from_batch_images(self, file_path):"""从批量图像创建数据库Parameters:file_path (str or list): 图像文件夹的路径或图像路径列表"""image_list, self.images_path = self.read_images_from_folder(file_path)self.extract_features_form_batch_images(image_list)def extract_features_form_images(self, images):"""从图像列表提取特征Parameters:images (list): 图像列表"""for image_id, img in enumerate(images):with torch.no_grad():inputs = self.processor(images=img,return_tensors="pt").to(self.device)outputs = self.model(**inputs)features = outputs.last_hidden_state.mean(dim=1)self.add_vector_to_index(features, self.index)# 记录特征向量的 idself.ids.append(image_id)# faiss.write_index(self.index, self.db_path)def extract_features_form_batch_images(self, image_list):"""从批量图像中提取特征Parameters:image_list (list): 图像列表"""img_num = len(image_list)for beg_img_no in tqdm(range(0, img_num, self.batch_size),desc="Extracting features"):end_img_no = min(img_num, beg_img_no + self.batch_size)batch_image = image_list[beg_img_no:end_img_no]with torch.no_grad():batch_inputs = self.processor(images=batch_image,return_tensors="pt").to(self.device)batch_outputs = self.model(**batch_inputs)batch_features = batch_outputs.last_hidden_state.mean(dim=1)for image_id, features in enumerate(batch_features):self.add_vector_to_index(features.reshape(1, -1),self.index)# 记录特征向量的 idimage_id += beg_img_noself.ids.append(image_id)# faiss.write_index(self.index, self.db_path)def search_similar_batch_images(self, user_image_path, k=1):"""批量搜索相似图像Parameters:user_image_path (str or list): 用户图像的路径或路径列表k (int): 搜索的近邻数量Returns:list: 相似图像的路径列表"""image_list, image_path_list = self.read_images_from_folder(user_image_path)img_num = len(image_list)similar_base_images = []similar_query_images = []for beg_img_no in tqdm(range(0, img_num, self.batch_size),desc="Searching similar images"):end_img_no = min(img_num, beg_img_no + self.batch_size)batch_image = image_list[beg_img_no:end_img_no]batch_image_path = image_path_list[beg_img_no:end_img_no]with torch.no_grad():batch_inputs = self.processor(images=batch_image,return_tensors="pt").to(self.device)batch_outputs = self.model(**batch_inputs)query_features = batch_outputs.last_hidden_state.mean(dim=1)batch_query_vector = query_features.detach().cpu().numpy()batch_query_vector = np.float32(batch_query_vector)faiss.normalize_L2(batch_query_vector)batch_distances, batch_indices = self.index.search(batch_query_vector, k)if len(batch_distances) > 0:for image_path, dis, ind in zip(batch_image_path,batch_distances,batch_indices):# 保存超过阈值的最相似特征向量if dis[0] > self.threshold:similar_base_images.append(self.images_path[ind[0]])similar_query_images.append(image_path)dissimilar_images = list(set(image_path_list).difference(set(similar_query_images)))return similar_query_imagesdef search_similar_images(self, query_image_path, k=1):"""搜索相似图像Parameters:query_image_path (str): 查询图像的路径k (int): 搜索的近邻数量Returns:list: 相似图像的路径列表"""img = Image.open(query_image_path).convert('RGB')with torch.no_grad():inputs = self.processor(images=img,return_tensors="pt").to(self.device)outputs = self.model(**inputs)query_features = outputs.last_hidden_state.mean(dim=1)query_vector = query_features.detach().cpu().numpy()query_vector = np.float32(query_vector)faiss.normalize_L2(query_vector)distances, indices = self.index.search(query_vector, k)similar_base_images = []similar_query_images = []if len(distances) > 0:for dis, ind in zip(distances, indices):# 保存超过阈值的最相似特征向量if dis > self.threshold:similar_base_images.append(self.images_path[ind[0]])similar_query_images.append(query_image_path)return similar_query_imagesdef remove_image_by_id(self, image_id):"""根据 ID 删除图像Parameters:image_id (int): 图像的 ID"""# 先删除高索引的元素,再删除低索引的元素,避免索引错位的问题。index_to_remove = [i for i, stored_id in enumerate(self.ids) if stored_id == image_id] # 找到需要删除的特征向量的索引for i in sorted(index_to_remove, reverse=True):# 从 Faiss 索引中删除对应的特征向量self.index.remove_ids(np.array([i]))# 从 id 列表中删除对应的 iddel self.ids[i]def get_num_images(self):"""获取保存的图像数量Returns:int: 保存的图像数量"""# 返回保存的图片向量数量return len(self.ids)def clear_data_base(self):"""清空数据库"""# 重置 Faiss 索引self.index.reset()# 清空 id 列表self.ids.clear()if __name__ == '__main__':# 创建一个 SceneRecognition 实例scene_rec = SceneRecognition()scene_rec.create_database_from_batch_images('imgs')scene_rec.search_similar_batch_images('images')
参考网址:
- https://blog.csdn.net/level_code/article/details/137772620
- https://blog.csdn.net/weixin_38739735/article/details/136979083
- https://blog.csdn.net/u010970956/article/details/134945210
- https://blog.csdn.net/hh1357102/article/details/135066581
- https://blog.csdn.net/sinat_34770838/article/details/137021023
- https://zhuanlan.zhihu.com/p/704250322
- https://zhuanlan.zhihu.com/p/668148439
- https://blog.51cto.com/u_14273/10165547
- https://blog.csdn.net/ResumeProject/article/details/135350945
- https://www.zhihu.com/question/637818872/answer/3380169469
- https://zhuanlan.zhihu.com/p/644077057