深入学习pytorch笔记

两个重要的函数

  • dir(): 一个内置函数,用于列出对象的所有属性和方法
    在这里插入图片描述

  • help():一个内置函数,用于获取关于Python对象、模块、函数、类等的详细信息
    在这里插入图片描述

Dateset类

  • Dataset:pytorch中的一个类,开发者在训练和测试时,用一个子类去继承Dataset类,继承和重写Dataset类中方法和属性,以加载数据集。
class Dataset(object):"""An abstract class representing a Dataset.All other datasets should subclass it. All subclasses should override``__len__``, that provides the size of the dataset, and ``__getitem__``,supporting integer indexing in range from 0 to len(self) exclusive."""def __getitem__(self, index):raise NotImplementedErrordef __len__(self):raise NotImplementedErrordef __add__(self, other):return ConcatDataset([self, other])
  • def getitem(self, index):必须重写,用于以加载数据集。
  • def len(self):可不重写,用于计算数据集中样本个数。
    在这里插入图片描述

TensorBoard

  • TensorBoard 是pytorch中一组用于数据可视化的工具,包含在TensorFlow库。
  • SummaryWriter类:用于在给定目录中创建事件文件,在训练时,将数据添加到文件中,用于显示。使用SummaryWriter类创建对象时,若没有给出事件文件名,则默认的事件文件名为run。

损失函数

  • torch.nn.loss():PyTorch 中的一个类,用于计算L1 损失函数,即计算了预测值与实际值之间的L1范数(即绝对差值)。
  • 在创建torch.nn.L1Loss(reduction)对象时,可以传入一个可选的参数reduction,它决定了如何从每个样本的损失中聚合得到最终的损失。
    1. reduction=‘mean’:计算所有样本损失的平均值作为最终损失。默认情况下,reduction参数的值为’mean’,即计算所有样本损失的平均值作为最终损失。
    2. reduction=‘none’:不进行任何聚合操作,直接返回每个样本的损失。
    3. reduction=‘sum’:计算所有样本损失的总和作为最终损失。
    4. reduction= ‘mean_none’: 计算所有样本损失的平均值,但是不除以样本数,即不进行归一化。
    5. reduction=‘sum_none’:计算所有样本损失的总和,但是不乘以样本数,即不进行归一化。
  • 在调用torch.nn.L1Loss()对象时,要传入预测值和实际值。
    在这里插入图片描述
  • torch.nn.MSELoss():PyTorch库中的一个类,用于计算均方误差。MSE损失函数的计算方式是:对于每个样本,计算预测值与真实值之间的平方差,然后取这些平方差的平均值。具体公式为:loss = 1/n Σ (y_pred - y_true)^2,其中n是样本数量。
    在这里插入图片描述
  • torch.nn.CrossEntropyLoss:是PyTorch库中的一个类,用于计算交叉熵损失。
  • 在创建对象时,torch.nn.CrossEntropyLoss()参数:
    1. weight: 类别权重。这是一个一维的tensor,用于为每个类别指定不同的权重。默认值是None,这时所有的类别权重都相等。如果指定了类别权重,那么在计算损失时,每个类别的损失将会根据其对应的权重进行加权平均。
    2. reduction: 损失的归约方式。这个参数决定了如何将交叉熵损失的值从样本级别降低到批次级别。可能的值有:‘none’(不进行归约,返回每个样本的交叉熵损失),‘mean’(对所有样本的交叉熵损失取平均),‘sum’(将所有样本的交叉熵损失相加)。默认值是’mean’。
    3. ignore_index: 被忽略的类别索引。如果设置了该参数,那么在计算交叉熵损失时,该类别对应的损失将被忽略。这个参数主要用于处理数据集中的无效类别或不需要分类的类别。默认值是-100。
  • 在调用torch.nn.CrossEntropyLoss的对象时,需要传入两个参数:
    1. input:这是一个一维或二维张量,表示模型的输出。对于每个输入样本,输出应该是一个长度为类别数量的向量,每个元素表示该类别与输入样本的相似度。
    2. target:这是一个一维张量,表示每个输入样本的正确类别标签。
      在这里插入图片描述

优化器(参数更新)

  • torch.optim.SGD:PyTorch 中的一个类,它实现了随机梯度下降(Stochastic Gradient Descent)算法。
  • 创建类对象时,torch.optim.SGD(params,lr,momentum,dampening,weight_decay,nesterov)的参数:
    1. params:要优化的参数,通常是模型中的参数。
    2. lr:学习率。控制参数更新的步长。默认值是0.01。
    3. momentum:动量。这个参数会考虑之前梯度的方向,使得优化器具有一定的"惯性",有助于加速训练。默认值是0。
    4. dampening:阻尼。这个参数可以防止动量过大导致震荡。默认值是0。
    5. weight_decay:权重衰减。可以防止过拟合,通过对参数本身进行惩罚来控制模型的复杂度。默认值是0,表示不进行权重衰减。
    6. nesterov:是否使用 Nesterov 动量。如果为 True,会使用 Nesterov 动量,否则使用标准 momentum。默认值是False
  • 创建优化器后,我们可以通过调用 optimizer.zero_grad() 清除之前的梯度,然后通过反向传播计算新的梯度,最后使用 optimizer.step() 更新模型的参数。

