深度学习-Pytorch如何构建和训练模型

深度学习-Pytorch如何构建和训练模型

用pytorch如何构建模型,如何训练模型,如何测试模型?

pytorch 目前在深度学习具有重要的地位,比起早先的caffe,tensorflow,keras越来越受到欢迎,其他的深度学习框架越来越显得小众。

数据分析

数据分析-Pandas如何转换产生新列

数据分析-Pandas如何统计数据概况

数据分析-Pandas如何轻松处理时间序列数据

数据分析-Pandas如何选择数据子集

数据分析-Pandas如何重塑数据表-CSDN博客

经典算法

经典算法-遗传算法的python实现

经典算法-模拟退火算法的python实现

经典算法-粒子群算法的python实现-CSDN博客

LLM应用

大模型查询工具助手之股票免费查询接口

Python技巧-终端屏幕打印光标和文字控制

如何构建深度学习模型

在pytorch里构建模型,继承现有的模块比较方便——这就是nn.Module 模块。通过模型的初始化 _init_ 函数,定义网络层,指定数据是如何通过网络的前进流向。

同时,最重要的是提升训练和预测的速度,加速训练就免不了把模型放在GPU或者CPU里面,这就需要模型配置硬件部署位置。

# 模型配置到cpu还是GPU,或者多处理器.
device = ("cuda"if torch.cuda.is_available()else "mps"if torch.backends.mps.is_available()else "cpu"
)
print(f"Using {device} device")# 定义模型
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork().to(device)
print(model)

此处用一个简单的神经网络模型解释深度学习网络的构建,麻雀虽小五脏俱全。

模型网络有6层:Flatten层,顺序网络里,构建1个28 X 28像素到512的线性连接层,接着是ReLU激活层 ,随后又是512 X 512的线性连接层,随后激活层ReLU层,最后是一个512 X 10线性连接层,输出10个结果值。

Using cuda device
NeuralNetwork(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
)
)

如何训练并优化模型

构建好模型后,如何优化模型,提升识别精确度呢?

其实也很简单,pytorch已经封装好了很多处理,只需要调用损失函数和优化函数即可。

此处损失函数采用交叉熵函数CrossEntropyLoss,优化函数采用SGD函数,你也可以用别的函数。

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

当然,上篇说道,训练数据集较多,GPU显存又有限,只好分批次进行加载。如此,便需要结合上篇文章分批次加载模型,不再赘述,详见:

[深度学习-Pytorch数据集构造和分批加载-CSDN博客]

在每个训练循环中,模型都在训练数据集计算,并计算损失函数,然后再反向传播,叠加优化器对模型进行优化处理。不断循环叠加,使得模型预测能力逐渐提升。

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationloss.backward()optimizer.step()optimizer.zero_grad()if batch % 100 == 0:loss, current = loss.item(), (batch + 1) * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

如何测试模型效果

当要测试模型性能效果时,是不能再使用训练数据来测的。就像是考试不能提前知道考卷题目,否则漏题就不能获得真正的能力。模型测试也一样,因此在模型加载前先把数据集按比例分了,或三七分,或二八分。把一部分测试集进行测试,观察效果如何。

依旧加载测试的数据集,告知模型是做测试,不再训练 model.eval(),因此也不再把残差反向传播,torch.no_grad()。当然,依旧把数据和模型加载到加速设备上,如果有GPU的话。

监控数据主要有两个:平均损失值,准确度。

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

有了监测数据后,每个轮次都想查看训练效果是否朝着好的方向前进,如果好的就继续不断迭代,直到满足模型设计的预测精度等指标要求,如果不是就提前终止了。

epochs = 5
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")

很简单吧,这里面把训练5个轮次,就结束。也就是示意一下深度学习的流程。

以上代码只是一个简单示例,示例代码中的表达式可以根据实际问题进行修改。

觉得有用 收藏 收藏 收藏

点个赞 点个赞 点个赞

End


DeepLearning文章:

深度学习-Pytorch数据集构造和分批加载-CSDN博客

GPT专栏文章:

GPT实战系列-ChatGLM3本地部署CUDA11+1080Ti+显卡24G实战方案

GPT实战系列-LangChain + ChatGLM3构建天气查询助手

大模型查询工具助手之股票免费查询接口

GPT实战系列-简单聊聊LangChain

GPT实战系列-大模型为我所用之借用ChatGLM3构建查询助手

