李沐深度学习-线性回归从零开始

# 核心Tensor,autograd
import torch
from IPython import display
import numpy as np
import random
from matplotlib import pyplot as pltimport syssys.path.append('路径')
from d2lzh_pytorch import *'''
backward()函数:一次小批量执行完在进行反向传播
线性回归模型步骤;1.数据处理2.模型定义:根据矩阵形式运算,模型可以一次计算多个样本,比如X:1000x2, w:2x1  则模型可以一次计算1000个样本3.损失函数:4.优化算法:sgd则是小批量中每个样本loss运行完后,对应参数的梯度进行了累加,得到一个小批量的代表梯度 w1,w2,b然后将每个小批量的参数梯度进行梯度下降5.模型预测
'''
# ------------------------------------------------------------------------
# 生成数据集
'''
样本X=1000,特征=2,w=2,-3.4;b=4.2   随机噪声ξ     y=Xw+b+ξ
噪声服从均值为0,标准差为0.01的正态分布  噪声代表了数据集中无意义的干扰
'''
num_inputs = 2  # 特征数
num_examples = 1000  # 样本数量
true_w = [2, -3.4]  # w
true_b = 4.3  # b
# 生成所有包含特征  1000x2的样本 向量
features = torch.randn(num_examples, num_inputs, dtype=torch.float32)
# 下列运算属于矢量运算  预测y 表达式  是个向量,1000x1
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
# 添加符合正态分布的噪声
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float32)# set_figsize()
# plt.scatter(features[:, 1].numpy(), labels.numpy(), 1)
# plt.savefig('/home/eilab2023/yml/project/limu/picture/picture.png')# 读取数据集# ---------------------------------------------------------------------------------------------# ---------------------------------------------------------------------------------------------
# 定义模型# ---------------------------------------------------------------------------------------------# 损失函数
# ---------------------------------------------------------------------------------------------# 优化算法
# ---------------------------------------------------------------------------------------------# 模型训练
# ---------------------------------------------------------------------------------------------
lr = 0.03
num_epoch = 3
net = linreg
loss = squared_loss
'''
每次返回一个batch-size大小随机样本的特征和标签
'''
batch_size = 10
# 初始化模型参数 w b  都是列矩阵  上面的是确定的公式,x,w,b都是确定的,label确定。这里的参数是初始化模拟的
# 一般是,X确定,label确定,然后初始化w,b,在模型训练寻找最优解,这里提前确定是为了方便
w = torch.tensor((np.random.normal(0, 0.01, (num_inputs, 1))), dtype=torch.float32)
b = torch.tensor(1, dtype=torch.float32)
w.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True)
# 外部定义了变量w后若在方法内有改变w,则该变量值会随着改变for epoch in range(num_epoch):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y).sum()  # 小批量的损失计算完l.backward()  # 计算小批量的样本参数梯度,这里每个样本的参数梯度会自动累加sgd([w, b], lr, batch_size)  # 一个小批量出一个w,b   梯度下降算法# 一个小批量的参数更新后就要对参数的梯度进行清零操作w.grad.data.zero_()b.grad.data.zero_()train_l = loss(net(features, w, b), labels)  # 这里的w,b是一轮所有的批量更新完成之后得到的最新的值,然后用于所有的样本进行损失计算print('epoch %d, loss %f' % (epoch + 1, train_l.mean().item()))  # 因为loss是一个1000x1的一个张量

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/633726.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java如何做到无感知刷新token含示例代码(值得珍藏)

1. 前言 在系统页面进行业务操作时,有时会突然遇到应用闪退,并被重定向至登录页面,要求重新登录。此问题的出现,通常与系统中用于存储用户ID和token信息的Redis缓存有关。具体来说,这可能是由于token过期所导致的身份…

容器部署的nextcloud配置onlyoffice时开启密钥

容器部署的nextcloud配置onlyoffice时开启密钥 配置 进入onlyoffice容器 docker exec -it 容器id bash编辑配置vi /etc/onlyoffice/documentserver/local.json enable设置为true,并配置secret 重启容器,并将配置的密钥填入nextcloud密钥页面 docker r…

复杂字幕特效SDK,重塑视频字幕新体验

字幕特效已经成为了提升视频品质、增强观众体验的重要手段。美摄科技作为行业领先的技术提供商,近期推出的复杂字幕特效SDK,更是引领了这一领域的创新潮流。 美摄科技复杂字幕特效SDK,不仅具备了电影级别的字幕功能,更实现了众多…

系统学习Python——警告信息的控制模块warnings:常用函数-[warnings.warn]

分类目录:《系统学习Python》总目录 函数 warnings.warn(message, categoryNone, stacklevel1, sourceNone, \*, skip_file_prefixesNone) 常被用于引发警告、忽略或者触发异常。 参数 如果给出category参数,则必须是警告类别类 ;默认为U…