import torch
from torch import nn
from torch.nn import Sequential,Conv2d,MaxPool2d,Flatten
from torch.nn import Linear
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torchvision.transforms
from torch.utils.data import DataLoaderdataset = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset, batch_size=64)class MY_Dodule(nn.Module):def __init__(self):super(MY_Dodule,self).__init__()self.model = Sequential(Conv2d(3, 32, kernel_size=5, padding=2),MaxPool2d(2),Conv2d(32, 32, kernel_size=5, padding=2),MaxPool2d(2),Conv2d(32, 64, kernel_size=5, padding=2),MaxPool2d(2),Flatten(),Linear(1024,64),Linear(64,10))def forward(self,input):output = self.model(input)return outputmy_module = MY_Dodule()
loss = nn.CrossEntropyLoss()
optim = torch.optim.SGD(my_module.parameters(),lr=0.1)
for epoch in range(20):running_loss = 0.0for data in dataloader:images,targets = datainput = imagesoutput = my_module(input)  # 前向转播result_loss = loss(output,targets)  # 计算损失optim.zero_grad()  # 清除之前的梯度result_loss.backward() # 反向转播optim.step() #梯度更新running_loss += result_losspassprint(running_loss)pass

网络模型的使用和修改

  • torchvision.models.vgg16(pretrained,progress):PyTorch 中的一个类,是用来加载预训练的 VGG-16 模型的函数。

    1. pretrained:布尔型,决定是否从 PyTorch 的预训练模型库中加载训练好的权重。如果设为 True,则返回的模型会包含在大规模图像分类任务上训练得到的权重。如果设为 False,则模型不包含预训练的权重,你需要自己训练模型。默认为False。
    2. progress:布尔型,决定是否显示下载预训练模型过程的进度条。如果设为 True,则在下载预训练模型时会显示进度条。默认为True。
  • 在 VGG-16 模型中添加层:model是torchvision.models.vgg16()示例化对象,model.classifier.add_module(str,nn.Module)这个函数接受两个参数。

    1. 模块名称(str):这是你想要添加的模块的名称。你可以自己定义一个有意义的名称,以便在后续的代码中引用这个模块。
    2. 模块对象(nn.Module):这是你想要添加的模块本身。这个模块可以是任何PyTorch定义的神经网络层或者你自己定义的层。
  • 在 VGG-16 模型中修改层:model是torchvision.models.vgg16()示例化对象,model.classifier[n] = nn.Module

    1. n:VGG-16 模型中修改层的层号
    2. nn.Module:修改后的模块本身。这个模块可以是任何PyTorch定义的神经网络层或者你自己定义的层。
      在这里插入图片描述

网络模型的保存与读取

  • torch.save(model, ‘model.pth’):PyTorch 中的一个函数,模型model的权重和参数,保存在指定文件model.pth中。
  • model = torch.load(‘model.pth’):PyTorch 中的一个函数,根据model.pth文件,加载保存的模型并返回给变量 model
  • torch.save(model.state_dict(), ‘model.pth’): 将模型model参数(权重和偏置等,不包括模型的结构),以字典的形式保存到指定的文件 ‘model.pth’ 中。
  • model.load_state_dict(torch.load(‘model.pth’)):torch.load()函数读取文件中模型的参数信息,加载到model模型中。请注意,这种方式要求你在加载模型时已经知道模型model的结构。

模型训练流程(以CIFAR10为例)

  • 第一步:准备数据集,包括训练集和测试集
