图像抠图DIS——自然图像中高精度二分图像抠图的方法(C++/python模型推理)

概述

DIS(Dichotomous Image Segmentation)是一种新的图像分割任务,旨在从自然图像中分割出高精度的物体。与传统的图像分割任务相比,DIS更侧重于具有单个或几个目标的图像,因此可以提供更丰富准确的细节。

为了研究DIS任务,研究人员创建了一个名为DIS5K的大规模、可扩展的数据集。DIS5K数据集包含了5,470张高分辨率图像,每张图像都配有高精度的二值分割掩码。这个数据集的建立有助于推动多个应用方向的发展,如图像去背景、艺术设计、模拟视图运动、基于图像的增强现实(AR)应用、基于视频的AR应用、3D视频制作等。

通过研究DIS任务和使用DIS5K数据集,研究人员可以探索新的图像分割方法,并为各种应用领域提供更精确、更可靠的图像分割技术,从而推动分割技术在更广泛的领域中的应用。

官网:https://xuebinqin.github.io/dis/index.html
Github:https://github.com/xuebinqin/DIS

数据集

图像二类分割是将图像分割成两个主要区域:前景和背景。在这种情况下,前景代表图像中的某个类别的物体,而背景则是除了该物体之外的所有内容。
官方公布了算所使用的数据集DIS5K, DIS5K数据集中的每张图像都经过了像素级别的手工标注,标注的真值掩码非常精确,每张图像的标记时间相当长。这种高精度的标注使得数据集中的每个像素都与其相应的类别关联起来,从而为模型提供了可靠的训练数据。这种高精度的标注是实现图像二类分割的关键,因为模型需要能够准确地识别和分割出前景物体。

在DIS5K数据集中,标注对象的类型多样,包括透明和半透明的物体,标注使用单个像素的二值掩码进行。这种精确的标注确保了模型训练的有效性和准确性,并且使得模型能够预测出高精度的物体分割结果。

DIS5K数据集网盘地址:https://pan.baidu.com/s/1umNk2AeBG5aB5kXlHTHdIg
提取码:7qfs

模型训练

模型训练可参考git上的官方的文档

模型推理

模型C++使用onnxruntime进行推理

