【代码分析】Unet-Pytorch

1:unet_parts.py

主要包含:

【1】double conv,双层卷积

【2】down,下采样

【3】up,上采样

【4】out conv,输出卷积

""" Parts of the U-Net model """import torch
import torch.nn as nn
import torch.nn.functional as Fclass DoubleConv(nn.Module):"""(convolution => [BN] => ReLU) * 2"""def __init__(self, in_channels, out_channels, mid_channels=None):super().__init__()if not mid_channels:mid_channels = out_channelsself.double_conv = nn.Sequential(nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(mid_channels),nn.ReLU(inplace=True),nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))def forward(self, x):return self.double_conv(x)class Down(nn.Module):"""Downscaling with maxpool then double conv"""def __init__(self, in_channels, out_channels):super().__init__()self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2),DoubleConv(in_channels, out_channels))def forward(self, x):return self.maxpool_conv(x)class Up(nn.Module):"""Upscaling then double conv"""def __init__(self, in_channels, out_channels, bilinear=True):super().__init__()# if bilinear, use the normal convolutions to reduce the number of channelsif bilinear:self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)else:# // 是整除运算self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels)def forward(self, x1, x2):x1 = self.up(x1)# input is CHWdiffY = x2.size()[2] - x1.size()[2]diffX = x2.size()[3] - x1.size()[3]x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2])# if you have padding issues, see# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bdx = torch.cat([x2, x1], dim=1)return self.conv(x)class OutConv(nn.Module):def __init__(self, in_channels, out_channels):super(OutConv, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)def forward(self, x):return self.conv(x)

【1】double conv

=》卷积。卷积核是3*3,填充是1

=》批归一化。

=》ReLU。激活函数

=》卷积。卷积核是3*3,填充是1

=》批归一化。

=》ReLU。激活函数

【2】down

=》最大池化。池化核是2*2

=》double conv。

【3】up

=》上采样。可选择upsample + double conv 和 transpose + double conv

=》计算尺寸差异。

=》填充x1。使得x1和x2对齐

=》拼接x2和x1。按照dim=1,也就是channel通道拼接

=》double conv。

【4】out conv

=》卷积。卷积核是1*1

2:unet_model.py

主要包含:UNet完整架构

