LLM学习笔记-5

目录

  • 1.多层神经网络的实现
  • 2. 训练轮次示例
  • 3. 保存并加载模型
  • 4. 使用GPU加速训练
  • 5. 使用上面所教,进行一次训练

摘要:今天想整理一下Pytorch常用操作,以便以后进行预习(不是)
在这里插入图片描述

1.多层神经网络的实现

这是常用的操作,要会

class NeuralNetwork(torch.nn.Module):def __init__(self, num_inputs, num_outputs):super().__init__()self.layers = torch.nn.Sequential(# 第一个隐藏层torch.nn.Linear(num_inputs, 30),torch.nn.ReLU(),# 第二个隐藏层torch.nn.Linear(30, 20),torch.nn.ReLU(),# 输出层torch.nn.Linear(20, num_outputs),)def forward(self, x):logits = self.layers(x)return logitsmodel = NeuralNetwork(50, 3)
print(model)

NeuralNetwork(
(layers): Sequential(
(0): Linear(in_features=50, out_features=30, bias=True)
(1): ReLU()
(2): Linear(in_features=30, out_features=20, bias=True)
(3): ReLU()
(4): Linear(in_features=20, out_features=3, bias=True)
)
)

2. 训练轮次示例

import torch.nn.functional as Ftorch.manual_seed(123)
model = NeuralNetwork(num_inputs=2, num_outputs=2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)num_epochs = 3for epoch in range(num_epochs):model.train()for batch_idx, (features, labels) in enumerate(train_loader):logits = model(features)loss = F.cross_entropy(logits, labels) # 损失函数optimizer.zero_grad()loss.backward()optimizer.step()### 日志print(f"Epoch: {epoch+1:03d}/{num_epochs:03d}"f" | Batch {batch_idx:03d}/{len(train_loader):03d}"f" | Train/Val Loss: {loss:.2f}")model.eval()# 可选的模型评估指标

Epoch: 001/003 | Batch 000/002 | Train/Val Loss: 0.75
Epoch: 001/003 | Batch 001/002 | Train/Val Loss: 0.65
Epoch: 002/003 | Batch 000/002 | Train/Val Loss: 0.44
Epoch: 002/003 | Batch 001/002 | Train/Val Loss: 0.13
Epoch: 003/003 | Batch 000/002 | Train/Val Loss: 0.03
Epoch: 003/003 | Batch 001/002 | Train/Val Loss: 0.00

3. 保存并加载模型

就一句话

torch.save(model.state_dict(), "model.pth")

4. 使用GPU加速训练

我们常常说的CUDA就是在GPU上训练

import torch
# 显示PyTorch是否支持GPU
print(torch.cuda.is_available())

如果显示True,则代表可以用GPU,否则则要用CPU

# 根据设备可用情况选择设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

5. 使用上面所教,进行一次训练

创建了一个简单的神经网络模型来对二分类问题进行训练,并且使用了 PyTorch 提供的 Dataset 和 DataLoader 类来加载数据集并进行批处理。此外,你还定义了一个函数来计算模型的准确率。

import torch
X_train = torch.tensor([[-1.2, 3.1],[-0.9, 2.9],[-0.5, 2.6],[2.3, -1.1],[2.7, -1.5]
])
y_train = torch.tensor([0, 0, 0, 1, 1])
X_test = torch.tensor([[-0.8, 2.8],[2.6, -1.6],
])
y_test = torch.tensor([0, 1])from torch.utils.data import Dataset
class ToyDataset(Dataset):def __init__(self, X, y):self.features = Xself.labels = ydef __getitem__(self, index):one_x = self.features[index]one_y = self.labels[index]return one_x, one_ydef __len__(self):return self.labels.shape[0]
train_ds = ToyDataset(X_train, y_train)
test_ds = ToyDataset(X_test, y_test)from torch.utils.data import DataLoader
torch.manual_seed(123)
train_loader = DataLoader(dataset=train_ds,batch_size=2,shuffle=True,num_workers=1,drop_last=True
)
test_loader = DataLoader(dataset=test_ds,batch_size=2,shuffle=False,num_workers=1
)class NeuralNetwork(torch.nn.Module):def __init__(self, num_inputs, num_outputs):super().__init__()self.layers = torch.nn.Sequential(# 第一个隐藏层torch.nn.Linear(num_inputs, 30),torch.nn.ReLU(),# 第二个隐藏层torch.nn.Linear(30, 20),torch.nn.ReLU(),# 输出层torch.nn.Linear(20, num_outputs),)def forward(self, x):logits = self.layers(x)return logits# 使用accuracy(准确率)作为指标
def compute_accuracy(model, dataloader, device):model = model.eval()correct = 0.0total_examples = 0for idx, (features, labels) in enumerate(dataloader):# 将数据移动到指定的设备上features, labels = features.to(device), labels.to(device) # Newwith torch.no_grad():logits = model(features)# 获取预测结果并计算准确数量predictions = torch.argmax(logits, dim=1)compare = labels == predictionscorrect += torch.sum(compare)total_examples += len(compare)# 计算并返回准确率return (correct / total_examples).item()import torch.nn.functional as F
# 设置随机数种子,以确保可复现性
torch.manual_seed(123)
# 创建神经网络模型
model = NeuralNetwork(num_inputs=2, num_outputs=2)
# 根据设备可用情况选择设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 将模型移动到所选设备上
model = model.to(device)
# 定义优化器,使用随机梯度下降 (SGD)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
# 定义训练循环的 epoch 数量
num_epochs = 3
for epoch in range(num_epochs):model.train()for batch_idx, (features, labels) in enumerate(train_loader):features, labels = features.to(device), labels.to(device) logits = model(features)loss = F.cross_entropy(logits, labels) # 损失函数optimizer.zero_grad()loss.backward()optimizer.step()### 训练日志print(f"Epoch: {epoch+1:03d}/{num_epochs:03d}"f" | Batch {batch_idx:03d}/{len(train_loader):03d}"f" | Train/Val Loss: {loss:.2f}")model.eval()print('accuracy',str(compute_accuracy(model, train_loader, device=device)))

Epoch: 001/003 | Batch 000/002 | Train/Val Loss: 0.75
Epoch: 001/003 | Batch 001/002 | Train/Val Loss: 0.65
Epoch: 002/003 | Batch 000/002 | Train/Val Loss: 0.44
Epoch: 002/003 | Batch 001/002 | Train/Val Loss: 0.13
Epoch: 003/003 | Batch 000/002 | Train/Val Loss: 0.03
Epoch: 003/003 | Batch 001/002 | Train/Val Loss: 0.00
accuracy:1.0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/3864.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Elcomsoft iOS Forensics Toolkit: iPhone/iPad/iPod 设备取证工具包

天津鸿萌科贸发展有限公司是 ElcomSoft 系列取证软件的授权代理商。 Elcomsoft iOS Forensics Toolkit 软件工具包适用于取证工作,对 iPhone、iPad 和 iPod Touch 设备执行完整文件系统和逻辑数据采集。对设备文件系统制作镜像,提取设备机密&#xff08…

阿斯达年代记三强争霸服务器没反应 安装中发生错误的解决方法

阿斯达年代记三强争霸服务器没反应 安装中发生错误的解决方法 最近刚上线的由影视剧改编的游戏《阿斯达年代记三强争霸》可谓是在游戏圈内引起了轩然大波,这是一款由网石集团与龙工作室联合开发的MMORPG游戏,游戏背景设定在一个名为阿斯大陆的区域&…

vue 实现项目进度甘特图

项目需求: 实现以1天、7天、30天为周期(周期根据筛选条件选择),展示每个项目不同里程碑任务进度。 项目在Vue-Gantt-chart: 使用Vue做数据控制的Gantt图表基础上进行了改造。 有需要的小伙伴也可以直接引入插件,自己…

用Scrapy编写第一个入门项目(基础四件套:spider,pipeline,setting,items)

简介:scrapy是一个用于爬取网页并提取数据的应用框架,也可用于提取API数据 写在前面:只想看scrapy的童鞋子请跳过5-7直接step8) step5,6是xpath和css入门,用于提取数据; step7是文件储存方式&…

国产麒麟系统下打包electron+vue项目(AppImage、deb)

需要用到的一些依赖包、安装包以及更详细的打包方法word以及麒麟官网给出的文档都已放网盘,链接在文章最后!!!!!!!!!!!!&a…

备考数通HCIE证书4点经验分享!

大家好,我是来自安阳工学院20级网络工程的刁同学,在2023年12月20日成功通过了华为Datacom HCIE认证,并且取得了笔试900多分,实验B的成绩。在此,我想把我的一些考证心得分享给正在备考的小伙伴们。 关于为什么考证 我…

使用自定义注解处理器,自动收集类信息

背景 在开发过程有些时候我们会需要收集一些类信息。比如要知道某个子类下的所有实现类。可以通过反射的方式实现。但是这种方法有性能问题,因为在运行时,所有类都会包含在dex文件中。这个文件中的类可能有几十万个。而且在实际开发中会发现&#xff0c…

ArcGIS专题图制作—3D峡谷地形

6分钟教你在ArcGIS Pro中优雅完成炫酷的美国大峡谷3D地图 6分钟教你在ArcGIS Pro中优雅完成炫酷的美国大峡谷3D地图。 这一期的制图教程将带我们走入美国大峡谷,让我们一起绘制这张美妙的地图吧!视频也上传到了B站,小伙伴可以去! …

数据结构与算法解题-20240426

这里写目录标题 面试题 08.04. 幂集367. 有效的完全平方数192. 统计词频747. 至少是其他数字两倍的最大数718. 最长重复子数组 面试题 08.04. 幂集 中等 幂集。编写一种方法,返回某集合的所有子集。集合中不包含重复的元素。 说明:解集不能包含重复的子…

【网络原理】TCP协议的连接管理机制(三次握手和四次挥手)

系列文章目录 【网络通信基础】网络中的常见基本概念 【网络编程】网络编程中的基本概念及Java实现UDP、TCP客户端服务器程序(万字博文) 【网络原理】UDP协议的报文结构 及 校验和字段的错误检测机制(CRC算法、MD5算法) 【网络…

Swift - 流程控制

文章目录 Swift - 流程控制if-else2. while3. for3.1 闭区间运算符3.2 半开区间运算符3.3 for - 区间运算符用在数组上3.3.1 单侧区间 3.4 区间类型3.5 带间隔的区间值 4. switch4.1 fallthrough4.2 switch注意点 5. 复合条件6. 区间匹配、元组匹配7. 值绑定8. where9. 标签语句…

DRF JWT认证进阶

JWT认证进阶 【0】准备工作 (1)模型准备 模型准备(继承django的auth_user表) from django.db import models from django.contrib.auth.models import AbstractUserclass UserInfo(AbstractUser):mobile models.CharField(ma…

C语言——内存函数的实现与模拟

1. memcpy 函数 与strcpy 函数类似 1.头文件 <string.h> 2.基本格式 • 函数memcpy从source的位置开始向后复制num个 字节 的数据到destination指向的内存位置。 • 这个函数在遇到 \0 的时候并不会停下来。 • 如果source和destination有任何的重叠&#xff0…

2024年钉钉直播回放怎么下载

又到了2024年,最近钉钉迎来了一波更新,经过我的研究,总算研究出来了一个方法,并且做成了工具 首先&#xff0c;让我们了解一下钉钉直播回放的下载方法。 钉钉直播回放工具链接&#xff1a;https://pan.baidu.com/s/1oPWJOp8L2SBDlklt_t5WQQ?pwd1234 提取码&#xff1a;1234 -…

【快速上手ESP32(基于ESP-IDFVSCode)】10-事件循环WiFi

事件循环 本来这篇文章是只写WiFi的&#xff0c;但是写的时候才发现离不开事件循环&#xff0c;因此再多添一点内容在WiFi前面。 事件循环简单来说就是一个&#xff08;循&#xff09;环&#xff0c;我们可以在这个环上绑上一些事件&#xff0c;我们也可以监听这个环&#xf…

JavaScript进阶(十五):JS 垃圾回收机制_vue gc

内存&#xff1a;由可读写单元组成&#xff0c;表示一片可操作空间&#xff1b;管理&#xff1a;人为的去操作一片空间的申请、使用和释放&#xff1b;内存管理&#xff1a;开发者主动申请空间、使用空间、释放空间&#xff1b;管理流程&#xff1a;申请-使用-释放&#xff1b;…

oracle sql monitor简单使用说明

一 sql monitor介绍 二 用命令行方式生成sql monitor报告 set long 1000000 set longchunksize 100000 set linesize 1000 set pagesize 0 set trim on set trimspool on set echo off set feedback off spool report_sql_monitor.html select dbms_sqltune.report_s…

线性代数-行列式-p1 矩阵的秩

目录 1.定义 2. 计算矩阵的秩 3. 矩阵的秩性质 1.定义 2. 计算矩阵的秩 3. 矩阵的秩性质

美国言语听力学会(ASHA)关于非处方 (OTC) 助听器的媒体声明(翻译稿)

美国国会于 2021 年 4 月 13 日批准美国听力学会积极提供建议&#xff0c;并一直积极参与制定FDA关于非处方助听器销售的拟议法规。根据2017年通过的立法授权。学院积极参与帮助塑造授权立法&#xff0c;并就即将出台的条例分享了建议。 根据美国卫生与公众服务部NIH / NIDCD的…

用Python绘制了几张有趣的可视化图表

流程图存在于我们生活的方方面面&#xff0c;对于我们追踪项目的进展&#xff0c;做出各种事情的决策都有着巨大的帮助&#xff0c;而对于的Python而言呢&#xff0c;绘制流程图也是十分轻松的&#xff0c;今天小编就来为大家介绍两个用于绘制流程图的模块&#xff0c;我们先来…