U-Net代码复现–train.py

本文记录自己的学习过程,内容包括:
代码解读:Pytorch-UNet
深度学习编程基础:Pytorch-深度学习(新手友好)
UNet论文解读:医学图像分割:U_Net 论文阅读
数据:https://hackernoon.com/hacking-gta-v-for-carvana-kaggle-challenge-6d0b7fb4c781

完整代码解读详见:U-Net代码复现–更新中

train.py

CarvanaDataset 读取并创建输入数据(具体实现详见:U-Net代码复现–utils data_loading.py)

	# 1. Create datasettry:dataset = CarvanaDataset(dir_img, dir_mask, img_scale)except (AssertionError, RuntimeError, IndexError):dataset = BasicDataset(dir_img, dir_mask, img_scale)

random_split()函数说明:

  • 这个函数的作用是划分数据集

参数说明:

  • dataset (Dataset): 划分的数据集
  • lengths (sequence): 被划分数据集的长度
	# 2. Split into train / validation partitions# 将数据集分为训练集和验证集n_val = int(len(dataset) * val_percent)n_train = len(dataset) - n_valtrain_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

加载和迭代数据集
关于DataLoader参考:Pytorch:torch.utils.data.DataLoader()

	# 3. Create data loadersloader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)train_loader = DataLoader(train_set, shuffle=True, **loader_args)val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

Wandb是Weights & Biases的缩写,是类似TensorBoard, visdom的一款可视化工具;是属于Python的,不是Pytorch的(大家感兴趣可以自己看看,这里就不多解释了)

	# (Initialize logging)experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')experiment.config.update(dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp))

打印日志

	logging.info(f'''Starting training:Epochs:          {epochs}Batch size:      {batch_size}Learning rate:   {learning_rate}Training size:   {n_train}Validation size: {n_val}Checkpoints:     {save_checkpoint}Device:          {device.type}Images scaling:  {img_scale}Mixed Precision: {amp}''')
  • 关于 optim.RMSprop ,参考机器学习:优化器Optimizer的总结与比较
  • torch.optim.lr_scheduler 模块提供了一些根据epoch训练次数来调整学习率(learning rate)的方法。一般情况下我们会设置随着epoch的增大而逐渐减小学习率从而达到更好的训练效果。
  • torch.optim.lr_scheduler.ReduceLROnPlateau 则提供了基于训练中某些测量值使学习率动态下降的方法。
  • torch.cuda.amp.GradScaler 参考:PyTorch : torch.cuda.amp: 自动混合精度详解
  • nn.CrossEntropyLoss() 损失函数
	# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMPoptimizer = optim.RMSprop(model.parameters(),lr=learning_rate, weight_decay=weight_decay, momentum=momentum, foreach=True)scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # goal: maximize Dice scoregrad_scaler = torch.cuda.amp.GradScaler(enabled=amp)criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()global_step = 0

训练函数,这部分内容比较多,将拆分为多个小部分分析:
========== part 1============
迭代数据集:

  • for batch in train_loader: 对应前文中:train_loader = DataLoader (train_set, shuffle=True, **loader_args)

  • images, true_masks = batch['image'], batch['mask'] 对应U-Net代码复现–utils data_loading.py中的

    'image': torch.as_tensor(img.copy()).float().contiguous(),
    'mask': torch.as_tensor(mask.copy()).long().contiguous().

  • images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last) .
    true_masks = true_masks.to(device=device, dtype=torch.long) .
    将Tensor或模型移动到指定的设备上;关于.to()的用法详见:pytorch:to()、device()、cuda()

========== part 2============

  • with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):

    with torch.autocast: 语句块内的代码会自动进行混合精度计算,也就是根据输入数据的类型自动选择合适的精度进行计算

  • masks_pred = model(images) .
    单次预测结果

  • if model.n_classes == 1: .
    n_classes:输出图的通道数,也就是最终得到几张特征图

  • loss = criterion(masks_pred.squeeze(1), true_masks.float()) .
    loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
    参考:U-Net代码复现–utils dice_score.py

========== part 3============

  • torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping).
    clip_grad_norm_的理解参考:Pytorch:torch.nn.utils.clip_grad_norm_梯度截断_解读

  • optimizer.zero_grad(set_to_none=True).
    grad_scaler.scale(loss).backward().
    grad_scaler.step(optimizer).
    grad_scaler.update().
    关于optimizer.zero_grad(), loss.backward(), optimizer.step()的理解参考:Pytorch:optimizer.zero_grad(), loss.backward(), optimizer.step()

