基于U2-Net如何训练一个一键抠图模型

1. 前言

抠图是图像编辑的基础功能之一,在抠图的基础上可以发展出很多有意思的玩法和特效。比如一键更换背景、一键任务卡通化、一键人物素描化等。正是因为这些有意思的玩法。

最近也是对此模型背后的U^2-Net网络很感兴趣,收集数据训练了人脸素描化模型,尽管受限于数据集,只能在人脸图片上转换成功,但自己仍然玩的不亦乐乎。不仅乐于玩模型的有意思的效果,更乐在训练模型过程中,以及遇到问题解决问题过程中,对模型理解的不断加深。

最近对一键扣图模型从头训练了一遍,并在训练过程中持续测试了不同阶段模型的表现,看着模型一点点的收敛,抠图效果慢慢变好。

此处记录下训练过程以及训练的效果。也可以对后来者有一个参考。

提前说一声,模型训练很耗时!

2. 代码 & 数据 & 环境准备

2.1 代码

代码是U-2-Net的开源代码,可以从Github下载:https://github.com/NathanUA/U-2-Net。这个模型本来是做显著性检测的,但是当成一键扣图模型也很好玩。

需要注意的地方是,如果是安装的最新的Pytorch,获取loss值的时候,需要将loss.data[0] 修改为loss.data.item()

笔者在训练过程中曾尝试修改Loss函数为其他的,比如改成BCESSIM的加权(参考U-2-Net作者的文章BASNet),未见明显提升。也曾修改输出通道训练其他模型,暂无好玩的结果,就当是积累经验了。

2.2 数据

数据集我们就用论文中提到的DUTS数据集,已经分好了训练集和测试集。网上搜一下直接下载即可。

当然,也可以用自己的数据集,按照DUTS的格式重新组织下数据集即可。

然后在训练代码里面把数据读取部分的路径更换为自己准备的数据的路径。

2.3 机器

然后基于Anaconda安装训练所需的Python环境,创建虚拟环境,安装pytorch, torchvision, skimage, opencv等等,直接pip install或者conda install即可。不多说。

另外多卡的话,代码还需要有一些细微的改动,在构建模型之后,将代码:

    if torch.cuda.is_available():net.cuda()

修改为

    if torch.cuda.is_available():net.cuda()net = nn.DataParallel(net)

3. 训练与测试

3.1 模型训练

以上代码、数据、机器和运行环境都已经准备好之后,就可以开始训练了。多卡训练的命令大概长下面这样:

CUDA_VISIBLE_DEVICES=0,1,2,3 nohup python3 -u u2net_train.py > log_train.log &

然后tail命令查看日志文件log_train.log,如果看到下面这样的输出,说明跑起来了:

再用命令watch -n 1 nvidia-smi查看GPU的情况,可以看到四张卡都被充分利用起来了。

模型训练将近一周,达到了接近论文的效果。

另外,由于中间保存过多,为了节省空间,笔者删掉了太多前期模型,以下展示的前期效果是另外一次训练的前期模型的效果。

3.2 各阶段模型测试

笔者微调测试代码结构,把测试转移到了Jupyter里,这样画图看效果更加直观。

笔者测试模型的时候,每张图都会画出三个图:黑色背景的抠图结果、模型输出的Mask或称Alpha,原图。这样对比来看结果一目了然。这里每张图都展示了四个阶段模型的测试效果。显然,以下图片都不在训练集里面。

四个阶段对比着看,能更加直观地感受到模型的收敛过程。

从以下四个阶段的对比图可以看出,随着训练的进行

  • 前景逐渐变亮,背景逐渐变暗,即前景收敛于1,背景收敛于0。前两幅图之间的对比最为明显。

  • 前景的轮廓从模糊到清晰细锐,轮廓处的不确定区域,越来越少。

  • 注意指缝和发梢部分的Mask的变化,细节越来越清晰。

下面这幅图请注意这个卡通人物背后背的那个是蜗牛还是啥的东西的轮廓的细化过程。以及其嘴角的一撮小胡子。这个图美中不足的是两脚之间的背景没有被识别出来。

