【代码分析】Unet-Pytorch

1:unet_parts.py

主要包含:

【1】double conv,双层卷积

【2】down,下采样

【3】up,上采样

【4】out conv,输出卷积

""" Parts of the U-Net model """import torch
import torch.nn as nn
import torch.nn.functional as Fclass DoubleConv(nn.Module):"""(convolution => [BN] => ReLU) * 2"""def __init__(self, in_channels, out_channels, mid_channels=None):super().__init__()if not mid_channels:mid_channels = out_channelsself.double_conv = nn.Sequential(nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(mid_channels),nn.ReLU(inplace=True),nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))def forward(self, x):return self.double_conv(x)class Down(nn.Module):"""Downscaling with maxpool then double conv"""def __init__(self, in_channels, out_channels):super().__init__()self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2),DoubleConv(in_channels, out_channels))def forward(self, x):return self.maxpool_conv(x)class Up(nn.Module):"""Upscaling then double conv"""def __init__(self, in_channels, out_channels, bilinear=True):super().__init__()# if bilinear, use the normal convolutions to reduce the number of channelsif bilinear:self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)else:# // 是整除运算self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels)def forward(self, x1, x2):x1 = self.up(x1)# input is CHWdiffY = x2.size()[2] - x1.size()[2]diffX = x2.size()[3] - x1.size()[3]x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2])# if you have padding issues, see# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bdx = torch.cat([x2, x1], dim=1)return self.conv(x)class OutConv(nn.Module):def __init__(self, in_channels, out_channels):super(OutConv, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)def forward(self, x):return self.conv(x)

【1】double conv

=》卷积。卷积核是3*3,填充是1

=》批归一化。

=》ReLU。激活函数

=》卷积。卷积核是3*3,填充是1

=》批归一化。

=》ReLU。激活函数

【2】down

=》最大池化。池化核是2*2

=》double conv。

【3】up

=》上采样。可选择upsample + double conv 和 transpose + double conv

=》计算尺寸差异。

=》填充x1。使得x1和x2对齐

=》拼接x2和x1。按照dim=1,也就是channel通道拼接

=》double conv。

【4】out conv

=》卷积。卷积核是1*1

2:unet_model.py

主要包含:UNet完整架构

