【代码整理】Pytorch从0实现图像分类pipeline

文章目录

  • 引言
  • 1.数据集读取部分`dataloader.py`
    • 1.1.分类数据集的数据组织形式
    • 1.2自定义数据增强/数据预处理类
    • 1.3.重写`torch.utils.data.Dataset`数据集读取类
    • 1.4.模块测试样例
  • 2.网络部分`mynet.py`
    • 2.1.自定义分类网络`torch.nn.Module`
    • 2.2.模块测试样例
  • 3.训练/验证/测试模块`runner.py`
    • 3.1.`Runner`类初始化
      • 3.1.1.日志模块初始化
      • 3.1.2.自定义变量记录类
      • 3.1.3.定义优化器和学习率衰减策略
    • 3.2.训练pipeline `trainer`
      • 3.2.1.训练一个epoch的pipeline `fitEpoch`
        • 3.2.1.1.`recoardArgs`
        • 3.2.1.2.`recordTensorboardLog`
        • 3.2.1.3.`printLog`
      • 3.2.2.评估一个epoch的pipeline `evaler`
      • 3.2.3.保存网络权重 `saveCkpt`
    • 3.3.验证pipeline `evaler`
      • 3.3.1.推理得到网络预测结果`eval`
      • 3.3.2.可视化混淆矩阵`showComMatrix`
      • 3.3.3.绘制类别的PR曲线`drawPRCurve`
      • 3.3.4.计算每个类别的 AP, F1Score`clacAP`
      • 3.3.4.可视化训练过程中保存的参数`visArgsHistory`
    • 3.4.测试pipeline `tester`
    • 3.5.其他
      • 3.5.1 获取命令行参数
      • 3.5.2 根据给定路径动态import模块(config.py)
      • 3.5.3 config.py样例

引言

本篇博客可以看做是对:

pytorch实现手写英文字母识别 和 Pytorch搭建预训练VGG16实现10 Monkey Species Classification

这两篇博客中代码的重构和一些细节上的调整,并对功能类似的部分进行了模块化封装

● 还未实现的部分:混合精度训练、多GPU训练等(待补充)。

● 任何逻辑不完善的地方,欢迎指出或讨论。

● 后续随缘更新。

完整代码可在github获取👇https://github.com/Scienthusiasts/Classification_pytorch
若对你有帮助,不妨star支持一下

1.数据集读取部分dataloader.py

1.1.分类数据集的数据组织形式

images
├─valid
│ ├─apple_pie
│ ├─baby_back_ribs
│ … …
│ └─waffles
└─train
├─apple_pie
├─baby_back_ribs

​ … …

​ └─waffles

images为图像根目录,train为训练集图像, valid为验证集图像,对应的子目录以类别命名,用于存储不同类别的图像。

1.2自定义数据增强/数据预处理类

数据增强/预处理方法分为三类,分别是训练时增强,验证时增强和测试时增强。训练时增强包含最全的数据增强操作,并依概率随机对每张图像执行;验证时增强只保留最基础的数据预处理方法,不包含数据增强,测试时增强只针对最终的可视化,不包含对图像的归一化处理。

class Transforms():'''数据预处理/数据增强(基于albumentations库)'''def __init__(self, imgSize):# 训练时增强self.trainTF = A.Compose([# 随机旋转A.Rotate(limit=15, p=0.5),# 最长边限制为imgSizeA.LongestMaxSize(max_size=imgSize),# 随机镜像A.HorizontalFlip(p=0.5),# 参数:随机色调、饱和度、值变化A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, always_apply=False, p=0.5),# 随机明亮对比度A.RandomBrightnessContrast(p=0.2),   # 高斯噪声A.GaussNoise(var_limit=(0.05, 0.09), p=0.4),     A.OneOf([# 使用随机大小的内核将运动模糊应用于输入图像A.MotionBlur(p=0.2),   # 中值滤波A.MedianBlur(blur_limit=3, p=0.1),    # 使用随机大小的内核模糊输入图像A.Blur(blur_limit=3, p=0.1),  ], p=0.2),# 较短的边做paddingA.PadIfNeeded(imgSize, imgSize, border_mode=cv2.BORDER_CONSTANT, value=[0,0,0]),A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),])# 验证时增强self.validTF = A.Compose([# 最长边限制为imgSizeA.LongestMaxSize(max_size=imgSize),# 较短的边做paddingA.PadIfNeeded(imgSize, imgSize, border_mode=0, mask_value=[0,0,0]),A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),])# 可视化增强(只reshape)self.visTF = A.Compose([# 最长边限制为imgSizeA.LongestMaxSize(max_size=imgSize),# 较短的边做paddingA.PadIfNeeded(imgSize, imgSize, border_mode=0, mask_value=[0,0,0]),])

1.3.重写torch.utils.data.Dataset数据集读取类

基本逻辑就是遍历数据集下每个类别对应的文件夹,并获取文件夹中的图像,图像的类别(标签)根据图像所在的文件夹划分。

class MyDataset(data.Dataset):      '''有监督分类任务对应的数据集读取方式'''def __init__(self, dir, mode, imgSize):    '''__init__() 为默认构造函数,传入数据集类别(训练或测试),以及数据集路径Args::param dir:     图像数据集的根目录:param mode:    模式(train/valid):param imgSize: 网络要求输入的图像尺寸Returns:precision, recall'''      self.tf = Transforms(imgSize = imgSize)# 记录数据集大小self.dataSize = 0      # 数据集类别数      self.labelsNum = len(os.listdir(os.path.join(dir, mode)))           # 训练/验证 self.mode = mode              # 数据预处理方法self.tf = Transforms(imgSize=imgSize)# 遍历所有类别self.imgPathList, self.labelList = [], []'''对类进行排序,很重要!!!,否则会造成分类时标签匹配不上导致评估的精度很低(默认按字符串,如果类是数字还需要更改)'''catDirs = sorted(os.listdir(os.path.join(dir, mode)))for idx, cat in enumerate(catDirs):catPath = os.path.join(dir, mode, cat)labelFiles = os.listdir(catPath)# 每个类别里图像数length = len(labelFiles)# 存放图片路径self.imgPathList += [os.path.join(catPath, labelFiles[i]) for i in range(length)]# 存放图片对应的标签(根据所在文件夹划分)self.labelList += [idx for _ in range(length)]self.dataSize += length        def __getitem__(self, item):  '''重载data.Dataset父类方法, 获取数据集中数据内容'''   # 读取图片img = Image.open(self.imgPathList[item]).convert('RGB')     img = np.array(img)# 获取image对应的labellabel = self.labelList[item]                 # 数据预处理/数据增强if self.mode=='train':transformed = self.tf.trainTF(image=img)if self.mode=='valid':transformed = self.tf.validTF(image=img)          img = transformed['image']    return img.transpose(2,1,0), torch.LongTensor([label])def __len__(self):'''重载data.Dataset父类方法, 返回数据集大小'''return self.dataSizedef get_cls_num(self):'''返回数据集类别数'''return self.labelsNum

