Pytorch 实现简单的 线性回归 算法

Pytorch实现简单的线性回归算法

简单 tensor的运算

Pytorch涉及的基本数据类型是tensor(张量)和Autograd(自动微分变量)

import torch
x = torch.rand(5, 3) #产生一个5*3的tensor,在 [0,1) 之间随机取值
y = torch.ones(5, 3) #产生一个5*3的Tensor,元素都是1  z = x + y                    #两个tensor可以直接相加
q = x.mm(y.transpose(0, 1))  #x乘以y的转置  mm为矩阵的乘法,矩阵相乘必须某一个矩阵的行与另一个矩阵的列相等

Tensor与numpy.ndarray之间的转换

import numpy as np        #导入numpy包
a = np.ones([5, 3])       #建立一个5*3全是1的二维数组(矩阵)
b = torch.from_numpy(a)   #利用from_numpy将其转换为tensor
c = torch.FloatTensor(a)  #另外一种转换为tensor的方法,类型为FloatTensor,还可以使LongTensor,整型数据类型
b.numpy()                 #从一个tensor转化为numpy的多维数组
from torch.autograd import Variable                  # 导入自动梯度的运算包,主要用Variable这个类
x = Variable(torch.ones(2, 2), requires_grad=True)   # 创建一个Variable,包裹了一个2*2张量,将需要计算梯度属性置为True

用pytorch做一个简单的线性关系预测

线性关系是一种非常简单的变量之间的关系,因变量和自变量在线性关系的情况下,可以使用线性回归算法对一个或多个因变量和自变量间的线性关系进行建模,该模型的系数可以用最小二乘法进行求解。生活中的场景往往会比较复杂,需要考虑多元线性关系和非线性关系,用其他的回归分析方法求解。


x = Variable(torch.linspace(0, 100, 100).type(torch.FloatTensor))  # 生成一些样本点作为原始数据
rand = Variable(torch.randn(100)) * 10                             # 随机生成100个满足标准正态分布的随机数,均值为0,方差为1.将这个数字乘以10,标准方差变为10
y = x + rand                                                       # 将x和rand相加,得到伪造的标签数据y。所以(x,y)应能近似地落在y=x这条直线上import matplotlib.pyplot as plt  
plt.figure(figsize=(10,8))                    #设定绘制窗口大小为10*8 inch
plt.plot(x.data.numpy(), y.data.numpy(), 'o') #绘制数据,考虑到x和y都是Variable,需要用data获取它们包裹的Tensor,并专成numpy
plt.xlabel('X') 
plt.ylabel('Y') 
plt.show() 

在这里插入图片描述

构建模型

#a,b就是要构建的线性函数的系数
a = Variable(torch.rand(1), requires_grad = True) #创建a变量,并随机赋值初始化
b = Variable(torch.rand(1), requires_grad = True) #创建b变量,并随机赋值初始化
print('Initial parameters:', [a, b])learning_rate = 0.0001 #设置学习率
for i in range(1000):### 增加了这部分代码,清空存储在变量a,b中的梯度信息,以免在backward的过程中会反复不停地累加if (a.grad is not None) and (b.grad is not None):  a.grad.data.zero_() b.grad.data.zero_() predictions = a.expand_as(x) * x+ b.expand_as(x)  #计算在当前a、b条件下的模型预测数值# 在 PyTorch 中,a.expand_as(x) 用于将张量 a 扩展(expand)为与张量 x 具有相同的形状loss = torch.mean((predictions - y) ** 2)         #通过与标签数据y比较,计算误差print('loss:', loss)loss.backward() #对损失函数进行梯度反传,backward的方向传播算法a.data.add_(- learning_rate * a.grad.data)  #利用上一步计算中得到的a的梯度信息更新a中的data数值b.data.add_(- learning_rate * b.grad.data)  #利用上一步计算中得到的b的梯度信息更新b中的data数值

绘制结果


x_data = x.data.numpy()                       # 将tensor 转为 numpy
plt.figure(figsize = (10, 7))
xplot = plt.plot(x_data, y.data.numpy(), 'o') # 绘制原始数据
yplot = plt.plot(x_data, a.data.numpy() * x_data + b.data.numpy())  #绘制拟合数据
plt.xlabel('X') 
plt.ylabel('Y') 
str1 = str(a.data.numpy()[0]) + 'x +' + str(b.data.numpy()[0]) # 图例信息 拟合的直线
plt.legend(['Obs', 'Model']) #绘制图例
plt.show()

