Pytorch线性回归教程

import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt

生成测试数据

# 长期趋势
def trend(time, slope=0):return slope * time# 季节趋势
def seasonal_pattern(season_time):return np.where(season_time < 0.4,np.cos(season_time * 2 * np.pi),1 / np.exp(3 * season_time))
def seasonality(time, period, amplitude=1, phase=0):season_time = ((time + phase) % period) / periodreturn amplitude * seasonal_pattern(season_time)# 噪声
def noise(time, noise_level=1):return np.random.randn(len(time)) * noise_level
X = torch.arange(1, 1001)
# Y = 0.7 * X + 100 + torch.randn(X.size())
Y = trend(X, 0.3) + seasonality(X, period=365, amplitude=30) + noise(X, 15) + 200
X.shape, Y.shape
(torch.Size([1000]), torch.Size([1000]))
plt.plot(X.numpy(), Y.numpy());

对测试数据进行处理

# 模型的数据的类型需要是32位浮点型
X = X.type(torch.float32)
Y = Y.type(torch.float32)
X.dtype, Y.dtype
(torch.float32, torch.float32)
# 模型的数据需要进行归一化或者标准化,下面是归一化
X = (X - X.min()) / (X.max() - X.min())
Y = (Y - Y.min()) / (Y.max() - Y.min())
plt.plot(X.numpy(), Y.numpy());

定义模型和模型参数

# 线性模型只有两个参数斜率k,和偏置b
# 线性模型的方程为y = k * x + b
k = nn.Parameter(torch.rand(1, dtype=torch.float32))
b = nn.Parameter(torch.rand(1, dtype=torch.float32))
# 下面输出中的requires_grad=True 表示该参数需要计算梯度
# 梯度用于在反向传播中对参数进行优化,优化方法即梯度下降
k, b 
(Parameter containing:tensor([0.6231], requires_grad=True),Parameter containing:tensor([0.0044], requires_grad=True))
def linear_model(x):return k * x + b

梯度下降优化参数

# 可以通过改变学习率lr和epoch_num学习各自的用途
# 定义优化器,用于更新模型的参数,即传入的k和b
optimizer = torch.optim.SGD([k, b], lr=0.01)
# 损失函数,模型优化的目的
loss_func = nn.MSELoss()# 每个epoch表示把全部的数据过一遍
epoch_num = 2000
for epoch in range(epoch_num):# 获取模型预测结果y_pred = linear_model(X)# 计算损失值loss = loss_func(y_pred, Y)# 将梯度设为0optimizer.zero_grad()# 反向传播,计算梯度loss.backward()# 执行梯度下降,优化参数optimizer.step()
k, b
(Parameter containing:tensor([0.8825], requires_grad=True),Parameter containing:tensor([0.0419], requires_grad=True))
# detach()函数用于将参数设置为不需要梯度
k2 = k.detach().numpy()[0]
b2 = b.detach().numpy()[0]plt.plot(X, Y);
plt.plot(X, k2 * X + b2);

优化模型

class LinearModel(nn.Module):def __init__(self):super().__init__()self.k = nn.Parameter(torch.rand(1, dtype=torch.float32))self.b = nn.Parameter(torch.rand(1, dtype=torch.float32))def forward(self, x):return self.k * x + self.b
model = LinearModel()
# 定义优化器,用于更新模型的参数,即传入的k和b
optimizer = torch.optim.SGD([k, b], lr=0.01)
# 损失函数,模型优化的目的
loss_func = nn.MSELoss()epoch_num = 2000
for epoch in range(epoch_num):y_pred = model(X)loss = loss_func(y_pred, Y)optimizer.zero_grad()loss.backward()optimizer.step()
k, b
(Parameter containing:tensor([0.8825], requires_grad=True),Parameter containing:tensor([0.0419], requires_grad=True))
k2 = k.detach().numpy()[0]
b2 = b.detach().numpy()[0]plt.plot(X, Y);
plt.plot(X, k2 * X + b2);

随机梯度下降