1.4.模块测试样例

# for test only
if __name__ == '__main__':datasetDir = 'E:/datasets/Classification/food-101/images'mode = 'train'bs = 64seed = 22seed_everything(seed)train_data = MyDataset(datasetDir, mode, imgSize=224)print(f'数据集大小:{train_data.__len__()}')print(f'数据集类别数:{train_data.get_cls_num()}')train_data_loader = data.DataLoader(dataset = train_data, batch_size=bs, shuffle=True)# 获取label namecatNames = sorted(os.listdir(os.path.join(datasetDir, mode)))# 可视化一个batch里的图像from utils import visBatchvisBatch(train_data_loader, catNames)# 输出数据格式for step, batch in enumerate(train_data_loader):print(batch[0].shape, batch[1].shape)break

输出:

在这里插入图片描述

数据集大小:75750
数据集类别数:101
torch.Size([64, 3, 224, 224]) torch.Size([64, 1])

2.网络部分mynet.py

网络模块基于微调timm库里提供的模型,基本的逻辑就是将原来模型的分类头的分类数替换为当前数据集的分类数,Backbone部分保持不变,并使用ImageNet预训练权重初始化,训练时可以冻结Backbone只训练分类头,或者微调整个网络。

timm库里提供的模型名称和权重可以从huggingface中获取:https://huggingface.co/timm?sort_models=downloads#models

2.1.自定义分类网络torch.nn.Module

想要添加更多的Backbone,可以在modelList和分支语句中添加相应内容:

class Model(nn.Module):'''Backbone'''def __init__(self, catNums:int, modelType:str, loadckpt=False, pretrain=True, froze=True):'''网络初始化Args::param catNums:   数据集类别数:param modelType: 使用哪个模型(timm库里的模型):param loadckpt:  是否导入模型权重:param pretrain:  是否用预训练模型进行初始化(是则输入权重路径):param froze:     是否只训练分类头Returns:None'''super(Model, self).__init__()# 模型接到线性层的维度modelList = {'mobilenetv3_small_100.lamb_in1k':            1024, 'mobilenetv3_large_100.ra_in1k':              1280,'vit_base_patch16_224.augreg2_in21k_ft_in1k': 768, 'efficientnet_b5.sw_in12k_ft_in1k':           2048,'resnet50.a1_in1k':                           2048,'vgg16.tv_in1k':                              4096,}# 加载模型self.backbone = timm.create_model(modelType, pretrained=pretrain)# 删除原来的分类头并添加新的分类头(self.backbone就是去除了分类头的原始完整模型)baseModel = modelType.split('.')[0]if(baseModel in ['mobilenetv3_small_100', 'mobilenetv3_large_100', 'efficientnet_b5']):self.backbone.classifier = nn.Identity()self.head = nn.Linear(modelList[modelType], catNums)if(baseModel=='vit_base_patch16_224'):self.backbone.head = nn.Identity()self.head = nn.Linear(modelList[modelType], catNums)if(baseModel=='resnet50'):self.backbone.fc = nn.Identity()self.head = nn.Linear(modelList[modelType], catNums)if(baseModel=='vgg16'):self.backbone.head.fc = nn.Identity()self.head = nn.Linear(modelList[modelType], catNums)# 是否导入预训练权重if loadckpt: self.load_state_dict(torch.load(loadckpt))# 是否只训练线性层if froze:for param in self.backbone.parameters():param.requires_grad_(False)def forward(self, x):'''前向传播'''feat = self.backbone(x)out = self.head(feat)return out

2.2.模块测试样例

# for test only
if __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print(device)model = Model(catNums=101, modelType='mobilenetv3_large_100.ra_in1k', pretrain=True, froze=True).to(device)'''验证 1'''# print(model)'''验证 2'''# summary(model, input_size=[(3, 224, 224)])  '''验证 3'''x = torch.rand((4, 3, 600, 600)).to(device)out = model(x)print(out.shape)

3.训练/验证/测试模块runner.py

所有的训练、验证(一个epoch结束)和测试(推理一张图像)的pipeline集成在自定义的Runner类中

3.1.Runner类初始化

