通过Faiss和DINOv2进行场景识别

目标:通过Faiss和DINOv2进行场景识别,确保输入的照片和注册的图片,保持内容一致。

MetaAI 通过开源 DINOv2,在计算机视觉领域取得了一个显着的里程碑,这是一个在包含1.42 亿张图像的令人印象深刻的数据集上训练的模型。产生适用于图像级视觉任务(图像分类、实例检索、视频理解)以及像素级视觉任务(深度估计、语义分割)的通用特征。

Faiss是一个用于高效相似性搜索和密集向量聚类的库。它包含的算法可以搜索任意大小的向量集,甚至可能无法容纳在 RAM 中的向量集。

#!usr/bin/env python
# -*- coding:utf-8 -*-# pip install transformers faiss-gpu torch Pillow
import torch
import os
import concurrent.futures
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import faiss
from tqdm import tqdm
import numpy as np
from utils_img import *os.environ["CUDA_VISIBLE_DEVICES"] = "1"class SceneRecognition:def __init__(self, dimension=384, threshold=0.8, batch_size=128):"""初始化 SceneRecognition 类Parameters:dimension (int): 向量的维度,默认为 384"""# 加载模型和处理器self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')self.processor = AutoImageProcessor.from_pretrained('models/dinov2-small')self.model = AutoModel.from_pretrained('models/dinov2-small').to(self.device)# 保存特征向量的维度self.dimension = dimension# 搭配 faiss.normalize_L2 创建 Faiss 的余弦相似度索引self.index = faiss.IndexFlatIP(self.dimension)# 保存图片向量对应的 idself.ids = []# 保存特征向量数据库self.db_path = "vector.index"# 图片路径self.images_path = []# 相似度阈值self.threshold = thresholdself.batch_size = batch_size# 初始化self.init()def read_image_open(self, image_path):"""打开图像并返回图像和图像路径Parameters:image_path (str): 图像的路径Returns:Image: 打开的图像str: 图像的路径"""# 默认以 RGB 模式打开image = Image.open(image_path)return image, image_pathdef read_images_from_folder(self, file_path):"""从文件夹中读取图像Parameters:file_path (str or list): 文件夹的路径或文件路径列表Returns:list: 图像列表list: 图像路径列表"""image_list = []image_path_list = []try:if type(file_path) is not list:task_list = get_files(file_path)else:task_list = file_path# 使用线程池执行器with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:res = executor.map(self.read_image_open, task_list)image_list, image_path_list = list(zip(*res))return image_list, image_path_listexcept Exception as e:return None, Nonedef download_model(self):"""从 huggingface_hub 下载模型"""from huggingface_hub import snapshot_downloadsnapshot_download(repo_id="facebook/dinov2-small",  # 模型 IDlocal_dir="./models/dinov2-small")  # 指定本地地址保存模型def init(self):"""初始化模型"""batch_images = np.zeros((self.batch_size, 1920, 1080, 3),dtype=np.int8)batch_inputs = self.processor(images=batch_images,return_tensors="pt").to(self.device)batch_outputs = self.model(**batch_inputs)def add_vector_to_index(self, embedding, index):"""将向量添加到索引中Parameters:embedding (torch.Tensor): 特征向量index (faiss.Index): Faiss 索引"""vector = embedding.detach().cpu().numpy()vector = np.float32(vector)faiss.normalize_L2(vector)index.add(vector)def create_database_from_images(self, file_path):"""从图像创建数据库Parameters:file_path (str or list): 图像文件夹的路径或图像路径列表"""image_list, self.images_path = self.read_images_from_folder(file_path)self.extract_features_form_images(image_list)def create_database_from_batch_images(self, file_path):"""从批量图像创建数据库Parameters:file_path (str or list): 图像文件夹的路径或图像路径列表"""image_list, self.images_path = self.read_images_from_folder(file_path)self.extract_features_form_batch_images(image_list)def extract_features_form_images(self, images):"""从图像列表提取特征Parameters:images (list): 图像列表"""for image_id, img in enumerate(images):with torch.no_grad():inputs = self.processor(images=img,return_tensors="pt").to(self.device)outputs = self.model(**inputs)features = outputs.last_hidden_state.mean(dim=1)self.add_vector_to_index(features, self.index)# 记录特征向量的 idself.ids.append(image_id)# faiss.write_index(self.index, self.db_path)def extract_features_form_batch_images(self, image_list):"""从批量图像中提取特征Parameters:image_list (list): 图像列表"""img_num = len(image_list)for beg_img_no in tqdm(range(0, img_num, self.batch_size),desc="Extracting features"):end_img_no = min(img_num, beg_img_no + self.batch_size)batch_image = image_list[beg_img_no:end_img_no]with torch.no_grad():batch_inputs = self.processor(images=batch_image,return_tensors="pt").to(self.device)batch_outputs = self.model(**batch_inputs)batch_features = batch_outputs.last_hidden_state.mean(dim=1)for image_id, features in enumerate(batch_features):self.add_vector_to_index(features.reshape(1, -1),self.index)# 记录特征向量的 idimage_id += beg_img_noself.ids.append(image_id)# faiss.write_index(self.index, self.db_path)def search_similar_batch_images(self, user_image_path, k=1):"""批量搜索相似图像Parameters:user_image_path (str or list): 用户图像的路径或路径列表k (int): 搜索的近邻数量Returns:list: 相似图像的路径列表"""image_list, image_path_list = self.read_images_from_folder(user_image_path)img_num = len(image_list)similar_base_images = []similar_query_images = []for beg_img_no in tqdm(range(0, img_num, self.batch_size),desc="Searching similar images"):end_img_no = min(img_num, beg_img_no + self.batch_size)batch_image = image_list[beg_img_no:end_img_no]batch_image_path = image_path_list[beg_img_no:end_img_no]with torch.no_grad():batch_inputs = self.processor(images=batch_image,return_tensors="pt").to(self.device)batch_outputs = self.model(**batch_inputs)query_features = batch_outputs.last_hidden_state.mean(dim=1)batch_query_vector = query_features.detach().cpu().numpy()batch_query_vector = np.float32(batch_query_vector)faiss.normalize_L2(batch_query_vector)batch_distances, batch_indices = self.index.search(batch_query_vector, k)if len(batch_distances) > 0:for image_path, dis, ind in zip(batch_image_path,batch_distances,batch_indices):# 保存超过阈值的最相似特征向量if dis[0] > self.threshold:similar_base_images.append(self.images_path[ind[0]])similar_query_images.append(image_path)dissimilar_images = list(set(image_path_list).difference(set(similar_query_images)))return similar_query_imagesdef search_similar_images(self, query_image_path, k=1):"""搜索相似图像Parameters:query_image_path (str): 查询图像的路径k (int): 搜索的近邻数量Returns:list: 相似图像的路径列表"""img = Image.open(query_image_path).convert('RGB')with torch.no_grad():inputs = self.processor(images=img,return_tensors="pt").to(self.device)outputs = self.model(**inputs)query_features = outputs.last_hidden_state.mean(dim=1)query_vector = query_features.detach().cpu().numpy()query_vector = np.float32(query_vector)faiss.normalize_L2(query_vector)distances, indices = self.index.search(query_vector, k)similar_base_images = []similar_query_images = []if len(distances) > 0:for dis, ind in zip(distances, indices):# 保存超过阈值的最相似特征向量if dis > self.threshold:similar_base_images.append(self.images_path[ind[0]])similar_query_images.append(query_image_path)return similar_query_imagesdef remove_image_by_id(self, image_id):"""根据 ID 删除图像Parameters:image_id (int): 图像的 ID"""# 先删除高索引的元素,再删除低索引的元素,避免索引错位的问题。index_to_remove = [i for i, stored_id in enumerate(self.ids) if stored_id == image_id]  # 找到需要删除的特征向量的索引for i in sorted(index_to_remove, reverse=True):# 从 Faiss 索引中删除对应的特征向量self.index.remove_ids(np.array([i]))# 从 id 列表中删除对应的 iddel self.ids[i]def get_num_images(self):"""获取保存的图像数量Returns:int: 保存的图像数量"""# 返回保存的图片向量数量return len(self.ids)def clear_data_base(self):"""清空数据库"""# 重置 Faiss 索引self.index.reset()# 清空 id 列表self.ids.clear()if __name__ == '__main__':# 创建一个 SceneRecognition 实例scene_rec = SceneRecognition()scene_rec.create_database_from_batch_images('imgs')scene_rec.search_similar_batch_images('images')