""" Full assembly of the parts to form the complete network """from .unet_parts import *class UNet(nn.Module):def __init__(self, n_channels, n_classes, bilinear=False):super(UNet, self).__init__()self.n_channels = n_channelsself.n_classes = n_classesself.bilinear = bilinearself.inc = (DoubleConv(n_channels, 64))self.down1 = (Down(64, 128))self.down2 = (Down(128, 256))self.down3 = (Down(256, 512))factor = 2 if bilinear else 1self.down4 = (Down(512, 1024 // factor))self.up1 = (Up(1024, 512 // factor, bilinear))self.up2 = (Up(512, 256 // factor, bilinear))self.up3 = (Up(256, 128 // factor, bilinear))self.up4 = (Up(128, 64, bilinear))self.outc = (OutConv(64, n_classes))def forward(self, x):x1 = self.inc(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)logits = self.outc(x)return logitsdef use_checkpointing(self):self.inc = torch.utils.checkpoint(self.inc)self.down1 = torch.utils.checkpoint(self.down1)self.down2 = torch.utils.checkpoint(self.down2)self.down3 = torch.utils.checkpoint(self.down3)self.down4 = torch.utils.checkpoint(self.down4)self.up1 = torch.utils.checkpoint(self.up1)self.up2 = torch.utils.checkpoint(self.up2)self.up3 = torch.utils.checkpoint(self.up3)self.up4 = torch.utils.checkpoint(self.up4)self.outc = torch.utils.checkpoint(self.outc)

其中,use_checkpointing的作用是丢弃中间计算结果,加快训练速度。

上面的代码可以结合下图分析

前向传播过程:

        x1 = self.inc(x)

通过double conv双层卷积,输入通道为图像自身的,输出通道为64

        x2 = self.down1(x1)

通过down下采样,输入通道为64,输出通道为128

        x3 = self.down2(x2)

通过down下采样,输入通道为128,输出通道为256

        x4 = self.down3(x3)

通过down下采样,输入通道为256,输出通道为512

        x5 = self.down4(x4)

通过down下采样,输入通道为512,输出通道为1024(非bilinear,后续上采样也是如此)

        x = self.up1(x5, x4)

通过up上采样,输入通道为1024,输出通道为512

这个地方concat的对象是x4,也就是下采样输出通道为512的时候的特征

        x = self.up2(x, x3)

通过up上采样,输入通道为512,输出通道为256

这个地方concat的对象是x,也就是原图(后续也是原图)

其实这里和原作者的跳跃连接有点不太一样,代码库的作者直接省事用了原图进行拼接

        x = self.up3(x, x2)

通过up上采样,输入通道为256,输出通道为128

        x = self.up4(x, x1)

通过up上采样,输入通道为128,输出通道为64

        logits = self.outc(x)

通过out conv输出卷积,输入通道为64,输出通道为2,也就是分割为背景和物体2个类别的像素

3:完整代码

可以在github上通过git clone下载

milesial/Pytorch-UNet: PyTorch implementation of the U-Net for image semantic segmentation with high quality images (github.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/65480.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[Leetcode] 最大子数组和 [击败99%的解法]

解法1&#xff1a; 暴力解法 遍历每个元素&#xff0c;从它当前位置一直加到最后&#xff0c;然后用一个最大值来记录全局最大值。 代码如下&#xff1a; class Solution {public int maxSubArray(int[] nums) {long sum, max nums[len-1];for (int i0; i<nums.length;…

系统压力测试助手——stress-ng

1、背景 在系统性能测试和压力测试中&#xff0c;stress-ng 是一个非常强大的工具&#xff0c;广泛应用于对 Linux 系统进行各种硬件和软件方面的负载测试。它能够模拟多种极端负载情况&#xff0c;帮助开发人员和运维人员检查系统在高负载下的表现&#xff0c;以便发现潜在的…

计算机网络500题2024-2025学年度第一学期复习题库(选择、判断、填空)

一、单选题 1、&#xff08; &#xff09;是实现两个同种网络互连的设备 A. 网桥 B. 网关 C. 集线器 D. 路由器 2、10M以太网有三种接口标准&#xff0c;其中10BASE-T采用&#xff08; &#xff09; A. 双绞线 B. 粗同轴电缆 C. 细同轴电缆 D. 光纤 3、HDLC是哪…

在JavaScript文件中定义方法和数据(不是在对象里定以数据和方法,不要搞错了)

在对象里定以数据和方法看这一篇 对象字面量内定义属性和方法&#xff08;什么使用const等关键字&#xff0c;什么时候用键值对&#xff09;-CSDN博客https://blog.csdn.net/m0_62961212/article/details/144788665 下是在JavaScript文件中定义方法和数据的基本方式&#xff…

基于SpringBoot的垃圾分类系统设计与实现【源码+文档+部署讲解】

系统介绍 基于SpringBootVue实现的垃圾分类系统设计了三种角色、分别是管理员、垃圾分类管理员、用户&#xff0c;实现了个人中心、用户管理、垃圾分类管理员管理、垃圾分类管理、垃圾类型管理、垃圾图谱管理、系统管理等功能 技术选型 开发工具&#xff1a;idea2020.3Webst…

今日总结 2024-12-28

今天全身心投入到鸿蒙系统下 TCPSocket 的学习中。从最基础的 TCP 协议三次握手、四次挥手原理重新梳理&#xff0c;深刻理解其可靠连接建立与断开机制&#xff0c;这是后续运用 TCPSocket 无误通信的根基。在深入鸿蒙体系时&#xff0c;仔细研读了其为 TCPSocket 封装的 API&a…

springboot启动不了 因一个spring-boot-starter-web底下的tomcat-embed-core依赖丢失

这个包丢失了 启动不了 起因是pom中加入了 <tomcat.version></tomcat.version>版本指定&#xff0c;然后idea自动编译后&#xff0c;包丢了&#xff0c;删除这个配置后再也找不回来&#xff0c; 这个包正常在 <dependency><groupId>org.springframe…

前后端分离(对话框的使用)

1.首先先定义两个按钮(一个添加按钮&#xff0c;一个修改按钮) <el-button type"primary" click"openDialog(true)">添加员工</el-button> <el-button size"mini" click"openDialog(false, scope.row)">编辑</…

doris集群存储目录切换

doris集群存储目录切换 1. 背景 3节点集群&#xff0c;BE存储目录&#xff0c;因为运维原因。存储盘系统放在了一一起。 需要增加硬盘&#xff0c;并替换原有目录。 3节点集群&#xff0c;如果各个表都是3副本&#xff0c;可以实现轮流停机&#xff0c;方式处理。 但是业务…

【Maven_bugs】The project main artifact does not exist

背景&#xff1a;我想使用 maven-shade-plugin 打一个 fat jar 时报了标题中的错误&#xff0c;使用的命令是&#xff1a;org.apache.maven.plugins:maven-shade-plugin:shade -pl :shade-project。项目结构如下图&#xff0c;我想把子模块 shade-project 打成一个 fat jar&…

Qt 的信号槽机制详解:之信号槽引发的 Segmentation Fault 问题拆析(上)

Qt 的信号槽机制详解&#xff1a;之因信号槽误用引发的 Segmentation Fault 问题拆析&#xff08;上&#xff09; 前言一. 信号与槽的基本概念信号&#xff08;Signal&#xff09;槽&#xff08;Slot&#xff09;连接信号与槽 二. 信号槽机制的实现原理元对象系统&#xff08;M…

贪心算法(常见贪心模型)

常见贪心模型 简单排序模型 最小化战斗力差距 题目分析&#xff1a; #include <bits/stdc.h> using namespace std;const int N 1e5 10;int n; int a[N];int main() {// 请在此输入您的代码cin >> n;for (int i 1;i < n;i) cin >> a[i];sort(a1,a1n);…

Docker 安装与配置 Nginx

摘要 1、本文全面介绍了如何在 Docker 环境中安装和配置 Nginx 容器。 2、文中详细解释了如何设置 HTTPS 安全连接及配置 Nginx 以实现前后端分离的代理服务。 2、同时&#xff0c;探讨了通过 IP 和域名两种方式访问 Nginx 服务的具体配置方法 3、此外&#xff0c;文章还涵…

机器学习常用术语

目录 概要 机器学习常用术语 1、模型 2、数据集 3、样本与特征 4、向量 5、矩阵 6、假设函数与损失函数 7、拟合、过拟合与欠拟合 8、激活函数(Activation Function) 9、反向传播(Backpropagation) 10、基线(Baseline) 11、批量(Batch) 12、批量大小(Batch Size)…

微服务篇-深入了解 MinIO 文件服务器(你还在使用阿里云 0SS 对象存储图片服务?教你使用 MinIO 文件服务器:实现从部署到具体使用)

&#x1f525;博客主页&#xff1a; 【小扳_-CSDN博客】 ❤感谢大家点赞&#x1f44d;收藏⭐评论✍ 文章目录 1.0 MinIO 文件服务器概述 1.1 MinIO 使用 Docker 部署 1.2 MinIO 控制台的使用 2.0 使用 Java 操作 MinIO 3.0 使用 minioClient 对象的方法 3.1 判断桶是否存在 3.2…

第一个C++程序|cin和cout|命名空间

第一个C程序 基础程序 使用DevC5.4.0 写一个C程序 在屏幕上打印hello world #include <iostream> using namespace std;int main() {cout << "hello world" << endl;return 0; } 运行这个C程序 F9->编译 F10->运行 F11->编译运行 mai…

【大模型】wiki中文语料的word2vec模型构建

在自然语言处理&#xff08;NLP&#xff09;任务中&#xff0c;词向量&#xff08;Word Embedding&#xff09;是一个非常重要的概念。通过将词语映射到一个高维空间中&#xff0c;我们能够以向量的形式表达出词语之间的语义关系。Word2Vec作为一种流行的词向量学习方法&#x…

1.RPC基本原理

文章目录 RPC1.定义2.概念3.优缺点4.RPC结构5.RPC消息协议5.1 消息边界5.2 内容5.3 压缩 6.RPC的实现6.1 divide_protocol.py6.2 server.py6.3 client.py RPC 1.定义 远程过程调用(remote procedure call) 2.概念 广义:所有通过网络进行通讯,的调用统称为RPC调用 狭义:不采…

强化特种作业管理,筑牢安全生产防线

在各类生产经营活动中&#xff0c;特种作业由于其操作的特殊性和高风险性&#xff0c;一直是安全生产管理的重点领域。有效的特种作业管理体系涵盖多个关键方面&#xff0c;从作业人员的资质把控到安全设施的配备维护&#xff0c;再到特种设备的精细管理以及作业流程的严格规范…

iOS 苹果开发者账号: 查看和添加设备UUID 及设备数量

参考链接&#xff1a;苹果开发者账号下添加新设备UUID - 简书 如果要添加新设备到 Profiles 证书里&#xff1a; 1.登录开发者中心 Sign In - Apple 2.找到证书设置&#xff1a; Certificate&#xff0c;Identifiers&Profiles > Profiles > 选择对应证书 edit &g…