在Runner类的初始化阶段,一些与训练有关的模块会被定义,比如日志模块(用于训练时实时打印训练情况)、tensorboard模块、数据集,模型、损失函数、是否恢复断点等等。基于传入参数mode的不同,初始化的模块也会有所不同。

    def __init__(self, timm_model_name, img_size, ckpt_load_path, dataset_dir, epoch, bs, lr, log_dir, log_interval, pretrain, froze, optim_type, mode, resume=None, seed=0):'''Runner初始化Args::param timm_model_name: 模型名称(timm):param img_size:        统一图像尺寸的大小:param ckpt_load_path:  预加载的权重路径:param dataset_dir:     数据集根目录:param eopch:           训练批次:param bs:              训练batch size:param lr:              学习率:param log_dir:         日志文件保存目录:param log_interval:    训练或验证时隔多少bs打印一次日志:param pretrain:        backbone是否用ImageNet预训练权重初始化:param froze:           是否冻结Backbone只训练分类头:param optim_type:      优化器类型:param mode:            训练模式:train/eval/test:param resume:          是否从断点恢复训练:param seed:            固定全局种子Returns:None'''# 设置全局种子seed_everything(seed)self.timm_model_name = timm_model_nameself.img_size = img_sizeself.ckpt_load_path = ckpt_load_pathself.dataset_dir = dataset_dirself.epoch = epochself.bs = bsself.lr = lrself.log_dir = log_dirself.log_interval = log_intervalself.pretrain = pretrainself.froze = frozeself.mode = modeself.optim_type = optim_typeself.cats = os.listdir(os.path.join(self.dataset_dir, 'valid'))self.cls_num = len(self.cats)'''GPU/CPU'''self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')'''日志模块'''if mode == 'train' or mode == 'eval':self.logger, self.log_save_path = self.myLogger()'''训练/验证时参数记录模块'''json_save_dir, _ = os.path.split(self.log_save_path)self.argsHistory = ArgsHistory(json_save_dir)'''实例化tensorboard summaryWriter对象'''if mode == 'train':self.tb_writer = SummaryWriter(log_dir=os.path.join(self.log_dir, self.log_save_path.split('.')[0]))'''导入数据集'''if mode == 'train':# 导入训练集self.train_data = MyDataset(dataset_dir, 'train', imgSize=img_size)self.train_data_loader = DataLoader(dataset = self.train_data, batch_size=bs, shuffle=True, num_workers=2)if mode == 'train' or mode == 'eval':# 导入验证集self.val_data = MyDataset(dataset_dir, 'valid', imgSize=img_size)self.val_data_loader = DataLoader(dataset = self.val_data, batch_size=1, shuffle=False, num_workers=2)'''导入模型'''self.model = Model(catNums=self.cls_num, modelType=timm_model_name, loadckpt=ckpt_load_path, pretrain=pretrain, froze=froze).to(self.device)'''定义损失函数(多分类交叉熵损失)'''if mode == 'train' or mode == 'eval':self.loss_func = nn.CrossEntropyLoss()'''定义优化器(自适应学习率的带动量梯度下降方法)'''if mode == 'train':self.optimizer, self.scheduler = self.defOptimSheduler()'''当恢复断点训练'''self.start_epoch = 0if resume != None:checkpoint = torch.load(resume)self.start_epoch = checkpoint['epoch'] + 1 # +1是因为从当前epoch的下一个epoch开始训练self.model.load_state_dict(checkpoint['model_state_dict'])self.optimizer.load_state_dict(checkpoint['optim_state_dict'])self.scheduler.load_state_dict(checkpoint['sched_state_dict'])# 导入上一次中断训练时的argsjson_dir, _ = os.path.split(resume)self.argsHistory.loadRecord(json_dir)# 打印日志if mode == 'train':self.logger.info('训练集大小:   %d' % self.train_data.__len__())if mode == 'train' or mode == 'eval':self.logger.info('验证集大小:   %d' % self.val_data.__len__())self.logger.info('数据集类别数: %d' % self.cls_num)if mode == 'train':self.logger.info(f'损失函数: {self.loss_func}')self.logger.info(f'优化器: {self.optimizer}')if mode == 'train' or mode == 'eval':self.logger.info(f'全局种子: {seed}')self.logger.info('='*100)

由于其中有些初始化方法过于冗长,因此封装成为类中的方法:

3.1.1.日志模块初始化

日志模块基于logging库,会初始化以下内容:

  1. 定义文件日志(这部分日志会写入日志文件)
  2. 定义终端日志(这部分日志会打印在终端上)
  3. 定义文件日志保存路径(根据self.mode的不同而不同)
def myLogger(self):'''生成日志对象'''logger = logging.getLogger('runer')logger.setLevel(level=logging.DEBUG)# 日志格式formatter = logging.Formatter('%(asctime)s - %(levelname)s: %(message)s')if self.mode == 'train':# 写入文件的日志self.log_dir = os.path.join(self.log_dir, f"{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}_train")# 日志文件保存路径log_save_path = os.path.join(self.log_dir, f"{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}_train.log")if self.mode == 'eval':# 写入文件的日志self.log_dir = os.path.join(self.log_dir, f"{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}_val")# 日志文件保存路径log_save_path = os.path.join(self.log_dir, f"{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}_val.log")if not os.path.isdir(self.log_dir):os.makedirs(self.log_dir)file_handler = logging.FileHandler(log_save_path, encoding="utf-8", mode="a")file_handler.setLevel(level=logging.INFO)file_handler.setFormatter(formatter)logger.addHandler(file_handler)# 终端输出的日志stream_handler = logging.StreamHandler()stream_handler.setLevel(logging.INFO)stream_handler.setFormatter(formatter)logger.addHandler(stream_handler)return logger, log_save_path

3.1.2.自定义变量记录类

ArgsHistory以iter为最小记录单位,记录train或val过程中的一些变量(比如 loss, acc, lr等),并将记录内容在每个epoch结束以json文件保存。可以方便在训练结束后对这些变量进行可视化。

ArgsHistory.recoard方法通过传参自动添加新变量,或在已有变量列表的末尾进行更新,无需提前定义变量名。

class ArgsHistory():'''记录train或val过程中的一些变量(比如 loss, acc, lr等)'''def __init__(self, json_save_dir):self.json_save_dir = json_save_dirself.args_history_dict = {}def record(self, key, value):'''记录argsArgs::param key:   要记录的当前变量的名字:param value: 要记录的当前变量的数值Returns:None'''# 可能存在json格式不支持的类型, 因此统一转成float类型value = float(value)# 如果日志中还没有这个变量,则新建if key not in self.args_history_dict.keys():self.args_history_dict[key] = []# 更新self.args_history_dict[key].append(value)def saveRecord(self):'''以json格式保存args'''if not os.path.isdir(self.json_save_dir):os.makedirs(self.json_save_dir) json_save_path = os.path.join(self.json_save_dir, 'args_history.json')# 保存with open(json_save_path, 'w') as json_file:json.dump(self.args_history_dict, json_file)def loadRecord(self, json_load_dir):'''导入上一次训练时的args(一般用于resume)'''json_path = os.path.join(json_load_dir, 'args_history.json')with open(json_path, "r", encoding="utf-8") as json_file:self.args_history_dict = json.load(json_file)

3.1.3.定义优化器和学习率衰减策略

优化器支持pytorch官方提供的sgd, adam, adamw优化器,优化策略基于timm.scheduler.CosineLRScheduler,采用 warmup+余弦退火。