参考网址:

  1. https://blog.csdn.net/level_code/article/details/137772620
  2. https://blog.csdn.net/weixin_38739735/article/details/136979083
  3. https://blog.csdn.net/u010970956/article/details/134945210
  4. https://blog.csdn.net/hh1357102/article/details/135066581
  5. https://blog.csdn.net/sinat_34770838/article/details/137021023
  6. https://zhuanlan.zhihu.com/p/704250322
  7. https://zhuanlan.zhihu.com/p/668148439
  8. https://blog.51cto.com/u_14273/10165547
  9. https://blog.csdn.net/ResumeProject/article/details/135350945
  10. https://www.zhihu.com/question/637818872/answer/3380169469
  11. https://zhuanlan.zhihu.com/p/644077057

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/50049.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于 HTML+ECharts 实现智慧运维数据可视化大屏(含源码)

智慧运维数据可视化大屏:基于 HTML 和 ECharts 的实现 在现代企业中,运维管理是确保系统稳定运行的关键环节。随着数据量的激增,如何高效地监控和分析运维数据成为了一个重要课题。本文将介绍如何利用 HTML 和 ECharts 实现一个智慧运维数据可…

深入理解 Java NIO:ByteBuffer和MappedByteBuffer的特性与使用

目录 前言 ByteBuffer是什么 重要特点 分配缓冲区 读写模式切换 操作文本数据 操作基本数据类型 案例解析-循环输出数据 MappedByteBuffer是什么 MappedByteBuffer 的工作机制 刷盘时机 总结 前言 在深入学习 RocketMQ 这款高性能消息队列框架的源码时&#xff0c…