import torchvision# 准备训练集
train_data = torchvision.datasets.CIFAR10("dataset",train=True,transform=torchvision.transforms.ToTensor(),download=True)# 准备测试集
test_data = torchvision.datasets.CIFAR10("dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)
  • 第二步:计算数据长度
# 计算数据集长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度:{}".format(train_data_size))
print("测试数据集的长度:{}".format(test_data_size))
  • 第三步:用dataloader()加载数据集,将数据集划分为批量子集
# dataloader()加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
  • 第四步:搭建神经网络,一般用一个单独python文件保存
import torch
from torch import nnclass My_Module(nn.Module):def __init__(self):super(My_Module,self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32 ,32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(32,64,5,1,2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4,64),nn.Linear(64,10),)def forward(self,input):output = self.model(input)return outputif __name__ == '__main__':my_module = My_Module()input = torch.ones((64, 3, 32, 32))output = my_module(input)print(output.shape)
  • 第五步:创建网络模型
# 创建网络模型
my_module = My_Module()
  • 第六步:定义损失函数
loss_f = nn.CrossEntropyLoss()
  • 第七步:定义优化器,进行梯度下降
# 定义优化器,进行梯度下降
learning_rate = 0.01  # 学习效率
optimizer = torch.optim.SGD(my_module, lr=learning_rate)
  • 第八步:设置训练网络模型的一些参数
# 设置训练网络模型的一些参数
total_train_step = 0  # 记录训练次数
total_test_step = 0  # 记录测试次数
epoch = 10 # 训练的轮次
writer = SummaryWriter("P27")  # 添加tensorboard
  • 第九步:训练网络模型
# 训练网络模型
for i in range(epoch):print("------第{}轮训练开始------".format(i + 1))for data in train_dataloader:images ,targets = datainput = imagesoutput = my_module(input)  # 前向传播loss = loss_f(output, targets)  # 计算损失loss.backward()  # 反向转播optimizer.zero_grad()  #optimizer.step() # 梯度下降total_train_step = total_train_step + 1print("训练次数:{},loss:{}".format(total_train_step, loss.item()))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/166476.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

抖音电商品牌力不足咋办?如何升级或强开旗舰店、官方旗舰店?我们有妙招!

随着抖音电商的发展,越来越多的商家蜂拥而至,入驻经营抖音小店... 然而我们在开店的时候,选择开通官方旗舰店、旗舰店、专营店或专卖店,却被系统提示为你的商标品牌力不足,无法开通官方旗舰店、旗舰店、专营店、专卖店…

在 vscode 中的json文件写注释,不报错的解决办法

打开 vscode 的「设置」,搜索:files: associations,然后添加 *.json jsonc最后

Nginx 配置错误导致的漏洞

目录 1. CRLF注入漏洞 Bottle HTTP头注入漏洞 2.目录穿越漏洞 3. http add_header被覆盖 本篇要复现的漏洞实验有一个网站直接为我们提供了Docker的环境,我们只需要下载下来就可以使用: Docker环境的安装可以参考:Docker安装 漏洞环境的…

展现天津援疆工作成果 “团结村里看振兴”媒体采风团走进和田

央广网天津11月19日消息(记者周思杨)11月18日,由媒体记者、书法和摄影家、旅行社企业代表等40余人组成的“团结村里看振兴”媒体采风团走进新疆和田。在接下来的一周时间里,采风团将走访天津援疆和田地区策勒县、于田县、民丰县乡村振兴示范村&#xff0…

HTML CSS登录网页设计

一、效果图: 二、HTML代码: <!DOCTYPE html> <!-- 定义HTML5文档 --> <html lang="en"> …

在全球碳市场中崭露头角的中碳CCNG

在全球气候治理的大背景下&#xff0c;中国碳中和发展集团有限公司&#xff08;简称中国碳中和&#xff09;正在成为全球碳交易市场的一个重要参与者。随着国际社会对碳排放的日益关注&#xff0c;中国碳中和凭借其在碳资产开发、咨询与管理等领域的深厚积累&#xff0c;正成为…

视频剪辑新招:批量随机分割,分享精彩瞬间

随着社交媒体的普及&#xff0c;短视频已经成为分享生活、交流信息的重要方式。为制作出吸引的短视频&#xff0c;许多创作者都投入了大量的时间和精力进行剪辑。然而&#xff0c;对于一些没有剪辑经验的新手来说&#xff0c;这个过程可能会非常繁琐。现在一起来看云炫AI智剪批…

杨传辉:从一体化架构,到一体化产品,为关键业务负载打造一体化数据库

在刚刚结束的年度发布会上&#xff0c;OceanBase正式推出一体化数据库的首个长期支持版本 4.2.1 LTS&#xff0c;这是面向 OLTP 核心场景的全功能里程碑版本&#xff0c;相比上一个 3.2.4 LTS 版本&#xff0c;新版本能力全面提升&#xff0c;适应场景更加丰富&#xff0c;有更…

web前端之若依框架图标对照表、node获取文件夹中的文件名,并通过数组返回文件名、在html文件中引入.svg文件、require、icon

MENU 前言效果图htmlJavaScripstylenode获取文件夹中的文件名 前言 需要把若依原有的icon的svg文件拿到哦&#xff01; 注意看生成svg的路径。 效果图 html <div id"idSvg" class"svg_box"></div>JavaScrip let listSvg [404, bug, build, …

CentOS7安装Docker运行环境

1 引言 Docker 是一个用于开发&#xff0c;交付和运行应用程序的开放平台。Docker 使您能够将应用程序与基础架构分开&#xff0c;从而可以快速交付软件。借助 Docker&#xff0c;您可以与管理应用程序相同的方式来管理基础架构。通过利用 Docker 的方法来快速交付&#xff0c;…

在新疆乌鲁木齐的汽车托运

在新疆乌鲁木齐要托运的宝! 看过来了 找汽车托运公司了 连夜吐血给你们整理了攻略!! ⬇️以下&#xff1a; 1 网上搜索 可以在搜索引擎或专业的货运平台上搜索相关的汽车托运公司信息。在网站上可以了解到公司的服务范围、托运价格、运输时效等信息&#xff0c;也可以参考其他车…

2024年的云趋势:云计算的前景如何?

本文讨论了2024年云计算的发展趋势。 适应复杂的生态系统、提供实时功能、优先考虑安全性和确保可持续性的需求正在引领云计算之船。多样化的工作负载允许探索通用的公共云基础设施范例之外的选项。由于需要降低成本、提高灵活性和降低风险&#xff0c;混合云和多云系统越来越受…

RabbitMQ 消息队列编程

安装与配置 安装 RabbitMQ 读者可以在 RabbitMQ 官方文档中找到完整的安装教程&#xff1a;Downloading and Installing RabbitMQ — RabbitMQ 本文使用 Docker 的方式部署。 RabbitMQ 社区镜像列表&#xff1a;https://hub.docker.com/_/rabbitmq 创建目录用于映射存储卷…

YOLOv5 分类模型 预处理 OpenCV实现

YOLOv5 分类模型 预处理 OpenCV实现 flyfish YOLOv5 分类模型 预处理 PIL 实现 YOLOv5 分类模型 OpenCV和PIL两者实现预处理的差异 YOLOv5 分类模型 数据集加载 1 样本处理 YOLOv5 分类模型 数据集加载 2 切片处理 YOLOv5 分类模型 数据集加载 3 自定义类别 YOLOv5 分类模型…

关于python 语音转字幕,字幕转语音大杂烩

文字转语音 Python语音合成之第三方库gTTs/pyttsx3/speech横评(内附使用方法)_python_脚本之家 代码示例 from gtts import gTTStts gTTS(你好你在哪儿&#xff01;,langzh-CN)tts.save(hello.mp3)import pyttsx3engine pyttsx3.init() #创建对象"""语速"…

目前比较好用的护眼台灯,小学生适合的护眼台灯推荐

随着技术的发展&#xff0c;灯光早已成为每家每户都需要的东西。但是灯光不好可能会对眼睛造成伤害是很多人没有注意到的。现在随着护眼灯产品越来越多&#xff0c;市场上台灯的选择越来越多样化&#xff0c;如何选择一个对眼睛无伤害、无辐射的台灯成为许多家长首先要考虑的问…

【C++初阶】四、类和对象(构造函数、析构函数、拷贝构造函数、赋值运算符重载函数)

相关代码gitee自取&#xff1a; C语言学习日记: 加油努力 (gitee.com) 接上期&#xff1a; 【C初阶】三、类和对象 &#xff08;面向过程、class类、类的访问限定符和封装、类的实例化、类对象模型、this指针&#xff09; -CSDN博客 引入&#xff1a;类的六个默认成员函数…

如何使用springboot服务端接口公网远程调试——实现HTTP服务监听

&#x1f308;个人主页&#xff1a;聆风吟 &#x1f525;系列专栏&#xff1a;网络奇遇记、Cpolar杂谈 &#x1f516;少年有梦不应止于心动&#xff0c;更要付诸行动。 文章目录 &#x1f4cb;前言一. 本地环境搭建1.1 环境参数1.2 搭建springboot服务项目 二. 内网穿透2.1 安装…

ATA-2042高压放大器在细胞的剪切应力传感器研究中的应用

微流控技术是一种通过微小的通道和微型装置对流体进行精确操控和分析的技术。它是现代医学技术发展过程中的一种重要的生物医学工程技术&#xff0c;具有广泛的应用前景和重要性。它在高通量分析、个性化医疗、细胞筛选等方面有着巨大的潜力&#xff0c;Aigtek安泰电子今天就将…

HR8833 双通道H桥电机驱动芯片

HR8833为玩具、打印机和其它电机一T化应用提供一种双通道电机驱动方案。HR8833提供两种封装&#xff0c;一种是带有L露焊盘的TSSOP-16封装&#xff0c;能改进散热性能&#xff0c;且是无铅产品&#xff0c;引脚框采用100&#xff05;无锡电镀。另一种封装为SOP16&#xff0c;不…