【深度学习】实验12 使用PyTorch训练模型

文章目录

  • 使用PyTorch训练模型
    • 1. 线性回归类
    • 2. 创建数据集
    • 3. 训练模型
    • 4. 测试模型
  • 附:系列文章

使用PyTorch训练模型

PyTorch是一个基于Python的科学计算库,它是一个开源的机器学习框架,由Facebook公司于2016年开源。它提供了构建动态计算图的功能,可以更自然地使用Python语言编写深度神经网络的程序,具有易于使用、灵活、高效等特点,被广泛应用于深度学习任务中。

PyTorch的核心是动态计算图(Dynamic Computational Graph),这意味着计算图是在运行时动态生成的,而不是预先编译好的。这个特点使得PyTorch具有高度的灵活性,可以更加轻松地进行实验和调试。同时,它也有一个静态计算图模块,可以用于生产环境中,提高计算效率。

另外,PyTorch的另一个特点是它的张量计算。张量是PyTorch中的核心数据结构,类似于NumPy中的数组。PyTorch支持GPU加速,可以使用GPU进行张量计算,大大提高了计算效率。同时,它也支持自动求导功能,可以自动计算张量的梯度,使得深度学习的模型训练更加便捷。

PyTorch还提供了丰富的模型库,包括经典的深度学习模型,如卷积神经网络(CNN)、循环神经网络(RNN)和生成对抗网络(GAN),以及各种领域的预训练模型,如自然语言处理(NLP)和计算机视觉(CV),可以快速搭建和训练模型。

PyTorch也具有良好的社区支持。它的文档详细且易于理解,社区提供了大量的示例和教程,可以帮助用户更好地学习和使用PyTorch。同时,PyTorch还有一个活跃的开发团队,定期发布新的版本,修复bug和增加新的特性,保证了PyTorch的稳定性和可用性。

总的来说,PyTorch是一个强大、灵活、易于使用的机器学习框架,具有良好的社区支持和广泛的应用领域,能够满足不同用户的需求。随着人工智能的不断发展,PyTorch的应用将会更加广泛。

1. 线性回归类

import torch
import numpy as np
import matplotlib.pyplot as plt
class LinearRegression(torch.nn.Module):def __init__(self):super().__init__()self.linear = torch.nn.Linear(1, 1)self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01)self.loss_function = torch.nn.MSELoss()def forward(self, x):out = self.linear(x)return outdef train(self, data, model_save_path='model.path'):x = data["x"]y = data["y"]for epoch in range(10000):prediction = self.forward(x)loss = self.loss_function(prediction, y)self.optimizer.zero_grad()loss.backward()self.optimizer.step()if epoch % 100 == 0:print("epoch:{}, loss is:{}".format(epoch, loss.item()))torch.save(self.state_dict(), "linear.pth")def test(self, x, model_path="linear.pth"):x = data["x"]y = data["y"]self.load_state_dict(torch.load(model_path))prediction = self.forward(x)plt.scatter(x.numpy(), y.numpy(), c=x.numpy())plt.plot(x.numpy(), prediction.detach().numpy(), color="r")plt.show()

该Python代码实现了一个简单的线性回归模型,并进行了训练和测试。

首先,导入了PyTorch、NumPy和Matplotlib.pyplot库。

接下来,定义了一个名为LinearRegression的类,它是一个继承自torch.nn.Module的类,因此可以利用PyTorch的自动求导和优化功能。在该类的初始化方法中,定义了一个torch.nn.Linear对象,它表示一个全连接层,输入大小为1,输出大小为1;并定义了一个torch.optim.SGD对象,它表示随机梯度下降法的优化器,学习率为0.01;以及一个torch.nn.MSELoss对象,它表示均方误差损失函数。

接下来,定义了一个名为forward的方法,它表示前向传递过程,即对输入进行线性变换,得到输出。

