pytorch 实现线性回归(深度学习)

一 查看原始函数

        y=2x+4.2

初始化

%matplotlib inline
import random
import torch
from d2l import torch as d2l

1.1 生成原始数据

def synthetic_data(w, b, num_examples):x = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(x, w) + bprint('x:', x)print('y:', y)y += torch.normal(0, 0.01, y.shape)  # 噪声return x, y.reshape((-1 , 1))
true_w = torch.tensor([2.])
true_b = 4.2
print(f'true_w: {true_w}, true_b: {true_b}')features, labels = synthetic_data(true_w, true_b, 10)

1.2 数据转换

def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]batch_size = 10
for x, y in data_iter(batch_size, features, labels):print(f'x: {x}, \ny: {y}')

1.3 初始化权重

随机初始化,w使用 均值0,方差 0.01 的随机值, b 初始化为1

w = torch.normal(0, 0.01, size = (1,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b

二 执行训练

查看训练过程中的 参数变化:

print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:# print('param:', param, 'param.grad:', param.grad)param -= lr * param.grad / batch_sizeparam.grad.zero_()lr = 0.03
num_epochs = 1000
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y)   # 计算总损失print('w:', w, 'b:', b)  # l:', l, '\nl.sum().backward()sgd([w, b], lr, batch_size)

 


三 测试梯度更新

初始化数据

%matplotlib inline
import random
import torch
from d2l import torch as d2ldef synthetic_data(w, b, num_examples):x = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(x, w) + bprint('x:', x)print('y:', y)y += torch.normal(0, 0.01, y.shape)  # 噪声return x, y.reshape((-1 , 1))true_w = torch.tensor([2.])
true_b = 4.2
print(f'true_w: {true_w}, true_b: {true_b}')features, labels = synthetic_data(true_w, true_b, 10)def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]batch_size = 10
for x, y in data_iter(batch_size, features, labels):print(f'x: {x}, \ny: {y}')w = torch.normal(0, 0.01, size = (1,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b

3.1 测试更新

print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:print('param:', param, 'param.grad:', param.grad)
#             param -= lr * param.grad / batch_size
#             param.grad.zero_()lr = 0.03
num_epochs = 2
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y)   # 计算总损失print(f'\nepoch: {epoch},w:', w, 'b:', b)  # l:', l, '\nl.sum().backward()  # 计算更新梯度sgd([w, b], lr, batch_size)

使用 l.sum().backward()  # 计算更新梯度:

不使用更新时:

print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:print('param:', param, 'param.grad:', param.grad)
#             param -= lr * param.grad / batch_size
#             param.grad.zero_()lr = 0.03
num_epochs = 2
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y)   # 计算总损失print(f'\nepoch: {epoch},w:', w, 'b:', b)  # l:', l, '\n# l.sum().backward()  # 计算更新梯度sgd([w, b], lr, batch_size)#     break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/687001.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

springboot第55集:思维导图Sharding-JDBC,事务,微服务分布式架构周刊

事务相关知识,你知道多少? 事务定义 在数据库管理系统中,事务是单个逻辑或工作单元,有时由多个操作组成,在数据库中以一致模式完成的逻辑处理称为事务。一个例子是从一个银行账户转账到另一个账户:完整的交…

【ChatGPT】的定价模式:免费还是收费?

ChatGPT的定价模式:免费还是收费? 人工智能技术的快速发展正在为我们的生活带来巨大的变化,而OpenAI最近推出的ChatGPT模型引发了人们对它的定价模式的关注。这篇文章将探讨ChatGPT是免费还是收费的问题,并对这个话题进行深入分析…

阿里云ECS香港服务器性能强大、cn2高速网络租用价格表

阿里云香港服务器中国香港数据中心网络线路类型BGP多线精品,中国电信CN2高速网络高质量、大规格BGP带宽,运营商精品公网直连中国内地,时延更低,优化海外回中国内地流量的公网线路,可以提高国际业务访问质量。阿里云服务…

免费chatgpt使用

基本功能如下: https://go.aigcplus.cc/auth/register?inviteCode3HCULH2UD

[嵌入式系统-25]:RT-Thread -12- 内核组件编程接口 - 网络组件 - HTTP编程

目录 一、HTTP编程概述 1.1 概述 1.2 HTTP 服务器和 HTTP 客户端 二、HTTP Client 2.1 如何配置HTTP Client 2.2 HTTP Client代码实例1:socket发送http报文 2.3 HTTP Client代码实例2:httpc_xx接口收发HTTP报文 2.3.1 接口函数描述 2.3.2 代码实…

中科大计网学习记录笔记(十二):TCP 套接字编程

前前言:大家看到这一章节的时候一定不要跳过,虽然标题是编程,但实际上是对 socket 的运行机制做了详细的讨论,对理解 TCP 有很大的帮助;但是由于本节涉及到了大量的编程知识,对于一些朋友来说不是很好理解&…

【深度学习】S3 线性神经网络 P1 线性回归(未完)