下面这张图值得关注的应该就是其发梢的抠图细化过程、腰部的亮度变化过程。还有就是其手中的衣服了,对于要不要把一副也给抠出来,模型看起来也很纠结啊。

这个图最引人瞩目的莫过于这位美女在风中凌乱的发丝,这不是难为模型吗?说实话,如果不是看到Mask里胸前多出的东西,我都没注意到这个东西,衣服的胸结还是啥。

这大概就是训练了五天左右的效果,模型仍然在缓慢的收敛中,故事仍然在继续......

直到我实在是受不了越来越慢的收敛速度,等不及训练其他魔改的模型,终止了训练任务......

本着报喜不报忧的原则,下面再放几张测试效果还不错的图片,效果不怎么样的就不拿出来献丑了

上面的抠图效果还是有待提高,比如头发等边缘处,还是可见部分背景未分离。前几天刚转发了动物抠图的新论文,边缘和毛发的抠图效果很赞。其单开一条支路专门做轮廓边缘处的抠图的思路值得参考。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/179656.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

五、cookie、session、token、localstroage、sessionStroage区别

一、localStorage 跟 sessionStorage有什么不同???? localStorage 1、生命周期:localStorage的生命周期是永久的,关闭页面或浏览器之后localStorage中的数据也不会消失。localStorage除非主动删除数据&am…

Spark---资源、任务调度

一、Spark资源调度源码 1、Spark资源调度源码过程 Spark资源调度源码是在Driver启动之后注册Application完成后开始的。Spark资源调度主要就是Spark集群如何给当前提交的Spark application在Worker资源节点上划分资源。Spark资源调度源码在Master.scala类中的schedule()中进行…

界面控件DevExpress WPF流程图组件,完美复制Visio UI!(二)

DevExpress WPF Diagram(流程图)控件帮助用户完美复制Microsoft Visio UI,并将信息丰富且组织良好的图表、流程图和组织图轻松合并到您的下一个WPF项目中。 在上文中(点击这里回顾>>),我们为大家介绍…

AUTOSAR汽车电子嵌入式编程精讲300篇-基于智能网联车的CAN总线攻击与入侵检测(续)

目录 车辆总线攻击的远程实现 3.1 车辆总线攻击的实现方法 3.2 车身控制模块攻击 3.3 仪表盘攻击

git commit 撤销的三种方法

一般在提交代码的时候,顺序是这样的 git status // 查看修改文件状态(已添加至暂存区还是未添加至暂存区)git add . // 添加所有已修改文件 git add xxx/xxx // 添加目录为xxx/xxx的文件至暂存区git commit -m xx功能全部完成 // 提交暂存区…

Linux_Linux终端常用快捷键

Linux命令行核心常用快捷键是一些在终端中使用的快捷键组合,用于提高命令行操作的效率。下面是这些快捷键的原理详细解释、使用场景解释 Ctrl A :将光标移动到命令行的开头。这个快捷键的原理是发送一个控制序列到终端,告诉终端将光标移动到…

Java后端开发——MVC商品管理程序

Java后端开发——MVC商品管理程序 本篇文章内容主要有下面几个部分: MVC架构介绍项目环境搭建商品管理模块Servlet代码重构BaseServlet文件上传 MVC 是模型-视图-控制器(Model-View-Controller),它是一种设计模式,也…

## spring-@Autowired实现

spring-Autowired实现 我们知道 spring 中有很多的后置处理器 BeanPostProcessor, 而 Autowired 就是通过 AutowiredAnnotationBeanPostProcessor 来实现的 与之相似的还有 CommonAnnotationBeanPostProcessor 处理 Resource 注解 AutowiredAnnotationBeanPostProcessor 构…

java设计模式学习之【原型模式】

文章目录 引言原型模式简介定义与用途实现方式UML 使用场景优势与劣势原型模式在spring中的应用员工记录示例代码地址 引言 原型模式是一种创建型设计模式,它允许对象能够复制自身,以此来创建一个新的对象。这种模式在需要重复地创建相似对象时非常有用…

