PyTorch-线性回归

已经进入大模微调的时代,但是学习pytorch,对后续学习rasa框架有一定帮助吧。

<!--  给出一系列的点作为线性回归的数据,使用numpy来存储这些点。 -->
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],[9.779], [6.182], [7.59], [2.167], [7.042],[10.791], [5.313], [7.997], [3.1]], dtype=np.float32)
y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],[3.366], [2.596], [2.53], [1.221], [2.827],[3.465], [1.65], [2.904], [1.3]], dtype=np.float32)<!--  转化tensor格式。 -->
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)<!--  这里的nn.Linear表示的是 y=w*x b,里面的两个参数都是1,表示的是x是1维,y也是1维。当然这里是可以根据你想要的输入输出维度来更改的。 -->
class linearRegression(nn.Module):def __init__(self):super(linearRegression, self).__init__()self.linear = nn.Linear(1, 1)  # input and output is 1 dimensiondef forward(self, x):out = self.linear(x)return out
model = linearRegression()<!-- 定义loss和优化函数,这里使用的是最小二乘loss,之后我们做分类问题更多的使用的是cross entropy loss,交叉熵。优化函数使用的是随机梯度下降,注意需要将model的参数model.parameters()传进去让这个函数知道他要优化的参数是那些。 -->
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)<!-- 开始训练 -->
num_epochs = 1000
for epoch in range(num_epochs):inputs = Variable(x_train)target = Variable(y_train)# forwardout = model(inputs) # 前向传播loss = criterion(out, target) # 计算loss# backwardoptimizer.zero_grad() # 梯度归零loss.backward() # 反向传播optimizer.step() # 更新参数if (epoch 1) % 20 == 0:print(f'Epoch[{epoch+1}/{num_epochs}], loss: {loss.item():.6f}')<!--训练完成之后我们就可以开始测试模型了-->
model.eval()
predict = model(Variable(x_train))
predict = predict.data.numpy()<!-- 显示图例 -->
fig = plt.figure(figsize=(10, 5))
plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original data')
plt.plot(x_train.numpy(), predict, label='Fitting Line')plt.legend() 
plt.show()<!-- 保存模型 -->
torch.save(model.state_dict(), './linear.pth')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/686447.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【VTKExamples::PolyData】第二十九期 LoopBooleanPolyDataFilter

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ:870202403 前言 本文分享VTK样例LoopBooleanPolyDataFilter,并解析接口vtkLoopBooleanPolyDataFilter,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动力(^U^…

win32汇编获取系统信息

.data fmt db "页尺寸&#xff1a;%d",0 db "" lpsystem SYSTEM_INFO <?> szbuf db 200 dup(0) .const szCaption db 系统信息,0 .code start: invoke GetSystemInfo,addr lpsystem …

Java编程在工资信息管理中的最佳实践

✍✍计算机编程指导师 ⭐⭐个人介绍&#xff1a;自己非常喜欢研究技术问题&#xff01;专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目&#xff1a;有源码或者技术上的问题欢迎在评论区一起讨论交流&#xff01; ⚡⚡ Java实战 |…

困于环中的机器人

1041. 困于环中的机器人 在无限的平面上&#xff0c;机器人最初位于 (0, 0) 处&#xff0c;面朝北方。注意: 北方向 是y轴的正方向。南方向 是y轴的负方向。东方向 是x轴的正方向。西方向 是x轴的负方向。 机器人可以接受下列三条指令之一&#xff1a; "G"&#…

用Java实现简单的图书管理系统

目录 1.总体框架 2.book包 Books类 booklist类 3.operation包 IO接口&#xff1a; addbooks类&#xff1a; borrowbooks类&#xff1a; delbooks类&#xff1a; returnbooks类&#xff1a; exit类&#xff1a; 4.user包 user类 Adminuser类&#xff08;难点&#…

Pytorch的安装教程,解决jupyter不能使用pytorch的问题

一.Pytorch的安装教程&#xff1a;PyTorch深度学习快速入门教程&#xff08;绝对通俗易懂&#xff01;&#xff09;【小土堆】_哔哩哔哩_bilibili 在anaconda prompt 提示符输入以下语句&#xff1a; 激活pytorch环境&#xff1a;conda activate pytorch查看pytorch环境下安装了…

嵌入式linux驱动开发篇之设备树

什么是设备树&#xff1f; 设备树&#xff08;Device Tree&#xff09;是一种用于描述嵌入式系统硬件组件及其连接关系的数据结构。它被广泛用于嵌入式 Linux 系统&#xff0c;尤其是针对使用多种不同架构和平台的嵌入式系统。它是一种与硬件描述相关的中间表示形式&#xff0c…

嵌入式培训机构四个月实训课程笔记(完整版)-Linux ARM驱动编程第五天-ARM Linux编程之设备节点 (物联技术666)

链接&#xff1a;https://pan.baidu.com/s/1hOBKyRom-4EZMBpFn1H9kQ?pwd1688 提取码&#xff1a;1688 Linux设备节点 设备管理是linux中比较基础的东西&#xff0c;但是由于Linux智能程度的越来越高&#xff0c;Udev的使用越来越广泛&#xff0c;使得越来越多的Linux新用户对…

vivado Convergent Rounding (LSB CorrectionTechnique)

DSP块基元利用模式检测电路来计算收敛舍入&#xff08;要么为偶数&#xff0c;要么为奇数&#xff09;。以下是收敛舍入推理的示例&#xff0c;它在块满时进行推理并且还推断出2输入and门&#xff08;1 LUT&#xff09;以实现LSB校正。 Rounding to Even (Verilog) Filename: …

如何生成狗血短剧

如何生成狗血短剧 狗血短剧剧本将上述剧本转成对话 狗血短剧剧本 标题&#xff1a;《爱的轮回》 类型&#xff1a;现代都市爱情短剧 角色&#xff1a; 1. 林晓雪 - 女&#xff0c;25岁&#xff0c;职场小白&#xff0c;善良单纯 2. 陆子轩 - 男&#xff0c;28岁&#xff0c;公…

WINCC如何新增下单菜单,切换显示页面

杭州工控赖工 首先我们先看一下&#xff0c;显示的效果&#xff0c;通过下拉菜单&#xff0c;切换主显示页面。如图一&#xff1a; 图1 显示效果 第一步&#xff1a; 通过元件新增一个组合框&#xff0c;见图2&#xff1b; 组合框的设置&#xff0c;设置下拉框的长宽及组合数…

GPT4的平替llama2本地部署教程,打造自己的专属大模型

llama2 是Meta公司发布的大预言模型&#xff0c;而且是一款开源免费的AI模型。光开源这个格局就吊打了GPT。从性能上来说更是号称是GPT4的平替。 今天这篇文章会从以下几个方面介绍下llama2&#xff1a; 1 基本介绍 2 本地mac环境部署llama2 llama2官方网址 https://llama…

Rust 数据结构与算法:1算法分析之乱序字符串检查

Rust 数据结构与算法 一、算法分析 算法是通用的旨在解决某种问题的指令列表。 算法分析是基于算法使用的资源量来进行比较的。之所以说一个算法比另一个算法好,原因就在于前者在使用资源方面更有效率,或者说前者使用了更少的资源。 ●算法使用的空间指的是内存消耗。算法…

JAVA语言程序设计 第12版 5.14题

求两个数之间的最大公约数 public class exer14 {public static void main(String[] args) {//方法 1Scanner inputnew Scanner(System.in);System.out.println("Enter first integer : ");int n1input.nextInt();System.out.println("Enter second integer: &…

JDK 17 新特性 (一)

既然 Springboot 3.0 强制使用 JDK 17 那就看看 JDK17 有哪些新特性吧 参考链接 介绍一下 新特性的历史渊源 JDK 17是Java Development Kit&#xff08;JDK&#xff09;的一个版本&#xff0c;它是Java编程语言的一种实现。JDK 17于2021年9月14日发布&#xff0c;并作为Java …

基于springboot智慧外贸平台源码和论文

网络的广泛应用给生活带来了十分的便利。所以把智慧外贸管理与现在网络相结合&#xff0c;利用java技术建设智慧外贸平台&#xff0c;实现智慧外贸的信息化。则对于进一步提高智慧外贸管理发展&#xff0c;丰富智慧外贸管理经验能起到不少的促进作用。 智慧外贸平台能够通过互…

js脚本的 defer 和 async 的区别

defer 和 async 都是用于控制 HTML 中 <script> 标签加载和执行 JavaScript 的属性&#xff0c;它们的作用有所不同&#xff1a; defer&#xff1a; 当浏览器遇到带有 defer 属性的 <script> 标签时&#xff0c;它会继续解析 HTML 页面&#xff0c;同时并行下载 de…

ddp是什么意思

DDP通常代表"Distributed Data Parallelism"&#xff0c;即分布式数据并行。它是一种用于训练深度学习模型的并行计算策略。在深度学习中&#xff0c;模型训练通常需要处理大量的数据和复杂的计算任务。DDP的目标是通过将数据和计算任务分布到多个计算设备&#xff0…

神经网络算法原理

目录 得分函数 数学表示 计算方法 损失函数 ​编辑 前向传播 反向传播 ​编辑 整体架构 正则化的作用 数据预处理 ​过拟合解决方法 得分函数 得分函数是在机器学习和自然语言处理中常用的一种函数&#xff0c;用于评估模型对输入数据的预测结果的准确性或匹配程度。…

【Python---六大数据结构】

&#x1f680; 作者 &#xff1a;“码上有前” &#x1f680; 文章简介 &#xff1a;Python &#x1f680; 欢迎小伙伴们 点赞&#x1f44d;、收藏⭐、留言&#x1f4ac; Python---六大数据结构 往期内容前言概述一下可变与不可变 Number四种不同的数值类型Number类型的创建i…