深度学习之pytorch实现线性回归

度学习之pytorch实现线性回归

  • pytorch用到的函数
    • torch.nn.Linearn()函数
    • torch.nn.MSELoss()函数
    • torch.optim.SGD()
  • 代码实现
  • 结果分析

pytorch用到的函数

torch.nn.Linearn()函数

torch.nn.Linear(in_features, # 输入的神经元个数out_features, # 输出神经元个数bias=True # 是否包含偏置)

在这里插入图片描述

作用j进行线性变换
Linear(1, 1) : 表示一维输入,一维输出

torch.nn.MSELoss()函数

在这里插入图片描述

torch.optim.SGD()

优化器对象
在这里插入图片描述

代码实现

import torchx_data = torch.tensor([[1.0], [2.0], [3.0]])  # 将x_data设置为tensor类型数据
y_data = torch.tensor([[2.0], [4.0], [6.0]])class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__()  # 继承父类self.linear = torch.nn.Linear(1, 1)# 用torch.nn.Linear来构造对象  (y = w * x + b)def forward(self, x):y_pred = self.linear(x) #调用之前的构造的对象(调用构造函数),计算 y = w * x + breturn y_predmodel = LinearModel()criterion = torch.nn.MSELoss(size_average=False)  # 定义损失函数,不求平均损失(为False)#优化器对象
# #model.parameters()会扫描module中的所有成员,如果成员中有相应权重,那么都会将结果加到要训练的参数集合上
# #类似权重的更新
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 定义梯度优化器为随机梯度下降for epoch in range(10000):  # 训练过程y_pred = model(x_data)  # 向前传播,求y_predloss = criterion(y_pred, y_data)  # 根据y_pred和y_data求损失print(epoch, loss)# 记住在backward之前要先梯度归零optimizer.zero_grad()  # 将优化器数值清零loss.backward()  # 反向传播,计算梯度optimizer.step()  # 根据梯度更新参数#打印权重和b
print("w = ", model.linear.weight.item())
print("b = ", model.linear.bias.item())#检测模型
x_test = torch.tensor([4.0])
y_test = model(x_test)
print('y_pred = ', y_test.data)  # 测试

结果分析

9961 tensor(4.0927e-12, grad_fn=)
9962 tensor(4.0927e-12, grad_fn=)
9963 tensor(4.0927e-12, grad_fn=)
9964 tensor(4.0927e-12, grad_fn=)
9965 tensor(4.0927e-12, grad_fn=)
9966 tensor(4.0927e-12, grad_fn=)
9967 tensor(4.0927e-12, grad_fn=)
9968 tensor(4.0927e-12, grad_fn=)
9969 tensor(4.0927e-12, grad_fn=)
9970 tensor(4.0927e-12, grad_fn=)
9971 tensor(4.0927e-12, grad_fn=)
9972 tensor(4.0927e-12, grad_fn=)
9973 tensor(4.0927e-12, grad_fn=)
9974 tensor(4.0927e-12, grad_fn=)
9975 tensor(4.0927e-12, grad_fn=)
9976 tensor(4.0927e-12, grad_fn=)
9977 tensor(4.0927e-12, grad_fn=)
9978 tensor(4.0927e-12, grad_fn=)
9979 tensor(4.0927e-12, grad_fn=)
9980 tensor(4.0927e-12, grad_fn=)
9981 tensor(4.0927e-12, grad_fn=)
9982 tensor(4.0927e-12, grad_fn=)
9983 tensor(4.0927e-12, grad_fn=)
9984 tensor(4.0927e-12, grad_fn=)
9985 tensor(4.0927e-12, grad_fn=)
9986 tensor(4.0927e-12, grad_fn=)
9987 tensor(4.0927e-12, grad_fn=)
9988 tensor(4.0927e-12, grad_fn=)
9989 tensor(4.0927e-12, grad_fn=)
9990 tensor(4.0927e-12, grad_fn=)
9991 tensor(4.0927e-12, grad_fn=)
9992 tensor(4.0927e-12, grad_fn=)
9993 tensor(4.0927e-12, grad_fn=)
9994 tensor(4.0927e-12, grad_fn=)
9995 tensor(4.0927e-12, grad_fn=)
9996 tensor(4.0927e-12, grad_fn=)
9997 tensor(4.0927e-12, grad_fn=)
9998 tensor(4.0927e-12, grad_fn=)
9999 tensor(4.0927e-12, grad_fn=)

w = 1.9999985694885254
b = 2.979139480885351e-06
y_pred = tensor([8.0000])

因为轮数过多,这里展示后面几轮
模型的准确性,跟轮数的多少有关系 ,如果轮数为100,最后测试结果的y_pred肯定不为8.00,这里轮数为10000,预测结果跟实际结果基本一样

这里是轮数为100,结果是 7点多,有一定误差
0 tensor(101.4680, grad_fn=)
1 tensor(45.8508, grad_fn=)
2 tensor(21.0819, grad_fn=)
3 tensor(10.0458, grad_fn=)
4 tensor(5.1234, grad_fn=)
5 tensor(2.9227, grad_fn=)
6 tensor(1.9338, grad_fn=)
7 tensor(1.4844, grad_fn=)
8 tensor(1.2754, grad_fn=)
9 tensor(1.1736, grad_fn=)
10 tensor(1.1195, grad_fn=)
11 tensor(1.0869, grad_fn=)
12 tensor(1.0639, grad_fn=)
13 tensor(1.0453, grad_fn=)
14 tensor(1.0288, grad_fn=)
15 tensor(1.0134, grad_fn=)
16 tensor(0.9985, grad_fn=)
17 tensor(0.9841, grad_fn=)
18 tensor(0.9699, grad_fn=)
19 tensor(0.9559, grad_fn=)
20 tensor(0.9421, grad_fn=)
21 tensor(0.9286, grad_fn=)
22 tensor(0.9153, grad_fn=)
23 tensor(0.9021, grad_fn=)
24 tensor(0.8891, grad_fn=)
25 tensor(0.8764, grad_fn=)
26 tensor(0.8638, grad_fn=)
27 tensor(0.8513, grad_fn=)
28 tensor(0.8391, grad_fn=)
29 tensor(0.8271, grad_fn=)
30 tensor(0.8152, grad_fn=)
31 tensor(0.8034, grad_fn=)
32 tensor(0.7919, grad_fn=)
33 tensor(0.7805, grad_fn=)
34 tensor(0.7693, grad_fn=)
35 tensor(0.7582, grad_fn=)
36 tensor(0.7474, grad_fn=)
37 tensor(0.7366, grad_fn=)
38 tensor(0.7260, grad_fn=)
39 tensor(0.7156, grad_fn=)
40 tensor(0.7053, grad_fn=)
41 tensor(0.6952, grad_fn=)
42 tensor(0.6852, grad_fn=)
43 tensor(0.6753, grad_fn=)
44 tensor(0.6656, grad_fn=)
45 tensor(0.6561, grad_fn=)
46 tensor(0.6466, grad_fn=)
47 tensor(0.6373, grad_fn=)
48 tensor(0.6282, grad_fn=)
49 tensor(0.6192, grad_fn=)
50 tensor(0.6103, grad_fn=)
51 tensor(0.6015, grad_fn=)
52 tensor(0.5928, grad_fn=)
53 tensor(0.5843, grad_fn=)
54 tensor(0.5759, grad_fn=)
55 tensor(0.5676, grad_fn=)
56 tensor(0.5595, grad_fn=)
57 tensor(0.5514, grad_fn=)
58 tensor(0.5435, grad_fn=)
59 tensor(0.5357, grad_fn=)
60 tensor(0.5280, grad_fn=)
61 tensor(0.5204, grad_fn=)
62 tensor(0.5129, grad_fn=)
63 tensor(0.5056, grad_fn=)
64 tensor(0.4983, grad_fn=)
65 tensor(0.4911, grad_fn=)
66 tensor(0.4841, grad_fn=)
67 tensor(0.4771, grad_fn=)
68 tensor(0.4703, grad_fn=)
69 tensor(0.4635, grad_fn=)
70 tensor(0.4569, grad_fn=)
71 tensor(0.4503, grad_fn=)
72 tensor(0.4438, grad_fn=)
73 tensor(0.4374, grad_fn=)
74 tensor(0.4311, grad_fn=)
75 tensor(0.4250, grad_fn=)
76 tensor(0.4188, grad_fn=)
77 tensor(0.4128, grad_fn=)
78 tensor(0.4069, grad_fn=)
79 tensor(0.4010, grad_fn=)
80 tensor(0.3953, grad_fn=)
81 tensor(0.3896, grad_fn=)
82 tensor(0.3840, grad_fn=)
83 tensor(0.3785, grad_fn=)
84 tensor(0.3730, grad_fn=)
85 tensor(0.3677, grad_fn=)
86 tensor(0.3624, grad_fn=)
87 tensor(0.3572, grad_fn=)
88 tensor(0.3521, grad_fn=)
89 tensor(0.3470, grad_fn=)
90 tensor(0.3420, grad_fn=)
91 tensor(0.3371, grad_fn=)
92 tensor(0.3322, grad_fn=)
93 tensor(0.3275, grad_fn=)
94 tensor(0.3228, grad_fn=)
95 tensor(0.3181, grad_fn=)
96 tensor(0.3136, grad_fn=)
97 tensor(0.3091, grad_fn=)
98 tensor(0.3046, grad_fn=)
99 tensor(0.3002, grad_fn=)
w = 1.6352288722991943
b = 0.8292105793952942
y_pred = tensor([7.3701])

Process finished with exit code 0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/689138.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Jlink+OpenOCD+STM32 Vscode 下载和调试环境搭建

对于 Mingw 的安装比较困难,国内的网无法正常在线下载组件, 需要手动下载 x86_64-8.1.0-release-posix-seh-rt_v6-rev0.7z 版本的软件包,添加环境变量,并将 mingw32-make.exe 名字改成 make.exe。 对于 OpenOCD,需要…

中标麒麟桌面操作系统软件及安装方法

中标麒麟桌面操作系统软件及安装方法 安装力法 TeXworks编辑器 1、打开终端 输入“ sudo yum install texworks "命令 2、打开软件中心,搜索“TeX 编辑器”下载安装编程软件Anjuta集成开发环境 1、打开终端 输入“ sudo yum install anjuta ” 命令 2、打开软…

Python爬虫知识图谱

下面是一份详细的Python爬虫知识图谱,涵盖了从基础入门到进阶实战的各个环节,涉及网络请求、页面解析、数据提取、存储优化、反爬策略应对以及法律伦理等多个方面,并配以关键点解析和代码案例,以供读者深入学习和实践。 一、Pyth…

django中session值的数据类型是dict,需要手动save(),更新才会传递到其他页面。

django 项目在一个页面中删除了session中的某一个成员(del request.session["test"]["a"]),切换到另外一个页面的时候,session中的那个成员居然还在。让我一阵莫名其妙。 # 对session["test"]进行初…

mysql 2-18

加密与解密函数 其他函数 聚合函数 三者效率 GROUP BY HAVING WHERE和HAVING的区别 子查询 单行子查询和多行子查询 单行比较操作符 多行比较操作符 把平均工资生成的结果当成一个新表 相关子查询 EXISTS 一条数据的存储过程 标识符命名规则 创建数据库 MYSQL的数据类型 创建表…

目标检测一般性问题

Precision(查准率/精确率) 所有预测为正样本的结果中,预测正确的比率。 Precision TP / (TP FP)Recall (查全率/召回率) 所有正样本中被正确预测的比率。 Recall TP / (TP FN)正样本负样本预测为正True Positive(TP)False Positive(FP)预测为负False Negati…

利用ChatGPT进行数据分析——如何提出一个好的prompt

利用ChatGPT进行数据分析——如何提出一个好的prompt ​ 使用ChatGPT时,能否得到一个好的输出结果,关键在于能否提出好的prompt。 1.1 prompt是什么 ​ 现在大家把向ChatGPT输入的内容称作prompt(提示),它的作用是引…

Docker部署Java项目

打包 添加 <!-- 打包插件--> <build><plugins><plugin><groupId>org.springframework.boot</groupId><artifactId>spring-boot-maven-plugin</artifactId></plugin></plugins> </build> 通过执行以下命令进…

MySQL数据库基础(九):SQL约束

文章目录 SQL约束 一、主键约束 二、非空约束 三、唯一约束 四、默认值约束 五、外键约束&#xff08;了解&#xff09; 六、总结 SQL约束 一、主键约束 PRIMARY KEY 约束唯一标识数据库表中的每条记录。主键必须包含唯一的值。主键列不能包含 NULL 值。每个表都应该有…

String为什么是不可变的?

原因一、String字符串类型的数据结构 字符串在String类的内部是用一个char[]数组表示的,而这个数组使用final关键字修饰的&#xff0c;所以不能修改。 举例说明&#xff1a; String ip"127.0.0.1"; String retip.replace(".","#"); System.out.…

spring通过类名称获取名字

在Spring中&#xff0c;可以使用反射来根据类的全限定名获取其对应的Bean名称。 下面是示例代码&#xff1a; import org.springframework.beans.factory.BeanFactory; import org.springframework.context.support.ClassPathXmlApplicationContext;public class Main {publi…

BUGKU-WEB 留言板1

题目描述 题目截图如下&#xff1a; 进入场景看看&#xff1a; 解题思路 之间写过一题类似的&#xff0c;所以这题应该是有什么不同的那就按照之前的思路进行测试试试提示说&#xff1a;需要xss平台接收flag&#xff0c;这个和之前说的提示一样 相关工具 xss平台&#xf…

银河麒麟操作系统自动同步时间更新

1、银河麒麟操作系统基于Centos8的&#xff0c;因centos8取消了ntp服务器&#xff0c;所以导致之前使用ntpdate命令无法同步时间 2、centos默认使用chrony模块来进行同步时间 3、修改chrony配置同步时间服务器 vim /etc/chrony.conf 4、目前使用的是阿里云的时间服务器&…

Postgresql源码(122)Listen / Notify与事务的联动机制

前言 Notify和Listen是Postgresql提供的不同会话间异步消息通信功能&#xff0c;例子&#xff1a; LISTEN virtual; NOTIFY virtual; Asynchronous notification "virtual" received from server process with PID 8448. NOTIFY virtual, This is the payload; Asy…

Unity笔记:数据持久化的几种方式

正文 主要方法&#xff1a; ScriptableObjectPlayerPrefsJSONXML数据库&#xff08;如Sqlite&#xff09; 1. PlayerPerfs PlayerPrefs 存储的数据是全局共享的&#xff0c;它们存储在用户设备的本地存储中&#xff0c;并且可以被应用程序的所有部分访问。这意味着&#xf…

深入浅出熟悉OpenAI最新大作Sora文生视频大模型

蠢蠢欲动&#xff0c;惴惴不安&#xff0c;朋友们我又来了&#xff0c;这个春节真的过的是像过山车&#xff0c;Gemini1.5 PRO还没过劲&#xff0c;OpenAI又放大招&#xff0c;人类真的要认输了吗&#xff0c;让我忍不住想要再探究竟&#xff0c;到底是什么让文生视频发生了质的…

头歌C++语言之选择排序练习题

目录 第1关:第二统计数字 任务描述 相关知识 数组声明: 初始化数组: 访问数组元素 选择排序 编程要求 第2关:运动会排名 任务描述 相关知识 多维数组 访问二维数组 编程要求 第3关:单词排序 任务描述 相关知识 strcmp()函数 编程要求

流星蝴蝶剑之七夜听雪中文版下载

软件介绍&#xff1a; 中文名称: 流星蝴蝶剑七夜听雪 英文名称: Meteor 游戏类型: 3D武侠格斗 发行时间: 2002年08月 制作发行: 流星江湖悠悠客栈 语言 :中文 配置要求: 操作系统&#xff1a;Windows 95 / 98 / Me / 2000 / XP 最低配置 CPU&#xff1a;Pentium II 450MHz 以上…

记录 | git win C://User/Administrator/.ssh下没有id_rsa.pub找不到

在用 ssh-keygen -t rsa -C "xxx163.com”生成后&#xff0c;在 C://User/Administrator/.ssh 下找不到 id_rsa.pub 文件 在这个下面找找&#xff1a; C:\Users\Administrator\AppData\Roaming\SPB_Data\.ssh 或者直接看 ssh-keygen 生成的终端日志&#xff0c;上面有说…

单向/双向V2G环境下分布式电源与电动汽车充电站联合配置方法(matlab代码)

目录 1 主要内容 目标函数 电动汽车负荷建模 算例系统图 程序亮点 2 部分代码 3 程序结果 4 下载链接 1 主要内容 该程序复现博士文章《互动环境下分布式电源与电动汽车充电站的优化配置方法研究》第五章《单向/双向V2G环境下分布式电源与电动汽车充电站联合配置方法》…