""" Full assembly of the parts to form the complete network """from .unet_parts import *class UNet(nn.Module):def __init__(self, n_channels, n_classes, bilinear=False):super(UNet, self).__init__()self.n_channels = n_channelsself.n_classes = n_classesself.bilinear = bilinearself.inc = (DoubleConv(n_channels, 64))self.down1 = (Down(64, 128))self.down2 = (Down(128, 256))self.down3 = (Down(256, 512))factor = 2 if bilinear else 1self.down4 = (Down(512, 1024 // factor))self.up1 = (Up(1024, 512 // factor, bilinear))self.up2 = (Up(512, 256 // factor, bilinear))self.up3 = (Up(256, 128 // factor, bilinear))self.up4 = (Up(128, 64, bilinear))self.outc = (OutConv(64, n_classes))def forward(self, x):x1 = self.inc(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)logits = self.outc(x)return logitsdef use_checkpointing(self):self.inc = torch.utils.checkpoint(self.inc)self.down1 = torch.utils.checkpoint(self.down1)self.down2 = torch.utils.checkpoint(self.down2)self.down3 = torch.utils.checkpoint(self.down3)self.down4 = torch.utils.checkpoint(self.down4)self.up1 = torch.utils.checkpoint(self.up1)self.up2 = torch.utils.checkpoint(self.up2)self.up3 = torch.utils.checkpoint(self.up3)self.up4 = torch.utils.checkpoint(self.up4)self.outc = torch.utils.checkpoint(self.outc)

其中,use_checkpointing的作用是丢弃中间计算结果,加快训练速度。

上面的代码可以结合下图分析

前向传播过程:

        x1 = self.inc(x)

通过double conv双层卷积,输入通道为图像自身的,输出通道为64

        x2 = self.down1(x1)

通过down下采样,输入通道为64,输出通道为128

        x3 = self.down2(x2)

通过down下采样,输入通道为128,输出通道为256

        x4 = self.down3(x3)

通过down下采样,输入通道为256,输出通道为512

        x5 = self.down4(x4)

通过down下采样,输入通道为512,输出通道为1024(非bilinear,后续上采样也是如此)

        x = self.up1(x5, x4)

通过up上采样,输入通道为1024,输出通道为512

这个地方concat的对象是x4,也就是下采样输出通道为512的时候的特征

        x = self.up2(x, x3)

通过up上采样,输入通道为512,输出通道为256

这个地方concat的对象是x,也就是原图(后续也是原图)

其实这里和原作者的跳跃连接有点不太一样,代码库的作者直接省事用了原图进行拼接

        x = self.up3(x, x2)

通过up上采样,输入通道为256,输出通道为128

        x = self.up4(x, x1)

通过up上采样,输入通道为128,输出通道为64

        logits = self.outc(x)

通过out conv输出卷积,输入通道为64,输出通道为2,也就是分割为背景和物体2个类别的像素

3:完整代码

可以在github上通过git clone下载

milesial/Pytorch-UNet: PyTorch implementation of the U-Net for image semantic segmentation with high quality images (github.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/65480.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[Leetcode] 最大子数组和 [击败99%的解法]

解法1&#xff1a; 暴力解法 遍历每个元素&#xff0c;从它当前位置一直加到最后&#xff0c;然后用一个最大值来记录全局最大值。 代码如下&#xff1a; class Solution {public int maxSubArray(int[] nums) {long sum, max nums[len-1];for (int i0; i<nums.length;…

在JavaScript文件中定义方法和数据(不是在对象里定以数据和方法,不要搞错了)

在对象里定以数据和方法看这一篇 对象字面量内定义属性和方法&#xff08;什么使用const等关键字&#xff0c;什么时候用键值对&#xff09;-CSDN博客https://blog.csdn.net/m0_62961212/article/details/144788665 下是在JavaScript文件中定义方法和数据的基本方式&#xff…

基于SpringBoot的垃圾分类系统设计与实现【源码+文档+部署讲解】

系统介绍 基于SpringBootVue实现的垃圾分类系统设计了三种角色、分别是管理员、垃圾分类管理员、用户&#xff0c;实现了个人中心、用户管理、垃圾分类管理员管理、垃圾分类管理、垃圾类型管理、垃圾图谱管理、系统管理等功能 技术选型 开发工具&#xff1a;idea2020.3Webst…

springboot启动不了 因一个spring-boot-starter-web底下的tomcat-embed-core依赖丢失

这个包丢失了 启动不了 起因是pom中加入了 <tomcat.version></tomcat.version>版本指定&#xff0c;然后idea自动编译后&#xff0c;包丢了&#xff0c;删除这个配置后再也找不回来&#xff0c; 这个包正常在 <dependency><groupId>org.springframe…

前后端分离(对话框的使用)

1.首先先定义两个按钮(一个添加按钮&#xff0c;一个修改按钮) <el-button type"primary" click"openDialog(true)">添加员工</el-button> <el-button size"mini" click"openDialog(false, scope.row)">编辑</…

【Maven_bugs】The project main artifact does not exist

背景&#xff1a;我想使用 maven-shade-plugin 打一个 fat jar 时报了标题中的错误&#xff0c;使用的命令是&#xff1a;org.apache.maven.plugins:maven-shade-plugin:shade -pl :shade-project。项目结构如下图&#xff0c;我想把子模块 shade-project 打成一个 fat jar&…

贪心算法(常见贪心模型)

常见贪心模型 简单排序模型 最小化战斗力差距 题目分析&#xff1a; #include <bits/stdc.h> using namespace std;const int N 1e5 10;int n; int a[N];int main() {// 请在此输入您的代码cin >> n;for (int i 1;i < n;i) cin >> a[i];sort(a1,a1n);…

Docker 安装与配置 Nginx

摘要 1、本文全面介绍了如何在 Docker 环境中安装和配置 Nginx 容器。 2、文中详细解释了如何设置 HTTPS 安全连接及配置 Nginx 以实现前后端分离的代理服务。 2、同时&#xff0c;探讨了通过 IP 和域名两种方式访问 Nginx 服务的具体配置方法 3、此外&#xff0c;文章还涵…

机器学习常用术语

目录 概要 机器学习常用术语 1、模型 2、数据集 3、样本与特征 4、向量 5、矩阵 6、假设函数与损失函数 7、拟合、过拟合与欠拟合 8、激活函数(Activation Function) 9、反向传播(Backpropagation) 10、基线(Baseline) 11、批量(Batch) 12、批量大小(Batch Size)…

微服务篇-深入了解 MinIO 文件服务器(你还在使用阿里云 0SS 对象存储图片服务?教你使用 MinIO 文件服务器:实现从部署到具体使用)

&#x1f525;博客主页&#xff1a; 【小扳_-CSDN博客】 ❤感谢大家点赞&#x1f44d;收藏⭐评论✍ 文章目录 1.0 MinIO 文件服务器概述 1.1 MinIO 使用 Docker 部署 1.2 MinIO 控制台的使用 2.0 使用 Java 操作 MinIO 3.0 使用 minioClient 对象的方法 3.1 判断桶是否存在 3.2…

第一个C++程序|cin和cout|命名空间

第一个C程序 基础程序 使用DevC5.4.0 写一个C程序 在屏幕上打印hello world #include <iostream> using namespace std;int main() {cout << "hello world" << endl;return 0; } 运行这个C程序 F9->编译 F10->运行 F11->编译运行 mai…

强化特种作业管理,筑牢安全生产防线

在各类生产经营活动中&#xff0c;特种作业由于其操作的特殊性和高风险性&#xff0c;一直是安全生产管理的重点领域。有效的特种作业管理体系涵盖多个关键方面&#xff0c;从作业人员的资质把控到安全设施的配备维护&#xff0c;再到特种设备的精细管理以及作业流程的严格规范…

iOS 苹果开发者账号: 查看和添加设备UUID 及设备数量

参考链接&#xff1a;苹果开发者账号下添加新设备UUID - 简书 如果要添加新设备到 Profiles 证书里&#xff1a; 1.登录开发者中心 Sign In - Apple 2.找到证书设置&#xff1a; Certificate&#xff0c;Identifiers&Profiles > Profiles > 选择对应证书 edit &g…

如何让Tplink路由器自身的IP网段 与交换机和电脑的IP网段 保持一致?

问题分析&#xff1a; 正常情况下&#xff0c;我的需求是&#xff1a;电脑又能上网&#xff0c;又需要与路由器处于同一局域网下&#xff08;串流Pico4 VR眼镜&#xff09;&#xff0c;所以&#xff0c;我是这么连接 交换机、路由器、电脑 的&#xff1a; 此时&#xff0c;登录…

(南京观海微电子)——GH7009开机黑屏案例分析

一、 现象描述&#xff1a; 不良现象: LVDS模组&#xff0c;开机大概2秒后就黑屏。 二、问题分析 等主机进入Kernel 后做以下测试&#xff1a; 1、手动reset LCM 后 可以显示正常&#xff1b; 总结&#xff1a; 1&#xff09;uboot 部分HS 太窄&#xff0c;仅有4个clk宽度&am…

第10章 初等数论

2024年12月27日一稿&#xff08;P341&#xff09; 2024年12月28日二稿 2024年12月29日三稿 当命运这扇大门向你打开的时候&#xff0c;不要犹豫和害怕&#xff0c;一直往前跑就是了&#xff01; 10.1 素数 这里写错了&#xff0c;不能整除应该表示为 10.2 最大公约数与最小公…

XXE漏洞 黑盒测试 白盒测试 有无回显问题

前言 什么是XXE&#xff08;xml外部实体注入漏洞&#xff09;&#xff1f; 就是网站以xml传输数据 的时候我们截取他的传输流进行修改&#xff08;网站没有对我们的输入进行过滤&#xff09; 添加恶意代码 导致数据传输到后台 后台解析xml形式 导致恶意代码被执行 几种常见的…

yolov5 yolov6 yolov7 yolov8 yolov9目标检测、目标分类 目标切割 性能对比

文章目录 YOLOv1-YOLOv8之间的对比如下表所示&#xff1a;一、YOLO算法的核心思想1. YOLO系列算法的步骤2. Backbone、Neck和Head 二、YOLO系列的算法1.1 模型介绍1.2 网络结构1.3 实现细节1.4 性能表现 2. YOLOv2&#xff08;2016&#xff09;2.1 改进部分2.2 网络结构 3. YOL…

jdk版本介绍

1.JDK版本编号 • 主版本号&#xff1a;表示JDK的主要版本&#xff0c;如JDK 8、JDK 11中的8和11。主版本号的提升通常意味着引入了重大的新特性或变更。 • 次版本号&#xff1a;在主版本号之后&#xff0c;有时会跟随一个或多个次版本号&#xff08;如JDK 11.0.2中的0.2&…

低代码开源项目Joget的研究——基本概念和Joget7社区版应用

大纲 1. 基本概念1.1 Form1.1.1 Form1.1.1.1 概述1.1.1.2 主要特点和用途1.1.1.3 创建和使用 Form1.1.1.4 示例 1.1.2 Section1.1.2.1 概述1.1.2.2 主要特点和用途1.1.2.3 示例 1.1.3 Column1.1.4 Field1.1.5 示例 1.2 Datalist1.2.1 Datalist1.2.1.1 主要特点和用途1.2.1.2 创…