深度学习 - PyTorch基本流程 (代码)

直接上代码

import torch 
import matplotlib.pyplot as plt 
from torch import nn# 创建data
print("**** Create Data ****")
weight = 0.3
bias = 0.9
X = torch.arange(0,1,0.01).unsqueeze(dim = 1)
y = weight * X + bias
print(f"Number of X samples: {len(X)}")
print(f"Number of y samples: {len(y)}")
print(f"First 10 X & y sample: \n X: {X[:10]}\n y: {y[:10]}")
print("\n")# 将data拆分成training 和 testing
print("**** Splitting data ****")
train_split = int(len(X) * 0.8)
X_train = X[:train_split]
y_train = y[:train_split]
X_test = X[train_split:]
y_test = y[train_split:]
print(f"The length of X train: {len(X_train)}")
print(f"The length of y train: {len(y_train)}")
print(f"The length of X test: {len(X_test)}")
print(f"The length of y test: {len(y_test)}\n")# 显示 training 和 testing 数据
def plot_predictions(train_data = X_train,train_labels = y_train,test_data = X_test,test_labels = y_test,predictions = None):plt.figure(figsize = (10,7))plt.scatter(train_data, train_labels, c = 'b', s = 4, label = "Training data")plt.scatter(test_data, test_labels, c = 'g', label="Test data")if predictions is not None:plt.scatter(test_data, predictions, c = 'r', s = 4, label = "Predictions")plt.legend(prop = {"size": 14})
plot_predictions()# 创建线性回归
print("**** Create PyTorch linear regression model by subclassing nn.Module ****")
class LinearRegressionModel(nn.Module):def __init__(self):super().__init__()self.weight = nn.Parameter(data = torch.randn(1,requires_grad = True,dtype = torch.float))self.bias = nn.Parameter(data = torch.randn(1,requires_grad = True,dtype = torch.float))def forward(self, x):return self.weight * x + self.biastorch.manual_seed(42)
model_1 = LinearRegressionModel()
print(model_1)
print(model_1.state_dict())
print("\n")# 初始化模型并放到目标机里
print("*** Instantiate the model ***")
print(list(model_1.parameters()))
print("\\n")# 创建一个loss函数并优化
print("*** Create and Loss function and optimizer ***")
loss_fn = nn.L1Loss()
optimizer = torch.optim.SGD(params = model_1.parameters(),lr = 0.01)
print(f"loss_fn: {loss_fn}")
print(f"optimizer: {optimizer}\n")# 训练
print("*** Training Loop ***")
torch.manual_seed(42)
epochs = 300
for epoch in range(epochs):# 将模型加载到训练模型里model_1.train()# 做 Forwardy_pred = model_1(X_train)# 计算 Lossloss = loss_fn(y_pred, y_train)# 零梯度optimizer.zero_grad()# 反向传播loss.backward()# 步骤优化optimizer.step()### 做测试if epoch % 20 == 0:# 将模型放到评估模型并设置上下文model_1.eval()with torch.inference_mode():# 做 Forwardy_preds = model_1(X_test)# 计算测试 losstest_loss = loss_fn(y_preds, y_test)# 输出测试结果print(f"Epoch: {epoch} | Train loss: {loss:.3f} | Test loss: {test_loss:.3f}")# 在测试集上对训练模型做预测
print("\n")
print("*** Make predictions with the trained model on the test data. ***")
model_1.eval()
with torch.inference_mode():y_preds = model_1(X_test)
print(f"y_preds:\n {y_preds}")
## 画图
plot_predictions(predictions = y_preds) # 保存训练好的模型
print("\n")
print("*** Save the trained model ***")
from pathlib import Path 
## 创建模型的文件夹
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents = True, exist_ok = True)
## 创建模型的位置
MODEL_NAME = "trained model"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME 
## 保存模型到刚创建好的文件夹
print(f"Saving model to {MODEL_SAVE_PATH}")
torch.save(obj = model_1.state_dict(), f = MODEL_SAVE_PATH)
## 创建模型的新类型
loaded_model = LinearRegressionModel()
loaded_model.load_state_dict(torch.load(f = MODEL_SAVE_PATH))
## 做预测,并跟之前的做预测
y_preds_new = loaded_model(X_test)
print(y_preds == y_preds_new)

结果如下

