深度学习 框架代码(草稿)

文章目录

  • train.py
  • dataload_five_flower.py
  • train_engin.py
  • lr_methods.py
  • __init__.py

  • train_sample.py 和 test.py 见文章:
    • 深度学习-AlexNet代码实现
  • 用 parser 方便服务器中的终端操作
  • 第三个代码将 mac的 mps 和 cuda 混合了,有点问题,看下代码整体思想就行,不用去跑
  • 因为我的电脑是 mac 的 mps,还没找到代码的替代方法
  • 可以直接用上面那篇文章中的 train_sample.py
  • 只要不是训练,cpu 一般都比 cuda快

train.py

############################################################################################################
# 相较于简单版本的训练脚本 train_sample 增添了以下功能:
# 1. 使用argparse类实现可以在训练的启动命令中指定超参数
# 2. 可以通过在启动命令中指定 --seed 来固定网络的初始化方式,以达到结果可复现的效果
# 3. 使用了更高级的学习策略 cosine warm up:在训练的第一轮使用一个较小的lr(warm_up),从第二个epoch开始,随训练轮数逐渐减小lr。 
# 4. 可以通过在启动命令中指定 --model 来选择使用的模型 
# 5. 使用amp包实现半精度训练,在保证准确率的同时尽可能的减小训练成本
# 6. 实现了数据加载类的自定义实现
# 7. 可以通过在启动命令中指定 --tensorboard 来进行tensorboard可视化, 默认不启用。
#    注意,使用tensorboad之前需要使用命令 "tensorboard --logdir= log_path"来启动,结果通过网页 http://localhost:6006/'查看可视化结果
############################################################################################################
# --model 可选的超参如下:
# alexnet   zfnet   vgg   vgg_tiny   vgg_small   vgg_big   googlenet   xception   resnet_small   resnet   resnet_big   resnext   resnext_big  
# densenet_tiny   densenet_small   densenet   densenet_big   mobilenet_v3   mobilenet_v3_large   shufflenet_small   shufflenet
# efficient_v2_small   efficient_v2   efficient_v2_large   convnext_tiny   convnext_small   convnext   convnext_big   convnext_huge
# vision_transformer_small   vision_transformer   vision_transformer_big   swin_transformer_tiny   swin_transformer_small   swin_transformer # 训练命令示例: # python train.py --model alexnet --num_classes 5
############################################################################################################
import os 
import argparse 
import math
import shutil
import random
import numpy as np
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler import classic_models 
from utils.lr_methods import warmup 
from dataload.dataload_five_flower import Five_Flowers_Load
from utils.train_engin import train_one_epoch, evaluate parser = argparse.ArgumentParser()
parser.add_argument('--num_classes', type=int, default=5, help='the number of classes')
parser.add_argument('--epochs', type=int, default=50, help='the number of training epoch')
parser.add_argument('--batch_size', type=int, default=64, help='batch_size for training')
parser.add_argument('--lr', type=float, default=0.0002, help='star learning rate')
parser.add_argument('--lrf', type=float, default=0.0001, help='end learning rate') 
parser.add_argument('--seed', default=False, action='store_true', help='fix the initialization of parameters')
parser.add_argument('--tensorboard', default=False, action='store_true', help=' use tensorboard for visualization') 
parser.add_argument('--use_amp', default=False, action='store_true', help=' training with mixed precision') 
# 数据路径需要改成自己的
parser.add_argument('--data_path', type=str, default="/Users/jiangxiyu/根目录/深度学习/flower")
parser.add_argument('--model', type=str, default="vgg", help=' select a model for training') 
parser.add_argument('--device', default='mps', help='device id (i.e. 0 or 0,1 or cpu)')# 把超参数实例化
opt = parser.parse_args()  if opt.seed:def seed_torch(seed=7):random.seed(seed) # Python random module.	os.environ['PYTHONHASHSEED'] = str(seed) # 为了禁止hash随机化,使得实验可复现np.random.seed(seed) # Numpy module.torch.manual_seed(seed)  # 为CPU设置随机种子# mac m1 mps gpu可以不用# torch.cuda.manual_seed(seed) # 为当前GPU设置随机种子# torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.# 设置cuDNN:cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:# torch.backends.cudnn.benchmark = False# torch.backends.cudnn.deterministic = True# 实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。print('random seed has been fixed')seed_torch() def main(args):# mac m1 gpudevice = torch.device(args.device if torch.backends.mps.is_available() else "cpu")print(args)if opt.tensorboard:# 这是存放你要使用tensorboard显示的数据的绝对路径log_path = os.path.join('./results/tensorboard' , args.model)print('Start Tensorboard with "tensorboard --logdir={}"'.format(log_path)) if os.path.exists(log_path) is False:os.makedirs(log_path)print("tensorboard log save in {}".format(log_path))else:shutil.rmtree(log_path) #当log文件存在时删除文件夹。记得在代码最开始import shutil # 实例化一个tensorboardtb_writer = SummaryWriter(log_path)# 数据集比较大的归一化ImageNet [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])} # 对标pytorch封装好的ImageFlolder,我们自己实现了一个数据加载类 Five_Flowers_Load,并使用指定的预处理操作来处理图像,结果会同时返回图像和对应的标签。  train_dataset = Five_Flowers_Load(os.path.join(args.data_path , 'train'), transform=data_transform["train"])val_dataset = Five_Flowers_Load(os.path.join(args.data_path , 'val'), transform=data_transform["val"]) if args.num_classes != train_dataset.num_class:raise ValueError("dataset have {} classes, but input {}".format(train_dataset.num_class, args.num_classes))nw = min([os.cpu_count(), args.batch_size if args.batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))# 使用 DataLoader 将加载的数据集处理成批量(batch)加载模式train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=nw, collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True,  num_workers=nw, collate_fn=val_dataset.collate_fn)# create modelmodel = classic_models.find_model_using_name(opt.model, num_classes=opt.num_classes).to(device) pg = [p for p in model.parameters() if p.requires_grad] optimizer = optim.Adam(pg, lr=args.lr)# Scheduler https://arxiv.org/pdf/1812.01187.pdflf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosinescheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)best_acc = 0.# save parameters pathsave_path = os.path.join(os.getcwd(), 'results/weights', args.model)if os.path.exists(save_path) is False:os.makedirs(save_path)for epoch in range(args.epochs):# trainmean_loss, train_acc = train_one_epoch(model=model, optimizer=optimizer, data_loader=train_loader, device=device, epoch=epoch, use_amp=args.use_amp, lr_method= warmup)scheduler.step()# validateval_acc = evaluate(model=model, data_loader=val_loader, device=device)print('[epoch %d] train_loss: %.3f  train_acc: %.3f  val_accuracy: %.3f' %  (epoch + 1, mean_loss, train_acc, val_acc))   with open(os.path.join(save_path, "AlexNet_log.txt"), 'a') as f: f.writelines('[epoch %d] train_loss: %.3f  train_acc: %.3f  val_accuracy: %.3f' %  (epoch + 1, mean_loss, train_acc, val_acc) + '\n')if opt.tensorboard:tags = ["train_loss", "train_acc", "val_accuracy", "learning_rate"]tb_writer.add_scalar(tags[0], mean_loss, epoch)tb_writer.add_scalar(tags[1], train_acc, epoch)tb_writer.add_scalar(tags[2], val_acc, epoch)tb_writer.add_scalar(tags[3], optimizer.param_groups[0]["lr"], epoch)# 判断当前验证集的准确率是否是最大的,如果是,则更新之前保存的权重if val_acc > best_acc:best_acc = val_acctorch.save(model.state_dict(), os.path.join(save_path, "AlexNet.pth")) if __name__ == '__main__':         main(opt)

