TensorBoard可视化+Confustion Matrix Drawing

for later~

代码阅读

1. 加载trainset

import argparse
import logging
import os
import numpy as npimport torch
from torch import distributed
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom backbones import get_model
from dataset import get_dataloader
from face_fc_ddp import FC_ddp
from utils.utils_callbacks import CallBackLogging, CallBackVerification
from utils.utils_config import get_config
from utils.utils_distributed_sampler import setup_seed
from utils.utils_logging import AverageMeter, init_loggingfrom utils.utils_invreg import env_loss_ce_ddp, assign_loss
from utils.utils_feature_saving import concat_feat, extract_feat_per_gpu
from utils.utils_partition import load_past_partitionassert torch.__version__ >= "1.9.0", "In order to enjoy the features of the new torch, \
we have upgraded the torch to 1.9.0. torch before than 1.9.0 may not work in the future."import datetimeos.environ["NCCL_BLOCKING_WAIT"] = "1"try:world_size = int(os.environ["WORLD_SIZE"])rank = int(os.environ["RANK"])distributed.init_process_group("nccl", timeout=datetime.timedelta(hours=3))
except KeyError:world_size = 1rank = 0distributed.init_process_group(backend="nccl",init_method="tcp://127.0.0.1:12584",rank=rank,world_size=world_size,)def main(args):cfg = get_config(args.config)setup_seed(seed=cfg.seed, cuda_deterministic=False)torch.cuda.set_device(args.local_rank)os.makedirs(cfg.output, exist_ok=True)init_logging(rank, cfg.output)summary_writer = (SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard"))if rank == 0else None)##################### Trainset definition ###################### only horizon-flip is used in transformstrain_loader = get_dataloader(cfg.rec,args.local_rank,cfg.batch_size,False,cfg.seed,cfg.num_workers,return_idx=True)

3. 定义backbone model,加载权重,并行化训练

    ##################### Model backbone definition #####################backbone = get_model(cfg.network, dropout=cfg.dropout, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda()if cfg.resume:if rank == 0:dict_checkpoint = torch.load(os.path.join(cfg.pretrained, f"checkpoint_{cfg.pretrained_ep}.pt"))backbone.load_state_dict(dict_checkpoint["state_dict_backbone"])del dict_checkpointbackbone = torch.nn.parallel.DistributedDataParallel(module=backbone, broadcast_buffers=False, device_ids=[args.local_rank], bucket_cap_mb=16,find_unused_parameters=True)backbone.train()backbone._set_static_graph()

4. 分类函数+损失定义

    ##################### FC classification & loss definition ######################if cfg.invreg['irm_train'] == 'var':reduction = 'none'else:reduction = 'mean'module_fc = FC_ddp(cfg.embedding_size, cfg.num_classes, scale=cfg.scale,margin=cfg.cifp['m'], mode=cfg.cifp['mode'], use_cifp=cfg.cifp['use_cifp'],reduction=reduction).cuda()if cfg.resume:if rank == 0:dict_checkpoint = torch.load(os.path.join(cfg.pretrained, f"checkpoint_{cfg.pretrained_ep}.pt"))module_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"])del dict_checkpointmodule_fc = torch.nn.parallel.DistributedDataParallel(module_fc, device_ids=[args.local_rank])module_fc.train().cuda()opt = torch.optim.SGD(params=[{"params": backbone.parameters()}, {"params": module_fc.parameters()}],lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay)##################### Train scheduler definition #####################cfg.total_batch_size = cfg.batch_size * world_sizecfg.num_image = len(train_loader.dataset)n_cls = cfg.num_classescfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epochcfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epochassert cfg.scheduler == 'step'from torch.optim.lr_scheduler import MultiStepLRlr_scheduler = MultiStepLR(optimizer=opt,milestones=cfg.step,gamma=0.1,last_epoch=-1)start_epoch = 0global_step = 0if cfg.resume:dict_checkpoint = torch.load(os.path.join(cfg.pretrained, f"checkpoint_{cfg.pretrained_ep}.pt"),map_location={'cuda:0': f'cuda:{rank}'})start_epoch = dict_checkpoint["epoch"]global_step = dict_checkpoint["global_step"]opt.load_state_dict(dict_checkpoint["state_optimizer"])del dict_checkpoint
  • dict_checkpoint是 检查点的信息,用字典存储

5. 评估定义

    ##################### Evaluation definition #####################callback_verification = CallBackVerification(val_targets=cfg.val_targets, rec_prefix=cfg.val_rec, summary_writer=summary_writer)callback_logging = CallBackLogging(frequent=cfg.frequent,total_step=cfg.total_step,batch_size=cfg.batch_size,start_step=global_step,writer=summary_writer)loss_am = AverageMeter()amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)updated_split_all = []for key, value in cfg.items():num_space = 25 - len(key)logging.info(": " + key + " " * num_space + str(value))loss_weight_irm_init = cfg.invreg['loss_weight_irm']

6. 训练迭代

    ##################### Training iterations #####################if cfg.resume:callback_verification(global_step, backbone)for epoch in range(start_epoch, cfg.num_epoch):if cfg.invreg['loss_weight_irm_anneal'] and cfg.invreg['loss_weight_irm'] > 0:cfg.invreg['loss_weight_irm'] = loss_weight_irm_init * (1 + 0.09) ** (epoch - 5)if epoch in cfg.invreg['stage'] and cfg.invreg['loss_weight_irm'] > 0:cfg.invreg['env_num'] = cfg.invreg['env_num_lst'][cfg.invreg['stage'].index(epoch)]save_dir = os.path.join(cfg.output, 'saved_feat', 'epoch_{}'.format(epoch))if os.path.exists(os.path.join(save_dir, 'final_partition.npy')):logging.info('Loading the past partition...')updated_split_all = load_past_partition(cfg, epoch)logging.info(f'Total {len(updated_split_all)} partition are loaded...')else:if os.path.exists(os.path.join(save_dir, 'feature.npy')):logging.info('Loading the pre-saved features...')else:# extract features for each gpuextract_feat_per_gpu(backbone, cfg, args, save_dir)if rank == 0:_, _ = concat_feat(cfg.num_image, world_size, save_dir)distributed.barrier()emb = np.load(os.path.join(save_dir, 'feature.npy'))lab = np.load(os.path.join(save_dir, 'label.npy'))# conduct partition learninglogging.info('Started partition learning...')from utils.utils_partition import update_partitionupdated_split = update_partition(cfg, save_dir, n_cls, emb, lab, summary_writer,backbone.device, rank, world_size)del emb, labdistributed.barrier()updated_split_all.append(updated_split)if isinstance(train_loader, DataLoader):train_loader.sampler.set_epoch(epoch)for _, (index, img, local_labels) in enumerate(train_loader):global_step += 1local_embeddings = backbone(img)# cross-entropy lossif cfg.invreg['irm_train'] == 'var':loss_ce_tensor, acc = module_fc(local_embeddings, local_labels, return_logits=False)loss_ce = loss_ce_tensor.mean()loss = loss_ceelif cfg.invreg['irm_train'] == 'grad':loss_ce, acc, logits = module_fc(local_embeddings, local_labels, return_logits=True)loss = loss_ce# IRM lossif len(updated_split_all) > 0:if cfg.invreg['irm_train'] == 'grad':loss_irm = env_loss_ce_ddp(logits, local_labels, world_size, cfg, updated_split_all, epoch)elif cfg.invreg['irm_train'] == 'var':import dist_all_gatherloss_total_lst = dist_all_gather.all_gather(loss_ce_tensor)label_total_lst = dist_all_gather.all_gather(local_labels)loss_total = torch.cat(loss_total_lst, dim=0)label_total = torch.cat(label_total_lst, dim=0)loss_irm_lst = []for updated_split in updated_split_all:n_env = updated_split.size(-1)loss_env_lst = []for env_idx in range(n_env):loss_env = assign_loss(loss_total, label_total, updated_split, env_idx)loss_env_lst.append(loss_env.mean())loss_irm_lst.append(torch.stack(loss_env_lst).var())loss_irm = sum(loss_irm_lst) / len(updated_split_all)else:print('Please check the IRM train mode')loss += loss_irm * cfg.invreg['loss_weight_irm']if rank == 0:callback_logging.writer.add_scalar(tag='Loss CE', scalar_value=loss_ce.item(),global_step=global_step)if len(updated_split_all) > 0:callback_logging.writer.add_scalar(tag='Loss IRM', scalar_value=loss_irm.item(),global_step=global_step)if cfg.fp16:amp.scale(loss).backward()amp.unscale_(opt)torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)amp.step(opt)amp.update()else:loss.backward()torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)opt.step()opt.zero_grad()if cfg.step[0] > cfg.num_epoch:# use global iteration as the stepslr_scheduler.step(global_step)else:lr_scheduler.step(epoch=epoch)with torch.no_grad():loss_am.update(loss.item(), 1)callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp, acc)if global_step % cfg.verbose == 0 and global_step > 0:callback_verification(global_step, backbone)if rank == 0:path_module = os.path.join(cfg.output, f"model_{epoch}.pt")torch.save(backbone.module.state_dict(), path_module)if cfg.save_all_states:checkpoint = {"epoch": epoch + 1,"global_step": global_step,"state_dict_backbone": backbone.module.state_dict(),"state_dict_softmax_fc": module_fc.module.state_dict(),"state_optimizer": opt.state_dict(),"state_lr_scheduler": lr_scheduler.state_dict()}torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_{epoch}.pt"))callback_verification(global_step, backbone)if rank == 0:path_module = os.path.join(cfg.output, f"model_{epoch}.pt")torch.save(backbone.module.state_dict(), path_module)# convert model and save itfrom torch2onnx import convert_onnxconvert_onnx(backbone.module.cpu().eval(), path_module, os.path.join(cfg.output, "model.onnx"))distributed.destroy_process_group()

Run it with “main” f

if __name__ == "__main__":torch.backends.cudnn.benchmark = Trueparser = argparse.ArgumentParser(description="Distributed Training of InvReg in Pytorch")parser.add_argument("config", type=str, help="py config file")parser.add_argument("--local_rank", type=int, default=0, help="local_rank")main(parser.parse_args())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/792989.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024年150道高频Java面试题(十七)

33. Stream API 在 Java 8 中是如何简化集合操作的 Java 8 引入的 Stream API 是一个强大的新特性,它极大地简化了集合操作的复杂性,提高了代码的可读性和效率。Stream API 允许开发人员以声明式的方式处理数据集合(比如集合、数组等&#x…

1.Spring Boot框架整合

Spring Boot项目创建&#xff08;约定大于配置&#xff09; 2.1.3.RELEASE版本示例 idea创建 从官网下载&#xff08;https://start.spring.io/&#xff09;单元测试默认依赖不对时&#xff0c;直接删除即可 Web支持&#xff08;SpringMVC&#xff09; <dependency>&…

Collection与数据结构 链表与LinkedList(三):链表精选OJ例题(下)

1. 分割链表 OJ链接 class Solution {public ListNode partition(ListNode head, int x) {if(head null){return null;//空链表的情况}ListNode cur head;ListNode formerhead null;ListNode formerend null;ListNode latterhead null;ListNode latterend null;//定义…

见证历史:Quantinuum与微软取得突破性进展,开启了可靠量子计算的新时代!

Quantinuum与微软的合作取得了重大突破&#xff0c;将可靠量子计算带入了新的时代。他们结合了Quantinuum的System Model H2量子计算机和微软创新的量子比特虚拟化系统&#xff0c;在逻辑量子比特领域取得了800倍于物理电路错误率的突破。这一创新不仅影响深远&#xff0c;加速…

Java中Stream流介绍

Java 8引入的Stream API是Java中处理集合的一种高效方式&#xff0c;它提供了一种高级的迭代方式&#xff0c;允许你以声明式方式处理数据。Stream API可以对数据执行复杂的查询操作&#xff0c;而不需要编写冗长且复杂的循环语句。下面是一些使用Stream API的常见场景和示例&a…

Python数据分析与可视化笔记 九 分类问题

分类 分类是找出数据库中一组数据对象的共同特点&#xff0c;并按照分类模式将其划分为不同的类&#xff0c;其目的是通过分类模型&#xff0c;将数据库中的数据项映射到某个给定的类别。 分类学习是一类监督学习的问题&#xff0c;训练数据会包含其分类结果&#xff0c;根据分…

设计模式面试题(二)

1.单例优缺点 单例模式是一种常用的设计模式&#xff0c;它确保一个类仅有一个实例&#xff0c;并提供一个全局访问点。单例模式的使用具有一定的优点&#xff0c;同时也伴随着一些潜在的缺点。 优点 资源控制&#xff1a;单例模式能够确保资源如数据库连接或文件系统的一致…

目标检测——监控下的汽车

一、重要性及意义 首先&#xff0c;车辆检测技术是保证视频监控系统正常运行的基础。通过监控摄像头实时获取的图像&#xff0c;可以自动检测出图像中的车辆&#xff0c;并进行车辆类型的分类和识别。这对于优化城市交通管理、实现智能交通系统具有重要意义。此外&#xff0c;…

【无标题】html中使用div标签的坏处

在HTML中使用<div>作为布局元素时&#xff0c;尽管它已经成为现代Web开发的标准做法之一&#xff0c;并且与CSS结合使用可以实现灵活、语义化的布局设计&#xff0c;但也存在以下潜在的坏处或挑战&#xff1a; 复杂度增加&#xff1a; - 学习曲线&#xff1a;对于初学者…

【SQL Server的详细使用教程】

&#x1f3a5;博主&#xff1a;程序员不想YY啊 &#x1f4ab;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f917;点赞&#x1f388;收藏⭐再看&#x1f4ab;养成习惯 ✨希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出…

【QT+QGIS跨平台编译】056:【pdal_lepcc+Qt跨平台编译】(一套代码、一套框架,跨平台编译)

点击查看专栏目录 文章目录 一、pdal_lepcc介绍二、pdal下载三、文件分析四、pro文件五、编译实践一、pdal_lepcc介绍 pdal_lepcc 是 PDAL(Point Data Abstraction Library)的一个插件,用于点云数据的压缩。它基于 EPCC(Entwine Point Cloud Compression)算法,提供了对点…

Go语言实现Redis分布式锁

基于go-redis的设计与实现 本文将基于go语言&#xff0c;使用了一个常用的go Redis客户端 go-redis库 , 一步一步探索与实现一个简单的Redis分布式锁。 代码&#xff1a;https://github.com/liwook/Redislock 连接Redis ​ func NewClient() *redis.Client {return redis.N…

51单片机入门之独立按键

目录 1.按键简介 2.独立按键控制LED亮灭 3.独立按键控制LED移位 1.按键简介 在生活中&#xff0c;我们常常会见到各种按键&#xff0c;我们的开发板上也有按键&#xff0c;就在左下角有四个按键&#xff0c;我们把它们叫做独立按键。 独立按键的原理比较简单&…

VUE实现下一页的功能

实现步骤&#xff1a;1、确定分页参数&#xff1a;确定当前页码和每页显示的数量&#xff1b;2、获取数据&#xff1a;使用vue的axios或其他http库向后端发送请求&#xff0c;传递当前页码和每页显示的数量作为参数&#xff1b;3、更新数据&#xff1a;在vue组件中&#xff0c;…

Qt与OpenCV实现图像模板匹配

在 Qt 中使用 OpenCV 实现模板匹配可以通过集成 OpenCV 库和使用其相关函数来完成。以下是一般的步骤&#xff1a; 安装 OpenCV&#xff1a;首先&#xff0c;确保你已经安装了 OpenCV 库&#xff0c;并将其配置到你的开发环境中。 创建 Qt 项目&#xff1a;使用 Qt creator 或…

VSCode 插件 Todo Tree 待办事项

官方介绍&#xff1a;这个扩展可以快速搜索工作区中的注释标签&#xff0c;并将它们显示在活动栏的树状图中 我们写代码的时候&#xff0c;难免会遇到一些情况需要标记或搁置&#xff0c;比如&#xff1a;前端开发者在编写页面的时候页面样式完成了&#xff0c;但是后端接口还…

【机器学习】《机器学习算法竞赛实战》第7章用户画像

文章目录 第7章 用户画像7.1 什么是用户画像7.2 标签系统7.2.1 标签分类方式7.2.2 多渠道获取标签7.2.3 标签体系框架 7.3 用户画像数据特征7.3.1 常见的数据形式7.3.2 文本挖掘算法7.3.3 神奇的嵌入表示7.3.4 相似度计算方法 7.4 用户画像的应用7.4.1 用户分析7.4.2 精准营销7…

RabbitMQ安装详细教程

&#xff08;一&#xff09;在Windows系统上安装Erlang的步骤如下&#xff1a; 打开Erlang的官方下载页面&#xff0c;选择适合你的Windows系统的版本进行下载。 下载完成后&#xff0c;双击运行下载的.exe文件&#xff0c;进入Erlang的安装向导。 在安装向导中&#xff0c;按…

vscode-keil一起用

安装插件 1、C/C Extension Pack 2、Keil Assistant 配置 重启生效&#xff01;&#xff01;&#xff01; 下载安装 Mingw 下载链接&#xff1a; 添加环境变量&#xff1a; 注意确认&#xff01;&#xff01;&#xff01; 报错 gccC:\迅雷下载\MinGW\MinGW\bin…

力扣爆刷第111天之CodeTop100五连刷41-45

力扣爆刷第111天之CodeTop100五连刷41-45 文章目录 力扣爆刷第111天之CodeTop100五连刷41-45一、232. 用栈实现队列二、4. 寻找两个正序数组的中位数三、31. 下一个排列四、69. x 的平方根五、8. 字符串转换整数 (atoi) 一、232. 用栈实现队列 题目链接&#xff1a;https://le…