#include <opencv2/opencv.hpp>
#include <onnxruntime_cxx_api.h>class DIS
{
public:DIS(std::string model_path);void inference(cv::Mat& cv_src, cv::Mat& cv_mask);
private:std::vector<float> input_image_;int inpWidth;int inpHeight;int outWidth;int outHeight;const float score_th = 0;Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "DIS");Ort::Session* ort_session = nullptr;Ort::SessionOptions sessionOptions = Ort::SessionOptions();std::vector<char*> input_names;std::vector<char*> output_names;std::vector<std::vector<int64_t>> input_node_dims; // >=1 outputsstd::vector<std::vector<int64_t>> output_node_dims; // >=1 outputs
};DIS::DIS(std::string model_path)
{std::wstring widestr = std::wstring(model_path.begin(), model_path.end());//OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0);sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);ort_session = new Ort::Session(env, widestr.c_str(), sessionOptions);size_t numInputNodes = ort_session->GetInputCount();size_t numOutputNodes = ort_session->GetOutputCount();Ort::AllocatorWithDefaultOptions allocator;for (int i = 0; i < numInputNodes; i++){input_names.push_back(ort_session->GetInputName(i, allocator));Ort::TypeInfo input_type_info = ort_session->GetInputTypeInfo(i);auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();auto input_dims = input_tensor_info.GetShape();input_node_dims.push_back(input_dims);}for (int i = 0; i < numOutputNodes; i++){output_names.push_back(ort_session->GetOutputName(i, allocator));Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo(i);auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();auto output_dims = output_tensor_info.GetShape();output_node_dims.push_back(output_dims);}this->inpHeight = input_node_dims[0][2];this->inpWidth = input_node_dims[0][3];this->outHeight = output_node_dims[0][2];this->outWidth = output_node_dims[0][3];
}void DIS::inference(cv::Mat& cv_src, cv::Mat& cv_mask)
{cv::Mat cv_dst;cv::resize(cv_src, cv_dst, cv::Size(this->inpWidth, this->inpHeight));this->input_image_.resize(this->inpWidth * this->inpHeight * cv_dst.channels());for (int c = 0; c < 3; c++){for (int i = 0; i < this->inpHeight; i++){for (int j = 0; j < this->inpWidth; j++){float pix = cv_dst.ptr<uchar>(i)[j * 3 + 2 - c];this->input_image_[c * this->inpHeight * this->inpWidth + i * this->inpWidth + j] = pix / 255.0 - 0.5;}}}std::array<int64_t, 4> input_shape_{ 1, 3, this->inpHeight, this->inpWidth };auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);Ort::Value input_tensor_ = Ort::Value::CreateTensor<float>(allocator_info,input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size());std::vector<Ort::Value> ort_outputs = ort_session->Run(Ort::RunOptions{ nullptr }, &input_names[0],&input_tensor_, 1, output_names.data(), output_names.size());   // 开始推理float* pred = ort_outputs[0].GetTensorMutableData<float>();cv::Mat mask(outHeight, outWidth, CV_32FC1, pred);double min_value, max_value;minMaxLoc(mask, &min_value, &max_value, 0, 0);mask = (mask - min_value) / (max_value - min_value);cv::resize(mask, cv_mask, cv::Size(cv_src.cols, cv_src.rows));
}void show_img(std::string name, const cv::Mat& img)
{cv::namedWindow(name, 0);int max_rows = 500;int max_cols = 600;if (img.rows >= img.cols && img.rows > max_rows) {cv::resizeWindow(name, cv::Size(img.cols * max_rows / img.rows, max_rows));}else if (img.cols >= img.rows && img.cols > max_cols) {cv::resizeWindow(name, cv::Size(max_cols, img.rows * max_cols / img.cols));}cv::imshow(name, img);
}cv::Mat replaceBG(const cv::Mat cv_src, cv::Mat& alpha, std::vector<int>& bg_color)
{int width = cv_src.cols;int height = cv_src.rows;cv::Mat cv_matting = cv::Mat::zeros(cv::Size(width, height), CV_8UC3);float* alpha_data = (float*)alpha.data;for (int i = 0; i < height; i++){for (int j = 0; j < width; j++){float alpha_ = alpha_data[i * width + j];cv_matting.at < cv::Vec3b>(i, j)[0] = cv_src.at < cv::Vec3b>(i, j)[0] * alpha_ + (1 - alpha_) * bg_color[0];cv_matting.at < cv::Vec3b>(i, j)[1] = cv_src.at < cv::Vec3b>(i, j)[1] * alpha_ + (1 - alpha_) * bg_color[1];cv_matting.at < cv::Vec3b>(i, j)[2] = cv_src.at < cv::Vec3b>(i, j)[2] * alpha_ + (1 - alpha_) * bg_color[2];}}return cv_matting;
}int main()
{DIS dis_net("isnet_general_use_720x1280.onnx");std::string path = "images";std::vector<std::string> filenames;cv::glob(path, filenames, false);for (auto file_name : filenames){cv::Mat cv_src = cv::imread(file_name);//std::vector<cv::Mat> cv_dsts;cv::Mat cv_dst, cv_mask;dis_net.inference(cv_src, cv_mask);std::vector<int> color{255, 0, 0};cv_dst=replaceBG(cv_src, cv_mask, color);show_img("src", cv_src);show_img("mask", cv_mask);show_img("dst", cv_dst);cv::waitKey(0);}
}

python推理代码也依赖onnxruntime

import argparse
import cv2
import numpy as np
import onnxruntime
### onnxruntime load ['isnet_general_use_HxW.onnx', 'isnet_HxW.onnx', 'isnet_Nx3xHxW.onnx']  inference failed
class DIS():def __init__(self, modelpath, score_th=None):so = onnxruntime.SessionOptions()so.log_severity_level = 3self.net = onnxruntime.InferenceSession(modelpath, so)self.input_height = self.net.get_inputs()[0].shape[2]self.input_width = self.net.get_inputs()[0].shape[3]self.input_name = self.net.get_inputs()[0].nameself.output_name = self.net.get_outputs()[0].nameself.score_th = score_thdef detect(self, srcimg):img = cv2.resize(srcimg, dsize=(self.input_width, self.input_height))img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = img.astype(np.float32) / 255.0 - 0.5blob = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0).astype(np.float32)outs = self.net.run([self.output_name], {self.input_name: blob})mask = np.array(outs[0]).squeeze()min_value = np.min(mask)max_value = np.max(mask)mask = (mask - min_value) / (max_value - min_value)if self.score_th is not None:mask = np.where(mask < self.score_th, 0, 1)mask *= 255mask = mask.astype('uint8')mask = cv2.resize(mask, dsize=(srcimg.shape[1], srcimg.shape[0]), interpolation=cv2.INTER_LINEAR)return maskdef generate_overlay_image(srcimg, mask):overlay_image = np.zeros(srcimg.shape, dtype=np.uint8)overlay_image[:] = (255, 255, 255)mask = np.stack((mask,) * 3, axis=-1).astype('uint8') mask_image = np.where(mask, srcimg, overlay_image)return mask, mask_imageif __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument("--imgpath", type=str, default='images/cam_image47.jpg')parser.add_argument("--modelpath", type=str, default='weights/isnet_general_use_480x640.onnx')args = parser.parse_args()mynet = DIS(args.modelpath)srcimg = cv2.imread(args.imgpath)mask = mynet.detect(srcimg)mask, overlay_image = generate_overlay_image(srcimg, mask)winName = 'Deep learning object detection in onnxruntime'cv2.namedWindow(winName, cv2.WINDOW_NORMAL)cv2.imshow(winName, np.hstack((srcimg, mask)))cv2.waitKey(0)cv2.destroyAllWindows()