dataload_five_flower.py

  • 不同的数据集,torch封装的dataload不一定适配,所以要学会自己封装dataload
from PIL import Image
from matplotlib.cbook import ls_mapper
import torch
from torch.utils.data import Dataset
import random
import osclass Five_Flowers_Load(Dataset):def __init__(self, data_path: str, transform=None):self.data_path = data_path self.transform = transformrandom.seed(0)  # 保证随机结果可复现assert os.path.exists(data_path), "dataset root: {} does not exist.".format(data_path)# 遍历文件夹,一个文件夹对应一个类别,['daisy', 'dandelion', 'roses', 'sunflower', 'tulips']flower_class = [cla for cla in os.listdir(os.path.join(data_path))] # 得到一个列表self.num_class = len(flower_class)# 排序,保证顺序一致flower_class.sort()# 生成类别名称以及对应的数字索引  {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}class_indices = dict((cla, idx) for idx, cla in enumerate(flower_class)) self.images_path = []  # 存储训练集的所有图片路径self.images_label = []  # 存储训练集图片对应索引信息 self.images_num = []  # 存储每个类别的样本总数supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型# 遍历每个文件夹下的文件for cla in flower_class:cla_path = os.path.join(data_path, cla)# 遍历获取supported支持的所有文件路径images = [os.path.join(data_path, cla, i) for i in os.listdir(cla_path) if os.path.splitext(i)[-1] in supported]# 获取该类别对应的索引image_class = class_indices[cla]# 记录该类别的样本数量self.images_num.append(len(images)) # 写入列表for img_path in images: self.images_path.append(img_path)self.images_label.append(image_class)print("{} images were found in the dataset.".format(sum(self.images_num))) def __len__(self):return sum(self.images_num)def __getitem__(self, idx):img = Image.open(self.images_path[idx])label = self.images_label[idx]if img.mode != 'RGB':raise ValueError("image: {} isn't RGB mode.".format(self.images_path[idx]))if self.transform is not None:img = self.transform(img)else:raise ValueError('Image is not preprocessed')return img, label# 非必须实现,torch里有默认实现;该函数的作用是: 决定一个batch的数据以什么形式来返回数据和标签# 官方实现的default_collate可以参考# https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py@staticmethoddef collate_fn(batch):images, labels = tuple(zip(*batch))images = torch.stack(images, dim=0) labels = torch.as_tensor(labels)  return images, labels