========== part 4============

  • 参数更新

  • 进度条更新

	# 5. Begin training# ================================ part 1 =======================================for epoch in range(1, epochs + 1):model.train()epoch_loss = 0with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:for batch in train_loader:images, true_masks = batch['image'], batch['mask']assert images.shape[1] == model.n_channels, \f'Network has been defined with {model.n_channels} input channels, ' \f'but loaded images have {images.shape[1]} channels. Please check that ' \'the images are loaded correctly.'images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)true_masks = true_masks.to(device=device, dtype=torch.long)# ================================ part 2 =======================================with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):masks_pred = model(images)if model.n_classes == 1:loss = criterion(masks_pred.squeeze(1), true_masks.float())loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)else:loss = criterion(masks_pred, true_masks)loss += dice_loss(F.softmax(masks_pred, dim=1).float(),F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),multiclass=True)# ================================ part 3 =======================================optimizer.zero_grad(set_to_none=True)grad_scaler.scale(loss).backward()torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)grad_scaler.step(optimizer)grad_scaler.update()# ================================ part 4 =======================================pbar.update(images.shape[0])global_step += 1epoch_loss += loss.item()experiment.log({'train loss': loss.item(),'step': global_step,'epoch': epoch})pbar.set_postfix(**{'loss (batch)': loss.item()})# ================================ part 5 =======================================# Evaluation rounddivision_step = (n_train // (5 * batch_size))if division_step > 0:if global_step % division_step == 0:histograms = {}for tag, value in model.named_parameters():tag = tag.replace('/', '.')if not (torch.isinf(value) | torch.isnan(value)).any():histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())if not (torch.isinf(value.grad) | torch.isnan(value.grad)).any():histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())val_score = evaluate(model, val_loader, device, amp)scheduler.step(val_score)logging.info('Validation Dice score: {}'.format(val_score))try:experiment.log({'learning rate': optimizer.param_groups[0]['lr'],'validation Dice': val_score,'images': wandb.Image(images[0].cpu()),'masks': {'true': wandb.Image(true_masks[0].float().cpu()),'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()),},'step': global_step,'epoch': epoch,**histograms})except:passif save_checkpoint:Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)state_dict = model.state_dict()state_dict['mask_values'] = dataset.mask_valuestorch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))logging.info(f'Checkpoint {epoch} saved!')
  • argparse.ArgumentParser :创建 ArgumentParser() 对象
  • parser.add_argument :调用 add_argument() 方法添加参数
  • parser.parse_args() : 使用 parse_args() 解析添加的参数

其中 parser.add_argument

name or flags - 一个命名或者一个选项字符串的列表,例如 foo 或 -f, --foo。
action -当参数在命令行中出现时使用的动作基本类型。
nargs - 命令行参数应当消耗的数目。
const - 被一些 action 和 nargs选择所需求的常数。
default - 当参数未在命令行中出现时使用的值。
choices - 可用的参数的容器。
required -此命令行选项是否可省略 (仅选项可用)。
help - 一个此选项作用的简单描述。
metavar - 在使用方法消息中使用的参数值示例。
dest - 被添加到 parse_args() 所返回对象上的属性名。

def get_args():parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,help='Learning rate', dest='lr')parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,help='Percent of the data that is used as validation (0-100)')parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')return parser.parse_args()

main():

Memory Format(Channel First 和 Channel Last )
Pytorch:torch.utils.checkpoint()
Pytorch:torch.cuda.empty_cache()
Pytorch:模型的加载和保存 torch.save,torch.load,torch.nn.Module.state_dict 和 torch.nn.Module.load_state_dict