然后,定义了一个名为train的方法,它接受一个数据字典和一个模型保存路径作为输入。该方法首先从数据字典中获取输入数据x和输出数据y,然后进行10000次迭代训练。在每次迭代中,先将输入数据x送入模型中得到预测输出prediction,然后计算预测输出和真实输出之间的均方误差损失loss,并进行反向传播和参数优化。每100次迭代打印一次损失值。最后将模型参数保存到指定的文件路径中。

最后,定义了一个名为test的方法,它接受一个输入数据x和一个模型保存路径作为输入。该方法首先从文件中加载训练好的模型参数,然后将输入数据x送入模型中得到预测输出prediction,并将预测输出和真实输出以及输入数据可视化展示出来。

总之,这段代码实现了一个简单的线性回归模型,并可以通过train方法进行训练,通过test方法进行测试和可视化展示。

2. 创建数据集

def create_linear_data(nums_data, if_plot=False):x = torch.linspace(0, 1, nums_data)x = torch.unsqueeze(x, dim = 1)k = 2y = k * x + torch.rand(x.size())if if_plot:plt.scatter(x.numpy(), y.numpy(), c=x.numpy())plt.show()data = {"x":x, "y":y}return data
data = create_linear_data(300, if_plot=True)

1

3. 训练模型

model = LinearRegression()
model.train(data)
   epoch:0, loss is:3.8653182983398438epoch:100, loss is:0.31251025199890137epoch:200, loss is:0.2438090741634369epoch:300, loss is:0.20671892166137695epoch:400, loss is:0.17835141718387604epoch:500, loss is:0.15658551454544067epoch:600, loss is:0.13988454639911652epoch:700, loss is:0.12706983089447021epoch:800, loss is:0.11723710596561432epoch:900, loss is:0.10969242453575134epoch:1000, loss is:0.10390334576368332epoch:1100, loss is:0.09946136921644211epoch:1200, loss is:0.09605306386947632epoch:1300, loss is:0.09343785047531128epoch:1400, loss is:0.09143117070198059epoch:1500, loss is:0.0898914709687233epoch:1600, loss is:0.08871004730463028epoch:1700, loss is:0.08780352771282196epoch:1800, loss is:0.08710794895887375epoch:1900, loss is:0.08657423406839371epoch:2000, loss is:0.08616471290588379epoch:2100, loss is:0.08585048466920853epoch:2200, loss is:0.08560937643051147epoch:2300, loss is:0.08542437106370926epoch:2400, loss is:0.08528240770101547epoch:2500, loss is:0.08517350256443024epoch:2600, loss is:0.08508992940187454epoch:2700, loss is:0.08502580225467682epoch:2800, loss is:0.08497659116983414epoch:2900, loss is:0.08493883907794952epoch:3000, loss is:0.08490986377000809epoch:3100, loss is:0.08488764613866806epoch:3200, loss is:0.08487057685852051epoch:3300, loss is:0.08485749363899231epoch:3400, loss is:0.08484745025634766epoch:3500, loss is:0.08483975380659103epoch:3600, loss is:0.08483383059501648epoch:3700, loss is:0.08482930809259415epoch:3800, loss is:0.08482582122087479epoch:3900, loss is:0.08482315391302109epoch:4000, loss is:0.08482109755277634epoch:4100, loss is:0.08481952548027039epoch:4200, loss is:0.08481831848621368epoch:4300, loss is:0.08481740206480026epoch:4400, loss is:0.08481667935848236epoch:4500, loss is:0.08481614291667938epoch:4600, loss is:0.08481571823358536epoch:4700, loss is:0.08481539785861969epoch:4800, loss is:0.08481515198945999epoch:4900, loss is:0.08481497317552567epoch:5000, loss is:0.08481481671333313epoch:5100, loss is:0.08481471240520477epoch:5200, loss is:0.08481462299823761epoch:5300, loss is:0.08481455594301224epoch:5400, loss is:0.08481451123952866epoch:5500, loss is:0.08481448143720627epoch:5600, loss is:0.08481443673372269epoch:5700, loss is:0.08481442183256149epoch:5800, loss is:0.0848143994808197epoch:5900, loss is:0.0848143920302391epoch:6000, loss is:0.08481437712907791epoch:6100, loss is:0.08481436222791672epoch:6200, loss is:0.08481435477733612epoch:6300, loss is:0.08481435477733612epoch:6400, loss is:0.08481435477733612epoch:6500, loss is:0.08481435477733612epoch:6600, loss is:0.08481435477733612epoch:6700, loss is:0.08481435477733612epoch:6800, loss is:0.08481434732675552epoch:6900, loss is:0.08481435477733612epoch:7000, loss is:0.08481433987617493epoch:7100, loss is:0.08481435477733612epoch:7200, loss is:0.08481433987617493epoch:7300, loss is:0.08481433987617493epoch:7400, loss is:0.08481434732675552epoch:7500, loss is:0.08481434732675552epoch:7600, loss is:0.08481434732675552epoch:7700, loss is:0.08481434732675552epoch:7800, loss is:0.08481434732675552epoch:7900, loss is:0.08481434732675552epoch:8000, loss is:0.08481434732675552epoch:8100, loss is:0.08481434732675552epoch:8200, loss is:0.08481434732675552epoch:8300, loss is:0.08481434732675552epoch:8400, loss is:0.08481434732675552epoch:8500, loss is:0.08481434732675552epoch:8600, loss is:0.08481434732675552epoch:8700, loss is:0.08481434732675552epoch:8800, loss is:0.08481434732675552epoch:8900, loss is:0.08481434732675552epoch:9000, loss is:0.08481434732675552epoch:9100, loss is:0.08481434732675552epoch:9200, loss is:0.08481434732675552epoch:9300, loss is:0.08481434732675552epoch:9400, loss is:0.08481434732675552epoch:9500, loss is:0.08481434732675552epoch:9600, loss is:0.08481434732675552epoch:9700, loss is:0.08481434732675552epoch:9800, loss is:0.08481434732675552epoch:9900, loss is:0.08481434732675552
