Invarient facial recongnition

for later~

代码阅读

1. 加载trainset

import argparse
import logging
import os
import numpy as npimport torch
from torch import distributed
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom backbones import get_model
from dataset import get_dataloader
from face_fc_ddp import FC_ddp
from utils.utils_callbacks import CallBackLogging, CallBackVerification
from utils.utils_config import get_config
from utils.utils_distributed_sampler import setup_seed
from utils.utils_logging import AverageMeter, init_loggingfrom utils.utils_invreg import env_loss_ce_ddp, assign_loss
from utils.utils_feature_saving import concat_feat, extract_feat_per_gpu
from utils.utils_partition import load_past_partitionassert torch.__version__ >= "1.9.0", "In order to enjoy the features of the new torch, \
we have upgraded the torch to 1.9.0. torch before than 1.9.0 may not work in the future."import datetimeos.environ["NCCL_BLOCKING_WAIT"] = "1"try:world_size = int(os.environ["WORLD_SIZE"])rank = int(os.environ["RANK"])distributed.init_process_group("nccl", timeout=datetime.timedelta(hours=3))
except KeyError:world_size = 1rank = 0distributed.init_process_group(backend="nccl",init_method="tcp://127.0.0.1:12584",rank=rank,world_size=world_size,)def main(args):cfg = get_config(args.config)setup_seed(seed=cfg.seed, cuda_deterministic=False)torch.cuda.set_device(args.local_rank)os.makedirs(cfg.output, exist_ok=True)init_logging(rank, cfg.output)summary_writer = (SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard"))if rank == 0else None)##################### Trainset definition ###################### only horizon-flip is used in transformstrain_loader = get_dataloader(cfg.rec,args.local_rank,cfg.batch_size,False,cfg.seed,cfg.num_workers,return_idx=True)

3. 定义backbone model,加载权重,并行化训练

    ##################### Model backbone definition #####################backbone = get_model(cfg.network, dropout=cfg.dropout, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda()if cfg.resume:if rank == 0:dict_checkpoint = torch.load(os.path.join(cfg.pretrained, f"checkpoint_{cfg.pretrained_ep}.pt"))backbone.load_state_dict(dict_checkpoint["state_dict_backbone"])del dict_checkpointbackbone = torch.nn.parallel.DistributedDataParallel(module=backbone, broadcast_buffers=False, device_ids=[args.local_rank], bucket_cap_mb=16,find_unused_parameters=True)backbone.train()backbone._set_static_graph()

4. 分类函数+损失定义

    ##################### FC classification & loss definition ######################if cfg.invreg['irm_train'] == 'var':reduction = 'none'else:reduction = 'mean'module_fc = FC_ddp(cfg.embedding_size, cfg.num_classes, scale=cfg.scale,margin=cfg.cifp['m'], mode=cfg.cifp['mode'], use_cifp=cfg.cifp['use_cifp'],reduction=reduction).cuda()if cfg.resume:if rank == 0:dict_checkpoint = torch.load(os.path.join(cfg.pretrained, f"checkpoint_{cfg.pretrained_ep}.pt"))module_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"])del dict_checkpointmodule_fc = torch.nn.parallel.DistributedDataParallel(module_fc, device_ids=[args.local_rank])module_fc.train().cuda()opt = torch.optim.SGD(params=[{"params": backbone.parameters()}, {"params": module_fc.parameters()}],lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay)##################### Train scheduler definition #####################cfg.total_batch_size = cfg.batch_size * world_sizecfg.num_image = len(train_loader.dataset)n_cls = cfg.num_classescfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epochcfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epochassert cfg.scheduler == 'step'from torch.optim.lr_scheduler import MultiStepLRlr_scheduler = MultiStepLR(optimizer=opt,milestones=cfg.step,gamma=0.1,last_epoch=-1)start_epoch = 0global_step = 0if cfg.resume:dict_checkpoint = torch.load(os.path.join(cfg.pretrained, f"checkpoint_{cfg.pretrained_ep}.pt"),map_location={'cuda:0': f'cuda:{rank}'})start_epoch = dict_checkpoint["epoch"]global_step = dict_checkpoint["global_step"]opt.load_state_dict(dict_checkpoint["state_optimizer"])del dict_checkpoint
  • dict_checkpoint是 检查点的信息,用字典存储

5. 评估定义

    ##################### Evaluation definition #####################callback_verification = CallBackVerification(val_targets=cfg.val_targets, rec_prefix=cfg.val_rec, summary_writer=summary_writer)callback_logging = CallBackLogging(frequent=cfg.frequent,total_step=cfg.total_step,batch_size=cfg.batch_size,start_step=global_step,writer=summary_writer)loss_am = AverageMeter()amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)updated_split_all = []for key, value in cfg.items():num_space = 25 - len(key)logging.info(": " + key + " " * num_space + str(value))loss_weight_irm_init = cfg.invreg['loss_weight_irm']

6. 训练迭代

    ##################### Training iterations #####################if cfg.resume:callback_verification(global_step, backbone)for epoch in range(start_epoch, cfg.num_epoch):if cfg.invreg['loss_weight_irm_anneal'] and cfg.invreg['loss_weight_irm'] > 0:cfg.invreg['loss_weight_irm'] = loss_weight_irm_init * (1 + 0.09) ** (epoch - 5)if epoch in cfg.invreg['stage'] and cfg.invreg['loss_weight_irm'] > 0:cfg.invreg['env_num'] = cfg.invreg['env_num_lst'][cfg.invreg['stage'].index(epoch)]save_dir = os.path.join(cfg.output, 'saved_feat', 'epoch_{}'.format(epoch))if os.path.exists(os.path.join(save_dir, 'final_partition.npy')):logging.info('Loading the past partition...')updated_split_all = load_past_partition(cfg, epoch)logging.info(f'Total {len(updated_split_all)} partition are loaded...')else:if os.path.exists(os.path.join(save_dir, 'feature.npy')):logging.info('Loading the pre-saved features...')else:# extract features for each gpuextract_feat_per_gpu(backbone, cfg, args, save_dir)if rank == 0:_, _ = concat_feat(cfg.num_image, world_size, save_dir)distributed.barrier()emb = np.load(os.path.join(save_dir, 'feature.npy'))lab = np.load(os.path.join(save_dir, 'label.npy'))# conduct partition learninglogging.info('Started partition learning...')from utils.utils_partition import update_partitionupdated_split = update_partition(cfg, save_dir, n_cls, emb, lab, summary_writer,backbone.device, rank, world_size)del emb, labdistributed.barrier()updated_split_all.append(updated_split)if isinstance(train_loader, DataLoader):train_loader.sampler.set_epoch(epoch)for _, (index, img, local_labels) in enumerate(train_loader):global_step += 1local_embeddings = backbone(img)# cross-entropy lossif cfg.invreg['irm_train'] == 'var':loss_ce_tensor, acc = module_fc(local_embeddings, local_labels, return_logits=False)loss_ce = loss_ce_tensor.mean()loss = loss_ceelif cfg.invreg['irm_train'] == 'grad':loss_ce, acc, logits = module_fc(local_embeddings, local_labels, return_logits=True)loss = loss_ce# IRM lossif len(updated_split_all) > 0:if cfg.invreg['irm_train'] == 'grad':loss_irm = env_loss_ce_ddp(logits, local_labels, world_size, cfg, updated_split_all, epoch)elif cfg.invreg['irm_train'] == 'var':import dist_all_gatherloss_total_lst = dist_all_gather.all_gather(loss_ce_tensor)label_total_lst = dist_all_gather.all_gather(local_labels)loss_total = torch.cat(loss_total_lst, dim=0)label_total = torch.cat(label_total_lst, dim=0)loss_irm_lst = []for updated_split in updated_split_all:n_env = updated_split.size(-1)loss_env_lst = []for env_idx in range(n_env):loss_env = assign_loss(loss_total, label_total, updated_split, env_idx)loss_env_lst.append(loss_env.mean())loss_irm_lst.append(torch.stack(loss_env_lst).var())loss_irm = sum(loss_irm_lst) / len(updated_split_all)else:print('Please check the IRM train mode')loss += loss_irm * cfg.invreg['loss_weight_irm']if rank == 0:callback_logging.writer.add_scalar(tag='Loss CE', scalar_value=loss_ce.item(),global_step=global_step)if len(updated_split_all) > 0:callback_logging.writer.add_scalar(tag='Loss IRM', scalar_value=loss_irm.item(),global_step=global_step)if cfg.fp16:amp.scale(loss).backward()amp.unscale_(opt)torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)amp.step(opt)amp.update()else:loss.backward()torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)opt.step()opt.zero_grad()if cfg.step[0] > cfg.num_epoch:# use global iteration as the stepslr_scheduler.step(global_step)else:lr_scheduler.step(epoch=epoch)with torch.no_grad():loss_am.update(loss.item(), 1)callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp, acc)if global_step % cfg.verbose == 0 and global_step > 0:callback_verification(global_step, backbone)if rank == 0:path_module = os.path.join(cfg.output, f"model_{epoch}.pt")torch.save(backbone.module.state_dict(), path_module)if cfg.save_all_states:checkpoint = {"epoch": epoch + 1,"global_step": global_step,"state_dict_backbone": backbone.module.state_dict(),"state_dict_softmax_fc": module_fc.module.state_dict(),"state_optimizer": opt.state_dict(),"state_lr_scheduler": lr_scheduler.state_dict()}torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_{epoch}.pt"))callback_verification(global_step, backbone)if rank == 0:path_module = os.path.join(cfg.output, f"model_{epoch}.pt")torch.save(backbone.module.state_dict(), path_module)# convert model and save itfrom torch2onnx import convert_onnxconvert_onnx(backbone.module.cpu().eval(), path_module, os.path.join(cfg.output, "model.onnx"))distributed.destroy_process_group()

Run it with “main” f

if __name__ == "__main__":torch.backends.cudnn.benchmark = Trueparser = argparse.ArgumentParser(description="Distributed Training of InvReg in Pytorch")parser.add_argument("config", type=str, help="py config file")parser.add_argument("--local_rank", type=int, default=0, help="local_rank")main(parser.parse_args())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/794259.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Windows 2008虚拟机安装、安装VM Tools、快照和链接克隆、添加硬盘修改格式为GPT

一、安装vmware workstation软件 VMware workstation的安装介质,获取路径: 链接:https://pan.baidu.com/s/1AUAw_--yjZAUPbsR7StOJQ 提取码:umz1 所在目录:\vmware\VMware workstation 15.1.0 1.找到百度网盘中vmwa…

Mysql的基本命令

1 服务相关命令 命令描述systemctl status mysql查看MySQL服务的状态systemctl stop mysql停止MySQL服务systemctl start mysql启动MySQL服务systemctl restart mysql重启MySQL服务ps -ef | grep mysql查看mysql的进程mysql -uroot -hlocalhost -p123456登录MySQLhelp显示MySQ…

C++初阶:vector类的模拟实现(含模板)

目录 1.Vector.h 2.Test.cpp 1.Vector.h #include <iostream> #include <cassert> #include <algorithm> using namespace std; template <class T> class Vector { public:typedef T *iterator;typedef const T *const_iterator;iterator begin() …

8.list容器的使用

文章目录 list容器1.构造函数代码工程运行结果 2.赋值和交换代码工程运行结果 3.大小操作代码工程运行结果 4.插入和删除代码工程运行结果 5.数据存取工程代码运行结果 6.反转和排序代码工程运行结果 list容器 1.构造函数 /*1.默认构造-无参构造*/ /*2.通过区间的方式进行构造…

java-网络编程socket-聊天室-03

完整版代码 java -聊天室的代码: 用于存放聊天室的项目的代码和思路导图https://gitee.com/to-uphold-justice-for-others/java---code-for-chat-rooms.git 多线程并发问题 多线程的并发问题主要出现在当一个程序涉及多个线程同时运行时&#xff0c;这些线程可能会同时访问共…

conda删除环境指令

要删除Conda环境&#xff0c;你可以使用下面的指令&#xff1a; conda remove --name your_env_name --all这里的 your_env_name 是你想要删除的环境名称。这条指令会删除指定的环境及其下的所有包。 如果你正在使用的是Anaconda Prompt或者是命令行界面&#xff0c;确保你有…

Java:接口应用(Clonable 接口和深拷贝)

目录 1.引例2.Object中clone方法的实现3.Cloneable接口讲解4.深拷贝和浅拷贝4.1浅拷贝4.2深拷贝 1.引例 Java 中内置了一些很有用的接口, Clonable 就是其中之一. Object 类中存在一个 clone 方法, 调用这个方法可以创建一个对象的 “拷贝”. 但是要想合法调用 clone 方法。必…

精密电阻阻值表和电容容值表

前面2张是电阻阻值表&#xff08;E-96/0603/1%&#xff09; 常见贴片电容的容值表

解决windows下Qt Creator显示界面过大的问题

&#x1f40c;博主主页&#xff1a;&#x1f40c;​倔强的大蜗牛&#x1f40c;​ &#x1f4da;专栏分类&#xff1a;QT❤️感谢大家点赞&#x1f44d;收藏⭐评论✍️ 目录 问题描述 解决方法 1、右击此电脑--->属性 2、点击高级系统设置--->点击环境变量 3、 找到系…

【美团笔试题汇总】2023-08-26-美团春秋招笔试题-三语言题解(CPP/Python/Java)

&#x1f36d; 大家好这里是KK爱Coding &#xff0c;一枚热爱算法的程序员 ✨ 本系列打算持续跟新小米近期的春秋招笔试题汇总&#xff5e; &#x1f4bb; ACM银牌&#x1f948;| 多次AK大厂笔试 &#xff5c; 编程一对一辅导 &#x1f44f; 感谢大家的订阅➕ 和 喜欢&#x1f…

想要安装ssh?

SSH&#xff08;Secure Shell&#xff09;是一种加密的网络协议&#xff0c;用于在不安全的网络上安全地进行远程登录和执行命令。它通过加密通信和身份验证机制&#xff0c;确保用户和系统之间的通信是安全的。 SSH协议的主要功能包括&#xff1a; 加密通信&#xff1a;SSH使…

虚良SEO-蜘蛛池的作用与工作原理

蜘蛛池是一种SEO优化工具&#xff0c;其主要作用是吸引搜索引擎蜘蛛到特定网站进行爬行和索引&#xff0c;从而提高网站的可见性和排名。下面分别介绍蜘蛛池的作用和工作原理。 蜘蛛池的作用&#xff1a; 提高网站收录&#xff1a; 当一个网站新发布时&#xff0c;或者长时间…

哈佛大学商业评论 --- 第三篇:真实世界中的增强现实

AR将全面融入公司发展战略&#xff01; AR将成为人类和机器之间的新接口&#xff01; AR将成为人类的关键技术之一&#xff01; 请将此文转发给您的老板&#xff01; --- 本文作者&#xff1a;Michael E.Porter和James E.Heppelmann 虽然物理世界是三维的&#xff0c;但大…

如何查询大数据信用报告?查询时需要注意的事项有哪些?

在数字化时代&#xff0c;大数据信用评分对于需要资金周转的个人或企业来说至关重要。许多机构在贷款审批过程中使用大数据信用评分作为风险控制的重要手段。因此&#xff0c;了解自己的大数据信用状况成为了常规操作。本文将详细介绍如何查询大数据信用报告以及查询时需要注意…

ES6: class类

类 class 面相对象class关键字创建类关于类的继承 面相对象 一切皆对象。 举例&#xff1a; 操作浏览器要使用window对象&#xff1b;操作网页要使用document对象&#xff1b;操作控制台要使用console对象&#xff1b; ES6中增加了类的概念&#xff0c;其实ES5中已经可以实现类…

什么是 内网穿透

内网穿透是一种技术手段&#xff0c;用于在内部网络&#xff08;如家庭网络或公司网络&#xff09;中的设备能够被外部网络访问和控制。它允许将位于私有网络中的设备暴露在公共网络&#xff08;如互联网&#xff09;上&#xff0c;从而实现远程访问和管理。 内网穿透通常通过…

【云计算】云网络是未来的网络基础设施

云网络是未来的网络基础设施 1.什么是云网络2.云网络具备云的特征2.1 资源共享2.2 弹性伸缩2.3 自助服务2.4 可计量2.5 连接无处不在2.6 兼容性 3.云网络改变商业模式4.DevOps 变革产业生态 1.什么是云网络 到底什么是云网络&#xff1f;它和传统的网络有什么不同&#xff1f;…

震惊!子元素的padding/margin-top依据的是父元素的宽度!

子元素的margin-top和padding-top都是以父元素的宽度为基准 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, initial-scale1.0" />…

7.java openCV4.x 入门-Mat之转换、重塑与计算

专栏简介 &#x1f492;个人主页 &#x1f4f0;专栏目录 点击上方查看更多内容 &#x1f4d6;心灵鸡汤&#x1f4d6;我们唯一拥有的就是今天&#xff0c;唯一能把握的也是今天建议把本文当作笔记来看&#xff0c;据说专栏目录里面有相应视频&#x1f92b; &#x1f9ed;文…

Mybatis 传参方式

1、mybatis简介 Mybatis 是⼀个半 ORM&#xff08;对象关系映射&#xff09;框架&#xff0c;它内部封装了 JDBC&#xff0c;开发时只需要关注 SQL 语句本身&#xff0c;不 需要花费精⼒去处理加载驱动、创建连接、创建statement 等繁杂的过程。程序员直接编写原⽣态 sql&#…