train_engin.py

import sysimport torch
from tqdm import tqdmfrom utils.distrubute_utils import  is_main_process, reduce_value
from utils.lr_methods import warmupdef train_one_epoch(model, optimizer, data_loader, device, epoch, use_amp=False, lr_method=None):model.train()loss_function = torch.nn.CrossEntropyLoss()train_loss = torch.zeros(1).to(device)acc_num = torch.zeros(1).to(device)optimizer.zero_grad()lr_scheduler = Noneif epoch == 0  and lr_method == warmup : warmup_factor = 1.0/1000warmup_iters = min(1000, len(data_loader) -1)lr_scheduler = warmup(optimizer, warmup_iters, warmup_factor)if is_main_process():data_loader = tqdm(data_loader, file=sys.stdout)# 创建一个梯度缩放标量,以最大程度避免使用fp16进行运算时的梯度下溢 enable_amp = use_amp and "mps" in device.typescaler = torch.cuda.amp.GradScaler(enabled=enable_amp)sample_num = 0for step, data in enumerate(data_loader):images, labels = datasample_num += images.shape[0]with torch.cuda.amp.autocast(enabled=enable_amp):pred = model(images.to(device))loss = loss_function(pred, labels.to(device))pred_class = torch.max(pred, dim=1)[1]acc_num += torch.eq(pred_class, labels.to(device)).sum()scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()optimizer.zero_grad()train_loss += reduce_value(loss, average=True).detach()# 在进程中打印平均lossif is_main_process():info = '[epoch{}]: learning_rate:{:.5f}'.format(epoch + 1, optimizer.param_groups[0]["lr"])data_loader.desc = info # tqdm 成员 descif not torch.isfinite(loss):print('WARNING: non-finite loss, ending training ', loss)sys.exit(1)if lr_scheduler is not None:  # 如果使用warmup训练,逐渐调整学习率lr_scheduler.step()# 等待所有进程计算完毕if device != torch.device('cpu'):torch.cuda.synchronize(device)return train_loss.item() / (step + 1), acc_num.item() / sample_num@torch.no_grad()
def evaluate(model, data_loader, device):model.eval()# 验证集样本个数num_samples = len(data_loader.dataset) # 用于存储预测正确的样本个数sum_num = torch.zeros(1).to(device)for step, data in enumerate(data_loader):images, labels = datapred = model(images.to(device))pred_class = torch.max(pred, dim=1)[1]sum_num += torch.eq(pred_class, labels.to(device)).sum()# 等待所有进程计算完毕if device != torch.device('cpu'):torch.cuda.synchronize(device)sum_num = reduce_value(sum_num, average=False)val_acc = sum_num.item() / num_samplesreturn val_acc

lr_methods.py

import torch def warmup(optimizer, warm_up_iters, warm_up_factor):def f(x):"""根据step数返回一个学习率倍率因子, x代表step"""if x >= warm_up_iters:return 1alpha = float(x) / warm_up_iters# 迭代过程中倍率因子从warmup_factor -> 1return warm_up_factor * (1 - alpha) + alphareturn torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)

init.py

from .alexnet import alexnet
from .vggnet import vgg11, vgg13, vgg16, vgg19
from .zfnet import zfnet 
from .googlenet_v1 import googlenet
from .xception import xception
from .resnet import  resnet34, resnet50, resnet101, resnext50_32x4d, resnext101_32x8d
from .densenet import densenet121, densenet161, densenet169, densenet201
from .dla import dla34
from .mobilenet_v3 import mobilenet_v3_small, mobilenet_v3_large
from .shufflenet_v2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from .efficientnet_v2 import efficientnetv2_l, efficientnetv2_m, efficientnetv2_s
from .convnext import convnext_tiny, convnext_small, convnext_base, convnext_large, convnext_xlargefrom .vision_transformer import vit_base_patch16_224, vit_base_patch32_224, vit_large_patch16_224
from .swin_transformer import swin_tiny_patch4_window7_224, swin_small_patch4_window7_224, swin_base_patch4_window7_224
cfgs = {'alexnet': alexnet,'zfnet': zfnet,'vgg': vgg16,'vgg_tiny': vgg11,'vgg_small': vgg13,'vgg_big': vgg19,'googlenet': googlenet,'xception': xception,    'resnet_small': resnet34,'resnet': resnet50,'resnet_big': resnet101,'resnext': resnext50_32x4d,'resnext_big': resnext101_32x8d,'densenet_tiny': densenet121,'densenet_small': densenet161,'densenet': densenet169,'densenet_big': densenet121,'dla': dla34, 'mobilenet_v3': mobilenet_v3_small,'mobilenet_v3_large': mobilenet_v3_large,'shufflenet_small':shufflenet_v2_x0_5,'shufflenet': shufflenet_v2_x1_0,'efficient_v2_small': efficientnetv2_s,'efficient_v2': efficientnetv2_m,'efficient_v2_large': efficientnetv2_l,'convnext_tiny': convnext_tiny,'convnext_small': convnext_small,'convnext': convnext_base,'convnext_big': convnext_large,'convnext_huge': convnext_xlarge,'vision_transformer_small': vit_base_patch32_224,    'vision_transformer': vit_base_patch16_224,'vision_transformer_big': vit_large_patch16_224,'swin_transformer_tiny': swin_tiny_patch4_window7_224,'swin_transformer_small': swin_small_patch4_window7_224,'swin_transformer': swin_base_patch4_window7_224
}def find_model_using_name(model_name, num_classes):   return cfgs[model_name](num_classes)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/114724.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Node学习笔记之HTTP 模块

回顾:什么是客户端、什么是服务器? 在网络节点中,负责消费资源的电脑,叫做客户端;负责对外提供网络资源的电脑,叫做服务器。 http 模块是 Node.js 官方提供的、用来创建 web 服务器的模块。通过 http 模块…

用Python获取网络数据

用Python获取网络数据 网络数据采集是 Python 语言非常擅长的领域,上节课我们讲到,实现网络数据采集的程序通常称之为网络爬虫或蜘蛛程序。即便是在大数据时代,数据对于中小企业来说仍然是硬伤和短板,有些数据需要通过开放或付费…

相似度loss汇总,pytorch code

用于约束图像生成,作为loss。 可梯度优化 pytorch structural similarity (SSIM) loss https://github.com/Po-Hsun-Su/pytorch-ssimhttps://github.com/harveyslash/Facial-Similarity-with-Siamese-Networks-in-Pytorch/blob/master/Siamese-networks-medium.ip…

为什么嵌入通常优于TF-IDF:探索NLP的力量

塔曼纳 一、说明 自然语言处理(NLP)是计算机科学的一个领域,涉及人类语言的处理和分析。它用于各种应用程序,例如聊天机器人、情绪分析、语音识别等。NLP 中的重要任务之一是文本分类,我们根据文本的内容将文本分类为不…

UE4逆向篇-2_各类数据的查找方式

写在前面 1.通过前面的文章,相信各位已经能够自己找到GNames并使用DUMP工具导出GNames了。 2.本篇文章将介绍各种所需数据的查找方法。 一、准备工作 1.CheatEngine,本篇以及后续篇幅的重要工具。 2.一个记事本,保证你能记录下关键信息。…

ubuntu启动模式介绍以及如何进入单用户模式和恢复模式

Ubuntu操作系统提供了多种启动模式,每种模式都有不同的用途和功能。下面将深入介绍Ubuntu的几种启动模式: 正常启动模式(Normal boot):这是默认的启动模式,也是大多数用户使用的模式。在正常启动模式下&am…

在Mac上使用安卓桌面模式

在安装Homeblew的基础上 替换国内源 export HOMEBREW_API_DOMAIN"https://mirrors.tuna.tsinghua.edu.cn/homebrew-bottles/api" export HOMEBREW_BREW_GIT_REMOTE"https://mirrors.tuna.tsinghua.edu.cn/git/homebrew/brew.git" brew update 安装Scrcpy …

屏幕录像推荐:Apeaksoft Screen Recorder 中文 for mac

Apeaksoft Screen Recorder 是一款功能强大的屏幕录制软件,它允许用户在 Windows 和 Mac 系统上捕捉和录制屏幕活动。无论是记录游戏过程、创建教学视频、制作演示文稿还是捕捉在线流媒体内容,该软件都提供了丰富的功能和工具。 以下是 Apeaksoft Scree…

计算机视觉(CV)技术

是一种将数字图像或视频进行处理和分析的技术,旨在使计算机能够模拟人类视觉系统。该领域涉及到图像处理、模式识别、机器学习等多个领域,主要涵盖以下几方面: 图像处理:对图像进行去噪、增强、分割、特征提取等处理。图像分类&a…

vite+vue3+elementPlus+less+router+pinia+axios

1.创建项目2.按需引入elementplus3.引入less安装vue-router安装 axios安装 piniapinia的持久化配置(用于把数据放在localStorage中)---另外增加的配置 1.创建项目 npm init vitelatest2.按需引入elementplus npm install element-plus --save//按需引入 npm install -D unpl…

HTTP框架 - HttpMaster 核心基类上传

场景 在电子商务应用中,可能需要与多个供应商和物流服务提供商进行通信。这些服务提供商可能具有不同的 API 和身份验证要求。通过封装 HTTP 工具,可以统一管理与这些服务提供商的通信,处理价格查询、订单跟踪、库存查询等任务。如果供应商或…

【MATLAB源码-第52期】基于matlab的4用户DS-CDMA误码率仿真,对比不同信道以及不同扩频码。

操作环境: MATLAB 2022a 1、算法描述 1. DS-CDMA系统 DS-CDMA (Direct Sequence Code Division Multiple Access) 是一种多址接入技术,其基本思想是使用伪随机码序列来调制发送信号。DS-CDMA的特点是所有用户在同一频率上同时发送和接收信息&#xf…

《动手学深度学习 Pytorch版》 9.4 双向循环神经网络

之前的序列学习中假设的目标是在给定观测的情况下对下一个输出进行建模,然而也存在需要后文预测前文的情况。 9.4.1 隐马尔可夫模型中的动态规划 数学推导太复杂了,略。 9.4.2 双向模型 双向循环神经网络(bidirectional RNNs)…

Ubuntu 17.10的超震撼声音权限

从GNOME GUADEC 2017开发者大会归来之后,Canonical的Didier Roche就开始了一个日更博客系列,主要讲述即将带来的Ubuntu 17.10(Artful Aardvark)发行版将如何从Unity到GNOME Shell的转变。有趣的是,Ubuntu Unity桌面环境…

gin框架39--重构 BasicAuth 中间件

gin框架39--重构 BasicAuth 中间件 介绍gin BasicAuth 解析自定义newAuth实现基础认证注意事项说明 介绍 每当我们打开一个网址的时候,会自动弹出一个认证界面,要求我们输入用户名和密码,这种BasicAuth是最基础、最常见的认证方式&#xff0…

SIEMENS S7-1200 汽车转弯灯程序 编程与分析

公告 项目地址:https://github.com/MartinxMax/SIEMENS-1200-car_turn_signal 分析 题目: 画IO分配表 输入输出m3.0左转弯开关q0.0左闪灯m3.1右转弯开关q0.1右闪灯m3.2停止开关 博图V16配置 设置PLC的IP地址 允许远程通信访问 将HMI设备拖入 注意,我们这边选择的是HMI连接…

数据结构----算法--五大基本算法(这里没有写分支限界法)和银行家算法

数据结构----算法–五大基本算法(这里没有写分支限界法)和银行家算法 一.贪心算法 1.什么是贪心算法 在有多个选择的时候不考虑长远的情况,只考虑眼前的这一步,在眼前这一步选择当前的最好的方案 二.分治法 1.分治的概念 分…

某讯D-Link AC集中管理平台未授权访问漏洞复现 CNVD-2023-19479

目录 1.漏洞概述 2.影响版本 3.漏洞等级 4.漏洞复现 5.Nuclei自动化验证POC 6.修复建议

【JavaEE】Callable 接口

Callable 是一个 interface . 相当于把线程封装了一个 “返回值”. 方便程序猿借助多线程的方式计算结果. 实现Callable也是创建线程的一种方法!!!! Callable的用法非常接近于Runnable,Runnable描述了一个任务&#…