model.test(data)

4. 测试模型

2

附:系列文章

序号文章目录直达链接
1波士顿房价预测https://want595.blog.csdn.net/article/details/132181950
2鸢尾花数据集分析https://want595.blog.csdn.net/article/details/132182057
3特征处理https://want595.blog.csdn.net/article/details/132182165
4交叉验证https://want595.blog.csdn.net/article/details/132182238
5构造神经网络示例https://want595.blog.csdn.net/article/details/132182341
6使用TensorFlow完成线性回归https://want595.blog.csdn.net/article/details/132182417
7使用TensorFlow完成逻辑回归https://want595.blog.csdn.net/article/details/132182496
8TensorBoard案例https://want595.blog.csdn.net/article/details/132182584
9使用Keras完成线性回归https://want595.blog.csdn.net/article/details/132182723
10使用Keras完成逻辑回归https://want595.blog.csdn.net/article/details/132182795
11使用Keras预训练模型完成猫狗识别https://want595.blog.csdn.net/article/details/132243928
12使用PyTorch训练模型https://want595.blog.csdn.net/article/details/132243989
13使用Dropout抑制过拟合https://want595.blog.csdn.net/article/details/132244111
14使用CNN完成MNIST手写体识别(TensorFlow)https://want595.blog.csdn.net/article/details/132244499
15使用CNN完成MNIST手写体识别(Keras)https://want595.blog.csdn.net/article/details/132244552
16使用CNN完成MNIST手写体识别(PyTorch)https://want595.blog.csdn.net/article/details/132244641
17使用GAN生成手写数字样本https://want595.blog.csdn.net/article/details/132244764
18自然语言处理https://want595.blog.csdn.net/article/details/132276591

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/84846.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux 线程(thread)

