Pytorch线性回归教程

import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt

生成测试数据

# 长期趋势
def trend(time, slope=0):return slope * time# 季节趋势
def seasonal_pattern(season_time):return np.where(season_time < 0.4,np.cos(season_time * 2 * np.pi),1 / np.exp(3 * season_time))
def seasonality(time, period, amplitude=1, phase=0):season_time = ((time + phase) % period) / periodreturn amplitude * seasonal_pattern(season_time)# 噪声
def noise(time, noise_level=1):return np.random.randn(len(time)) * noise_level
X = torch.arange(1, 1001)
# Y = 0.7 * X + 100 + torch.randn(X.size())
Y = trend(X, 0.3) + seasonality(X, period=365, amplitude=30) + noise(X, 15) + 200
X.shape, Y.shape
(torch.Size([1000]), torch.Size([1000]))
plt.plot(X.numpy(), Y.numpy());

对测试数据进行处理

# 模型的数据的类型需要是32位浮点型
X = X.type(torch.float32)
Y = Y.type(torch.float32)
X.dtype, Y.dtype
(torch.float32, torch.float32)
# 模型的数据需要进行归一化或者标准化,下面是归一化
X = (X - X.min()) / (X.max() - X.min())
Y = (Y - Y.min()) / (Y.max() - Y.min())
plt.plot(X.numpy(), Y.numpy());

定义模型和模型参数

# 线性模型只有两个参数斜率k,和偏置b
# 线性模型的方程为y = k * x + b
k = nn.Parameter(torch.rand(1, dtype=torch.float32))
b = nn.Parameter(torch.rand(1, dtype=torch.float32))
# 下面输出中的requires_grad=True 表示该参数需要计算梯度
# 梯度用于在反向传播中对参数进行优化,优化方法即梯度下降
k, b 
(Parameter containing:tensor([0.6231], requires_grad=True),Parameter containing:tensor([0.0044], requires_grad=True))
def linear_model(x):return k * x + b

梯度下降优化参数

# 可以通过改变学习率lr和epoch_num学习各自的用途
# 定义优化器,用于更新模型的参数,即传入的k和b
optimizer = torch.optim.SGD([k, b], lr=0.01)
# 损失函数,模型优化的目的
loss_func = nn.MSELoss()# 每个epoch表示把全部的数据过一遍
epoch_num = 2000
for epoch in range(epoch_num):# 获取模型预测结果y_pred = linear_model(X)# 计算损失值loss = loss_func(y_pred, Y)# 将梯度设为0optimizer.zero_grad()# 反向传播,计算梯度loss.backward()# 执行梯度下降,优化参数optimizer.step()
k, b
(Parameter containing:tensor([0.8825], requires_grad=True),Parameter containing:tensor([0.0419], requires_grad=True))
# detach()函数用于将参数设置为不需要梯度
k2 = k.detach().numpy()[0]
b2 = b.detach().numpy()[0]plt.plot(X, Y);
plt.plot(X, k2 * X + b2);

优化模型

class LinearModel(nn.Module):def __init__(self):super().__init__()self.k = nn.Parameter(torch.rand(1, dtype=torch.float32))self.b = nn.Parameter(torch.rand(1, dtype=torch.float32))def forward(self, x):return self.k * x + self.b
model = LinearModel()
# 定义优化器,用于更新模型的参数,即传入的k和b
optimizer = torch.optim.SGD([k, b], lr=0.01)
# 损失函数,模型优化的目的
loss_func = nn.MSELoss()epoch_num = 2000
for epoch in range(epoch_num):y_pred = model(X)loss = loss_func(y_pred, Y)optimizer.zero_grad()loss.backward()optimizer.step()
k, b
(Parameter containing:tensor([0.8825], requires_grad=True),Parameter containing:tensor([0.0419], requires_grad=True))
k2 = k.detach().numpy()[0]
b2 = b.detach().numpy()[0]plt.plot(X, Y);
plt.plot(X, k2 * X + b2);

随机梯度下降