GPT实战系列-P-Tuning本地化训练ChatGLM2等LLM模型,到底做了什么?(二)

GPT实战系列-P-Tuning本地化训练ChatGLM2等LLM模型,到底做了什么?(一)

GPT实战系列-ChatGLM2模型的微调训练参数解读

GPT实战系列-如何用自己数据微调ChatGLM2模型训练

GPT实战系列-ChatGLM2部署Ubuntu+Cuda11+显存24G实战方案

GPT实战系列-Baichuan2本地化部署实战方案

GPT实战系列-Baichuan2等大模型的计算精度与量化

GPT实战系列-GPT训练的Pretraining,SFT,Reward Modeling,RLHF

GPT实战系列-探究GPT等大模型的文本生成-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/647566.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

10.多柱状图(MuliBarChart)

愿你出走半生,归来仍是少年&#xff01; 环境&#xff1a;.NET 7、MAUI 话接上回&#xff08;9.单柱状图&#xff08;SingleBarChart&#xff09;&#xff09;&#xff0c;从单柱拓展到多柱状图。 1.数据设置 private void InitValue(List<BasicSerieDto> dtos){Serie…

英文阅读-LinkedIn‘s Tips for Highly Effective Code Review

LinkedIn的CR技巧 LinkedIn团队CodeReview经验与方法&#xff0c;原文来自https://thenewstack.io/linkedin-code-review/ 总结 Do I Understand the “Why”? 在提交pr的同时需要描述本次修改的“动机”&#xff0c;有助于提高代码文档质量。 Am I Giving Positive Feedbac…

openssl3.2/test/certs - 055 - all DNS-like CNs allowed by CA1, no DNS SANs