# 前面执行梯度下降时,我们是一次将全部的数据都传入模型
# 但在实际应用中,可能会由于数据太大,没法全部传入模型
# 因此,可以一次传入一部分数据,这便是随机梯度下降
# 随机梯度下降的核心是,梯度是期望。期望可使用小规模的样本近似估计。
model = LinearModel()
# 定义优化器,用于更新模型的参数,即传入的k和b
optimizer = torch.optim.SGD([k, b], lr=0.01)
# 损失函数,模型优化的目的
loss_func = nn.MSELoss()# 每个epoch表示把全部的数据过一遍
epoch_num = 2000 
# iter_step表示在一个epoch内抽取几个小规模样本
iter_step = 10
# batch_size表示小规模样本的大小
batch_size = 100
for epoch in range(epoch_num):for i in range(iter_step):random_samples = torch.randint(X.size()[0], (batch_size, ))X_i, Y_i = X[random_samples], Y[random_samples]y_pred = model(X_i)loss = loss_func(y_pred, Y_i)optimizer.zero_grad()loss.backward()optimizer.step()
k, b
(Parameter containing:tensor([0.8825], requires_grad=True),Parameter containing:tensor([0.0419], requires_grad=True))
k2 = k.detach().numpy()[0]
b2 = b.detach().numpy()[0]plt.plot(X, Y);
plt.plot(X, k2 * X + b2);

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/204613.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue3 + mark.js 实现文字标注功能

效果图 安装依赖 npm install mark.js --save-dev npm i nanoid代码块 <template><!-- 文档标注 --><header><el-buttontype"primary":disabled"selectedTextList.length 0 ? true : false"ghostclick"handleAllDelete"…

Linux学习笔记2

web服务器部署&#xff1a; 1.装包&#xff1a; [rootlocalhost ~]# yum -y install httpd 2.配置一个首页&#xff1a; [rootlocalhost ~]# echo i love yy > /var/www/html/index.html 启动服务&#xff1a;[rootlocalhost ~]# systemctl start httpd Ctrl W以空格为界…

String、StringBuffer、StringBuilder

String类 特点&#xff1a; 1.类由final关键字修饰&#xff0c;不可被继承&#xff1b; 2.value是一个由final修饰的字符数组&#xff0c;即字符串的长度不可修改&#xff1b; 3.实现了Comparable<T>接口&#xff0c;可进行比较&#xff1b; StringBuffer 特点&#x…

almaLinux centos8 下载ffmpeg离线安装包、离线安装

脚本 # 添加RPMfusion仓库 sudo yum install https://download1.rpmfusion.org/free/el/rpmfusion-free-release-8.noarch.rpm wget -ymkdir -p /root/ffmpeg cd /root/ffmpegwget http://rpmfind.net/linux/epel/7/x86_64/Packages/s/SDL2-2.0.14-2.el7.x86_64.rpmyum instal…

排序算法介绍(一)插入排序

0. 简介 插入排序&#xff08;Insertion Sort&#xff09; 是一种简单直观的排序算法&#xff0c;它的工作原理是通过构建有序序列&#xff0c;对于未排序数据&#xff0c;在已排序序列中从后向前扫描&#xff0c;找到相应位置并插入。插入排序在实现上&#xff0c;通常…

算法刷题之数组篇

题目一&#xff1a;两数之和 给出一个整型数组 numbers 和一个目标值 target&#xff0c;请在数组中找出两个加起来等于目标值的数的下标&#xff0c;返回的下标按升序排列。 &#xff08;注&#xff1a;返回的数组下标从1开始算起&#xff0c;保证target一定可以由数组里面2…

mybatis数据输出-驼峰命名规则设置

1、建库建表 CREATE DATABASE mybatis-example;USE mybatis-example;CREATE TABLE t_emp(emp_id INT AUTO_INCREMENT,emp_name CHAR(100),emp_salary DOUBLE(10,5),PRIMARY KEY(emp_id) );INSERT INTO t_emp(emp_name,emp_salary) VALUES("tom",200.33); INSERT INTO…

持续集成交付CICD:GitLabCI 实现Sonarqube代码扫描

目录 一、实验 1.GitLabCI 代码扫描 二、问题 1.GitLab 执行sonar-scanner命令报错 一、实验 1.GitLabCI 代码扫描 &#xff08;1&#xff09;打开maven项目 &#xff08;2&#xff09;maven项目流水线调用公共库 &#xff08;3&#xff09;项目组添加token认证 &#xf…

【C++】手撕string思路梳理

目录 基本思路 代码实现 1.构建框架&#xff1a; 2.构建函数重载 3.迭代器&#xff1a; 4.遍历string 5.resetve 开空间&#xff0c;insert任意位置插入push_back,append,(按顺序依次实现) 6.erase删除&#xff0c;clear清除&#xff0c;resize缩容 7.流插入&#xff0…

