pytorch 实现线性回归(深度学习)

一 查看原始函数

        y=2x+4.2

初始化

%matplotlib inline
import random
import torch
from d2l import torch as d2l

1.1 生成原始数据

def synthetic_data(w, b, num_examples):x = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(x, w) + bprint('x:', x)print('y:', y)y += torch.normal(0, 0.01, y.shape)  # 噪声return x, y.reshape((-1 , 1))
true_w = torch.tensor([2.])
true_b = 4.2
print(f'true_w: {true_w}, true_b: {true_b}')features, labels = synthetic_data(true_w, true_b, 10)

1.2 数据转换

def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]batch_size = 10
for x, y in data_iter(batch_size, features, labels):print(f'x: {x}, \ny: {y}')

1.3 初始化权重

随机初始化,w使用 均值0,方差 0.01 的随机值, b 初始化为1

w = torch.normal(0, 0.01, size = (1,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b

二 执行训练

查看训练过程中的 参数变化:

print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:# print('param:', param, 'param.grad:', param.grad)param -= lr * param.grad / batch_sizeparam.grad.zero_()lr = 0.03
num_epochs = 1000
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y)   # 计算总损失print('w:', w, 'b:', b)  # l:', l, '\nl.sum().backward()sgd([w, b], lr, batch_size)

 


三 测试梯度更新

初始化数据

%matplotlib inline
import random
import torch
from d2l import torch as d2ldef synthetic_data(w, b, num_examples):x = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(x, w) + bprint('x:', x)print('y:', y)y += torch.normal(0, 0.01, y.shape)  # 噪声return x, y.reshape((-1 , 1))true_w = torch.tensor([2.])
true_b = 4.2
print(f'true_w: {true_w}, true_b: {true_b}')features, labels = synthetic_data(true_w, true_b, 10)def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]batch_size = 10
for x, y in data_iter(batch_size, features, labels):print(f'x: {x}, \ny: {y}')w = torch.normal(0, 0.01, size = (1,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b

3.1 测试更新

print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:print('param:', param, 'param.grad:', param.grad)
#             param -= lr * param.grad / batch_size
#             param.grad.zero_()lr = 0.03
num_epochs = 2
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y)   # 计算总损失print(f'\nepoch: {epoch},w:', w, 'b:', b)  # l:', l, '\nl.sum().backward()  # 计算更新梯度sgd([w, b], lr, batch_size)

使用 l.sum().backward()  # 计算更新梯度:

不使用更新时:

print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:print('param:', param, 'param.grad:', param.grad)
#             param -= lr * param.grad / batch_size
#             param.grad.zero_()lr = 0.03
num_epochs = 2
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y)   # 计算总损失print(f'\nepoch: {epoch},w:', w, 'b:', b)  # l:', l, '\n# l.sum().backward()  # 计算更新梯度sgd([w, b], lr, batch_size)#     break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/687001.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

阿里云ECS香港服务器性能强大、cn2高速网络租用价格表

阿里云香港服务器中国香港数据中心网络线路类型BGP多线精品,中国电信CN2高速网络高质量、大规格BGP带宽,运营商精品公网直连中国内地,时延更低,优化海外回中国内地流量的公网线路,可以提高国际业务访问质量。阿里云服务…

免费chatgpt使用

基本功能如下: https://go.aigcplus.cc/auth/register?inviteCode3HCULH2UD

中科大计网学习记录笔记(十二):TCP 套接字编程

前前言:大家看到这一章节的时候一定不要跳过,虽然标题是编程,但实际上是对 socket 的运行机制做了详细的讨论,对理解 TCP 有很大的帮助;但是由于本节涉及到了大量的编程知识,对于一些朋友来说不是很好理解&…

Nginx (window)2024版 笔记 下载 安装 配置

前言 Nginx (engine x) 是一款轻量级的 Web 服务器 、反向代理(Reverse Proxy)服务器及电子邮件(IMAP/POP3)代理服务器。 反向代理方式是指以代理服务器来接受 internet 上的连接请求,然后将请求转发给内部网络上的服…

[AIGC_coze] Kafka 的主题分区之间的关系

Kafka 的主题分区之间的关系 在 Kafka 中,主题(Topics)和分区(Partitions)是两个重要的概念,它们之间存在着密切的关系。 主题是 Kafka 中用于数据发布和订阅的逻辑单元。每个主题可以包含多个分区&#x…

BUGKU-WEB eval

题目描述 题目截图如下&#xff1a; 进入场景看看&#xff1a; <?phpinclude "flag.php";$a $_REQUEST[hello];eval( "var_dump($a);");show_source(__FILE__); ?>解题思路 PHP代码审计咯 相关工具 百度搜索PHP相关知识 解题步骤 分析脚…

OpenAI全新发布文生视频模型:Sora!

OpenAI官网原文链接&#xff1a;https://openai.com/research/video-generation-models-as-world-simulators#fn-20 我们探索视频数据生成模型的大规模训练。具体来说&#xff0c;我们在可变持续时间、分辨率和宽高比的视频和图像上联合训练文本条件扩散模型。我们利用对视频和…