在这里插入图片描述

x_test = Variable(torch.FloatTensor([1, 2, 10, 100, 1000])) #随便选择一些点1,2,……,1000
predictions = a.expand_as(x_test) * x_test + b.expand_as(x_test) #计算模型的预测结果
predictions  #输出预测的数值

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/26097.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OJ刷题——2080.夹角有多大II和2082.找单词、2085.核反应堆

2080.夹角有多大II 题目描述 Problem - 2080 运行代码 #include <iostream> #include <math.h> using namespace std; int main() {int T;double x1, y1, x2, y2;double res;scanf_s("%d", &T);while (T--) {scanf_s("%lf%lf%lf%lf", &…

机器学习算法 —— 贝叶斯分类之模拟离散数据集

&#x1f31f;欢迎来到 我的博客 —— 探索技术的无限可能&#xff01; &#x1f31f;博客的简介&#xff08;文章目录&#xff09; 目录 实战&#xff08;贝叶斯分类&#xff09;莺尾花数据模拟离散数据集库函数导入数据导入和分析模型训练和预测 总结 实战&#xff08;贝叶斯…

群体优化算法---水波优化算法介绍以及应用于聚类数据挖掘代码示例

介绍 水波优化算法&#xff08;Water Wave Optimization, WWO&#xff09;是一种新兴的群智能优化算法&#xff0c;灵感来自水波在自然环境中的传播和衰减现象。该算法模拟了水波在水面上传播和碰撞的行为&#xff0c;通过这些行为来寻找问题的最优解。WWO算法由三种主要的操作…

打工人的福利,NewspaceGpt使用新体验

使用地址&#xff1a;https://newspace.ai0.cn/ NewspaceGpt大体所有功能一览(​​newspace.ai0.cn​​) 使用体验与官网完全一致&#xff0c;可在第一时间体验到官网所有新功能。无需特殊上网。内置多个Plus账号&#xff0c;不用担心次数不够。支持所有GPTS功能:DALLE-3模型(…

CTFHUB-SQL注入-时间盲注

本题用到sqlmap工具&#xff0c;没有sqlmap工具点击&#x1f680;&#x1f680;&#x1f680;直达下载安装使用教程 理论简述 时间盲注概述 时间盲注是一种SQL注入技术的变种&#xff0c;它依赖于页面响应时间的不同来确定SQL注入攻击的成功与否。在某些情况下&#xff0c;攻…

攻防世界---misc---embarrass

1、下载附件是一个数据包 2、用wireshark分析 3、ctrlf查找字符 4、 flag{Good_b0y_W3ll_Done}

angular2网页前端执行流程

示例代码版本&#xff1a; http://192.168.102.9/jas-paas/cloudlink-front-framework/tree/045f4811da782c107eca72f9bdea39ebaa086a7d 命令行运行命令启动服务 在开发环境下&#xff0c;打开项目目录&#xff0c;运行命令npm start,这个命令会进入package.json文件中&#x…

cve_2014_3120-Elasticsearch-rce-vulfocus靶场

1.背景 来源&#xff1a;ElasticSearch&#xff08;CVE-2014-3120&#xff09;命令执行漏洞复现_mvel 漏洞-CSDN博客 参考&#xff1a;https://www.cnblogs.com/huangxiaosan/p/14398307.html 老版本ElasticSearch支持传入动态脚本&#xff08;MVEL&#xff09;来执行一些复…

Windows11上安装docker(WSL2后端)和使用docker安装MySQL和达梦数据库

Windows11上安装docker&#xff08;WSL2后端&#xff09;和使用docker安装MySQL和达梦数据库 1. 操作系统环境2. 首先安装wsl2.1 关于wsl2.2 安装wsl2.3 查看可用的wsl2.4 安装ubuntu-22.042.5 查看、启动ubuntu-22.04应用2.6 上面安装开了daili2.7 wsl的更多参考 3. 下载Docke…

Springboot 开发之任务调度框架(一)Quartz 简介

