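This post records my own learning process. It covers:
- Code walkthrough: Pytorch-UNet
- Deep learning programming basics: Pytorch deep learning (beginner-friendly)
- UNet paper notes: Medical image segmentation: reading the U_Net paper
Data: https://hackernoon.com/hacking-gta-v-for-carvana-kaggle-challenge-6d0b7fb4c781
For the full code walkthrough, see: U-Net code reproduction – (in progress)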
train.py
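CarvanaDataset: reads the images and masks and builds the input dataset (for the implementation, see: U-Net code reproduction – utils data_loading.py). If creating a CarvanaDataset raises one of the listed exceptions, the code falls back to BasicDataset.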
```python
# 1. Create dataset
try:
    dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
except (AssertionError, RuntimeError, IndexError):
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
```
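random_split()
Purpose:
- Randomly splits a dataset into non-overlapping subsets.
Parameters:
- dataset (Dataset): the dataset to split
- lengths (sequence): the lengths of the subsets to produce
- generator (Generator, optional): the random generator used for the permutation; seeding it, as below, makes the split reproducible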
```python
# 2. Split into train / validation partitions
# Divide the dataset into a training set and a validation set
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))
```
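A minimal sketch of the same split on a toy dataset, assuming a 100-sample TensorDataset in place of CarvanaDataset. It shows that the two subsets are disjoint and that a seeded generator reproduces the same split:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy stand-in for CarvanaDataset: 100 samples whose content is just their index
dataset = TensorDataset(torch.arange(100))
val_percent = 0.1

n_val = int(len(dataset) * val_percent)   # 10
n_train = len(dataset) - n_val            # 90
train_set, val_set = random_split(dataset, [n_train, n_val],
                                  generator=torch.Generator().manual_seed(0))

# Each returned Subset keeps the indices it drew from the original dataset
train_idx, val_idx = set(train_set.indices), set(val_set.indices)
print(len(train_set), len(val_set))       # 90 10
print(train_idx.isdisjoint(val_idx))      # True

# Same seed -> same split, which is why the original code passes a seeded generator
_, val_again = random_split(dataset, [n_train, n_val],
                            generator=torch.Generator().manual_seed(0))
print(val_again.indices == val_set.indices)  # True
```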
Loading and iterating over the dataset
About DataLoader, see: Pytorch: torch.utils.data.DataLoader()
```python
# 3. Create data loaders
loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)
```
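To see what these loaders yield, here is a sketch with a hypothetical ToySegDataset that returns the same {'image', 'mask'} dict structure as data_loading.py. num_workers=0 and pin_memory=False are my choices so it runs on a CPU-only machine (the original uses os.cpu_count() and pin_memory=True). The default collate function stacks the per-sample dicts into batched tensors, and drop_last=True discards the final incomplete batch:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToySegDataset(Dataset):
    """Hypothetical stand-in for BasicDataset: returns dicts with 'image' and 'mask'."""
    def __init__(self, n, h=8, w=8):
        self.n, self.h, self.w = n, h, w

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        image = torch.rand(3, self.h, self.w)
        mask = torch.randint(0, 2, (self.h, self.w))
        return {'image': image.float().contiguous(), 'mask': mask.long().contiguous()}

loader_args = dict(batch_size=4, num_workers=0, pin_memory=False)
train_loader = DataLoader(ToySegDataset(10), shuffle=True, **loader_args)
val_loader = DataLoader(ToySegDataset(10), shuffle=False, drop_last=True, **loader_args)

train_sizes = [batch['image'].shape[0] for batch in train_loader]
val_sizes = [batch['image'].shape[0] for batch in val_loader]
print(train_sizes)  # [4, 4, 2]  (last partial batch kept)
print(val_sizes)    # [4, 4]     (drop_last=True drops the 2 leftover samples)

first = next(iter(val_loader))
print(first['image'].shape, first['mask'].dtype)  # torch.Size([4, 3, 8, 8]) torch.int64
```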
Wandb is short for Weights & Biases, a visualization tool similar to TensorBoard and visdom. It is a Python library, not part of PyTorch (take a look if you are interested; I won't go into it here).
```python
# (Initialize logging)
experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
experiment.config.update(
    dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
         val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp)
)
```
Log the training configuration:
```python
logging.info(f'''Starting training:
    Epochs:          {epochs}
    Batch size:      {batch_size}
    Learning rate:   {learning_rate}
    Training size:   {n_train}
    Validation size: {n_val}
    Checkpoints:     {save_checkpoint}
    Device:          {device.type}
    Images scaling:  {img_scale}
    Mixed Precision: {amp}
''')
```
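- optim.RMSprop: see Machine learning: a summary and comparison of optimizers.
- torch.optim.lr_scheduler: this module provides several ways to adjust the learning rate according to the number of training epochs. Usually the learning rate is decreased gradually as the epoch count grows, which tends to give better training results.
- torch.optim.lr_scheduler.ReduceLROnPlateau: instead lowers the learning rate dynamically based on a metric measured during training. Here the metric is the validation Dice score, so the mode is 'max' (higher is better), and with patience=5 the learning rate is reduced once the score has not improved for more than 5 consecutive evaluations.
- torch.cuda.amp.GradScaler: loss scaling for automatic mixed precision; see: PyTorch: torch.cuda.amp: automatic mixed precision explained.
- nn.CrossEntropyLoss(): the loss function for the multi-class case (n_classes > 1); with a single output channel the code uses nn.BCEWithLogitsLoss() instead.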
```python
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
optimizer = optim.RMSprop(model.parameters(),
                          lr=learning_rate, weight_decay=weight_decay, momentum=momentum, foreach=True)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # goal: maximize Dice score
grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
global_step = 0
```
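To make the scheduler's behavior concrete, here is a small sketch with a dummy nn.Linear model and made-up Dice scores (both are assumptions, not part of the original). With mode 'max' and patience=5, the learning rate is multiplied by the default factor of 0.1 only after the score has failed to improve on more than 5 consecutive scheduler.step() calls:

```python
import torch
from torch import nn, optim

# Dummy model standing in for UNet; only its parameters matter to the optimizer
model = nn.Linear(4, 1)
optimizer = optim.RMSprop(model.parameters(), lr=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # default factor=0.1

# Made-up validation Dice scores: one initial value, then 6 evaluations with no improvement
dice_scores = [0.80] + [0.80] * 6
lrs = []
for val_score in dice_scores:
    scheduler.step(val_score)                      # same call as in the evaluation round
    lrs.append(optimizer.param_groups[0]['lr'])

print(lrs)  # first six entries stay at 1e-05; the seventh drops to about 1e-06
```

In the training loop, scheduler.step(val_score) is only called during the evaluation round, so "patience" counts evaluations rather than epochs.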
The training loop. This part is long, so I'll split it into several smaller parts:
========== part 1============
Iterating over the dataset:
- for batch in train_loader: corresponds to the earlier train_loader = DataLoader(train_set, shuffle=True, **loader_args).
- images, true_masks = batch['image'], batch['mask']: corresponds to 'image': torch.as_tensor(img.copy()).float().contiguous() and 'mask': torch.as_tensor(mask.copy()).long().contiguous() in U-Net code reproduction – utils data_loading.py.
- images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last) and true_masks = true_masks.to(device=device, dtype=torch.long): move the tensors to the target device and set their dtype; the images are also switched to the channels_last memory layout. For the usage of .to(), see: pytorch: to(), device(), cuda().
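A quick sketch of what the two .to() calls do, run on the CPU (an assumption so it works anywhere) with small random tensors. The logical NCHW shape and the values are unchanged; only the memory layout (strides) changes, with the channel dimension becoming the fastest-moving one:

```python
import torch

device = torch.device('cpu')  # assumption: CPU instead of the training device

original = torch.rand(2, 3, 4, 5)              # N, C, H, W, as produced by the DataLoader
true_masks = torch.randint(0, 2, (2, 4, 5))    # N, H, W class indices

images = original.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
true_masks = true_masks.to(device=device, dtype=torch.long)

print(images.shape)     # torch.Size([2, 3, 4, 5])  logical shape unchanged
print(images.stride())  # (60, 1, 15, 3)            channel stride is now 1
print(images.is_contiguous(memory_format=torch.channels_last))  # True
print(torch.equal(original, images))                            # True: same values
```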
========== part 2============
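- with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp): code inside this block runs under automatic mixed precision, meaning autocast picks a suitable precision for each operation (a lower precision such as float16 on CUDA for ops like convolutions and matrix multiplies, float32 for precision-sensitive ops). It only takes effect when amp is True; on Apple 'mps' devices the code passes 'cpu' as the autocast device type.
- masks_pred = model(images): the prediction (raw logits) for the current batch.
- if model.n_classes == 1: n_classes is the number of output channels, i.e. how many feature maps the network finally outputs.
- loss = criterion(masks_pred.squeeze(1), true_masks.float()) followed by loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False): with a single output channel the loss is BCE-with-logits plus Dice loss; squeeze(1) drops the channel dimension so the prediction has the same (N, H, W) shape as the mask. For dice_loss, see: U-Net code reproduction – utils dice_score.py.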
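The sketch below reproduces the shapes in both branches of the loss, using random logits and masks. dice_loss_sketch is a hypothetical, simplified soft Dice loss written for this example; it is not the repo's dice_score.py. It shows why the binary branch uses squeeze(1) and sigmoid, and why the multi-class branch one-hot encodes the mask and permutes it from (N, H, W, C) to (N, C, H, W):

```python
import torch
import torch.nn.functional as F
from torch import nn

def dice_loss_sketch(probs, target, eps=1e-6):
    """Hypothetical soft Dice loss: 1 - mean Dice, reducing over the last two (H, W) dims."""
    inter = (probs * target).sum(dim=(-2, -1))
    denom = probs.sum(dim=(-2, -1)) + target.sum(dim=(-2, -1))
    dice = (2 * inter + eps) / (denom + eps)
    return 1 - dice.mean()

torch.manual_seed(0)
N, H, W = 2, 8, 8

# Binary case (n_classes == 1): one logit map per image, masks in {0, 1}
logits_bin = torch.randn(N, 1, H, W)
masks_bin = torch.randint(0, 2, (N, H, W))
with torch.autocast('cpu', enabled=False):  # amp off, as when --amp is not passed
    loss_bin = nn.BCEWithLogitsLoss()(logits_bin.squeeze(1), masks_bin.float())
    loss_bin = loss_bin + dice_loss_sketch(torch.sigmoid(logits_bin.squeeze(1)), masks_bin.float())

# Multi-class case (n_classes = C): C logit maps, masks hold class indices 0..C-1
C = 3
logits_mc = torch.randn(N, C, H, W)
masks_mc = torch.randint(0, C, (N, H, W))
one_hot = F.one_hot(masks_mc, C).permute(0, 3, 1, 2).float()  # (N, H, W, C) -> (N, C, H, W)
loss_mc = nn.CrossEntropyLoss()(logits_mc, masks_mc)
loss_mc = loss_mc + dice_loss_sketch(F.softmax(logits_mc, dim=1), one_hot)

print(one_hot.shape)            # torch.Size([2, 3, 8, 8])
print(loss_bin.dim(), loss_mc.dim())  # 0 0  (both are scalars, ready for backward())
```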
========== part 3============
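Backpropagation and the parameter update, in the order the code runs:
- optimizer.zero_grad(set_to_none=True): clears the gradients from the previous step (setting them to None rather than to zero).
- grad_scaler.scale(loss).backward(): multiplies the loss by the current scale factor (when AMP is enabled) and backpropagates.
- torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping): rescales the gradients so their total norm is at most gradient_clipping. For clip_grad_norm_, see: Pytorch: torch.nn.utils.clip_grad_norm_ gradient clipping explained. With AMP enabled the gradients are still scaled at this point, so the PyTorch AMP documentation recommends calling grad_scaler.unscale_(optimizer) before clipping.
- grad_scaler.step(optimizer): unscales the gradients (if not already done) and calls optimizer.step(), skipping the step if the gradients contain inf or NaN.
- grad_scaler.update(): updates the scale factor for the next iteration.
For how optimizer.zero_grad(), loss.backward() and optimizer.step() work, see: Pytorch: optimizer.zero_grad(), loss.backward(), optimizer.step()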
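A minimal single-step sketch of part 3 on the CPU, assuming a toy nn.Linear model, random data and an MSE loss in place of UNet and the BCE/Dice loss. With enabled=False (the situation when --amp is not passed) every GradScaler call passes straight through, so the same code works with and without AMP. It includes grad_scaler.unscale_(optimizer) before clipping, following the PyTorch AMP docs:

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
model = nn.Linear(8, 1)                                # toy stand-in for UNet
optimizer = optim.RMSprop(model.parameters(), lr=1e-3)
grad_scaler = torch.cuda.amp.GradScaler(enabled=False)  # amp off: scaler is a pass-through
gradient_clipping = 1.0

x = torch.randn(16, 8) * 100   # large inputs so the raw gradient norm exceeds the clip value
y = torch.randn(16, 1)
before = model.weight.detach().clone()

loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad(set_to_none=True)
grad_scaler.scale(loss).backward()
grad_scaler.unscale_(optimizer)                        # no-op here; required for correct clipping with AMP
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)  # norm before clipping
clipped_norm = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
grad_scaler.step(optimizer)                            # calls optimizer.step() when disabled
grad_scaler.update()

print(total_norm > gradient_clipping)                  # tensor(True): clipping was needed
print(clipped_norm <= gradient_clipping + 1e-4)        # tensor(True)
print(torch.equal(before, model.weight))               # False: parameters were updated
```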
========== part 4============
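- Update the training counters: global_step += 1, accumulate epoch_loss, and log the batch loss to wandb
- Update the progress bar: pbar.update(images.shape[0]) advances it by the number of images in the batch, and pbar.set_postfix shows the current batch loss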
```python
# 5. Begin training
# ================================ part 1 =======================================
for epoch in range(1, epochs + 1):
    model.train()
    epoch_loss = 0
    with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
        for batch in train_loader:
            images, true_masks = batch['image'], batch['mask']

            assert images.shape[1] == model.n_channels, \
                f'Network has been defined with {model.n_channels} input channels, ' \
                f'but loaded images have {images.shape[1]} channels. Please check that ' \
                'the images are loaded correctly.'

            images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
            true_masks = true_masks.to(device=device, dtype=torch.long)

            # ================================ part 2 =======================================
            with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
                masks_pred = model(images)
                if model.n_classes == 1:
                    loss = criterion(masks_pred.squeeze(1), true_masks.float())
                    loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
                else:
                    loss = criterion(masks_pred, true_masks)
                    loss += dice_loss(
                        F.softmax(masks_pred, dim=1).float(),
                        F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
                        multiclass=True
                    )

            # ================================ part 3 =======================================
            optimizer.zero_grad(set_to_none=True)
            grad_scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient norm (a no-op when amp=False)
            grad_scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
            grad_scaler.step(optimizer)
            grad_scaler.update()

            # ================================ part 4 =======================================
            pbar.update(images.shape[0])
            global_step += 1
            epoch_loss += loss.item()
            experiment.log({
                'train loss': loss.item(),
                'step': global_step,
                'epoch': epoch
            })
            pbar.set_postfix(**{'loss (batch)': loss.item()})

            # ================================ part 5 =======================================
            # Evaluation round
            division_step = (n_train // (5 * batch_size))
            if division_step > 0:
                if global_step % division_step == 0:
                    histograms = {}
                    for tag, value in model.named_parameters():
                        tag = tag.replace('/', '.')
                        if not (torch.isinf(value) | torch.isnan(value)).any():
                            histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                        if not (torch.isinf(value.grad) | torch.isnan(value.grad)).any():
                            histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

                    val_score = evaluate(model, val_loader, device, amp)
                    scheduler.step(val_score)

                    logging.info('Validation Dice score: {}'.format(val_score))
                    try:
                        experiment.log({
                            'learning rate': optimizer.param_groups[0]['lr'],
                            'validation Dice': val_score,
                            'images': wandb.Image(images[0].cpu()),
                            'masks': {
                                'true': wandb.Image(true_masks[0].float().cpu()),
                                'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()),
                            },
                            'step': global_step,
                            'epoch': epoch,
                            **histograms
                        })
                    except Exception:
                        pass

    # Save a checkpoint at the end of every epoch
    if save_checkpoint:
        Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
        state_dict = model.state_dict()
        state_dict['mask_values'] = dataset.mask_values
        torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
        logging.info(f'Checkpoint {epoch} saved!')
```
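Part 5 runs the evaluation round every division_step = n_train // (5 * batch_size) global steps, i.e. roughly five times per epoch. This pure-Python sketch, with hypothetical sizes n_train=1000 and batch_size=4, lists the global_step values at which evaluate() and scheduler.step() would run. It assumes the training DataLoader keeps its last partial batch, as the original train_loader does (no drop_last), and that global_step is incremented before the check, as in part 4:

```python
import math

def evaluation_steps(n_train, batch_size, epochs):
    """Return the global_step values at which the part 5 evaluation round would run."""
    steps_per_epoch = math.ceil(n_train / batch_size)  # last partial batch is kept
    division_step = n_train // (5 * batch_size)
    if division_step == 0:
        return []  # dataset too small: the `if division_step > 0` guard skips evaluation
    total_steps = steps_per_epoch * epochs
    # global_step starts counting at 1 because it is incremented before the check
    return [step for step in range(1, total_steps + 1) if step % division_step == 0]

print(evaluation_steps(1000, 4, 1))  # [50, 100, 150, 200, 250]
print(evaluation_steps(10, 4, 3))    # []
```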
- argparse.ArgumentParser: create an ArgumentParser() object
- parser.add_argument: call the add_argument() method to add arguments
- parser.parse_args(): parse the added arguments with parse_args()
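The main parameters of parser.add_argument:
- name or flags: a name or a list of option strings, e.g. foo or -f, --foo.
- action: the basic type of action to take when the argument is encountered on the command line.
- nargs: the number of command-line arguments that should be consumed.
- const: a constant value required by some action and nargs selections.
- default: the value used when the argument is absent from the command line.
- type: the type to which the command-line argument should be converted.
- choices: a container of the allowable values for the argument.
- required: whether the command-line option may be omitted (optionals only).
- help: a brief description of what the argument does.
- metavar: a name for the argument value in usage messages.
- dest: the name of the attribute added to the object returned by parse_args().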
```python
def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')

    return parser.parse_args()
```
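To see dest, metavar, action='store_true' and default in action, here is a trimmed copy of get_args() (five of its options) that accepts an explicit argument list instead of reading sys.argv. The argv parameter is my addition for the demo:

```python
import argparse

def get_args(argv=None):
    # Trimmed copy of the original get_args(); argv=None falls back to sys.argv[1:]
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1,
                        help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    return parser.parse_args(argv)

defaults = get_args([])
custom = get_args(['-e', '10', '--batch-size', '4', '-l', '1e-4', '--amp'])

print(defaults.epochs, defaults.batch_size, defaults.lr, defaults.amp)  # 5 1 1e-05 False
print(custom.epochs, custom.batch_size, custom.lr, custom.amp)          # 10 4 0.0001 True
print(custom.val / 100)                                                 # 0.1
```

Note how '--batch-size' becomes args.batch_size and '--learning-rate' becomes args.lr because of dest, which is why main() reads args.lr.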
main():
References:
- Memory Format (Channels First and Channels Last)
- Pytorch: torch.utils.checkpoint()
- Pytorch: torch.cuda.empty_cache()
- Pytorch: saving and loading models: torch.save, torch.load, torch.nn.Module.state_dict and torch.nn.Module.load_state_dict
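The checkpoint written in part 5 stores dataset.mask_values next to the model weights, which is why main() deletes state_dict['mask_values'] before load_state_dict(). Below is a small sketch of that round trip, assuming a one-layer nn.Conv2d in place of UNet, an in-memory io.BytesIO buffer in place of a .pth file, and hypothetical mask values [0, 1]. It also shows the error strict loading raises if the extra key is left in:

```python
import io
import torch
from torch import nn

model = nn.Conv2d(3, 2, kernel_size=1)   # tiny stand-in for UNet

# Save: parameters plus an extra, non-parameter entry (like dataset.mask_values)
state_dict = model.state_dict()
state_dict['mask_values'] = [0, 1]        # hypothetical mask values for a binary mask
buffer = io.BytesIO()                     # in-memory file instead of checkpoint_epochN.pth
torch.save(state_dict, buffer)

# Load: remove the extra key first, because load_state_dict() is strict by default
buffer.seek(0)
loaded = torch.load(buffer, map_location='cpu')
mask_values = loaded.pop('mask_values')   # like `del state_dict['mask_values']`, but keeps the value
new_model = nn.Conv2d(3, 2, kernel_size=1)
new_model.load_state_dict(loaded)

print(mask_values)                                  # [0, 1]
print(torch.equal(new_model.weight, model.weight))  # True

# Without removing it, strict loading rejects the unexpected key
bad = model.state_dict()
bad['mask_values'] = [0, 1]
try:
    new_model.load_state_dict(bad)
    strict_error = None
except RuntimeError as err:
    strict_error = str(err)
print(strict_error is not None)                     # True
```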
```python
if __name__ == '__main__':
    args = get_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Change here to adapt to your data
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    model = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
    model = model.to(memory_format=torch.channels_last)

    logging.info(f'Network:\n'
                 f'\t{model.n_channels} input channels\n'
                 f'\t{model.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')

    if args.load:
        state_dict = torch.load(args.load, map_location=device)
        del state_dict['mask_values']
        model.load_state_dict(state_dict)
        logging.info(f'Model loaded from {args.load}')

    model.to(device=device)
    try:
        train_model(
            model=model,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            device=device,
            img_scale=args.scale,
            val_percent=args.val / 100,
            amp=args.amp
        )
    except torch.cuda.OutOfMemoryError:
        logging.error('Detected OutOfMemoryError! '
                      'Enabling checkpointing to reduce memory usage, but this slows down training. '
                      'Consider enabling AMP (--amp) for fast and memory efficient training')
        torch.cuda.empty_cache()
        model.use_checkpointing()
        train_model(
            model=model,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            device=device,
            img_scale=args.scale,
            val_percent=args.val / 100,
            amp=args.amp
        )
```
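The OOM fallback in main() calls model.use_checkpointing(), whose implementation lives in the repo's model code and is not covered in this post. As an illustration of the underlying torch.utils.checkpoint mechanism only, the sketch below wraps a small hypothetical conv block: outputs and input gradients match the plain version, but the block's intermediate activations are recomputed during the backward pass instead of being stored, trading compute for memory:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Hypothetical block standing in for one stage of the network
block = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 8, 3, padding=1))
x = torch.randn(2, 3, 16, 16, requires_grad=True)

# Plain forward/backward: all intermediate activations are kept for backward
out_plain = block(x)
out_plain.sum().backward()
grad_plain = x.grad.clone()
x.grad = None

# Checkpointed: only the block's input is saved; its internals are recomputed in backward
out_ckpt = checkpoint(block, x, use_reentrant=False)
out_ckpt.sum().backward()
grad_ckpt = x.grad.clone()

print(torch.allclose(out_plain, out_ckpt))    # True
print(torch.allclose(grad_plain, grad_ckpt))  # True
```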