醒醒,别睡了...讲《数据分析pandas库》了—/—<1>

一、了解pandas No.1 Pandas 是 Python 语言的一个扩展程序库,用于数据分析,是一个强大的分析结构化数据的工具集,基础是Numpy库,可以去参考前面所讲的课。(提供高性能的矩阵运算) No.2 应用 :P…

Kylin Cube的灵动更新:部分刷新策略全解析

Kylin Cube的灵动更新:部分刷新策略全解析 Apache Kylin是一个高性能的分布式分析引擎,它通过预计算和存储多维数据模型(Cube)来加速对大数据集的查询。在实际应用中,数据经常发生变化,这就引出了一个问题…

vue上传Excel文件并直接点击文件列表进行预览

本文主要内容:用elementui的Upload 组件上传Excel文件,上传后的列表采用xlsx插件实现点击预览表格内容效果。 在项目中可能会有这样的需求,有很多种方法实现。但是不想要跳转外部地址,所以用了xlsx插件来解析表格,并展…

【数据集处理】Polars库、Parquet 文件

一、Polars 库 Polars 库在数据处理和分析方面具有显著的优势,特别是在性能和效率上。 1. 高性能 Polars 设计的核心目标之一是性能优化,尤其是针对大数据集的处理: 多线程执行:Polars 利用 Rust 编写,并且默认使用…

Docker安装kkFileView实现在线文件预览

kkFileView为文件文档在线预览解决方案,该项目使用流行的spring boot搭建,易上手和部署,基本支持主流办公文档的在线预览,如doc,docx,xls,xlsx,ppt,pptx,pdf,txt,zip,rar,图片,视频,音频等等 官方文档地址:https://kkview.cn/zh-cn/docs/production.html 一、拉取镜像 do…

1 深度学习网络DNN

