文章目录
- 背景
- 1.Grounding DINO安装
- 2.裁剪指定目标的脚本
背景
- 在处理公开数据集ImageNet-21k的时候发现里面有很多的数据有问题,比如,数据目标有很多背景,且部分类别有其他种类的图片。
- 针对数据目标有很多背景,公开数据集ImageNet-21k的21k种类别进行裁剪。
- 文本提示检测图像任意目标(Grounding DINO),这更模型可以很好的应用在这个场景。
1.Grounding DINO安装
github地址
- 从 GitHub 克隆 GroundingDINO 存储库。
git clone https://github.com/IDEA-Research/GroundingDINO.git
- 将当前目录更改为 GroundingDINO 文件夹。
cd GroundingDINO/
- 在当前目录中安装所需的依赖项。
pip install -e .
- 下载预训练模型权重。
mkdir weights
cd weights
wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
cd ..
- 下载bert-base-uncased到text_encoder_type(自己创建一个文件夹)
需要下载下面的三个文件,放进text_encoder_type里面就好。
- 修改地址
修改/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py
文件中text_encoder_type
的路径。
-
如果您有 CUDA 环境,请确保设置了环境变量 CUDA_HOME 。如果没有可用的 CUDA,它将在仅 CPU 模式下编译。
-
可能遇到的bug
Segmentation fault (core dumped)
是因为timm版本和cuda,pytorch等版本不匹配重新安装可以解决这个bug。
pip uninstall timm
pip install timm
2.裁剪指定目标的脚本
- 如下是测试的demo
import cv2print("456")
from groundingdino.util.inference import load_model, load_image, predict, annotateprint("123")
model = load_model("groundingdino/config/GroundingDINO_SwinT_OGC.py", "weight/groundingdino_swint_ogc.pth", "cpu")
IMAGE_PATH = r"images/th.jpg"
TEXT_PROMPT = "dolphins"
BOX_TRESHOLD = 0.35
TEXT_TRESHOLD = 0.25
print("456")
image_source, image = load_image(IMAGE_PATH)print("789")
boxes, logits, phrases = predict(model=model,image=image,caption=TEXT_PROMPT,box_threshold=BOX_TRESHOLD,text_threshold=TEXT_TRESHOLD
)print("10")
print(boxes)
annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
cv2.imwrite("annotated_image.jpg", annotated_frame)
- 裁剪指定目标的脚本
该脚本指定目录后,会对该目录下子文件夹的不同目标类别,进行裁剪并将裁剪结果放在与原路径对应的相对路径种。
脚本全部代码:
import os
import time
from groundingdino.util.inference import load_model, load_image, predict
import cv2
import torch
from torchvision.ops import box_convertdef save_cropped_images(image, boxes, image_name, output_folder):os.makedirs(output_folder, exist_ok=True)h, w, _ = image.shapeboxes = boxes * torch.tensor([w, h, w, h])xyxy_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()for i, box in enumerate(xyxy_boxes):x_min, y_min, x_max, y_max = map(int, box)cropped_image = image[y_min:y_max, x_min:x_max]# Ensure the color channels are in BGR order for OpenCVcropped_image_bgr = cv2.cvtColor(cropped_image, cv2.COLOR_RGB2BGR)cv2.imwrite(f"{output_folder}/{image_name}_cropped_{i}.jpg", cropped_image_bgr)def process_image(image_path, model, output_folder, box_threshold=0.35, text_threshold=0.25):image_source, image = load_image(image_path)try:boxes, logits, phrases = predict(model=model,image=image,caption=TEXT_PROMPT,box_threshold=box_threshold,text_threshold=text_threshold)except RuntimeError as e:print(f"RuntimeError: {e}")# Get the image name without extensionimage_name = os.path.splitext(os.path.basename(image_path))[0]# Save cropped images with image name includedsave_cropped_images(image_source, boxes, image_name, output_folder)def process_images_in_folder(folder_path, model, box_threshold=0.35, text_threshold=0.25):folder_name = os.path.basename(folder_path.rstrip('/'))output_folder = os.path.join("/animals_classify/Cropped_Dataset/QuanKe", folder_name)print(f"{folder_name}, cropping.")# Start timer for processing this folderstart_time = time.time()for filename in os.listdir(folder_path):if filename.endswith(".jpg") or filename.endswith(".png") or filename.endswith(".JPEG"):image_path = os.path.join(folder_path, filename)process_image(image_path, model, output_folder, box_threshold, text_threshold)# End timer for processing this folderfolder_processing_time = time.time() - start_timeprocess_images_in_folder.total_time += folder_processing_timeprint(f"{folder_name}, cropped. Time taken: {folder_processing_time:.2f} seconds")print(f"Total time taken so far: {process_images_in_folder.total_time:.2f} seconds")# Initialize the total time taken to 0
process_images_in_folder.total_time = 0.0# Configuration and model loading
model = load_model("groundingdino/config/GroundingDINO_SwinT_OGC.py", "weight/groundingdino_swint_ogc.pth")
TEXT_PROMPT = "canine"
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25FOLDERS_PATH = "/animals_classify/Raw_Dataset/QuanKe"
for FOLDER_Name in os.listdir(FOLDERS_PATH):FOLDER_PATH = os.path.join(FOLDERS_PATH, FOLDER_Name)# Process all images in the folderprocess_images_in_folder(FOLDER_PATH, model, BOX_THRESHOLD, TEXT_THRESHOLD)
裁剪示例:
原图:
结果: