【Pytorch】模块化

文章目录

  • 1. 获取数据
  • 2. 创建Dataset和DataLoader
  • 3. 定义模型
  • 4. 创建训练模型引擎函数
  • 5. 创建保存模型的函数
  • 6. 训练、评估并保存模型


模块化涉及将jupyter notebook代码转换为一系列提供类似功能的不同 Python 脚本。

可以将笔记本代码从一系列单元格转换为以下 Python 文件:

  1. data_setup.py - 如果需要,用于准备和下载数据的文件。
  2. engine.py - 包含各种训练函数的文件。
  3. model_builder.pymodel.py - 用于创建 PyTorch 模型的文件。
  4. train.py - 用于利用所有其他文件并训练目标 PyTorch 模型的文件。
  5. utils.py - 专用于有用实用功能的文件。

上述文件的命名和布局将取决于您的用例和代码要求。 Python 脚本与单个notebook单元一样通用,这意味着您可以为几乎任何类型的功能创建脚本。

notebook非常适合快速迭代探索和运行实验,但是,对于较大规模的项目,您可能会发现 Python 脚本更具可重复性且更易于运行。

在你去download别人开源的项目时,可能会指示您在终端/命令行中运行如下代码来训练模型:

python train.py --model MODEL_NAME --batch_size BATCH_SIZE --lr LEARNING_RATE --num_epochs NUM_EPOCHS

train.py 是目标 Python 脚本,它可能包含训练 PyTorch 模型的函数,--model--batch_size--lr--num_epochs 被称为参数标志。

可以将它们设置为您喜欢的任何值,如果它们与 train.py 兼容,它们就会工作,如果不兼容,它们就会出错。

例如,训练 TinyVGG 模型 10 个时期,批量大小为 32,学习率为 0.001:

python train.py --model tinyvgg --batch_size 32 --lr 0.001 --num_epochs 10

Python脚本的目录结构:

going_modular/
├── going_modular/
│   ├── data_setup.py
│   ├── engine.py
│   ├── model_builder.py
│   ├── train.py
│   └── utils.py
├── models/
│   ├── 05_going_modular_cell_mode_tinyvgg_model.pth
│   └── 05_going_modular_script_mode_tinyvgg_model.pth
└── data/└── pizza_steak_sushi/├── train/│   ├── pizza/│   │   ├── image01.jpeg│   │   └── ...│   ├── steak/│   └── sushi/└── test/├── pizza/├── steak/└── sushi/

1. 获取数据


2. 创建Dataset和DataLoader

( data_setup.py )

"""
Contains functionality for creating PyTorch DataLoaders for 
image classification data.
"""
import osfrom torchvision import datasets, transforms
from torch.utils.data import DataLoaderNUM_WORKERS = os.cpu_count()def create_dataloaders(train_dir: str, test_dir: str, transform: transforms.Compose, batch_size: int, num_workers: int=NUM_WORKERS
):"""Creates training and testing DataLoaders.Takes in a training directory and testing directory path and turnsthem into PyTorch Datasets and then into PyTorch DataLoaders.Args:train_dir: Path to training directory.test_dir: Path to testing directory.transform: torchvision transforms to perform on training and testing data.batch_size: Number of samples per batch in each of the DataLoaders.num_workers: An integer for number of workers per DataLoader.Returns:A tuple of (train_dataloader, test_dataloader, class_names).Where class_names is a list of the target classes.Example usage:train_dataloader, test_dataloader, class_names = \= create_dataloaders(train_dir=path/to/train_dir,test_dir=path/to/test_dir,transform=some_transform,batch_size=32,num_workers=4)"""# Use ImageFolder to create dataset(s)train_data = datasets.ImageFolder(train_dir, transform=transform)test_data = datasets.ImageFolder(test_dir, transform=transform)# Get class namesclass_names = train_data.classes# Turn images into data loaderstrain_dataloader = DataLoader(train_data,batch_size=batch_size,shuffle=True,num_workers=num_workers,pin_memory=True,)test_dataloader = DataLoader(test_data,batch_size=batch_size,shuffle=False,num_workers=num_workers,pin_memory=True,)return train_dataloader, test_dataloader, class_names

