PyTorch-线性回归

已经进入大模微调的时代,但是学习pytorch,对后续学习rasa框架有一定帮助吧。

<!--  给出一系列的点作为线性回归的数据,使用numpy来存储这些点。 -->
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],[9.779], [6.182], [7.59], [2.167], [7.042],[10.791], [5.313], [7.997], [3.1]], dtype=np.float32)
y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],[3.366], [2.596], [2.53], [1.221], [2.827],[3.465], [1.65], [2.904], [1.3]], dtype=np.float32)<!--  转化tensor格式。 -->
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)<!--  这里的nn.Linear表示的是 y=w*x b,里面的两个参数都是1,表示的是x是1维,y也是1维。当然这里是可以根据你想要的输入输出维度来更改的。 -->
class linearRegression(nn.Module):def __init__(self):super(linearRegression, self).__init__()self.linear = nn.Linear(1, 1)  # input and output is 1 dimensiondef forward(self, x):out = self.linear(x)return out
model = linearRegression()<!-- 定义loss和优化函数,这里使用的是最小二乘loss,之后我们做分类问题更多的使用的是cross entropy loss,交叉熵。优化函数使用的是随机梯度下降,注意需要将model的参数model.parameters()传进去让这个函数知道他要优化的参数是那些。 -->
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)<!-- 开始训练 -->
num_epochs = 1000
for epoch in range(num_epochs):inputs = Variable(x_train)target = Variable(y_train)# forwardout = model(inputs) # 前向传播loss = criterion(out, target) # 计算loss# backwardoptimizer.zero_grad() # 梯度归零loss.backward() # 反向传播optimizer.step() # 更新参数if (epoch 1) % 20 == 0:print(f'Epoch[{epoch+1}/{num_epochs}], loss: {loss.item():.6f}')<!--训练完成之后我们就可以开始测试模型了-->
model.eval()
predict = model(Variable(x_train))
predict = predict.data.numpy()<!-- 显示图例 -->
fig = plt.figure(figsize=(10, 5))
plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original data')
plt.plot(x_train.numpy(), predict, label='Fitting Line')plt.legend() 
plt.show()<!-- 保存模型 -->
torch.save(model.state_dict(), './linear.pth')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/686447.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

win32汇编获取系统信息

.data fmt db "页尺寸&#xff1a;%d",0 db "" lpsystem SYSTEM_INFO <?> szbuf db 200 dup(0) .const szCaption db 系统信息,0 .code start: invoke GetSystemInfo,addr lpsystem …

Java编程在工资信息管理中的最佳实践

✍✍计算机编程指导师 ⭐⭐个人介绍&#xff1a;自己非常喜欢研究技术问题&#xff01;专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目&#xff1a;有源码或者技术上的问题欢迎在评论区一起讨论交流&#xff01; ⚡⚡ Java实战 |…

用Java实现简单的图书管理系统

目录 1.总体框架 2.book包 Books类 booklist类 3.operation包 IO接口&#xff1a; addbooks类&#xff1a; borrowbooks类&#xff1a; delbooks类&#xff1a; returnbooks类&#xff1a; exit类&#xff1a; 4.user包 user类 Adminuser类&#xff08;难点&#…

嵌入式linux驱动开发篇之设备树

什么是设备树&#xff1f; 设备树&#xff08;Device Tree&#xff09;是一种用于描述嵌入式系统硬件组件及其连接关系的数据结构。它被广泛用于嵌入式 Linux 系统&#xff0c;尤其是针对使用多种不同架构和平台的嵌入式系统。它是一种与硬件描述相关的中间表示形式&#xff0c…

如何生成狗血短剧

如何生成狗血短剧 狗血短剧剧本将上述剧本转成对话 狗血短剧剧本 标题&#xff1a;《爱的轮回》 类型&#xff1a;现代都市爱情短剧 角色&#xff1a; 1. 林晓雪 - 女&#xff0c;25岁&#xff0c;职场小白&#xff0c;善良单纯 2. 陆子轩 - 男&#xff0c;28岁&#xff0c;公…

WINCC如何新增下单菜单,切换显示页面

杭州工控赖工 首先我们先看一下&#xff0c;显示的效果&#xff0c;通过下拉菜单&#xff0c;切换主显示页面。如图一&#xff1a; 图1 显示效果 第一步&#xff1a; 通过元件新增一个组合框&#xff0c;见图2&#xff1b; 组合框的设置&#xff0c;设置下拉框的长宽及组合数…

Rust 数据结构与算法:1算法分析之乱序字符串检查

Rust 数据结构与算法 一、算法分析 算法是通用的旨在解决某种问题的指令列表。 算法分析是基于算法使用的资源量来进行比较的。之所以说一个算法比另一个算法好,原因就在于前者在使用资源方面更有效率,或者说前者使用了更少的资源。 ●算法使用的空间指的是内存消耗。算法…

基于springboot智慧外贸平台源码和论文

网络的广泛应用给生活带来了十分的便利。所以把智慧外贸管理与现在网络相结合&#xff0c;利用java技术建设智慧外贸平台&#xff0c;实现智慧外贸的信息化。则对于进一步提高智慧外贸管理发展&#xff0c;丰富智慧外贸管理经验能起到不少的促进作用。 智慧外贸平台能够通过互…