目录 线性回归基本元素基本名词线性模型 机器学习领域,大多数任务最终的目标都是预测。而预测的结果大致分为两大类,一种是需要估计连续数值的回归预测,另一种是确定离散类别的分类预测。本节博文将围绕线性回归内容。 线性回归基本元素 基…

Nginx (window)2024版 笔记 下载 安装 配置

前言 Nginx (engine x) 是一款轻量级的 Web 服务器 、反向代理(Reverse Proxy)服务器及电子邮件(IMAP/POP3)代理服务器。 反向代理方式是指以代理服务器来接受 internet 上的连接请求,然后将请求转发给内部网络上的服…

[AIGC_coze] Kafka 的主题分区之间的关系

Kafka 的主题分区之间的关系 在 Kafka 中,主题(Topics)和分区(Partitions)是两个重要的概念,它们之间存在着密切的关系。 主题是 Kafka 中用于数据发布和订阅的逻辑单元。每个主题可以包含多个分区&#x…

【退役之重学前端】关于在控制台得到undefined的事

在浏览器控制台中,undefined 会时不时地,在我不想看到的地方出现。如果你遇到相同的问题,在这篇博客中你会得到答案。 先来看代码块 function test(){} test()//undefined再看下一个代码块 function test(){return 1; } test()//1再来看一个…

BUGKU-WEB eval

题目描述 题目截图如下&#xff1a; 进入场景看看&#xff1a; <?phpinclude "flag.php";$a $_REQUEST[hello];eval( "var_dump($a);");show_source(__FILE__); ?>解题思路 PHP代码审计咯 相关工具 百度搜索PHP相关知识 解题步骤 分析脚…

OpenAI全新发布文生视频模型:Sora!

OpenAI官网原文链接&#xff1a;https://openai.com/research/video-generation-models-as-world-simulators#fn-20 我们探索视频数据生成模型的大规模训练。具体来说&#xff0c;我们在可变持续时间、分辨率和宽高比的视频和图像上联合训练文本条件扩散模型。我们利用对视频和…

【C++初阶】第三站:类和对象(中) -- 日期计算器

目录 前言 日期类的声明.h 日期类的实现.cpp 获取某年某月的天数 全缺省的构造函数 拷贝构造函数 打印函数 日期 天数 日期 天数 日期 - 天数 日期 - 天数 前置 后置 前置 -- 后置-- 日期类中比较运算符的重载 <运算符重载 运算符重载 ! 运算符重载 …

【Webpack】生产模式

生产模式介绍 生产模式是开发完成代码后&#xff0c;我们需要得到代码将来部署上线。 这个模式下我们主要对代码进行优化&#xff0c;让其运行性能更好。 优化主要从两个角度出发: 优化代码运行性能优化代码打包速度 生产模式准备 我们分别准备两个配置文件来放不同的配置…

用c语言做一个心算小游戏

有加减和乘法3种运算&#xff0c;由于除法涉及到浮点数存储有误差&#xff0c;所以比较难实现&#xff0c;改程序还有判定分数机制&#xff0c;根据难度给定合适的分数&#xff0c;随机抽取运算题目和符号。下面的代码适合Linux和安卓上的编译器&#xff0c;因为用了ANSI转义字…

SG5032EAN规格书

SG5032EAN 晶体振荡器结合了相位锁定环&#xff08;PLL&#xff09;技术和AT切割晶体单元&#xff0c;提供了73.5 MHz至700 MHz的广泛频率范围&#xff0c;以满足高速数字应用的需求。高性能的LV-PECL输出&#xff0c;2.5V和3.3V电源电压&#xff0c;可灵活适配不同设计的电源需…

layui表格中使用cascader后导致表格滚动条消失

修改前&#xff0c;受影响页面 修改后最终想要的效果 修改方法

《Go 简易速速上手小册》第8章:网络编程(2024 最新版)

文章目录 8.1 HTTP 客户端与服务端编程 - Go 语言的网络灯塔与探航船8.1.1 基础知识讲解服务端编程客户端编程 8.1.2 重点案例&#xff1a;简易博客服务服务端实现客户端实现运行示例 8.1.3 拓展案例 1&#xff1a;增加文章评论功能功能描述服务端实现客户端实现 8.1.4 拓展案例…

Python爬虫之Splash详解

爬虫专栏&#xff1a;http://t.csdnimg.cn/WfCSx Splash 的使用 Splash 是一个 JavaScript 渲染服务&#xff0c;是一个带有 HTTP API 的轻量级浏览器&#xff0c;同时它对接了 Python 中的 Twisted 和 QT 库。利用它&#xff0c;我们同样可以实现动态渲染页面的抓取。 1. 功…

sql常用语句小结

创建表&#xff1a; create table 表名&#xff08; 字段1 字段类型 【约束】【comment 字段1注释】&#xff0c; //【】里面的东西可以不用加上去 字段2 字段类型 【约束】【comment 字段2注释】 &#xff09;【comment 表注释】 约束&#xff1a;作用于表中字段上的规则…