如果我们想要创建 DataLoader ,我们现在可以在 data_setup.py 中使用该函数,如下所示:

# Import data_setup.py
from going_modular import data_setup# Create train/test dataloader and get class names as a list
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(...)

3. 定义模型

(model_builder.py)

"""
Contains PyTorch model code to instantiate a TinyVGG model.
"""
import torch
from torch import nn class TinyVGG(nn.Module):"""Creates the TinyVGG architecture.Replicates the TinyVGG architecture from the CNN explainer website in PyTorch.See the original architecture here: https://poloclub.github.io/cnn-explainer/Args:input_shape: An integer indicating number of input channels.hidden_units: An integer indicating number of hidden units between layers.output_shape: An integer indicating number of output units."""def __init__(self, input_shape: int, hidden_units: int, output_shape: int) -> None:super().__init__()self.conv_block_1 = nn.Sequential(nn.Conv2d(in_channels=input_shape, out_channels=hidden_units, kernel_size=3, stride=1, padding=0),  nn.ReLU(),nn.Conv2d(in_channels=hidden_units, out_channels=hidden_units,kernel_size=3,stride=1,padding=0),nn.ReLU(),nn.MaxPool2d(kernel_size=2,stride=2))self.conv_block_2 = nn.Sequential(nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=0),nn.ReLU(),nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=0),nn.ReLU(),nn.MaxPool2d(2))self.classifier = nn.Sequential(nn.Flatten(),# Where did this in_features shape come from? # It's because each layer of our network compresses and changes the shape of our inputs data.nn.Linear(in_features=hidden_units*13*13,out_features=output_shape))def forward(self, x: torch.Tensor):x = self.conv_block_1(x)x = self.conv_block_2(x)x = self.classifier(x)return x# return self.classifier(self.block_2(self.block_1(x))) # <- leverage the benefits of operator fusion

4. 创建训练模型引擎函数

  1. train_step() - 接受模型、 DataLoader 、损失函数和优化器,并在 DataLoader 上训练模型。
  2. test_step() - 接受模型、 DataLoader 和损失函数,并在 DataLoader 上评估模型。
  3. train() - 对给定数量的 epoch 一起执行 1. 和 2. 并返回结果字典。

由于这些将成为我们模型训练的引擎,因此我们可以将它们全部放入名为 engine.py 的 Python 脚本中:

"""
Contains functions for training and testing a PyTorch model.
"""
import torchfrom tqdm.auto import tqdm
from typing import Dict, List, Tupledef train_step(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module, optimizer: torch.optim.Optimizer,device: torch.device) -> Tuple[float, float]:"""Trains a PyTorch model for a single epoch.Turns a target PyTorch model to training mode and thenruns through all of the required training steps (forwardpass, loss calculation, optimizer step).Args:model: A PyTorch model to be trained.dataloader: A DataLoader instance for the model to be trained on.loss_fn: A PyTorch loss function to minimize.optimizer: A PyTorch optimizer to help minimize the loss function.device: A target device to compute on (e.g. "cuda" or "cpu").Returns:A tuple of training loss and training accuracy metrics.In the form (train_loss, train_accuracy). For example:(0.1112, 0.8743)"""# Put model in train modemodel.train()# Setup train loss and train accuracy valuestrain_loss, train_acc = 0, 0# Loop through data loader data batchesfor batch, (X, y) in enumerate(dataloader):# Send data to target deviceX, y = X.to(device), y.to(device)# 1. Forward passy_pred = model(X)# 2. Calculate  and accumulate lossloss = loss_fn(y_pred, y)train_loss += loss.item() # 3. Optimizer zero gradoptimizer.zero_grad()# 4. Loss backwardloss.backward()# 5. Optimizer stepoptimizer.step()# Calculate and accumulate accuracy metric across all batchesy_pred_class = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)train_acc += (y_pred_class == y).sum().item()/len(y_pred)# Adjust metrics to get average loss and accuracy per batch train_loss = train_loss / len(dataloader)train_acc = train_acc / len(dataloader)return train_loss, train_accdef test_step(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module,device: torch.device) -> Tuple[float, float]:"""Tests a PyTorch model for a single epoch.Turns a target PyTorch model to "eval" mode and then performsa forward pass on a testing dataset.Args:model: A PyTorch model to be tested.dataloader: A DataLoader instance for the model to be tested on.loss_fn: A PyTorch loss function to calculate loss on the test data.device: A target device to compute on (e.g. "cuda" or "cpu").Returns:A tuple of testing loss and testing accuracy metrics.In the form (test_loss, test_accuracy). For example:(0.0223, 0.8985)"""# Put model in eval modemodel.eval() # Setup test loss and test accuracy valuestest_loss, test_acc = 0, 0# Turn on inference context managerwith torch.inference_mode():# Loop through DataLoader batchesfor batch, (X, y) in enumerate(dataloader):# Send data to target deviceX, y = X.to(device), y.to(device)# 1. Forward passtest_pred_logits = model(X)# 2. Calculate and accumulate lossloss = loss_fn(test_pred_logits, y)test_loss += loss.item()# Calculate and accumulate accuracytest_pred_labels = test_pred_logits.argmax(dim=1)test_acc += ((test_pred_labels == y).sum().item()/len(test_pred_labels))# Adjust metrics to get average loss and accuracy per batch test_loss = test_loss / len(dataloader)test_acc = test_acc / len(dataloader)return test_loss, test_accdef train(model: torch.nn.Module, train_dataloader: torch.utils.data.DataLoader, test_dataloader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer,loss_fn: torch.nn.Module,epochs: int,device: torch.device) -> Dict[str, List]:"""Trains and tests a PyTorch model.Passes a target PyTorch models through train_step() and test_step()functions for a number of epochs, training and testing the modelin the same epoch loop.Calculates, prints and stores evaluation metrics throughout.Args:model: A PyTorch model to be trained and tested.train_dataloader: A DataLoader instance for the model to be trained on.test_dataloader: A DataLoader instance for the model to be tested on.optimizer: A PyTorch optimizer to help minimize the loss function.loss_fn: A PyTorch loss function to calculate loss on both datasets.epochs: An integer indicating how many epochs to train for.device: A target device to compute on (e.g. "cuda" or "cpu").Returns:A dictionary of training and testing loss as well as training andtesting accuracy metrics. Each metric has a value in a list for each epoch.In the form: {train_loss: [...],train_acc: [...],test_loss: [...],test_acc: [...]} For example if training for epochs=2: {train_loss: [2.0616, 1.0537],train_acc: [0.3945, 0.3945],test_loss: [1.2641, 1.5706],test_acc: [0.3400, 0.2973]} """# Create empty results dictionaryresults = {"train_loss": [],"train_acc": [],"test_loss": [],"test_acc": []}# Loop through training and testing steps for a number of epochsfor epoch in tqdm(range(epochs)):train_loss, train_acc = train_step(model=model,dataloader=train_dataloader,loss_fn=loss_fn,optimizer=optimizer,device=device)test_loss, test_acc = test_step(model=model,dataloader=test_dataloader,loss_fn=loss_fn,device=device)# Print out what's happeningprint(f"Epoch: {epoch+1} | "f"train_loss: {train_loss:.4f} | "f"train_acc: {train_acc:.4f} | "f"test_loss: {test_loss:.4f} | "f"test_acc: {test_acc:.4f}")# Update results dictionaryresults["train_loss"].append(train_loss)results["train_acc"].append(train_acc)results["test_loss"].append(test_loss)results["test_acc"].append(test_acc)# Return the filled results at the end of the epochsreturn results

现在我们已经有了 engine.py 脚本,我们可以通过以下方式从中导入函数:

# Import engine.py
from going_modular import engine# Use train() by calling it from engine.py
engine.train(...)

5. 创建保存模型的函数

( utils.py )

将 save_model() 函数保存到名为 utils.py 的文件中:

"""
Contains various utility functions for PyTorch model training and saving.
"""
import torch
from pathlib import Pathdef save_model(model: torch.nn.Module,target_dir: str,model_name: str):"""Saves a PyTorch model to a target directory.Args:model: A target PyTorch model to save.target_dir: A directory for saving the model to.model_name: A filename for the saved model. Should includeeither ".pth" or ".pt" as the file extension.Example usage:save_model(model=model_0,target_dir="models",model_name="05_going_modular_tingvgg_model.pth")"""# Create target directorytarget_dir_path = Path(target_dir)target_dir_path.mkdir(parents=True,exist_ok=True)# Create model save pathassert model_name.endswith(".pth") or model_name.endswith(".pt"), "model_name should end with '.pt' or '.pth'"model_save_path = target_dir_path / model_name# Save the model state_dict()print(f"[INFO] Saving model to: {model_save_path}")torch.save(obj=model.state_dict(),f=model_save_path)

可以导入它并通过以下方式使用它,而不是重新编写它:

# Import utils.py
from going_modular import utils# Save a model to file
save_model(model=...target_dir=...,model_name=...)

6. 训练、评估并保存模型

( train.py )
可以在命令行上使用一行代码来训练 PyTorch 模型:

python train.py

要创建 train.py ,我们将执行以下步骤:

  1. 导入各种依赖项,即 torch 、 os 、 torchvision.transforms 以及 going_modular 目录 data_setup 、 model_builder 、 utils 。
  2. 注意:由于 train.py 将位于 going_modular 目录中,因此我们可以通过 import … 而不是 from going_modular import … 导入其他模块。
  3. 设置各种超参数,例如批量大小、时期数、学习率和隐藏单元数(将来可以通过 Python 的 argparse 设置)。
  4. 设置训练和测试目录。
  5. 设置与设备无关的代码。
  6. 创建必要的数据转换。
  7. 使用 data_setup.py 创建 DataLoaders。
  8. 使用 model_builder.py 创建模型。
  9. 设置损失函数和优化器。
  10. 使用 engine.py 训练模型。
  11. 使用 utils.py 保存模型。
"""
Trains a PyTorch image classification model using device-agnostic code.
"""import os
import torch
import data_setup, engine, model_builder, utilsfrom torchvision import transforms# Setup hyperparameters
NUM_EPOCHS = 5
BATCH_SIZE = 32
HIDDEN_UNITS = 10
LEARNING_RATE = 0.001# Setup directories
train_dir = "data/pizza_steak_sushi/train"
test_dir = "data/pizza_steak_sushi/test"# Setup target device
device = "cuda" if torch.cuda.is_available() else "cpu"# Create transforms
data_transform = transforms.Compose([transforms.Resize((64, 64)),transforms.ToTensor()
])# Create DataLoaders with help from data_setup.py
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(train_dir=train_dir,test_dir=test_dir,transform=data_transform,batch_size=BATCH_SIZE
)# Create model with help from model_builder.py
model = model_builder.TinyVGG(input_shape=3,hidden_units=HIDDEN_UNITS,output_shape=len(class_names)
).to(device)# Set loss and optimizer
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=LEARNING_RATE)# Start training with help from engine.py
engine.train(model=model,train_dataloader=train_dataloader,test_dataloader=test_dataloader,loss_fn=loss_fn,optimizer=optimizer,epochs=NUM_EPOCHS,device=device)# Save the model with help from utils.py
utils.save_model(model=model,target_dir="models",model_name="05_going_modular_script_mode_tinyvgg_model.pth")