if __name__ == '__main__':args = get_args()logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')logging.info(f'Using device {device}')# Change here to adapt to your data# n_channels=3 for RGB images# n_classes is the number of probabilities you want to get per pixelmodel = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)model = model.to(memory_format=torch.channels_last)logging.info(f'Network:\n'f'\t{model.n_channels} input channels\n'f'\t{model.n_classes} output channels (classes)\n'f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')if args.load:state_dict = torch.load(args.load, map_location=device)del state_dict['mask_values']model.load_state_dict(state_dict)logging.info(f'Model loaded from {args.load}')model.to(device=device)try:train_model(model=model,epochs=args.epochs,batch_size=args.batch_size,learning_rate=args.lr,device=device,img_scale=args.scale,val_percent=args.val / 100,amp=args.amp)except torch.cuda.OutOfMemoryError:logging.error('Detected OutOfMemoryError! ''Enabling checkpointing to reduce memory usage, but this slows down training. ''Consider enabling AMP (--amp) for fast and memory efficient training')torch.cuda.empty_cache()model.use_checkpointing()train_model(model=model,epochs=args.epochs,batch_size=args.batch_size,learning_rate=args.lr,device=device,img_scale=args.scale,val_percent=args.val / 100,amp=args.amp)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/757540.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

中霖教育:二级建造师证书好考吗?

在建筑行业,二级建造师资格认证相较于一级建造师资格,难度会低一些。考试科目共有三科,考生需要在连续两个年度内通过所有科目的考试才为通过。 对于具备建筑相关基础和实践经验的考生来说,二级建造师的考试难度会低一些。根据往…

云扩展要求(云租户)

层面控制点四级三级二级安全 区域 边界访问控制应在虚拟化网络边界部署访问控制机制,并设置访问控制规则;应在虚拟化网络边界部署访问控制机制,并设置访问控制规则;应在虚拟化网络边界部署访问控制机制,并设置访问控制…

30天拿下Rust之错误处理

概述 在软件开发领域,对错误的妥善处理是保证程序稳定性和健壮性的重要环节。Rust作为一种系统级编程语言,以其对内存安全和所有权的独特设计而著称,其错误处理机制同样体现了Rust的严谨与实用。在Rust中,错误处理通常分为两大类&…

KUKA机器人自动回原点程序

一、创建全局变量点 创建两个全局变量分别用于储存机器人的笛卡尔姿态与关节角姿态。 打开System文件夹中的config文件创建全局变量的点位。 在USER GROBALS用户自定义变量Userdefined variables下创建一个E6POS类型的点位,一个E6AXIS类型的点位。 二、创建回原点…

webRtc麦克风摄像头检测

最近在做webRtc相关音视频项目&#xff0c;碰到了很多用户不知道自己设备是否被支持发起webRtc&#xff0c;所以特意总结相关实用方法&#xff1b; HTML /*id方便一会把媒体流赋值过去, autoPlay: 自动播放 */ <audio id"devDetectionMicroRef" autoPlay><…

基于SpringBoot+Vue交通管理在线服务系统的开发(源码+部署说明+演示视频+源码介绍)

您好&#xff0c;我是码农飞哥&#xff08;wei158556&#xff09;&#xff0c;感谢您阅读本文&#xff0c;欢迎一键三连哦。&#x1f4aa;&#x1f3fb; 1. Python基础专栏&#xff0c;基础知识一网打尽&#xff0c;9.9元买不了吃亏&#xff0c;买不了上当。 Python从入门到精通…

经典面试智力题总结

常见面试智力题总结 本部分主要是笔者在练习常见面试智力题所做的笔记&#xff0c;如果出现错误&#xff0c;希望大家指出&#xff01; 常见智力题 时针与分针夹角度数问题&#xff1f; 分析&#xff1a; 当时间为 m 点 n 分时&#xff0c;其时针与分针夹角的度数为多少&…

React状态管理Mobx

1 https://zh.mobx.js.org/README.html 2 https://juejin.cn/post/7046710251382374413 3 https://cn.mobx.js.org/refguide/observable.html ​​mobx入门基础教程-慕课网​​ ​​Mobx学习 - 掘金​​ 十分钟入门 MobX & React ​​十分钟入门 MobX & React​​…

警惕!On Hold被踢,2本1区,5本Springer旗下,共8本SCI/SSCI被剔除!

毕业推荐 SSCI&#xff08;ABS一星&#xff09; • 社科类&#xff0c;3.0-4.0&#xff0c;JCR2区&#xff0c;中科院3区 • 13天录用&#xff0c;28天见刊&#xff0c;13天检索 SCIE&#xff1a; • 计算机类&#xff0c;6.5-7.0&#xff0c;JCR1区&#xff0c;中科院2区…

农业气象站在农业生产中的应用—气象科普