神经网络算法原理

目录 得分函数 数学表示 计算方法 损失函数 ​编辑 前向传播 反向传播 ​编辑 整体架构 正则化的作用 数据预处理 ​过拟合解决方法 得分函数 得分函数是在机器学习和自然语言处理中常用的一种函数&#xff0c;用于评估模型对输入数据的预测结果的准确性或匹配程度。…

【Python---六大数据结构】

&#x1f680; 作者 &#xff1a;“码上有前” &#x1f680; 文章简介 &#xff1a;Python &#x1f680; 欢迎小伙伴们 点赞&#x1f44d;、收藏⭐、留言&#x1f4ac; Python---六大数据结构 往期内容前言概述一下可变与不可变 Number四种不同的数值类型Number类型的创建i…

2024年【天津市安全员B证】新版试题及天津市安全员B证复审考试

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 天津市安全员B证新版试题参考答案及天津市安全员B证考试试题解析是安全生产模拟考试一点通题库老师及天津市安全员B证操作证已考过的学员汇总&#xff0c;相对有效帮助天津市安全员B证复审考试学员顺利通过考试。 1、…

人工智能学习与实训笔记(七):神经网络之模型压缩与知识蒸馏

人工智能专栏文章汇总&#xff1a;人工智能学习专栏文章汇总-CSDN博客 本篇目录 七、模型压缩与知识蒸馏 7.1 模型压缩 7.2 知识蒸馏 7.2.1 知识蒸馏的原理 7.2.2 知识蒸馏的种类 7.2.3 知识蒸馏的作用 七、模型压缩与知识蒸馏 出于对响应速度&#xff0c;存储大小和能…

(07)Hive——窗口函数详解

一、 窗口函数知识点 1.1 窗户函数的定义 窗口函数可以拆分为【窗口函数】。窗口函数官网指路&#xff1a; LanguageManual WindowingAndAnalytics - Apache Hive - Apache Software Foundationhttps://cwiki.apache.org/confluence/display/Hive/LanguageManual%20Windowing…

【Redis实战】有MQ为啥不用?用Redis作消息队列!?Redis作消息队列使用方法及底层原理高级进阶

&#x1f389;&#x1f389;欢迎光临&#x1f389;&#x1f389; &#x1f3c5;我是苏泽&#xff0c;一位对技术充满热情的探索者和分享者。&#x1f680;&#x1f680; &#x1f31f;特别推荐给大家我的最新专栏《Redis实战与进阶》 本专栏纯属为爱发电永久免费&#xff01;&a…

致敬新春“不回家”的厨师,李锦记让厨师的年味更有滋味

“新春饭市万家团圆&#xff0c;致敬千万坚守岗位的厨师” 新春团圆饭向来是餐饮行业最为关注的节点&#xff0c;但过去几年&#xff0c;在疫情与后疫情时期&#xff0c;新年团圆饭市不免冷清。而今年餐饮行业真正迎来“龙抬头”&#xff0c;龙年除夕夜的团圆饭市终于重迎来了…

腾讯云4核8G服务器能支持多少人访问?

腾讯云4核8G服务器支持多少人在线访问&#xff1f;支持25人同时访问。实际上程序效率不同支持人数在线人数不同&#xff0c;公网带宽也是影响4核8G服务器并发数的一大因素&#xff0c;假设公网带宽太小&#xff0c;流量直接卡在入口&#xff0c;4核8G配置的CPU内存也会造成计算…

挑战杯 Yolov安全帽佩戴检测 危险区域进入检测 - 深度学习 opencv

1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; Yolov安全帽佩戴检测 危险区域进入检测 &#x1f947;学长这里给一个题目综合评分(每项满分5分) 难度系数&#xff1a;3分工作量&#xff1a;3分创新点&#xff1a;4分 该项目较为新颖&am…

Quantitative Analysis: PIM Chip Demands for LLAMA-7B inference

1 Architecture 如果将LLAMA-7B模型参数量化为4bit&#xff0c;则存储模型参数需要3.3GB。那么&#xff0c;至少PIM chip 的存储至少要4GB。 AiM单个bank为32MB&#xff0c;单个die 512MB&#xff0c;至少需要8个die的芯片。8个die集成在一个芯片上。 提供816bank级别的访存带…

C++入门学习(二十九)goto语句

在C中&#xff0c;goto语句是一种控制流语句&#xff0c;用于无条件地转移到程序中指定的行。goto语句的使用通常是不推荐的&#xff0c;因为它可能导致代码结构变得混乱、不易理解和维护。然而&#xff0c;在某些特殊情况下&#xff0c;goto语句可能是一种有效的解决方法。 示…

php switch、for、foreach、while、do...while

php switch 1. switch2. for循环3. foreach4. while、do...while 1. switch <?php$height 190;switch ($height) {case 160:echo 太矮了;break; //跳出本次循环case 170:echo 还行吧;break; //跳出本次循环case 180:echo 帅哥;break; //跳出本次循环default:echo 迷; }2.…