可以调整 train.py 文件以使用 Python 的 argparse 模块的参数标志输入,这将允许我们提供不同的超参数设置,如前面讨论的:

python train.py --model MODEL_NAME --batch_size BATCH_SIZE --lr LEARNING_RATE --num_epochs NUM_EPOCHS

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/705448.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JetBrains系列工具,配置PlantUML绘图

PlantUML是一个很强大的绘图工具&#xff0c;各种图都可以绘制&#xff0c;具体的可以去官网看看&#xff0c;或者百度。 PlantUML简述 https://plantuml.com/zh/ PlantUML语言参考指引 https://plantuml.com/zh/guide PlantUML语言是依赖Graphviz进行解析的。Graphviz是开源…

[设计模式Java实现附plantuml源码~行为型] 撤销功能的实现——备忘录模式

前言&#xff1a; 为什么之前写过Golang 版的设计模式&#xff0c;还在重新写Java 版&#xff1f; 答&#xff1a;因为对于我而言&#xff0c;当然也希望对正在学习的大伙有帮助。Java作为一门纯面向对象的语言&#xff0c;更适合用于学习设计模式。 为什么类图要附上uml 因为很…

2024程序员容器化上云之旅-第6集-Ubuntu-WSL2-Windows11版:艰难复活

故事梗概 Java程序员马意浓在互联网公司维护老旧电商后台系统。 渴望学习新技术的他在工作中无缘Docker和K8s。 他开始自学Vue3并使用SpringBoot3完成了一个前后端分离的Web应用系统&#xff0c;并打算将其用Docker容器化后用K8s上云。 8 复活重生 周末终于有点属于自己的…