# 前面执行梯度下降时,我们是一次将全部的数据都传入模型
# 但在实际应用中,可能会由于数据太大,没法全部传入模型
# 因此,可以一次传入一部分数据,这便是随机梯度下降
# 随机梯度下降的核心是,梯度是期望。期望可使用小规模的样本近似估计。
model = LinearModel()
# 定义优化器,用于更新模型的参数,即传入的k和b
optimizer = torch.optim.SGD([k, b], lr=0.01)
# 损失函数,模型优化的目的
loss_func = nn.MSELoss()# 每个epoch表示把全部的数据过一遍
epoch_num = 2000 
# iter_step表示在一个epoch内抽取几个小规模样本
iter_step = 10
# batch_size表示小规模样本的大小
batch_size = 100
for epoch in range(epoch_num):for i in range(iter_step):random_samples = torch.randint(X.size()[0], (batch_size, ))X_i, Y_i = X[random_samples], Y[random_samples]y_pred = model(X_i)loss = loss_func(y_pred, Y_i)optimizer.zero_grad()loss.backward()optimizer.step()
k, b
(Parameter containing:tensor([0.8825], requires_grad=True),Parameter containing:tensor([0.0419], requires_grad=True))
k2 = k.detach().numpy()[0]
b2 = b.detach().numpy()[0]plt.plot(X, Y);
plt.plot(X, k2 * X + b2);

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/204613.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue3 + mark.js 实现文字标注功能

效果图 安装依赖 npm install mark.js --save-dev npm i nanoid代码块 <template><!-- 文档标注 --><header><el-buttontype"primary":disabled"selectedTextList.length 0 ? true : false"ghostclick"handleAllDelete"…

Linux学习笔记2

web服务器部署&#xff1a; 1.装包&#xff1a; [rootlocalhost ~]# yum -y install httpd 2.配置一个首页&#xff1a; [rootlocalhost ~]# echo i love yy > /var/www/html/index.html 启动服务&#xff1a;[rootlocalhost ~]# systemctl start httpd Ctrl W以空格为界…

排序算法介绍(一)插入排序

0. 简介 插入排序&#xff08;Insertion Sort&#xff09; 是一种简单直观的排序算法&#xff0c;它的工作原理是通过构建有序序列&#xff0c;对于未排序数据&#xff0c;在已排序序列中从后向前扫描&#xff0c;找到相应位置并插入。插入排序在实现上&#xff0c;通常…

mybatis数据输出-驼峰命名规则设置

1、建库建表 CREATE DATABASE mybatis-example;USE mybatis-example;CREATE TABLE t_emp(emp_id INT AUTO_INCREMENT,emp_name CHAR(100),emp_salary DOUBLE(10,5),PRIMARY KEY(emp_id) );INSERT INTO t_emp(emp_name,emp_salary) VALUES("tom",200.33); INSERT INTO…

持续集成交付CICD:GitLabCI 实现Sonarqube代码扫描

目录 一、实验 1.GitLabCI 代码扫描 二、问题 1.GitLab 执行sonar-scanner命令报错 一、实验 1.GitLabCI 代码扫描 &#xff08;1&#xff09;打开maven项目 &#xff08;2&#xff09;maven项目流水线调用公共库 &#xff08;3&#xff09;项目组添加token认证 &#xf…

FreeRTOS系统延时函数分析

一、概述 FreeRTOS提供了两个系统延时函数&#xff0c;相对延时函数vTaskDelay()和绝对延时函数vTaskDelayUntil()。相对延时是指每次延时都是从任务执行函数vTaskDelay()开始&#xff0c;延时指定的时间结束&#xff0c;绝对延时是指每隔指定的时间&#xff0c;执行一次调用vT…

项目分析:解决类的复杂设计中遇到的问题

1.问题1&#xff1a;析构函数乱码问题 【样例输入】 -3 1 3 -1 -3 2 3 -2 【样例输出】 gouzao 1 -3 1 3 -1 gouzao 2 -3 2 3 -2 -3 1 3 -1 -3 2 3 -2 9.4245 18.849 Ellipse xigou 3 -2 Point xigou 3 -2 Point xigou -3 2 Point xigou 3 -2 Point xigou -3 2…

2023年阿里云云栖大会-核心PPT资料下载

一、峰会简介 历经14届的云栖大会&#xff0c;是云计算产业的建设者、推动者、见证者。2023云栖大会以“科技、国际、年轻”为基调&#xff0c;以“计算&#xff0c;为了无法计算的价值”为主题&#xff0c;发挥科技平台汇聚作用&#xff0c;与云计算全产业链上下游的先锋代表…

关于mysql高版本使用groupby导致的报错

在开发时&#xff0c;遇到mysql版本在5.7.X及以上版本时使用group by 语句会报以下的错误 Caused by: com.mysql.jdbc.exceptions.jdbc4.MySQLSyntaxErrorException: Expression #1 of SELECT list is not in GROUP BY clause and contains nonaggregated column business_typ…