一、引言 常见的定时任务框架有 Quartz、elastic-job、xxl-job等等&#xff0c;本文主要介绍 Spirng Boot 集成 Quartz 定时任务框架。 二、Quartz 简介 Quartz 是一个功能强大且灵活的开源作业调度库&#xff0c;广泛用于 Java 应用中。它允许开发者创建复杂的调度任务&…

【Jenkins+K8s】持续集成与交付 (二十):K8s集群通过Deployment方式部署安装Jenkins

🟣【Jenkins+K8s】持续集成与交付 (二十):K8s集群通过Deployment方式部署安装Jenkins 一、 准备工作二、安装 Jenkins2.1 设置NFS共享目录2.2 创建名称空间2.3 创建持久化卷和声明2.4 创建sa账号2.5 对sa账号授权2.6 通过Deployment方式部署Jenkins2.7 查看Jenkins是否创建…

AdroitFisherman模块测试日志(2024/6/10)

测试内容 测试AdroitFisherman分发包中SHAUtil模块。 测试用具 Django5.0.3框架&#xff0c;AdroitFisherman0.0.31 项目结构 路由设置 总路由 from django.contrib import admin from django.urls import path,include from Base64Util import urls urlpatterns [path(ad…

SCRM的全面了解

一、什么是SCRM SCRM&#xff08;Social CRM&#xff0c;社会化客户关系管理&#xff09;&#xff0c;是以用户为中心&#xff0c;通过社交平台与用户建立联系&#xff0c;以内容、活动、客服、商城等服务吸引用户注意力&#xff0c;并不断与用户产生互动&#xff0c;实现用户…

【Oracle篇】rman时间点异机恢复:从RAC环境到单机测试环境的转移(第六篇,总共八篇)

&#x1f4ab;《博主介绍》&#xff1a;✨又是一天没白过&#xff0c;我是奈斯&#xff0c;DBA一名✨ &#x1f4ab;《擅长领域》&#xff1a;✌️擅长Oracle、MySQL、SQLserver、阿里云AnalyticDB for MySQL(分布式数据仓库)、Linux&#xff0c;也在扩展大数据方向的知识面✌️…

谷歌AI助力软件工程的进展及未来展望

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

OpenAI 宕机事件:GPT 停摆的影响与应对

引言 2024年6月4日&#xff0c;OpenAI 的 GPT 模型发生了一次全球性的宕机&#xff0c;持续时间长达8小时。此次宕机不仅影响了OpenAI自家的服务&#xff0c;还导致大量用户涌向竞争对手平台&#xff0c;如Claude和Gemini&#xff0c;结果也导致这些平台出现故障。这次事件的广…

lua网站开发中如何制作自定义模块

自定义模块是FastWeb框架的重要拓展功能&#xff0c;用来扩展和增强服务的能力。通过自定义模块&#xff0c;开发者可以轻松添加特定的功能和特性&#xff0c;使得网站开发更加灵活和高效。本文将演示如何添加自己的模块作为FastWeb的拓展&#xff0c;为框架的壮大与支持提供重…

在 Word 中,如何有效调整文字与下划线之间的距离

&#x1f349; CSDN 叶庭云&#xff1a;https://yetingyun.blog.csdn.net/ 如果你在使用 Word 时&#xff0c;希望调整文字和下划线之间的距离&#xff0c;让它们看起来更加美观&#xff0c;可以按照以下步骤操作&#xff1a; 1. 在你想要加下划线的文字前后各加一个空格&…

c++【入门】米老鼠偷糖果

限制 时间限制 : 1 秒 内存限制 : 128 MB 题目 米老鼠发现了厨房放了n颗糖果&#xff0c;它一次可以背走a颗&#xff0c;请问米老鼠背了x次之后还剩多少颗&#xff1f;&#xff08;假设x次之后一定有糖果剩下&#xff09; 输入 三个整数n、a、x分别代表总共有n颗糖果&…

在windows10 安装子系统linux(WSL安装方式)

在 windows 10 平台采用了WSL安装方式安装linux子系统 1 查找自己想要安装的linux子系统 wsl --list --online 2 在线安装 个人用Debian比较多&#xff0c;这里选择Debian&#xff0c;如下图&#xff1a; wsl --install -d Debian 安装过程中有一步要求输入用户名与密码&…