**** Create Data ****
Number of X samples: 100
Number of y samples: 100
First 10 X & y sample: X: tensor([[0.0000],[0.0100],[0.0200],[0.0300],[0.0400],[0.0500],[0.0600],[0.0700],[0.0800],[0.0900]])y: tensor([[0.9000],[0.9030],[0.9060],[0.9090],[0.9120],[0.9150],[0.9180],[0.9210],[0.9240],[0.9270]])**** Splitting data ****
The length of X train: 80
The length of y train: 80
The length of X test: 20
The length of y test: 20**** Create PyTorch linear regression model by subclassing nn.Module ****
LinearRegressionModel()
OrderedDict([('weight', tensor([0.3367])), ('bias', tensor([0.1288]))])*** Instantiate the model ***
[Parameter containing:
tensor([0.3367], requires_grad=True), Parameter containing:
tensor([0.1288], requires_grad=True)]*** Create and Loss function and optimizer ***
loss_fn: L1Loss()
optimizer: SGD (
Parameter Group 0dampening: 0differentiable: Falseforeach: Nonelr: 0.01maximize: Falsemomentum: 0nesterov: Falseweight_decay: 0
)*** Training Loop ***
Epoch: 0 | Train loss: 0.757 | Test loss: 0.725
Epoch: 20 | Train loss: 0.525 | Test loss: 0.454
Epoch: 40 | Train loss: 0.294 | Test loss: 0.183
Epoch: 60 | Train loss: 0.077 | Test loss: 0.073
Epoch: 80 | Train loss: 0.053 | Test loss: 0.116
Epoch: 100 | Train loss: 0.046 | Test loss: 0.105
Epoch: 120 | Train loss: 0.039 | Test loss: 0.089
Epoch: 140 | Train loss: 0.032 | Test loss: 0.074
Epoch: 160 | Train loss: 0.025 | Test loss: 0.058
Epoch: 180 | Train loss: 0.018 | Test loss: 0.042
Epoch: 200 | Train loss: 0.011 | Test loss: 0.026
Epoch: 220 | Train loss: 0.004 | Test loss: 0.009
Epoch: 240 | Train loss: 0.004 | Test loss: 0.006
Epoch: 260 | Train loss: 0.004 | Test loss: 0.006
Epoch: 280 | Train loss: 0.004 | Test loss: 0.006*** Make predictions wit the trained model on the test data. ***
y_preds:tensor([[1.1464],[1.1495],[1.1525],[1.1556],[1.1587],[1.1617],[1.1648],[1.1679],[1.1709],[1.1740],[1.1771],[1.1801],[1.1832],[1.1863],[1.1893],[1.1924],[1.1955],[1.1985],[1.2016],[1.2047]])*** Save the trained model ***
Saving model to models/trained model
tensor([[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True]])

第一个结果图
第二个结果图

点个赞支持一下咯~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/778659.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

蓝桥杯2016年第十三届省赛真题-生日蜡烛

一、题目 生日蜡烛. 某君从某年开始每年都举办一次生日party,并且每次都要吹熄与年龄相同根数的蜡烛。 现在算起来,他一共吹熄了236根蜡烛。 请问,他从多少岁开始过生日party的? 请填写他开始过生日party的年龄数。 注意&#xff…

python笔记(6)String(字符串)

目录 访问字符串中的值 Python字符串运算符 Python 字符串格式化 str.format() 数字格式化 多行注释 f-string Unicode 字符串 Python 的字符串内建函数 我们可以用单引号或者双引号"来创建字符串。 创建字符串很简单,给变量分配一个值即可例如 ahell…

基于资源的约束委派(下)

webclient http self relay Web 分布式创作和版本控制 (WebDAV) 是超文本传输协议 (HTTP) 的扩展,它定义了如何使用 HTTP ( docs.microsoft.com )执行复 制、移动、删除和创建等基本文件功能 需要启用 WebClient 服务才能使基于 WebDAV 的程序和功能正常工作。事实…

全国中学基础信息 API 数据接口

全国中学基础信息 API 数据接口 基础数据,高校高考,提供全国初级高级中学基础数据,定时更新,多维度筛选。 1. 产品功能 2024 年数据已更新;提供最新全国中学学校基本信息;包含全国初级中学与高等中学&…

Rust机器学习框架Candle

一、概述 Candle 是由知名开源组织 Hugging Face 开发的一个极简的机器学习框架。它专为 Rust 语言打造,致力于提供高性能和易用性的完美结合。Candle 的诞生为 Rust 生态在机器学习领域带来了新的选择,让 Rust 开发者能够更轻松地构建和部署机器学习应…

家庭琐事对工作效率的影响及应对策略

在快节奏的现代生活中,工作与家庭生活之间的界限日益模糊,人们往往难以将两者完全割裂开来。有时候,我们正在全身心投入工作时,却可能被突如其来的家庭琐事打扰,这不仅影响了心情,更会波及到工作效率和质量…

silk-v3-decoder将sil转为mp3

一、新建临时目录 新建临时目录,可自定义,本次新建目录为 /opt/packages mkdir /opt/packages二、下载、安装lame # cd /opt/packages# wget http://downloads.sourceforge.net/lame/lame-3.100.tar.gz# tar -zxvf lame-3.100.tar.gz# cd lame-3.100#…

git之目前的主流版本