【代码】基于卷积神经网络(CNN)-支持向量机(SVM)的分类预测算法

程序名称:基于卷积神经网络(CNN)-支持向量机(SVM)的分类预测算法 实现平台:matlab 代码简介:CNN-SVM是一种常用的图像分类方法,结合了卷积神经网络(CNN)和支…

移动应用开发介绍及iOS方向学习路线(HUT移动组版)

移动应用开发介绍及iOS方向学习路线(HUT移动组版) 前言 ​ 作为一个HUT移动组待了一坤年(两年半)多的老人,在这里为还在考虑进哪个组的萌新们以及将来进组的新朋友提供一份关于移动应用开发介绍以及学习路线的白话文…

DC电源模块有哪些常见故障?怎么解决这些问题?

DC-DC电源模块的作用是将输入电压转换为所需的输出电压,广泛应用于电子产品、汽车电子、医疗设备、通信系统等领域。但是在使用过程中DC电源模块会出现一些故障和问题,影响电源模块和其它电路器件的性能。因此,纳米软件将为大家介绍常见的DC-…

大坝安全监测的内容及作用

大坝安全监测是指对大坝水雨情沉降、倾斜、渗压以及大坝形状特征有效地进行监测,及时发现潜在的安全隐患和异常情况,以便大坝管理人员能够做出科学决策,以确保大坝安全稳定运行。 大坝安全监测的主要内容 1.表面位移监测:监测大坝…

分子骨架跃迁工具-DiffHopp 评测

一、文章背景介绍 DiffHopp模型发表在ICML 2023 Workshop on Computational Biology(简称:2023 ICML-WCB)上的文章。第一作者是剑桥计算机系的Jos Torge。 DiffHopp是一个专门针对骨架跃迁任务而训练的E3等变条件扩散模型。此外,…

golang构建docker镜像的几种方式

目前docker支持以下几种方式指定上下文来构建镜像 本地项目路径本地压缩包路径docekrfile文本链接压缩包文件链接git仓库链接 在此记录下golang中使用git仓库链接构建方法 import ("context""github.com/docker/docker/api/types""github.com/dock…

LeetCode Hot100 84.柱状图中最大的矩形

题目: 给定 n 个非负整数,用来表示柱状图中各个柱子的高度。每个柱子彼此相邻,且宽度为 1 。 求在该柱状图中,能够勾勒出来的矩形的最大面积。 方法: 代码: class Solution {public int largestRectang…

Go操作MySQL

1、下载依赖 Go语言中的database/sql包提供了保证SQL或类SQL数据库的泛用接口,并不提供具体的数据库驱动。使用database/sql包时必须注入(至少)一个数据库驱动。 go get -u github.com/go-sql-driver/mysql2、相关函数 2.1 Open() func O…

MySOL常见四种连接查询

1、内联接 &#xff08;典型的联接运算&#xff0c;使用像 或 <> 之类的比较运算符&#xff09;。包括相等联接和自然联接。 内联接使用比较运算符根据每个表共有的列的值匹配两个表中的行。例如&#xff0c;检索 students和courses表中学生标识号相同的所有行。 2、…

CVPR 2023 精选论文学习笔记:Towards Scalable Neural Representation for Diverse Videos

基于 MECE 原则,我们给出以下四个分类标准: 分类标准 1:表示类型 隐式神经表示(INR) 隐式神经表示(INR)是一类神经网络架构,将场景或对象表示为从 3D 点映射到颜色和不透明度值的连续函数。该函数通常从一组训练图像或视频中学习,然后可以用于渲染场景或对象的新视…

机器学习之危险品车辆目标检测

危险品的运输涉及从离开仓库到由车辆运输到目的地的风险。监控事故、车辆运动动态以及车辆通过特定区域的频率对于监督车辆运输危险品的过程至关重要。 在线工具推荐&#xff1a; 三维数字孪生场景工具 - GLTF/GLB在线编辑器 - Three.js AI自动纹理化开发 - YOLO 虚幻合成数…