代码来自B站up爆肝杰哥 测试版本 import torch import torchvisiondef print_hi(name):print(fHi, {name}) if __name__ __main__:print_hi(陀思妥耶夫斯基)print("HELLO pytorch {}".format(torch.__version__))print("torchvision.version:", torchvi…

有多个第三方sdk 里的manifest里都配置了provider,如何优化

当多个第三方 SDK 的 AndroidManifest.xml 文件中都配置了 ContentProvider,并且导致应用启动变慢时,可以通过以下优化策略来改善启动性能: 1. 推迟 ContentProvider 的初始化 将一些 ContentProvider 的初始化推迟到应用实际需要使用时再进行,而不是在应用启动时进行。可…

用在ROS2系统中保持差速轮方向不变的PID程序

在ROS 2中,为了保持差速轮机器人的方向不变,通常需要使用PID(Proportional Integral Derivative)控制器来控制机器人的角速度。PID控制器可以帮助调整机器人的角速度,以维持其朝向不变。 下面是一个简单的ROS 2节点示…

使用el-table的案例小结——包含跨页多选、双击行、分页器、编辑\删除行、动态根据分页生成序号

首先看一下业务需求 需要实现跨页多选,双击行的时候弹出编辑对话框,对每行可以进行编辑和删除,实现分页器。 如果还没在项目中配置element-plus的可以参考文章 从零开始创建vue3项目——包含项目初始化、element-plus、eslint、axios、router…

vue import from

vue import from 导入文件,从XXXX路径;引入文件 import xxxx from “./minins/resize” import xxxx from “./minins/resize.js” vue.config.js 定义 : resolve(src);就是指src 目录 import xxxx from “/utils/auth” im…

014集——RSA非对称加密——vba源代码

今天介绍一种安全的加密方法,RSA非对称加密。 RSA算法基于一个十分简单的数论事实:将两个大质数相乘十分容易,但是想要对其乘积进行因式分解却极其困难,因此可以将乘积公开作为加密密钥。 部分源代码如下: qq4434402042024年3月…

【C++初阶】string类

【C初阶】string类 🥕个人主页:开敲🍉 🔥所属专栏:C🥭 🌼文章目录🌼 1. 为什么学习string类? 1.1 C语言中的字符串 1.2 实际中 2. 标准库中的string类 2.1 string类 2.…

Web响应式设计———1、Grid布局

1、网格布局 Grid布局 流动网格布局是响应式设计的基础。它通过使用百分比而不是固定像素来定义网格和元素的宽度。这样&#xff0c;页面上的元素可以根据屏幕宽度自动调整大小&#xff0c;适应不同设备和分辨率。 <!DOCTYPE html> <html lang"en"> &l…

并发线程学习(Java)

消费者生产者模型 package thread;import java.util.LinkedList; import java.util.Queue;public class ProducerConsumer {private static final int MAX_SIZE 5;private final Queue<Integer> buffer new LinkedList<>();public synchronized void producer(i…

element表单disabled功能失效问题

element表单disabled功能失效问题 场景:当需要根据商品状态来判断是否开启disabled来禁用表单时, disabled绑定了对应的值, 但无论商品是哪种状态, 表单都能操作, disabled失效 <el-form-item label"商品分类"><el-selectv-model"form.packagesTypeI…

二叉树---二叉搜索树的最近公共祖先

题目&#xff1a; 给定一个二叉搜索树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为&#xff1a;“对于有根树 T 的两个结点 p、q&#xff0c;最近公共祖先表示为一个结点 x&#xff0c;满足 x 是 p、q 的祖先且 x 的深度尽可能大&#xff08;一…

Unable to connect to Redis] with root cause

Servlet.service() for servlet [dispatcherServlet] in context with path [] threw exception [Request processing failed: org.springframework.data.redis.RedisConnectionFailureException: Unable to connect to Redis] with root cause springboot运行不了&#xff0c…

Object.entries()解析出来的数组顺序乱了,健是string类型

现象: 从后端哪里拿到了一长串数据 const obj {"2023-07-01":10,"2023-09-18":2,"2023-10-10":3,"2024-01-10":1,"2024-01-12":1,"2024-02-20":4,"2024-07-01":4,... }; 比如上面的数据有一年的 并…