官方文档 简介 我们都知道,在开发过程中,版本控制是至关重要的。Git作为目前最为流行的版本控制系统,已经成为了开发者们的标配。出于好奇,本人对git目前主流几大版本(GitLab、GitHub、Gitee 和 GitCode)…

虚拟现实(VR)项目的开发工具

虚拟现实(VR)项目的开发涉及到多种工具,这些工具可以帮助开发者从建模、编程到最终内容的发布。以下是一些被广泛认可的VR开发工具,它们覆盖了从3D建模到交互设计等多个方面。北京木奇移动技术有限公司,专业的软件外包…

取消svn关联脚本

写在前面,该脚本由朋友提供,来源与网络,侵删。 取消svn关联脚本 创建一个文件,后缀名为reg,将下面的脚本复制到文件里。 Windows Registry Editor Version 5.00 [HKEY_LOCAL_MACHINE\SOFTWARE\Classes\Folder\shel…

spring boot中使用spring cache

原因 项目原来越慢&#xff0c;为了提升效率加入spring cache 初步想法把数据库的压力减轻一点。 引入 pom 中加入&#xff1a; <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-cache</artifactId&g…

机器学习_集成学习_梯度提升_回归_决策树_XGBoost相关概念

目录 1. 机器学习 使用监督吗&#xff1f;什么又是监督学习&#xff1f; 2. 与XGBoost 类似的机器学习方法有哪些&#xff1f; 3. 随机森林方法 和 梯度提升方法 有什么区别&#xff1f; 分别应用于什么场景&#xff1f; 4. 决策树回归方法 和 Gradient Boosting类回归方法…

为什么我的微信小程序 窗口背景色backgroundColor设置参数 无效的问题处理记录!

当我们在微信小程序 json 中设置 backgroundColor 时&#xff0c;实际在电脑的模拟器中根本看不到效果。 这是因为 backgroundColor 指的窗体背景颜色&#xff0c;而不是页面的背景颜色&#xff0c;即窗体下拉刷新或上拉加载时露出的背景。在电脑的模拟器中是看不到这个动作的…

发挥ChatGPT潜力:高效撰写学术论文技巧

ChatGPT无限次数:点击直达 发挥ChatGPT潜力&#xff1a;高效撰写学术论文技巧 在当今信息爆炸的时代&#xff0c;如何高效撰写学术论文成为许多研究者关注的焦点。而随着人工智能技术的不断发展&#xff0c;如何利用ChatGPT这一先进的技术工具来提升论文写作效率&#xff0c;成…

Elasticsearch 面试题及参考答案:深入解析与实战应用

在大数据时代,Elasticsearch 以其强大的搜索能力和高效的数据处理性能,成为了数据架构师和开发者必备的技能之一。本文将为您提供一系列精选的 Elasticsearch 面试题及参考答案,帮助您在面试中脱颖而出,同时也为您的大数据架构设计提供实战参考。 目录 1. 为什么要使用 E…

Acwing_795前缀和 【一维前缀和】+【模板】二维前缀和

Acwing_795前缀和 【一维前缀和】 题目&#xff1a; 代码&#xff1a; #include <bits/stdc.h> #define int long long #define INF 0X3f3f3f3f #define endl \n using namespace std; const int N 100010; int arr[N];int n,m; int l,r; signed main(){std::ios::s…

Flink基于Hudi维表Join缺陷解析及解决方案

Hudi&#xff0c;这个近年来备受瞩目的数据存储解决方案&#xff0c;无疑是大数据领域的一颗耀眼新星。其凭借出色的性能和稳定性&#xff0c;以及对于数据湖场景的深度适配&#xff0c;赢得了众多企业和开发者的青睐。然而&#xff0c;正如任何一项新兴技术&#xff0c;Hudi在…

服务器不能DELETE和PUT

问题描述&#xff1a;前端VUE、后端JAVA&#xff0c;代码放在本地可以完美运行&#xff0c;放在服务器外网不能运行delete和put&#xff0c;get和post不能运行 经过摸索总结&#xff0c;在不改变原有RESTful的情况下&#xff0c;亲身实验&#xff0c;得到两种解决办法&#xff…

力扣爆刷第107天之CodeTop100五连刷21-25

力扣爆刷第107天之CodeTop100五连刷21-25 文章目录 力扣爆刷第107天之CodeTop100五连刷21-25一、103. 二叉树的锯齿形层序遍历二、92. 反转链表 II三、54. 螺旋矩阵四、160. 相交链表五、23. 合并 K 个升序链表 一、103. 二叉树的锯齿形层序遍历 题目链接&#xff1a;https://…

详解IOS的Automatically Sign在设备上打包

大家好我是咕噜美乐蒂&#xff0c;很高兴又和大家见面了&#xff01; "Automatically Sign" 是 Xcode 提供的一个功能&#xff0c;用于简化在设备上打包和签名应用的流程。通过使用 "Automatically Sign"&#xff0c;开发者可以在 Xcode 中轻松地进行应用…