进程线程区别 创建线程 #include <pthread.h> int pthread_create(pthread_t *thread, const pthread_attr_t *attr, void *(*start_routine) (void *), void *arg); -功能&#xff1a;创建一个子线程&#xff0c;一般情况下main函数所在的线程称为主线程&#xff0c;…

21天学会C++:Day14----模板

CSDN的uu们&#xff0c;大家好。这里是C入门的第十四讲。 座右铭&#xff1a;前路坎坷&#xff0c;披荆斩棘&#xff0c;扶摇直上。 博客主页&#xff1a; 姬如祎 收录专栏&#xff1a;C专题 目录 1. 知识引入 2. 模板的使用 2.1 函数模板 2.2 类模板 3. 模板声明和定义…

kubernetes(k8s)PVC

概念 PVC 的全称是&#xff1a;PersistentVolumeClaim&#xff08;持久化卷声明&#xff09;&#xff0c;PVC 是用户存储的一种声明&#xff0c;PVC 和 Pod 比较类似&#xff0c;Pod 消耗的是节点&#xff0c;PVC 消耗的是 PV 资源&#xff0c;Pod 可以请求 CPU 和内存&#x…

基于Kubernetes的Serverless PaaS稳定性建设万字总结

作者&#xff1a;许成铭&#xff08;竞霄&#xff09; 数字经济的今天&#xff0c;云计算俨然已经作为基础设施融入到人们的日常生活中&#xff0c;稳定性作为云产品的基本要求&#xff0c;研发人员的技术底线&#xff0c;其不仅仅是文档里承诺的几个九的 SLA 数字&#xff0c…

AI助手引领游戏创作革命

近期&#xff0c;Roblox 在其开发者大会&#xff08;RDC&#xff09;上宣布了一款新的对话式 AI 助手&#xff1a;RobloxAssistant。这款助手的本质是简化游戏制作难度&#xff0c;用自然语言代替编程。通过输入文字提示&#xff0c;创作者可以生成游戏场景、3D 模型等操作。该…

自动驾驶中的决策规划

参考: 【干货篇】轻舟智航&#xff1a;自动驾驶中的决策规划技术&#xff08;附视频回放 PPT 下载&#xff09; - AIQ 如图所示, 各模块介绍 定位模块主要负责解答的问题是“车现在在哪里”&#xff0c;是在道路上还是在路口&#xff0c;是在高架桥上还是在停车场里。 感知…

python随手小练

题目&#xff1a; 使用python做一个简单的英雄联盟商城登录界面 具体操作&#xff1a; print("英雄联盟商城登录界面") print("~ * "*15 "~") #找其规律 a "1、用户登录" b "2、新用户注册" c "3、退出系统&quo…

jq弹窗拖动改变宽高

预览效果 <div classtishiMask><div class"tishiEm"><div id"coor"></div><div class"topNew ismove"><span class"ismove">提示</span><p onclick"closeTishi()"></p&…

计算机组成原理——基础入门总结(二)

上一期的路径&#xff1a;基础入门总结&#xff08;一&#xff09; 目录 一.输入输出系统和IO控制方式 二.存储系统的基本概念 三.cache的基本概念和原理 四.CPU的功能和基本结构 五.总线概述 一.输入输出系统和IO控制方式 IO设备又可以被统一称为外部设备~ IO接口&…

Python 根据身高体重计算体质(BMI)指数