推理结果
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
资源和模型下载地址:https://download.csdn.net/download/matt45m/89024664

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/769645.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu20.04上,VTK9.3在QT5上的环境配置与开发测试

Ubuntu20.04上&#xff0c;VTK9.3在QT5上的环境配置与开发测试 1 背景介绍2 VTK9.3的编译安装2.1 安装ccmake 和 VTK 的依赖项&#xff1a;2.2 建立VTK编译文件夹并下载2.3 cmake配置VTK9.3的编译环境2.4 make编译安装VTK9.32.5 测试VTK安装是否成功 3 基于qmake的QT5的VTK9.3开…

异地共享文件如何设置?

在当今数字化时代&#xff0c;异地办公已成为常态&#xff0c;越来越多的企业和个人需要在不同地区间进行文件共享与访问。为了解决复杂网络环境下的远程连接问题&#xff0c;北京金万维科技有限公司推出了一款名为【天联】的异地组网内网穿透产品。 【天联】组网是一款由北京金…

python基础——语句

一、条件语句 就是 if else 语句 &#xff01; 代表不等于 代表等于if 关键字&#xff0c;判断语句&#xff0c;有“如果”的意思&#xff0c;后面跟上判断语句else 常和“if” 连用&#xff0c;有“否则”的意思&#xff0c;后面直接跟上冒号 …

qt学习第三天,qt设计师的第一个简单案例

3月25&#xff0c;应用qt设计师&#xff0c;手动设计界面形状 ​ 如何启动qt设计师&#xff0c;找到对应的安装地点&#xff0c;对应你自己安装的pyside6或其他qt的安装路径来找 ​ 应用qt设计师的优点是不用敲代码然后慢慢调节框框大小&#xff0c;位置等、可以直接修改…

TTS通用播放库技术设计

TTS音频播放库技术设计 目录介绍 01.整体介绍概述 1.1 项目背景介绍1.2 遇到问题1.3 基础概念介绍1.4 设计目标1.5 问题答疑和思考 02.技术调研说明 2.1 语音播放方案2.2 TTS技术分析2.3 语音合成技术2.4 方案选择说明2.5 方案设计思路2.6 文本生成音频 03.系统TTS使用实践 3…

JavaEE企业开发新技术4

2.16 模拟Spring IOC容器功能-1 2.17 模拟Spring IOC容器功能-2 什么是IOC&#xff1f; 控制反转&#xff0c;把对象创建和对象之间的调用过程交给Spring框架进行管理使用IOC的目的&#xff1a;为了耦合度降低 解释&#xff1a; 模仿 IOC容器的功能&#xff0c;我们利用 Map…

LLM - 大语言模型的指令微调(Instruction Tuning) 概述

欢迎关注我的CSDN&#xff1a;https://spike.blog.csdn.net/ 本文地址&#xff1a;https://blog.csdn.net/caroline_wendy/article/details/137009993 大语言模型的指令微调(Instruction Tuning)是一种优化技术&#xff0c;通过在特定的数据集上进一步训练大型语言模型(LLMs)&a…

【算法 高级数据结构】树状数组:一种高效的数据结构(二)

&#x1f680;个人主页&#xff1a;为梦而生~ 关注我一起学习吧&#xff01; &#x1f4a1;专栏&#xff1a;算法题、 基础算法、数据结构~赶紧来学算法吧 &#x1f4a1;往期推荐&#xff1a; 【算法基础 & 数学】快速幂求逆元&#xff08;逆元、扩展欧几里得定理、小费马定…

RTthread如何引入webclient和cjson来编写自己的模块代码||SecureCRT的安装与激活||安装VScode

目录 1.RTthread如何引入webclient和cjson来编写自己的模块代码 2.SecureCRT的安装与激活 3.static与const的区别 4.安装VScode 1.RTthread如何引入webclient和cjson来编写自己的模块代码 以我自己的工程为例&#xff1a; 首先将新引入的模块在applicatons下新建cpeinfo文件…

【MySQL】一条 SQL 查询语句在数据库中的执行流程 | SQL语句中各个关键字的执行顺序