def defOptimSheduler(self):'''定义优化器和学习率衰减策略'''optimizer = {# adam会导致weight_decay错误,使用adam时建议设置为 0'adamw' : torch.optim.AdamW(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=0),'adam' : torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=0),'sgd'  : torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9, nesterov=True, weight_decay=0)}[self.optim_type]# 使用warmup+余弦退火学习率scheduler = CosineLRScheduler(optimizer=optimizer,t_initial=self.epoch*len(self.train_data_loader),          # 总迭代数lr_min=self.lr*0.01,                                       # 余弦退火最低的学习率warmup_t=round(self.epoch/12)*len(self.train_data_loader), # 学习率预热阶段的epoch数量warmup_lr_init=self.lr*0.01,                               # 学习率预热阶段的lr起始值)return optimizer, scheduler

3.2.训练pipeline trainer

训练pipeline的基本的流程如下:

训练一个epoch(训练集)→在验证集上评估→每个epoch结束保存checkpoint→每个epoch结束打印日志(评估结果)

当所有epoch结束时:重新在验证集上使用最佳权重评估→打印或可视化各种评估指标

def trainer(self):'''把pytorch训练代码独自分装成一个函数'''for epoch in range(self.start_epoch, self.epoch):'''一个epoch的训练'''self.fitEpoch(epoch)'''一个epoch的验证'''self.evaler(epoch, self.log_dir)'''保存网络权重'''self.saveCkpt(epoch)'''打印日志(一个epoch结束)'''self.printLog('epoch', 0, epoch, len(self.val_data_loader))'''结果评估'''self.model.load_state_dict(torch.load(os.path.join(self.log_dir, 'best.pt')))# 评估各种指标self.evaler(self.log_dir)

3.2.1.训练一个epoch的pipeline fitEpoch

主要包括前向反向,更新梯度和学习率, 记录训练时变量,打印日志等步骤。

def fitEpoch(self, epoch):'''对一个epoch进行训练的流程'''self.model.train()# 一个Epoch包含几轮Batchtrain_batch_num = len(self.train_data_loader)for step, batch in enumerate(self.train_data_loader):# [bs, channel, w, h] -> [bs, w*h, channel]with torch.no_grad():x = batch[0].to(self.device)y = batch[1].to(self.device).reshape(-1)   # 标签[batch_size, 1]# 前向传播output = self.model(x) # [batchsize, cls_num]# 计算lossloss = self.loss_func(output, y)# 预测结果对应置信最大的那个下标pre_lab = torch.argmax(output, 1)# 计算一个batchsize的准确率train_acc = torch.sum(pre_lab == y.data) / x.shape[0]# 记录args(lr, loss, acc)self.recoardArgs(mode='train', loss=loss.item(), acc=train_acc)# 记录tensorboardself.recordTensorboardLog('train', epoch, train_batch_num, step)# 打印日志self.printLog('train', step, epoch, train_batch_num) # 将上一次迭代计算的梯度清零self.optimizer.zero_grad()# 反向传播计算梯度loss.backward()       # 更新参数self.optimizer.step() # 更新学习率self.scheduler.step(epoch * train_batch_num + step) 
3.2.1.1.recoardArgs
def recoardArgs(self, mode, loss=None, acc=None, mAP=None, mF1Score=None):'''训练/验证过程中记录变量(每个iter都会记录, 不间断)Args::param mode: 模式(train, epoch):param loss: 损失:param acc:  准确率Returns:None'''        if mode == 'train':current_lr = self.optimizer.param_groups[0]['lr']self.argsHistory.record('lr', current_lr)self.argsHistory.record('train_loss', loss)self.argsHistory.record('train_acc', acc)# 一个epoch结束后val评估结果的平均值if mode == 'epoch':self.argsHistory.record('mean_val_acc', acc)self.argsHistory.record('val_mAP', mAP)self.argsHistory.record('val_mF1Score', mF1Score)self.argsHistory.saveRecord()
3.2.1.2.recordTensorboardLog
def recordTensorboardLog(self, mode, epoch, batch_num=None, step=None):'''训练过程中记录tensorBoard日志Args::param mode:       模式(train, val, epoch):param step:       当前迭代到第几个batch:param batch_num:  当前batch的大小Returns:None'''    if mode == 'train':step = epoch * batch_num + steploss = self.argsHistory.args_history_dict['train_loss'][-1]acc = self.argsHistory.args_history_dict['train_acc'][-1]self.tb_writer.add_scalar('train_loss', loss, step)self.tb_writer.add_scalar('train_acc', acc, step)if mode == 'epoch':acc = self.argsHistory.args_history_dict['mean_val_acc'][-1]mAP = self.argsHistory.args_history_dict['val_mAP'][-1]mF1Score = self.argsHistory.args_history_dict['val_mF1Score'][-1]self.tb_writer.add_scalar('mean_valid_acc', acc, epoch)self.tb_writer.add_scalar('valid_mAP', mAP, epoch)self.tb_writer.add_scalar('valid_mF1Score', mF1Score, epoch)

可视化tensorboard:

在这里插入图片描述

3.2.1.3.printLog
def printLog(self, mode, step, epoch, batch_num):'''训练/验证过程中打印日志Args::param mode:       模式(train, val, epoch):param step:       当前迭代到第几个batch:param epoch:      当前迭代到第几个epoch:param batch_num:  当前batch的大小:param loss:       当前batch的loss:param acc:        当前batch的准确率:param best_epoch: 当前最佳模型所在的epochReturns:None'''        lr = self.optimizer.param_groups[0]['lr']if mode == 'train':# 每间隔self.log_interval个iter才打印一次if step % self.log_interval == 0:loss = self.argsHistory.args_history_dict['train_loss'][-1]acc = self.argsHistory.args_history_dict['train_acc'][-1]log = ("Epoch(train)  [%d][%d/%d]  lr: %8f  train_loss: %5f  train_acc.: %5f") % (epoch+1, step, batch_num, lr, loss, acc)self.logger.info(log)elif mode == 'epoch':acc_list = self.argsHistory.args_history_dict['mean_val_acc']mAP_list = self.argsHistory.args_history_dict['val_mAP']mF1Score_list = self.argsHistory.args_history_dict['val_mF1Score']# 找到最高准确率对应的epochbest_epoch = acc_list.index(max(acc_list)) + 1self.logger.info('=' * 100)log = ("Epoch  [%d]  mean_val_acc.: %.5f  mAP: %.5f  mF1Score: %.5f  best_val_epoch: %d" % (epoch+1, acc_list[-1], mAP_list[-1], mF1Score_list[-1], best_epoch))self.logger.info(log)self.logger.info('=' * 100)

值得注意的是,3.2.1.1,3.2.1.2 和3.2.1.3的基本逻辑是,先使用recoardArgs将变量记录到字典argsHistory中去,后续recordTensorboardLogprintLog需要打印或记录哪些变量直接从字典中获取即可,省去了再将变量参数作为函数传参。

3.2.2.评估一个epoch的pipeline evaler

这部分的流程直接调用验证pipeline,固将在相应章节介绍。

3.2.3.保存网络权重 saveCkpt

saveCkpt在一个epoch结束后保存断点信息(权重、优化器断点,学习率等),并根据验证集acc判断当前epoch是否是最佳权重,是则进行保存。

def saveCkpt(self, epoch):'''保存权重和训练断点Args::param epoch:        当前epoch:param max_acc:      当前最佳模型在验证集上的准确率:param mean_val_acc: 当前epoch准确率:param best_epoch:   当前最佳模型对应的训练epochReturns:None'''  # checkpoint_dict能够恢复断点训练checkpoint_dict = {'epoch': epoch, 'model_state_dict': self.model.state_dict(), 'optim_state_dict': self.optimizer.state_dict(),'sched_state_dict': self.scheduler.state_dict()}torch.save(checkpoint_dict, os.path.join(self.log_dir, f"epoch_{epoch}.pt"))# 如果本次Epoch的acc最大,则保存参数(网络权重)acc_list = self.argsHistory.args_history_dict['mean_val_acc']if epoch == acc_list.index(max(acc_list)):torch.save(self.model.state_dict(), os.path.join(self.log_dir, 'best.pt'))self.logger.info('best checkpoint has saved !')

3.3.验证pipeline evaler

验证集上完整推理一遍(batch size=1), 并评估各种指标,可视化等。这部分同样作为训练一个epoch结束的评估流程。

def evaler(self, epoch, log_dir):'''把pytorch训练代码独自分装成一个函数Args::param modelType:    模型名称(timm):param DatasetDir:   数据集根目录(到images那一层, 子目录是train/valid):param BatchSize:    BatchSize:param imgSize:      网络接受的图像输入尺寸:param ckptPath:     权重路径:param logSaveDir:   训练日志保存目录Returns:None'''# 得到网络预测结果# shape = [val_size,] [val_size,] [val_size, cls_num]predList, trueList, softList = self.eval()'''自定义的实现'''# 准确率acc = sum(predList==trueList) / predList.shape[0]self.logger.info(f'acc: {acc}')# # 可视化混淆矩阵showComMatrix(trueList, predList, self.cats, self.log_dir)# 绘制PR曲线PRs = drawPRCurve(self.cats, trueList, softList, self.log_dir)# 计算每个类别的 AP, F1ScoremAP, mF1Score, form = clacAP(PRs, self.cats)self.logger.info(f'\n{form}')# 记录args(epoch)self.recoardArgs(mode='epoch', acc=acc, mAP=mAP, mF1Score=mF1Score)# 绘制损失,学习率,准确率曲线visArgsHistory(log_dir, self.log_dir)# 记录tensorboard的logif self.mode == 'train':self.recordTensorboardLog('epoch', epoch)

3.3.1.推理得到网络预测结果eval

eval方法用于得到真实标签true_list, 预测标签pred_list, 置信度soft_list,为后续计算各种评估指标做准备。

def eval(self):'''得到网络在验证集的真实标签true_list, 预测标签pred_list, 置信度soft_list, 为后续评估做准备'''# 记录真实标签和预测标签pred_list, true_list, soft_list = [], [], []# 验证模式self.model.eval()# 验证时无需计算梯度with torch.no_grad():print('evaluating val dataset...')for batch in tqdm(self.val_data_loader):x = batch[0].to(self.device)   # [batch_size, 3, 64, 64]y = batch[1].to(self.device).reshape(-1)  # [batch_size, 1]# 前向传播output = self.model(x)# 预测结果对应置信最大的那个下标pre_lab = torch.argmax(output, dim=1)# 记录(真实标签true_list, 预测标签pred_list, 置信度soft_list)true_list += list(y.cpu().detach())pred_list += list(pre_lab.cpu().detach())soft_list += list(np.array(output.softmax(dim=-1).cpu().detach()))return np.array(pred_list), np.array(true_list), np.array(soft_list)

3.3.2.可视化混淆矩阵showComMatrix

def showComMatrix(trueList, predList, cat, evalDir):'''可视化混淆矩阵Args::param trueList:  验证集的真实标签:param predList:  网络预测的标签:param cat:       所有类别的字典Returns:None'''if len(cat)>=50:# 100类正合适的大小  plt.figure(figsize=(40, 33))plt.subplots_adjust(left=0.05, right=1, bottom=0.05, top=0.99) else:# 10类正合适的大小plt.figure(figsize=(12, 9))plt.subplots_adjust(left=0.1, right=1, bottom=0.1, top=0.99) conf_mat = confusion_matrix(trueList, predList)df_cm = pd.DataFrame(conf_mat, index=cat, columns=cat)heatmap = sns.heatmap(df_cm, annot=True, fmt='d', cmap=plt.cm.Blues)heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation = 0, ha = 'right')heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation = 50, ha = 'right')plt.ylabel('true label')plt.xlabel('pred label')if not os.path.isdir(evalDir):os.makedirs(evalDir)# 保存图像plt.savefig(os.path.join(evalDir, '混淆矩阵.png'), dpi=200)plt.clf() 

例(food-101验证集):

在这里插入图片描述

3.3.3.绘制类别的PR曲线drawPRCurve

def drawPRCurve(cat, trueList, softList, evalDir):'''绘制类别的PR曲线 Args::param cat:  类别索引列表:param trueList:  验证集的真实标签:param softList:  网络预测的置信度:param evalDir:   PR曲线图保存路径Returns:None'''def calcPRThreshold(trueList, softList, clsNum, T):'''给定一个类别, 单个阈值下的PR值Args::param trueList:  验证集的真实标签:param predList:  网络预测的标签:param clsNum:    类别索引Returns:precision, recall'''label = (trueList==clsNum)prob = softList[:,clsNum]>TTP = sum(label*prob)   # 正样本预测为正样本FN = sum(label*~prob)  # 正样本预测为负样本FP = sum(~label*prob)  # 负样本预测为正样本precision = TP / (TP + FP) if (TP + FP)!=0 else 1recall = TP / (TP + FN) return precision, recall, Tdef clacPRCurve(trueList, softList, clsNum, interval=100):'''所有类别下的PR曲线值Args::param trueList:  验证集的真实标签:param predList:  网络预测的标签:param clsNum:    类别索引列表:param interval:  阈值变化划分的区间,如interval=100, 则间隔=0.01Returns::param PRs:       不同阈值下的PR值[2, interval, cat_num]'''PRs = []print('calculating PR per classes...')for cls in trange(clsNum):PR_value = [calcPRThreshold(trueList, softList, cls, i/interval) for i in range(interval+1)]PRs.append(np.array(PR_value))return np.array(PRs)plt.figure(figsize=(12, 9))# 计算所有类别下的PR曲线值PRs = clacPRCurve(trueList, softList, len(cat))# 绘制每个类别的PR曲线for i in range(len(cat)):PR = PRs[i]plt.plot(PR[:,1], PR[:,0], linewidth=1)plt.legend(labels=cat)plt.xlabel('recall')plt.ylabel('precision')plt.xlim(0,1)plt.ylim(0,1)# 保存图像 plt.savefig(os.path.join(evalDir, '类别PR曲线.png'), dpi=200)plt.clf()  return PRs

例(food-101验证集):

在这里插入图片描述

3.3.4.计算每个类别的 AP, F1ScoreclacAP

def clacAP(PRs, cat):'''计算每个类别的 AP, F1ScoreArgs::param PRs:  不同阈值下的PR值[2, interval, cat_num]:param cat:  类别索引列表Returns:None'''form = [['catagory', 'AP', 'F1_Score']]# 所有类别的平均AP与平均F1ScoremAP, mF1Score = 0, 0for i in range(len(cat)):PR = PRs[i]AP = 0for j in range(PR.shape[0]-1):# 每小条梯形的矩形部分+三角形部分面积h = PR[j, 0] - PR[j+1, 0]w = PR[j, 1] - PR[j+1, 1]AP += (PR[j+1, 0] * w) + (w * h / 2)if(PR[j, 2]==0.5):F1Score0_5 = 2 * PR[j, 0] * PR[j, 1] / (PR[j, 0] + PR[j, 1])form.append([cat[i], AP, F1Score0_5])  mAP += APmF1Score += F1Score0_5form.append(['average', mAP / len(cat), mF1Score / len(cat)]) return mAP, mF1Score, tabulate(form, headers='firstrow') # tablefmt='fancy_grid'

例,输出的逐类别评估指标(food-101验证集):

catagory                       AP    F1_Score
-----------------------  --------  ----------
apple_pie                0.634729    0.596413
baby_back_ribs           0.871711    0.815574
baklava                  0.924247    0.870445
beef_carpaccio           0.920737    0.877551
beef_tartare             0.853179    0.806794
beet_salad               0.791348    0.746835
beignets                 0.929849    0.858824
bibimbap                 0.970363    0.931452
bread_pudding            0.628122    0.598778
breakfast_burrito        0.816585    0.759494
bruschetta               0.811589    0.74645
caesar_salad             0.92251     0.862205
cannoli                  0.912558    0.859504
caprese_salad            0.883566    0.81409
carrot_cake              0.853388    0.7833
ceviche                  0.742306    0.693446
cheesecake               0.91843     0.875
cheese_plate             0.775621    0.711579
chicken_curry            0.865944    0.805785
chicken_quesadilla       0.877484    0.830957
chicken_wings            0.935781    0.883534
chocolate_cake           0.713896    0.676596
chocolate_mousse         0.644484    0.616302
churros                  0.953728    0.917505
clam_chowder             0.928124    0.879032
club_sandwich            0.919055    0.866935
crab_cakes               0.81353     0.78
creme_brulee             0.9357      0.901354
croque_madame            0.923631    0.878049
cup_cakes                0.933192    0.886719
deviled_eggs             0.951984    0.927126
donuts                   0.894698    0.843177
dumplings                0.935255    0.903491
edamame                  0.997601    0.993964
eggs_benedict            0.926735    0.88755
escargots                0.938622    0.904
falafel                  0.855487    0.797495
filet_mignon             0.716486    0.666667
fish_and_chips           0.917366    0.866667
foie_gras                0.706522    0.655804
french_fries             0.957856    0.912621
french_onion_soup        0.902586    0.853175
french_toast             0.830571    0.789062
fried_calamari           0.914943    0.878543
fried_rice               0.909526    0.846602
frozen_yogurt            0.961339    0.91945
garlic_bread             0.853865    0.803245
gnocchi                  0.811815    0.745174
greek_salad              0.907218    0.847737
grilled_cheese_sandwich  0.802601    0.746507
grilled_salmon           0.823634    0.762887
guacamole                0.932105    0.893443
gyoza                    0.914981    0.882353
hamburger                0.872562    0.80167
hot_and_sour_soup        0.965024    0.930693
hot_dog                  0.901362    0.856
huevos_rancheros         0.784487    0.716484
hummus                   0.895116    0.847107
ice_cream                0.828925    0.763948
lasagna                  0.820834    0.774059
lobster_bisque           0.90727     0.858871
lobster_roll_sandwich    0.94466     0.907975
macaroni_and_cheese      0.891181    0.829569
macarons                 0.979871    0.95122
miso_soup                0.970756    0.918489
mussels                  0.956927    0.919918
nachos                   0.880572    0.830266
omelette                 0.783796    0.713656
onion_rings              0.947257    0.913725
oysters                  0.958218    0.928287
pad_thai                 0.957988    0.894027
paella                   0.913414    0.854839
pancakes                 0.91128     0.866935
panna_cotta              0.798689    0.743434
peking_duck              0.925841    0.864097
pho                      0.967159    0.938614
pizza                    0.932811    0.870406
pork_chop                0.651566    0.606557
poutine                  0.960185    0.918367
prime_rib                0.891828    0.850394
pulled_pork_sandwich     0.876265    0.812245
ramen                    0.936075    0.873016
ravioli                  0.752236    0.698545
red_velvet_cake          0.898298    0.860041
risotto                  0.797334    0.739394
samosa                   0.866941    0.8375
sashimi                  0.938466    0.899384
scallops                 0.783607    0.732919
seaweed_salad            0.957724    0.917172
shrimp_and_grits         0.833112    0.771084
spaghetti_bolognese      0.956515    0.916
spaghetti_carbonara      0.960952    0.900398
spring_rolls             0.885243    0.852459
steak                    0.570503    0.524313
strawberry_shortcake     0.859055    0.805726
sushi                    0.891105    0.846626
tacos                    0.82507     0.771037
takoyaki                 0.958959    0.92623
tiramisu                 0.845881    0.792079
tuna_tartare             0.75262     0.699411
waffles                  0.906239    0.854291
average                  0.873475    0.825314

3.3.4.可视化训练过程中保存的参数visArgsHistory

def visArgsHistory(json_dir, save_dir):'''可视化训练过程中保存的参数Args::param json_dir: 参数的json文件路径:param logDir:   可视化json文件保存路径Returns:None'''json_path = os.path.join(json_dir, 'args_history.json')with open(json_path) as json_file:args = json.load(json_file)for args_key in args.keys():arg = args[args_key]plt.plot(arg, linewidth=1)plt.xlabel('Epoch')plt.ylabel(args_key)plt.savefig(os.path.join(save_dir, f'{args_key}.png'), dpi=200)plt.clf()

例(food-101,mobilenet-v3-large):

在这里插入图片描述

3.4.测试pipeline tester

可视化图像CAM热图和分类结果(top10)

def tester(self, img_path, save_res_dir):'''把pytorch测试代码独自分装成一个函数Args::param img_path:     测试图像路径:param save_res_dir: 推理结果保存目录Returns:None'''from dataloader import Transforms# 加载一张图片并进行预处理image = Image.open(img_path)image = np.array(image)tf = Transforms(imgSize = self.img_size)visImg = tf.visTF(image=image)['image']img = torch.tensor(tf.validTF(image=image)['image']).permute(2,1,0).unsqueeze(0).to(self.device)# 加载网络self.model.eval()# 预测logits = self.model(img).softmax(dim=-1).cpu().detach().numpy()[0]sorted_id = sorted(range(len(logits)), key=lambda k: logits[k], reverse=True)# 超过10类则只显示top10的类别logits_top_10 = logits[sorted_id[:10]]cats_top_10 = [self.cats[i] for i in sorted_id[:10]]'''CAM'''# CAM需要网络能反传梯度, 否则会报错# 要可视化网络哪一层的CAM(以mobilenetv3_large_100.ra_in1k为例, 不同的网络这部分还需更改)target_layers = [self.model.backbone.blocks[-1]]cam = GradCAM(model=self.model, target_layers=target_layers)# 要关注的区域对应的类别targets = [ClassifierOutputTarget(sorted_id[0])]grayscale_cam = cam(input_tensor=img, targets=targets)[0].transpose(1,0)visualization = show_cam_on_image(visImg / 255., grayscale_cam, use_rgb=True)'''可视化预测结果'''fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 4))# 在第一个子图中绘制图像ax1.set_title('image')ax1.axis('off')# ax1.imshow(image)ax1.imshow(visualization)# 在第二个子图中绘制置信度(横向)ax2.barh(cats_top_10, logits_top_10.reshape(-1))ax2.set_title('classification')ax2.set_xlabel('confidence')# 将数值最大的条块设置为不同颜色bar2 = ax2.patches[0]bar2.set_color('orange')# y轴上下反转,不然概率最大的在最下面plt.gca().invert_yaxis()plt.subplots_adjust(left=0.05, right=0.99, bottom=0.1, top=0.90)if not os.path.isdir(save_res_dir):os.makedirs(save_res_dir)plt.savefig(os.path.join(save_res_dir, 'res.jpg'), dpi=200)plt.clf() 

例:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

3.5.其他

主函数部分,通过命令行参数获取config文件路径,读取config文件里的参数(字典形式)作为训练的超参。

if __name__ == '__main__':args = getArgs()# 使用动态导入的模块config_path = args.configconfig_file = import_module_by_path(config_path)# 调用动态导入的模块的函数config = config_file.configrunner = Runner(config['timm_model_name'], config['img_size'], config['ckpt_load_path'],config['dataset_dir'], config['epoch'], config['bs'], config['lr'], config['log_dir'], config['log_interval'], config['pretrain'], config['froze'], config['optim_type'], config['mode'], config['resume'], config['seed'])# 训练if config['mode'] == 'train':runner.trainer()# 评估elif config['mode'] == 'eval':runner.evaler(epoch=0, log_dir=config['eval_log_dir'])elif config['mode'] == 'test':runner.tester(config['img_path'], config['save_res_dir'])else:print("mode not valid. it must be 'train', 'eval' or 'test'.")

3.5.1 获取命令行参数

def getArgs():parser = argparse.ArgumentParser()parser.add_argument('--config', type=str, help='config file')args = parser.parse_args()return args

3.5.2 根据给定路径动态import模块(config.py)

def import_module_by_path(module_path):"""根据给定的完整路径动态导入模块(config.py)"""spec = importlib.util.spec_from_file_location("module_name", module_path)module = importlib.util.module_from_spec(spec)spec.loader.exec_module(module)return module

3.5.3 config.py样例

config = dict(# trainmode = 'test',timm_model_name = 'mobilenetv3_large_100.ra_in1k',img_size = 224,ckpt_load_path =  'log/2024-02-14-03-04-03_train/best.pt',dataset_dir = 'E:/datasets/Classification/food-101/images',epoch = 36,bs = 64,lr = 1e-3,log_dir = './log',log_interval = 50,pretrain = True,froze = False,optim_type = 'adamw',resume = None,  # 'log/2024-02-05-21-28-59_train/epoch_9.pt',seed=22,# evaleval_log_dir = 'log/2024-02-14-03-04-03_train',# test# french_fries/3171053.jpg 3897130.jpg 3393816.jpg club_sandwich/3143042.jpgimg_path = 'E:/datasets/Classification/food-101/images/valid/french_fries/3393816.jpg',save_res_dir = './result'
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/689279.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

嵌入式第十七天!(文件IO)

文件IO: 标准IO和文件IO的区别: 1. 标准IO是库函数,是对系统调用的封装 2. 文件IO是系统调用,是Linux内核中的函数接口 3. 标准IO是有缓存的 4. 文件IO是没有缓存的 1. 操作步骤: 打开 -> 读/写 -> 关闭 2. 打开…

基于Java SSM框架实现精准扶贫管理系统项目【项目源码】

基于java的SSM框架实现精准扶贫管理系统演示 JSP技术介绍 JSP技术本身是一种脚本语言,但它的功能是十分强大的,因为它可以使用所有的JAVA类。当它与JavaBeans 类进行结合时,它可以使显示逻辑和内容分开,这就极大的方便了用户的需…

⭐北邮复试刷题LCR 012. 寻找数组的中心下标__前缀和思想 (力扣119经典题变种挑战)

LCR 012. 寻找数组的中心下标 给你一个整数数组 nums ,请计算数组的 中心下标 。 数组 中心下标 是数组的一个下标,其左侧所有元素相加的和等于右侧所有元素相加的和。 如果中心下标位于数组最左端,那么左侧数之和视为 0 ,因为…

数据管理关键技术顶层设计

数据管理关键技术顶层设计 明仔 数据思考笔记 2023-12-27 07:36 广东 数据思考笔记 专注于数据架构,数据中台,数据治理的相关分享,寻求数据与业务的结合点。 13篇原创内容 公众号 数据治理65 数据治理 目录 上一篇企业数据资产管理解…

基于Java SSM框架实现生鲜食品o2o商城系统项目【项目源码+论文说明】

基于java的SSM框架实现生鲜食品o2o商城系统演示 摘要 随着社会的发展,社会的各行各业都在利用信息化时代的优势。计算机的优势和普及使得各种信息系统的开发成为必需。 生鲜食品o2o商城系统,主要的模块包括查看管理员;首页、个人中心、用户…

Tomcat版本号泄露

1.问题描述 Tomcat报错页面泄漏Apache Tomcat/7.0.92相关版本号信息,是攻击者攻击的途径之一。因此实际当中建议去掉版本号信息。 2.测试过程 随便访问一个tomcat不存在的界面 http://127.0.0.1:8080/examples/mytest.jsp 3.解决办法 1.进入到tomcat/lib目录下&a…

预检请求:为跨域请求保驾护航(下)

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

svg图片构造QGraphicsSvgItem对象耗时很长的问题解决

目录 1. 问题的提出 2. 问题解决 1. 问题的提出 今天通过一张像素为141 * 214,大小为426KB的svg格式的图片构造QGraphicsSvgItem对象,再通过Qt的Graphics View Framework框架,将QGraphicsSvgItem对象显示到场景视图上,代码如下&…

【深度优先搜索】【树】【有向图】【推荐】685. 冗余连接 II

LeetCode685. 冗余连接 II 在本问题中,有根树指满足以下条件的 有向 图。该树只有一个根节点,所有其他节点都是该根节点的后继。该树除了根节点之外的每一个节点都有且只有一个父节点,而根节点没有父节点。 输入一个有向图,该图由…

离线数仓(二)【用户行为日志采集平台搭建】

用户行为日志采集平台搭建 1、用户行为日志概述 用户行为日志的内容,主要包括用户的各项行为信息以及行为所处的环境信息。收集这些信息的主要目的是优化产品和为各项分析统计指标提供数据支撑。收集这些信息的手段通常为埋点。 目前主流的埋点方式,有代…

CleanMyMacX需要付费吗?多少钱?有哪些新功能

CleanMyMac X是一个付费应用程序**,需要许可证或订阅来解锁所有功能。不过,CleanMyMac X提供免费试用版供您访问其有限的功能。在试用模式下,用户可以使用部分功能进行体验,但这并非完全免费,因为某些功能会受到限制。…

怎样才能使网页跳转更加顺畅和自然?

网页跳转是用户在使用网页时经常遇到的操作,一个好的跳转设计可以提升用户体验,提高用户满意度。然而,有些网页的跳转设计却常常给用户带来不好的体验,比如页面加载缓慢、跳转速度慢、页面卡顿等问题。那么,怎样才能使…

在SpringBoot中@PathVariable与@RequestParam的区别

PathVariable GetMapping("/{userId}")public R<User> getUserById(PathVariable Long userId) {return userService.getUserById(userId);} // 根据id获取一条数据 function getStudentDataByIdAndDisplayInput(id) {// 发送 AJAX 请求$.ajax({url: /dorm/st…

Elasticsearch:将 IT 智能和业务 KPI 与 AI 连接起来 - 房间里的大象

作者&#xff1a;Fermi Fang 大象寓言的智慧 在信息技术和商业领导力的交叉点&#xff0c;蒙眼人和大象的古老寓言提供了一个富有洞察力的类比。 这个故事起源于印度次大陆&#xff0c;讲述了六个蒙住眼睛的人第一次遇到大象的故事。 每个人触摸大象的不同部位 —— 侧面、象牙…

Vue3快速上手(八) toRefs和toRef的用法

顾名思义&#xff0c;toRef 就是将其转换为ref的一种实现。详细请看&#xff1a; 一、toRef 1.1 示例 <script langts setup name"toRefsAndtoRef"> // 引入reactive,toRef import { reactive, toRef } from vue // reactive包裹的数据即为响应式对象 let p…

第三篇【传奇开心果系列】Python的文本和语音相互转换库技术点案例示例:pyttsx3实现语音助手经典案例

传奇开心果短博文系列 系列短博文目录Python的文本和语音相互转换库技术点案例示例系列 短博文目录一、项目背景和目标二、雏形示例代码三、扩展思路介绍四、与其他库和API集成示例代码五、自定义语音示例代码六、多语言支持示例代码七、语音控制应用程序示例代码八、文本转语音…

机器人初识 —— 定制AI

一、机器人设计难点 波士顿动力设计的机器人&#xff0c;尤其是其人形机器人Atlas和四足机器人Spot等产品&#xff0c;在技术上面临多重难点&#xff1a; 1. **动态平衡与稳定性**&#xff1a;双足或四足机器人在运动时需要维持极高的动态平衡&#xff0c;特别是在不平坦地面…

LiveGBS流媒体平台GB/T28181功能-自定义收流端口区间30000至30249UDP端口TCP端区间配置及相关端口复用问题说明

LiveGBS自定义收流端口区间30000至30249UDP端口TCP端区间配置及相关端口复用问题说明 1、收流端口配置1.1、INI配置1.2、页面配置 2、相关问题3、最少可以开放多少端口3.1、端口复用3.2、配置最少端口如下 4、搭建GB28181视频直播平台 1、收流端口配置 1.1、INI配置 可在lives…

基于Java SSM框架实现电影售票系统项目【项目源码】

基于java的SSM框架实现电影售票系统演示 SSM框架 当今流行的“SSM组合框架”是Spring SpringMVC MyBatis的缩写&#xff0c;受到很多的追捧&#xff0c;“组合SSM框架”是强强联手、各司其职、协调互补的团队精神。web项目的框架&#xff0c;通常更简单的数据源。Spring属于…

VNCTF2024misc方向部分wp

文章目录 sqlsharkLearnOpenGLez_msbOnlyLocalSql sqlshark tshark -r sqlshark.pcap -Y "http" -T fields -e frame.len -e http.file_data > data.txt不太像常规的盲注&#xff0c;一次性发送两条很类似的payload&#xff0c;比常规的多了一个least在判断passw…