【书籍分享 • 第三期】虚拟化与容器技术

文章目录 一、本书内容二、读者对象三、编辑推荐四、前言4.1 云计算技术的发展4.2 KVM、Docker4.3 本书内容简介4.4 作者简介 五、粉丝福利 一、本书内容 《虚拟化与容器技术》通过深入浅出的方式介绍KVM虚拟化技术与Docker容器技术的概念、原理及实现方法&#xff0c;内容包括…

Linux之安装Nginx、前后端分离项目部署

目录 一、安装Nginx 1.1先一键安装4个依赖 1.2下载并解压安装包 1.3安装nginx&#xff0c;一般我们在nginx都是要安装ssl证书的 1.4 启动nginx服务 1.5开放80端口 1.6配置nginx自启动 1.7修改/etc/rc.d/rc/local的权限 二、多个tomcat负载加后端部署 2.1创建多个tomca…

Windows已经安装了QT 6.3.0,如何再安装一个QT 5.12

要在Windows上安装Qt 5.12&#xff0c;您可以按照以下步骤操作&#xff1a; 下载Qt 5.12&#xff1a;访问Qt官方网站或其他可信赖的来源&#xff0c;下载Qt 5.12的安装包。 下载安装地址 下载安装详细教程 安装问题点 qt安装时“Error during installation process(qt.tools…