创作不易&#xff0c;本篇文章如果帮助到了你&#xff0c;还请点赞 关注支持一下♡>&#x16966;<)!! 主页专栏有更多知识&#xff0c;如有疑问欢迎大家指正讨论&#xff0c;共同进步&#xff01; 给大家跳段街舞感谢支持&#xff01;ጿ ኈ ቼ ዽ ጿ ኈ ቼ ዽ ጿ ኈ ቼ …

知识图表示学习中的负抽样研究综述

摘要 知识图表示学习(KGRL)或知识图嵌入(KGE)在知识构建和信息探索的人工智能应用中起着至关重要的作用。这些模型旨在将知识图中的实体和关系编码到低维向量空间中。在KGE模型的训练过程中&#xff0c;使用正样本和负样本是区分的必要条件。然而&#xff0c;直接从现有的知识…

鸿蒙NXET实战:高德地图定位SDK【获取Key+获取定位数据】(二)

如何申请key 1、创建新应用 进入[控制台]&#xff0c;创建一个新应用。如果您之前已经创建过应用&#xff0c;可直接跳过这个步骤。 2、添加新Key 在创建的应用上点击"添加新Key"按钮&#xff0c;在弹出的对话框中&#xff0c;依次&#xff1a;输入应用名名称&…

Muduo类详解之EventLoop

最核⼼的部分就是 EventLoop 、 Channel 以及 Poller 三个类&#xff0c;其中 EventLoop 可以看作是对业务线程的封装&#xff0c;⽽ Channel 可以看作是对每个已经建⽴连接的封装&#xff08;即 accept(3) 返回的⽂件描述符&#xff09; EventLoop class EventLoop { p…

解决SLF4J: Class path contains multiple SLF4J bindings.

JDK版本&#xff1a;jdk17 IDEA版本&#xff1a;IntelliJ IDEA 2022.1.3 SpringBoot 版本&#xff1a;v2.5.7 maven版本&#xff1a;3.6.3 文章目录 问题描述&#xff1a;原因分析&#xff1a;解决方案&#xff1a;参考资料&#xff1a; 问题描述&#xff1a; 当SpringBoot项目…

并发VS并行

参考文章 面试必考的&#xff1a;并发和并行有什么区别&#xff1f; 并发&#xff1a;一个人同时做多件事&#xff08;射击游戏队友抢装备&#xff09; 并行&#xff1a;多人同时处理同一件事&#xff08;射击游戏敌人同时射击对方&#xff09;

学习数据结构:算法的时间复杂度和空间复杂度

一、算法的复杂度 衡量一个算法的好坏&#xff0c;一般是从时间和空间两个维度来衡量的&#xff0c;即时间复杂度和空间复杂度。 时间复杂度主要衡量一个算法的运行快慢&#xff0c;而空间复杂度主要衡量一个算法运行所需要的额外空间。 算法的时间复杂度 算法中的基本操作的…

SAP BAS中Fiori开发的高阶功能(storyboard, navigation, guided development, variant)

1. 前言 在之前的几篇文章中&#xff0c;我介绍了SAP BAS的一些基本功能&#xff0c;包括账户申请&#xff0c;创建工作区&#xff0c;git的使用以及如何step-by-step去创建出你的第一个Fiori项目等等。在本篇中&#xff0c;我将进一步介绍一些在开发Fiori应用程序时会用到的高…

JAVA学习笔记19(面向对象编程)

1.面向对象编程 1.1 类与对象 1.类与对象的概念 ​ *对象[属性]/[行为] ​ *语法 class cat {String name;int age; }main() {//cat1就是一个对象//创建一只猫Cat cat1 new Cat();//给猫的属性赋值cat1.name "123";cat1.age 10; }​ *类是抽象的&#xff0c;…

前端使用正则表达式进行校验

一、定义 设计思想是用一种描述性的语言定义一个规则&#xff0c;凡是符合规则的字符串&#xff0c;我们就认为它“匹配”了&#xff0c;否则&#xff0c;该字符串就是不合法的。 在 JavaScript中&#xff0c;正则表达式也是对象&#xff0c;构建正则表达式有两种方式&#x…

【可用Claude Opus模型】Claude3国内镜像站,亲测完全超越GPT-4(可用Claude Opus,官网价值20刀)

#今天在知乎看到一个问题&#xff1a;“平民不参与内测的话没有账号还有机会使用Claude 3吗&#xff1f;” 从去年GPT大火到现在&#xff0c;关于GPT的消息铺天盖地&#xff0c;真要有心想要去用&#xff0c;途径很多&#xff0c;别的不说&#xff0c;国内GPT的镜像站到处都是…