【C++初阶】第三站:类和对象(中) -- 日期计算器

目录 前言 日期类的声明.h 日期类的实现.cpp 获取某年某月的天数 全缺省的构造函数 拷贝构造函数 打印函数 日期 天数 日期 天数 日期 - 天数 日期 - 天数 前置 后置 前置 -- 后置-- 日期类中比较运算符的重载 <运算符重载 运算符重载 ! 运算符重载 …

SG5032EAN规格书

SG5032EAN 晶体振荡器结合了相位锁定环&#xff08;PLL&#xff09;技术和AT切割晶体单元&#xff0c;提供了73.5 MHz至700 MHz的广泛频率范围&#xff0c;以满足高速数字应用的需求。高性能的LV-PECL输出&#xff0c;2.5V和3.3V电源电压&#xff0c;可灵活适配不同设计的电源需…

layui表格中使用cascader后导致表格滚动条消失

修改前&#xff0c;受影响页面 修改后最终想要的效果 修改方法

《Go 简易速速上手小册》第8章:网络编程(2024 最新版)

文章目录 8.1 HTTP 客户端与服务端编程 - Go 语言的网络灯塔与探航船8.1.1 基础知识讲解服务端编程客户端编程 8.1.2 重点案例&#xff1a;简易博客服务服务端实现客户端实现运行示例 8.1.3 拓展案例 1&#xff1a;增加文章评论功能功能描述服务端实现客户端实现 8.1.4 拓展案例…

Python爬虫之Splash详解

爬虫专栏&#xff1a;http://t.csdnimg.cn/WfCSx Splash 的使用 Splash 是一个 JavaScript 渲染服务&#xff0c;是一个带有 HTTP API 的轻量级浏览器&#xff0c;同时它对接了 Python 中的 Twisted 和 QT 库。利用它&#xff0c;我们同样可以实现动态渲染页面的抓取。 1. 功…

代码随想录|day 18

Day 18 哎&#xff0c;日子越来越近了&#xff0c;干什么都干不下去&#xff0c;但又必须要坚持。前途渺茫… 一、理论学习 1)自己误打误撞的时候&#xff0c;学习函数 int partitionmax_element(nums.begin(),nums.end())-nums.begin();也记录一下我的错误做法&#xff0c…

【Anaconda】conda创建、删除、查看虚拟环境,安装pytorch

1.删除环境 首先退出现有的环境 conda deactivate然后查看要删除的环境名称与路径 conda env list接下来就可以删除环境了 有两种方法 方法1&#xff1a; conda env remove -p 要删除的虚拟环境路径对我来说就是&#xff1a; conda env remove -p D:\Anaconda3\envs\MVDet…

【Unity】【VR开发】针对VR项目的优化版Unity Build Settings

【背景】 编辑器中做了功能后,打包后却总会画面不满意,所以到处学习,总结成本篇,希望有用。 【准备】 本篇总结基于Unity 2021 LTS。 模板选择3D(URP) 如果URP不支持所用的部分Assets,那么也可以选择Built-in管线,不过URP肯定画面效果上要胜过Built-in。 HDRP不适用…

智能摄像头prv文件恢复案例

家用智能摄像头一般采用的是mp4或者mov视频方案&#xff0c;常见的是mp4&#xff0c;对于部分有开发能力的厂商可能会采用自定义方案&#xff08;如360的bin文件&#xff09;,今天我们来看一个小厂的PRV自定义文件的恢复案例。 故障存储: 32G TF卡/fat32/ 簇&#xff08;块)大…

没钱、没资源、没团队、没商业模式,该怎么创业成功?

很多人为什么要去创业呢&#xff1f;大多还是万般无奈去创业的。如果人人都有王思聪的条件&#xff0c;天台你享受岂不是最好&#xff1f;谁还愿意苦哈哈创业呢&#xff1f; 对于很多创业者而言&#xff0c;创业初期就是会面对没钱、没资源、没团队、没商业模式的窘境&#xff…

vmware-17虚拟机安装教程及版本密钥(保姆级,包含图文讲解,不需注册账户)

文章目录 vmware安装教程一、下载vmware二、安装三、破解密匙 vmware安装教程 一、下载vmware 1.进入VMware官网&#xff1a;https://www.vmware.com/sg/products/workstation-pro.html 2.向下翻找到&#xff0c;如下界面并点击“现在安装” 3.稍事等待以下直到出现以下界面…

004 - Hugo, 分类

004 - Hugo, 分类content文件夹 004 - Hugo, 分类 content文件夹 ├─.obsidian ├─categories │ ├─Python │ └─Test ├─page │ ├─about │ ├─archives │ ├─links │ └─search └─post├─chinese-test├─emoji-support├─Git教程├─Hugo分类├─…

基于飞腾ARM+FPGA国产化计算模块联合解决方案

联合解决方案概述 随着特殊领域电子信息系统对自主创新需求的日益提升&#xff0c;需不断开展国产抗恶劣环境计算整机及模块产 品的研制和升级。特殊领域电子信息系统的自主创新&#xff0c;是指依靠自身技术手段和安全机制&#xff0c;实现信息系统从硬 件到软件的自主研发…