react useRef用法

1&#xff0c;保存变量永远不丢失 import React, { useState,useRef } from react export default function App() { const [count,setcount] useState(0) var mycount useRef(0)//保存变量永远不丢失--useRef用的是闭包原理 return( <div> <button onClick{()>…

跨境电商营销进化史:从传统广告到智能化策略的全面探析

随着全球化的不断推进和互联网技术的飞速发展&#xff0c;跨境电商在过去几年里取得了显著的发展。在这个竞争激烈的市场中&#xff0c;企业们纷纷调整营销策略以应对不断变化的消费者需求和市场趋势。本文Nox聚星将和大家探讨跨境电商营销策略的演变过程&#xff0c;从传统的推…

MySQL基础(二)

文章目录 MySQL基础&#xff08;二&#xff09;1. 数据库操作-DQL1.1 介绍1.2 语法1.3 基本查询1.4 条件查询1.5 聚合函数1.6 分组查询1.7 排序查询1.8 分页查询1.9 案例1.9.1 案例一1.9.2 案例二 2. 多表设计2.1 一对多2.1.1 表设计2.1.2 外键约束 2.2 一对一2.3 多对多2.4 案…

【Spring Boot 3】【JPA】@ManyToOne 实现一对多单向关联

【Spring Boot 3】【JPA】@ManyToOne 实现一对多单向关联 背景介绍开发环境开发步骤及源码工程目录结构总结背景 软件开发是一门实践性科学,对大多数人来说,学习一种新技术不是一开始就去深究其原理,而是先从做出一个可工作的DEMO入手。但在我个人学习和工作经历中,每次学…