【全网最全】2024华数杯国际赛B题成品论文50页+1-4问高质量代码+完整数据集+建模过程+保姆级教学

基于数据分析下的光伏发电 摘 要(完整版在文末) 根据最新数据,中国的总发电量超过20万亿千瓦时,总体排名世界第一,而光伏发电是一种重要的可再生能源,可以将太阳能转化为电能可以减少对传统能源的依赖&…

VBA窗体跟随活动单元格【简易版】(2/2)

上一篇博客(文章连接如下)中使用工作表事件Worksheet_SelectionChange实现了窗体跟随活动单元格的动态效果。 VBA窗体跟随活动单元格【简易版】(1/2) 为了在用户滚动工作表窗体之后仍能够实现跟随效果,需要使用Application.Windows(1).Visibl…

Nginx 常用的基础配置(前端相关方面)

Nginx配置前端 web 服务这篇文章;希望能够帮助更多的朋友。 基础配置 user root; worker_processes 1;events {worker_connections 10240; }http {log_format $remote_addr - $remote_u…

归并排序详解

目录 ​💡基本思想 💡图文介绍 💡动图演示 💡过程解释 💡代码实现 💡递归实现 💡非递归实现 💡总结 💡基本思想 归并排序(MERGE-SORT)是…

数据结构--串

本文为复习的草稿笔记,,,有点乱 1. 串的基本概念和基本操作 串是由零个或多个字符组成的有限序列 2. 串的存储结构 3.串的应用 模式匹配 BF算法(简单匹配算法 穷举法 算法思路:从子串的每一个字符开始依次与主串…

深耕文档型数据库12载,SequoiaDB再开源

1月15日,巨杉数据库举行SequoiaDB新特性及开源项目发布活动。本次活动回顾了巨杉数据库深耕JSON文档型数据库12年的发展历程与技术演进,全面解读了SequoiaDB包括在高可用、安全、实时、易用性四个方向的技术特性,宣布了2024年面向技术社区的开…

无法打开浏览器开发者工具的可能解决方法

网页地址: https://jx.xyflv.cc/?url视频地址url 我在抖音里面抓了一个视频地址, 获取到响应的json数据, 找到里面的视频地址信息 这个网站很好用: https://www.jsont.run/ 可以使用js语法对json对象操作, 找到所有视频的url地址 打开网页: https://jx.xyflv.cc/?urlhttps:…

【Linux C | 文件操作】目录相关操作 | mkdir、rmdir、opendir、readdir、closedir、getcwd、chdir

😁博客主页😁:🚀https://blog.csdn.net/wkd_007🚀 🤑博客内容🤑:🍭嵌入式开发、Linux、C语言、C、数据结构、音视频🍭 🤣本文内容🤣&a…

【LeetCode】栈精选9题

目录 1. 删除字符串中的所有相邻重复项(简单) 2. 逆波兰表达式(中等) 3. 基本计算器 II(中等) 4. 字符串解码(中等) 5. 验证栈序列(中等) 6. 小行星碰撞…

1月18日课前练习题

调整一个三位的百位,十位,个位 的数字让调整后的数字最大 //参数num:进行调整的整数 //返回值:调整后的最大整数 package com.ztt.Demo06Exercise;public class test { //1月18日public static void main(String[] args) {int n234;int retto…

新能源汽车智慧充电桩方案:基于视频监控的可视化智能监管平台

一、方案概述 TSINGSEE青犀&触角云新能源汽车智慧充电桩方案围绕互联网、物联网、车联网、人工智能、视频技术、大数据、4G/5G等技术,结合云计算、移动支付等,实现充电停车一体化、充电桩与站点管理等功能,达到充电设备与站点的有效监控…

有效防范网络风险的关键措施

在数字化时代,企业面临着日益复杂和频繁的网络风险。提高员工的网络安全意识是防范网络威胁的关键一步。本文将探讨企业在提升网络安全意识方面可以采取的措施,以有效预防潜在的网络风险。 1. 开展网络安全培训:企业应定期组织网络安全培训&…

面板小程序命令行工具介绍

Ray 体系提供配套的工程化解决方案。 由于多端构建的一些客观原因,在构建流程的设计上,必须将工程套件安装在项目内。 项目内的依赖至少包含以下内容: {"dependencies": {"ray-js/ray": "latest"},"de…

GDB 调用无符号的任意函数

我们知道有符号的函数调用很简单了,直接像写c语言一样传参调用即可。但是无符号的就不知道怎么弄了,查遍了整个网络我都没有查到怎么做。只好自己想办法了。总体的思路如下 1. 保存好所有的现场,如寄存器,当前pc, 返回地址&#…

Django migration 新增外键的坑

TL;DR 永远不要相信 makemigrations! migrate 之前一定好好看看 migrate 了啥东西,必要时手动修改生成的 migrate 文件。 最好把db的更新与服务代码更新解耦 场景 先描述下场景: 现在有两个表,一个是 question,一…