使用pandas制作图表

数据可视化对于数据分析的重要性不言而喻&#xff0c;一个优秀的图表有足以一眼就看出关键所在。pandas利用matplotlib实现绘图。能够提供各种各样的图表功能&#xff0c;包括: 单折线图多折线图柱状图叠加柱状图水平叠加柱状图直方图拆分直方图箱型图区域块图形散点图饼图多子…

13年测试老鸟总结,性能测试常遇问题+解决方案+分析...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 1、内存溢出 1&a…

【高数:1 映射与函数】

【高数&#xff1a;1 映射与函数】 例2.1 绝对值函数例2.2 符号函数例2.3 反函数表示例2.4 双曲正弦sinh&#xff0c;双曲余弦cosh&#xff0c;双曲正切tanh 参考书籍&#xff1a;毕文斌, 毛悦悦. Python漫游数学王国[M]. 北京&#xff1a;清华大学出版社&#xff0c;2022. 例2…

设计之初,成就AI创作的非凡之路——AI绘画

一.官方活动 活动链接&#xff1a;| 2023腾讯云 AI 绘画有奖征文大赛&#xff0c;秀出你的AI新质生产力 https://cloud.tencent.com/developer/article/2367375 二.产品体验 1.产品链接:https://cloud.tencent.com/act/pro/AIhuihua?from20421&from_column20421 2.产品…

Maven-高效的Java项目构建与管理工具(含Maven详细安装与配置过程)

Maven 什么是Maven&#xff1f; 正如题目所说&#xff0c;Maven就是一款高效的Java项目构建与管理工具&#xff0c;基于项目对象模型&#xff08;POM&#xff09;概念&#xff0c;利用一个中央信息片断能管理一个项目的构建、报告和文档等步骤。是Apache软件基金会的一个开源…

网站测试都要测试哪些?如何进行测试?

1 UI测试 看页面是否美观养眼(包括页面的布局是否合理&#xff0c;策划是否舒服美观&#xff0c;页面长度是否合理&#xff0c;前景色与背景色是否搭配&#xff0c;页面风格是否统一&#xff0c;色调是否适合人眼&#xff0c;会不会太刺眼&#xff0c;字体大小是否合适&#x…

Java多线程:代码不只是在‘Hello World‘

Java线程好书推荐 概述01 多线程对于Java的意义02 为什么Java工程师必须掌握多线程03 Java多线程使用方式04 如何学好Java多线程写在末尾&#xff1a; 主页传送门&#xff1a;&#x1f4c0; 传送 概述 摘要&#xff1a;互联网的每一个角落&#xff0c;无论是大型电商平台的秒杀…

IntelliJ IDEA图形安装教程

IntelliJ IDEA图形安装教程 之前开始Java程序&#xff0c;一直用的eclipse&#xff0c;觉得还可以。一直听说IntelliJ IDEA比eclipse好用很多&#xff0c;但因为比较懒&#xff0c;也没有学习使用。机缘巧合下&#xff0c;尝试用了下&#xff0c;顿时有种相见恨晚的感觉&#…

【问题思考】泰勒公式证明题如何选展开点?【对称美】

我的证明题水平很烂&#xff0c;这个纯属让自己有一个初步的理解&#xff0c;恳请指正&#xff01; 问题 我们可以看到这里有两种展开方式&#xff08;注意&#xff1a;x0叫展开点&#xff09;&#xff0c;分别是正确的做法&#xff0c;在x0展开&#xff0c;然后将0和a代入fx中…

Windows系统上如何搭建Linux操作系统

一、准备工作 1&#xff0c;VMware安装包 2&#xff0c;Centos IOS镜像 3&#xff0c;finalshell安装包 阿里云盘下载地址&#xff1a; https://www.alipan.com/s/uSQsWn15E3W 二&#xff0c;VMware安装 1&#xff0c;新建虚拟机 2&#xff0c;选择下一步 3&#xff0c;…

如何在Linux上部署1Panel运维管理面板并远程访问内网Web端管理界面

文章目录 前言1. Linux 安装1Panel2. 安装cpolar内网穿透3. 配置1Panel公网访问地址4. 公网远程访问1Panel管理界面5. 固定1Panel公网地址 前言 1Panel 是一个现代化、开源的 Linux 服务器运维管理面板。高效管理,通过 Web 端轻松管理 Linux 服务器&#xff0c;包括主机监控、…