Python爬虫中的单线程、多线程问题(文末送书)

前言 在使用爬虫爬取数据的时候&#xff0c;当需要爬取的数据量比较大&#xff0c;且急需很快获取到数据的时候&#xff0c;可以考虑将单线程的爬虫写成多线程的爬虫。下面来学习一些它的基础知识和代码编写方法。 一、进程和线程 进程可以理解为是正在运行的程序的实例。进…

【Flink精讲】Flink反压调优

Flink 网络流控及反压的介绍&#xff1a; Apache Flink学习网 反压的理解 简单来说&#xff0c; Flink 拓扑中每个节点&#xff08;Task&#xff09;间的数据都以阻塞队列的方式传输&#xff0c;下游来不及消费导致队列被占满后&#xff0c;上游的生产也会被阻塞&#xff0c;…

dpvs 笔记

20、 基于ECMP的多活负载均衡策略 当使用ospf/ECMP来实现高可用&#xff0c;所以keepalived不需要配置vrrp功能。keepalived只使用后端服务健康检查功能。 Equal-Cost Multi-Path Routing (ECMP) ECMP根据SIP-DIP对来选择路由 keepalived 健康检查机制说明 keepalived TCP chec…

深入探索计算机组成原理:构建信息时代的基石

### 深入探索计算机组成原理&#xff1a;构建信息时代的基石 在当代社会&#xff0c;计算机已经渗透到我们生活的方方面面&#xff0c;从家庭到工作场所&#xff0c;从基础科学研究到工业生产&#xff0c;无一不受到其深远影响。这一切的基础都建立在对计算机组成原理的深入理…

GaussDB SQL调优:选择合适的分布列

一、背景 GaussDB是华为公司倾力打造的自研企业级分布式关系型数据库&#xff0c;该产品具备企业级复杂事务混合负载能力&#xff0c;同时支持优异的分布式事务&#xff0c;同城跨AZ部署&#xff0c;数据0丢失&#xff0c;支持1000扩展能力&#xff0c;PB级海量存储等企业级数…

Netty NIO 非阻塞模式

1.概要 1.1 说明 使用非阻塞的模式&#xff0c;就可以用一个现场&#xff0c;处理多个客户端的请求了 1.2 要点 ssc.configureBlocking(false);if(sc!null){ sc.configureBlocking(false); channels.add(sc); }if(len>0){ byteBuffer.flip(); 2.代码 2.1 服务端代码 …

Springboot 多级缓存设计与实现

&#x1f3f7;️个人主页&#xff1a;牵着猫散步的鼠鼠 &#x1f3f7;️系列专栏&#xff1a;Java全栈-专栏 &#x1f3f7;️个人学习笔记&#xff0c;若有缺误&#xff0c;欢迎评论区指正 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&…

StarRocks之扩容缩容

扩缩容 FE 集群 StarRocks FE 节点分为 Follower 节点和 Observer 节点。Follower 节点参与选举投票和写入&#xff0c;Observer 节点只用来同步日志&#xff0c;扩展读性能。 注意&#xff1a; 所有 FE 节点的 http_port 必须相同。 Follower FE 节点&#xff08;包括 Leader…

超真诚婚礼邀请函小程序

结婚了&#xff0c;自己写个婚礼邀请函小程序&#xff0c;含泪省下&#xffe5;49.9&#xff1b;程序员的浪漫&#xff01; 1、定位直达 2、背景音乐 3、倒计时 4、CSDN图床 页面代码如下&#xff1a; <cu-custom bgColor"bg-yellow-light" isBack"{{fal…

基于HT32的智能家居demo(蓝牙上位机)

参加合泰杯作品的部分展示&#xff0c;基于HT32的智能家居&#xff0c;这里展示灯光的相关控制&#xff0c;是用蓝牙进行的数据透传&#xff0c;参考了一些资料&#xff0c;美化封装了一下之前的上位机界面。 成果展示 点击主界面的蓝牙设置&#xff0c;进行连接&#xff0c;下…