文章目录 openssl3.2/test/certs - 055 - all DNS-like CNs allowed by CA1, no DNS SANs概述笔记END openssl3.2/test/certs - 055 - all DNS-like CNs allowed by CA1, no DNS SANs 概述 openssl3.2 - 官方demo学习 - test - certs 笔记 /*! * \file D:\my_dev\my_local_…

Java转成m3u8,hls格式

Java转成m3u8,hls格式 需求分析 大致思路 循环文件夹下面所有文件判断当前文件是否是视频文件&#xff0c;如果是视频文件先转为ts文件 因为听别人说先转成ts之后再切片会快很多 转成ts文件&#xff0c;并为这些文件单独生成一个目录&#xff0c;如果目录不存在则新建一个目…

14.5 Flash查询和添加数据库数据

14.5 Flash查询和添加数据库数据 在Flash与数据库通讯的实际应用中&#xff0c;如何实现用户的登录与注册是经常遇到的一个问题。登录实际上就是ASP根据Flash提供的数据查询数据库的过程&#xff0c;而注册则是ASP将Flash提供的数据写入数据库的过程。 1.启动Access2003&…

C++Linux网络编程Day1

最简单server程序 #include <iostream>// sys&#xff08;系统&#xff09;,socket&#xff08;套接字&#xff09;&#xff0c;这个还是挺好理解的 #include <sys/socket.h>#include <arpa/inet.h>#include <stdio.h> #include <string.h>int …

项目管理的细节:屁股决定脑袋

在我刚入行的时候&#xff0c;由于我勤奋好学&#xff0c;技术上钻研得比较深入&#xff0c;我曾一度自视甚高&#xff0c;甚至有些傲慢。那时&#xff0c;我年轻气盛&#xff0c;未能充分认识到自己的不足。我曾认为领导对我的评价不公&#xff0c;未能充分认识到我的能力和价…

SpringBoot和Vue接口调用传参方式

简单总结一下常用的传参方式&#xff0c;一些前后端分离项目接口调试时经常出现传参格式错误问题。 前后端进行交互时方法一般就分为get和post&#xff0c;至于后面的delete和put都是基于post进行封装而出的。 Http请求中不同的请求方式会设置不同的Content-Type,参数的传递方…

C#,获取与设置Windows背景图片的源代码

为了满足孩子们个性化桌面的需求。 这里发布获取与设置Windows背景图片的源代码。 1 文本格式 using System; using System.IO; using System.Data; using System.Linq; using System.Text; using System.Drawing; using System.Collections; using System.Collections.Gene…

ES数据处理方法

由于日志数据存在ES项目里&#xff0c;需要从ES中获取日志进行分析&#xff0c;使用SQL数据进行处理&#xff0c;如下&#xff1a; select traceid-- STRING COMMENT 流程id, ,appnum -- BIGINT COMMENT 迭代号, ,appversion --STRING COMMENT APP版本, ,appc…

JeecgBoot集成TiDB,打造高效可靠的数据存储解决方案

TiDB简介 TiDB是PingCAP公司自主设计、研发的开源分布式关系型数据库&#xff0c;同时支持在线事务处理与在线分析处理 (Hybrid Transactional and Analytical Processing, HTAP) 的融合型分布式数据库产品&#xff0c;具备水平扩容或者缩容、金融级高可用、实时 HTAP、云原生…

【学习笔记】CF1349F2 Slime and Sequences (Hard Version)

多项式工业警告&#xff01;&#xff01;&#xff01; 点击看题意 思路来自 这位大佬 。 为什么这么好的题解没人评论。 Part 1 前置知识&#xff1a;拉格朗日反演&#xff08;多项式复合&#xff09;&#xff0c;分式域&#xff08;引入负整数次项&#xff09;。 条件&a…

基数排序算法

1. 排序算法分类 十种常见排序算法可以分为两大类&#xff1a; 比较类排序&#xff1a; 通过比较来决定元素间的相对次序&#xff0c;由于其时间复杂度不能突破O(nlogn)&#xff0c;因此也称为非线性时间比较类排序。比较类排序算法包括&#xff1a;插入排序、希尔排序、选择…

Netty接收超长TCP数据时 使用按行分隔Decoder无法正确解码的问题解决

使用Netty实现的tcp服务端&#xff0c;由于tcp是流式传输的&#xff0c;故需要选用一个解码器对流式消息进行解码和包分隔&#xff0c;以防收到不正确的包。例如LineBasedFrameDecoder&#xff0c;LengthFieldBasedFrameDecoder&#xff0c;DelimiterBasedFrameDecoder等常用解…

第139期 做大还是做小-Oracle名称哪些事(20240125)

数据库管理139期 2024-01-25 第139期 做大还是做小-Oracle名称哪些事&#xff08;20240125&#xff09;1 问题2 排查3 扩展总结 第139期 做大还是做小-Oracle名称哪些事&#xff08;20240125&#xff09; 作者&#xff1a;胖头鱼的鱼缸&#xff08;尹海文&#xff09; Oracle A…

SQL - 事务控制

SQL - 事务控制 文章目录 SQL - 事务控制TCL - 事务事务的边界事务的特性事务的应用 事务隔离等级MySQL支持四种隔离级别 TCL - 事务 **模拟场景&#xff1a;**生活当中转账是转账方账户扣钱&#xff0c;收账方账户加钱。用数据库操作来模拟现实转账。 数据库模拟&#xff1a…

CI/CD

介绍一下CI/CD CI/CD的出现改变了开发人员和测试人员发布软件的方式,从最初的瀑布模型,到最后的敏捷开发(Agile Development),再到今天的DevOps,这是现代开发人员构建出色产品的技术路线 随着DevOps的兴起,出现了持续集成,持续交付和持续部署的新方法,传统的软件开发和交付方…

软件设计师——软件工程(五)

&#x1f4d1;前言 本文主要是【软件工程】——软件设计师——软件工程的文章&#xff0c;如果有什么需要改进的地方还请大佬指出⛺️ &#x1f3ac;作者简介&#xff1a;大家好&#xff0c;我是听风与他&#x1f947; ☁️博客首页&#xff1a;CSDN主页听风与他 &#x1f304…

安全防御综合组网实验

题目 要求 生产区在工作时间可以访问服务器区&#xff0c;仅可以访问http服务器。办公区全天可以访问服务器区&#xff0c;其中10.0.2.20 可以访问FTP服务器和http服务器。10.0.2.10仅可以ping通10.0.3.10。办公区在访问服务器区时采用匿名认证的方式进行上网行为管理。办公区…

Vue内嵌套层级过深,el-input改变值视图无响应

出现场景 1.v-for内绑定复杂类型数据内部值&#xff0c;通过input更改其值时。 2.element表单或表格等组件内部&#xff0c;el-input绑定复杂类型数据内部值&#xff0c;通过input更改其值时。 解决思路 1.el-input加入 input“change($event)” … import { getCurrentInst…