农业气象站在农业生产中发挥着至关重要的作用。它能够有效监测和记录农田环境中的各类气象要素&#xff0c;为农民提供科学、准确的气象数据&#xff0c;帮助他们更好地掌握天气变化规律&#xff0c;从而合理安排农业生产活动。 首先&#xff0c;农业气象站能够实时提供温度、…

使用 Clojure 进行 OpenCV 开发简介

返回&#xff1a;OpenCV系列文章目录&#xff08;持续更新中......&#xff09; 上一篇&#xff1a;如何将OpenCV Java 与Eclipse结合使用 下一篇&#xff1a; OpenCV4.9.0在Android 开发简介 ​警告 本教程可以包含过时的信息。 从 OpenCV 2.4.4 开始&#xff0c;OpenCV 支持…

Hibernate相关问题

文章目录 Hibernate是如何简化JDBC操作的&#xff1f;解释Hibernate的ORM概念Hibernate中的Session和Transaction有什么区别&#xff1f;Session&#xff1a;Transaction&#xff1a; Hibernate有哪些缓存类型&#xff1f;它们是如何工作的&#xff1f;一级缓存&#xff08;Fir…

挑战设计极限!电路仿真软件成功案例大揭秘,助您圆梦创新之路

在电子设计领域&#xff0c;电路仿真软件扮演着至关重要的角色。它们不仅能够帮助工程师们模拟和分析电路的性能&#xff0c;还能够加速设计过程&#xff0c;降低成本&#xff0c;提高产品的质量和可靠性。今天&#xff0c;让我们一起挑战设计极限&#xff0c;揭秘电路仿真软件…

服务器版本ros镜像,包含了CAN通讯以及VNC界面操作

以下镜像包含了ros的moveit、novnc、CAN通讯&#xff0c;并且可以web操作界面: 19900617/ros-moveit-rviz-gazebo:noetic docker-compose.yml配置文件如下: version: 3services:ros:container_name: rosimage: 19900617/ros-moveit-rviz-gazebo:noeticentrypoint: ["b…

Clickhouse MergeTree异常数据处理

作者&#xff1a;俊达 说明 clickhouse mergetree的数据文件如果遇到数据损坏&#xff0c;可能会导致clickhouse无法启动。 本文章说明如何处理这类问题。 测试 我们先人为模拟破坏mergetree数据文件&#xff1a; detach table&#xff1a; ck01 :) detach table metric…

探索.NET中的定时器:选择最适合你的应用场景

概述&#xff1a;.NET提供多种定时器&#xff0c;如 System.Windows.Forms.Timer适用于UI&#xff0c;System.Web.UI.Timer用于Web&#xff0c;System.Diagnostics.Timer用于性能监控&#xff0c;System.Threading.Timer和System.Timers.Timer用于一般定时任务。在.NET 6及以上…

Java基础---反射

什么是反射&#xff1f; 反射允许对成员变量&#xff0c;成员方法和构造方法的信息进行编程访问。 这么说可能比较抽象&#xff0c;可以简单理解为&#xff1a;反射就是一个人&#xff0c;可以把类里面的成员变量&#xff0c;成员方法&#xff0c;构造方法都获取出来。 并且可…

Springcloud智慧工地APP云综合平台源码 SaaS服务

目录 智慧工地功能介绍 一、项目人员 二、视频监控 三、危大工程 四、绿色施工 五、安全隐患 具体功能介绍&#xff1a; 1.劳务管理&#xff1a; 2.施工安全管理&#xff1a; 3.视频监控管理&#xff1a; 4.机械安全管理&#xff1a; 5.危大工程监管&#xff1a; …

项目技术问题记录-内网环境下搭建LVS实现四层负载

原创作者&#xff1a;田超凡&#xff08;程序员田宝宝&#xff09; 版权所有&#xff0c;引用请注明原作者&#xff0c;严禁复制转载 lvs实现四层负载DR模式什么是lvs LVS是Linux Virtual Server的简写&#xff0c;意即Linux虚拟服务器&#xff0c;是一个虚拟的服务器集群系统…

ctf_show笔记篇(web入门---反序列化)

目录 反序列化 254&#xff1a;无用&#xff0c;是让熟悉序列化这个东西的 255&#xff1a;直接使$isViptrue 256&#xff1a;还是使用变量覆盖 257&#xff1a;开始使用魔法函数 258&#xff1a;将序列化最前面的过滤了&#xff0c;使用绕过 259: 这一题需要看writeup才…