Java集合使用注意事项

目录 1. 集合判空 2. 集合转 Map 3. 集合遍历 4. 集合去重 5. 集合转数组 6. 数组转集合 1. 集合判空 《阿里巴巴 Java 开发手册》的描述如下&#xff1a; 判断所有集合内部的元素是否为空&#xff0c;使用 isEmpty() 方法&#xff0c;而不是 size()0 的方式。 这是因为…

FreeRTOS系统延时函数分析

一、概述 FreeRTOS提供了两个系统延时函数&#xff0c;相对延时函数vTaskDelay()和绝对延时函数vTaskDelayUntil()。相对延时是指每次延时都是从任务执行函数vTaskDelay()开始&#xff0c;延时指定的时间结束&#xff0c;绝对延时是指每隔指定的时间&#xff0c;执行一次调用vT…

项目分析:解决类的复杂设计中遇到的问题

1.问题1&#xff1a;析构函数乱码问题 【样例输入】 -3 1 3 -1 -3 2 3 -2 【样例输出】 gouzao 1 -3 1 3 -1 gouzao 2 -3 2 3 -2 -3 1 3 -1 -3 2 3 -2 9.4245 18.849 Ellipse xigou 3 -2 Point xigou 3 -2 Point xigou -3 2 Point xigou 3 -2 Point xigou -3 2…

2023年阿里云云栖大会-核心PPT资料下载

一、峰会简介 历经14届的云栖大会&#xff0c;是云计算产业的建设者、推动者、见证者。2023云栖大会以“科技、国际、年轻”为基调&#xff0c;以“计算&#xff0c;为了无法计算的价值”为主题&#xff0c;发挥科技平台汇聚作用&#xff0c;与云计算全产业链上下游的先锋代表…

关于mysql高版本使用groupby导致的报错

在开发时&#xff0c;遇到mysql版本在5.7.X及以上版本时使用group by 语句会报以下的错误 Caused by: com.mysql.jdbc.exceptions.jdbc4.MySQLSyntaxErrorException: Expression #1 of SELECT list is not in GROUP BY clause and contains nonaggregated column business_typ…

使用pandas制作图表

数据可视化对于数据分析的重要性不言而喻&#xff0c;一个优秀的图表有足以一眼就看出关键所在。pandas利用matplotlib实现绘图。能够提供各种各样的图表功能&#xff0c;包括: 单折线图多折线图柱状图叠加柱状图水平叠加柱状图直方图拆分直方图箱型图区域块图形散点图饼图多子…

13年测试老鸟总结,性能测试常遇问题+解决方案+分析...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 1、内存溢出 1&a…

【高数:1 映射与函数】

【高数&#xff1a;1 映射与函数】 例2.1 绝对值函数例2.2 符号函数例2.3 反函数表示例2.4 双曲正弦sinh&#xff0c;双曲余弦cosh&#xff0c;双曲正切tanh 参考书籍&#xff1a;毕文斌, 毛悦悦. Python漫游数学王国[M]. 北京&#xff1a;清华大学出版社&#xff0c;2022. 例2…

设计之初,成就AI创作的非凡之路——AI绘画

一.官方活动 活动链接&#xff1a;| 2023腾讯云 AI 绘画有奖征文大赛&#xff0c;秀出你的AI新质生产力 https://cloud.tencent.com/developer/article/2367375 二.产品体验 1.产品链接:https://cloud.tencent.com/act/pro/AIhuihua?from20421&from_column20421 2.产品…

vue实现父子传参

1.父向子进行传值 在父组件中&#xff0c;使用子组件的标签&#xff0c;并通过属性将数据传递给子组件。 在子组件中&#xff0c;定义props选项来接收父组件传递的数据。 父组件的数据会通过props选项传递给子组件&#xff0c;子组件可以直接使用这些数据。 父组件&#xf…

Maven-高效的Java项目构建与管理工具(含Maven详细安装与配置过程)

Maven 什么是Maven&#xff1f; 正如题目所说&#xff0c;Maven就是一款高效的Java项目构建与管理工具&#xff0c;基于项目对象模型&#xff08;POM&#xff09;概念&#xff0c;利用一个中央信息片断能管理一个项目的构建、报告和文档等步骤。是Apache软件基金会的一个开源…