""" 根据身高体重计算体质(BMI)指数知识点&#xff1a;1、计算公式&#xff1a;体质指数(BMI) 体重(KG) / (身高(M) * 身高(M))2、变量类型转换3、运算符幂**&#xff0c;例如&#xff1a;3 ** 2 9 <> 3 * 34、更多的运用请参考&#xff1a;https://blo…

【2023全网最全最火】Selenium WebDriver教程(建议收藏)

在本教程中&#xff0c;我将向您介绍 Selenium Webdriver&#xff0c;它是当今市场上使用最广泛的自动化测试框架。它是开源的&#xff0c;可与所有著名的编程语言&#xff08;如Java、Python、C&#xff03;、Ruby、Perl等&#xff09;一起使用&#xff0c;以实现浏览器活动的…

【Hierarchical Coverage Path Planning in Complex 3D Environments】

Hierarchical Coverage Path Planning in Complex 3D Environments 复杂三维环境下的分层覆盖路径规划 视点采样全局TSP 算法分两层&#xff0c;一层高级一层低级&#xff1a; 高层算法将环境分离多个子空间&#xff0c;如果给定体积中有大量的结构&#xff0c;则空间会进一步细…

为什么要选择Spring cloud Sentinel

为什么要选择Spring cloud Sentinel &#x1f34e;对比Hystrix&#x1f342;雪崩问题及解决方案&#x1f342;雪崩问题&#x1f342;.超时处理&#x1f342;仓壁模式&#x1f342;断路器&#x1f342;限流&#x1f342;总结 &#x1f34e;对比Hystrix 在SpringCloud当中支持多…

美创科技参编《数字政府建设与发展研究报告(2023)》 正式发布

9月14日&#xff0c;中国信息通信研究院云计算与大数据研究所牵头编制的《数字政府建设与发展研究报告&#xff08;2023&#xff09;》正式发布。 美创科技结合在政务数据安全领域的丰富实践经验&#xff0c;参与报告编写。 《数字政府建设与发展研究报告》 以“技术、业务、数…

ARM 汇编指令作业(求公约数、for循环实现1-100之间和、从SVC模式切换到user模式简单写法)

1、求两个数最大公约数 .text .globl _start_start:mov r0, #9mov r1, #15 Loop: 循环cmp r0,r1 比较r0和r1的大小beq stop 当r0和r1相等时&#xff0c;跳到stop标签subhi r0,r0,r1 r0-r1>0 时&#xff0c;证明r0>r1,将r0-r1的值赋给r0&…

近年来国内室内定位领域硕士论文选题的现状与趋势

目录 一、前言 二、选题的目的和意义 三、选题现状分析 四、选题趋势分析 一、前言 本博文采用了图表统计法分析了近5年来100余篇高被引室内定位领域硕士论文选题的现状&#xff0c;并从选题现状中得出了该领域选题的大致趋势。本文还通过分析该领域硕士毕业论文选题的现…

数字孪生技术如何提升工厂生产效率?

数字孪生技术是一项引领工业界数字化转型的创新力量。随着工业4.0时代的到来&#xff0c;制造业正经历着巨大的变革&#xff0c;数字孪生技术在这个变革中发挥了关键作用。它不仅仅是一种技术&#xff0c;更是一种理念&#xff0c;将现实世界与数字世界相结合&#xff0c;为工厂…

C++真的是 C加加

&#x1f4dd;个人主页&#xff1a;夏目浅石. &#x1f4cc;博客专栏&#xff1a;C的故事 &#x1f3e0;学习社区&#xff1a;夏目友人帐. 文章目录 前言Ⅰ. 函数重载0x00 重载规则0x01 函数重载的原理名字修饰 Ⅱ. 引用0x00 引用的概念0x01 引用和指针区分0x03 引用的本质0x04…

U盘有病毒插上电脑会感染吗?了解下U盘的病毒传播机制

U盘作为一种常见的移动存储设备&#xff0c;我们会经常使用它来传输和存储重要的文件。然而&#xff0c;有时可能会遇到文件被当作病毒误删除的情况&#xff0c;这给我们带来了不便和焦虑。好在&#xff0c;这里将向您介绍一些简单而有效的方法&#xff0c;帮助您恢复被误删除的…

vite自定义打包路径

修改vite.config.js 增加: build: { outDir:‘…/out’ }, base: ‘./’, 例子: // https://vitejs.dev/config/ export default defineConfig({plugins: [vue(),WindiCSS()],build: {outDir:../out},base: ./,server: {host:0.0.0.0,// https:{// cert: fs.readFi…