001.从0开始实现线性回归(pytorch)

000动手从0实现线性回归

0. 背景介绍

我们构造一个简单的人工训练数据集,它可以使我们能够直观比较学到的参数和真实的模型参数的区别。
设训练数据集样本数为1000,输入个数(特征数)为2。给定随机生成的批量样本特征 X∈R1000×2
X∈R 1000×2 ,我们使用线性回归模型真实权重 w=[2,−3.4]⊤ 和偏差 b=4.2以及一个随机噪声项 ϵϵ 来生成标签
在这里插入图片描述

# 需要导入的包
import numpy as np
import torch
import random
from d2l import torch as d2l
from IPython import display
from matplotlib import pyplot as plt

1. 生成数据集合(待拟合)

使用python生成待拟合的数据

num_input = 2
num_example = 1000
w_true = [2,-3.4]
b_true = 4.2
features = torch.randn(num_example,num_input)
print('features.shape = '+ str(features.shape) )
labels =  w_true[0] * features[:,0] + w_true[1] * features[:,1] + b_true
labels += torch.tensor(np.random.normal(0,0.01 , size = labels.size() ),dtype = torch.float32)
print(features[0],labels[0])

2.数据的分批量处理

def data_iter(batch_size, features, labels):num_example = len(labels)indices = list(range(num_example))random.shuffle(indices)for i in range(0, num_example, batch_size):j = torch.tensor( indices[i:min(i+ batch_size,num_example)])yield features.index_select(0,j) ,labels.index_select(0,j)

3. 模型构建及训练

3.1 定义模型:

def linreg(X, w, b):return torch.mm(X,w)+b

3.2 定义损失函数

def square_loss(y, y_hat):return (y_hat - y.view(y_hat.size()))**2/2

3.3 定义优化算法

def sgd(params , lr ,batch_size):for param in params:param.data  -= lr * param.grad / batch_size

3.4 模型训练

# 设置超参数
lr = 0.03
num_epochs =5
net = linreg
loss = square_loss
batch_size = 10
for epoch in range(num_epochs):for X,y in data_iter(batch_size= batch_size,features=features,labels= labels):l = loss(net(X,w,b),y).sum()l.backward()sgd([w,b],lr,batch_size=batch_size)#梯度清零避免梯度累加w.grad.data.zero_()b.grad.data.zero_()train_l = loss(net(features,w,b),labels)print('epoch %d, loss %f' %(epoch +1 ,train_l.mean().item()))

epoch 1, loss 0.032550
epoch 2, loss 0.000133
epoch 3, loss 0.000053
epoch 4, loss 0.000053
epoch 5, loss 0.000053


基于pytorch的线性模型的实现

  1. 相关数据和初始化与上面构建相同
  2. 定义模型
import torch
from torch import nn
class LinearNet(nn.Module):def __init__(self, n_feature):# 调用父类的初始化super(LinearNet,self).__init__()# Linear(输入特征数,输出特征的数量,是否含有偏置项)self.linera = nn.Linear(n_feature,1)def forward(self,x):y = self.linera(x)return y
#打印模型的结构:
net = LinearNet(num_input)
print(net) 
# LinearNet( (linera): Linear(in_features=2, out_features=1, bias=True)
)
  1. 初始化模型的参数
from torch.nn import init
init.normal_(net.linera.weight,mean=0,std= 0.1)
init.constant_(net.linera.bias ,val=0)
  1. 定义损失函数
loss = nn.MSELoss()

5.定义优化算法

import torch.optim as optim
optimizer =  optim.SGD(net.parameters(),lr = 0.03)
print(optimizer)
  1. 训练模型:
num_epochs = 3
for epoch in range(1,num_epochs+1):for X,y in data_iter(batch_size= batch_size,features=features,labels= labels):output= net(X)l = loss(output,y.view(-1,1))optimizer.zero_grad()l.backward()optimizer.step()print('epoch %d ,loss: %f' %(epoch,l.item()) )

epoch 1 ,loss: 0.000159
epoch 2 ,loss: 0.000089
epoch 3 ,loss: 0.000066

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/54582.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Delphi】扩展现有组件创建新的 FireMonkey 组件(步骤二)

实现指定格式的属性 步骤 1 中创建的 TClockLabel 组件需要在显示当前时间时定义日期时间格式作为属性,以便组件用户可以指定。 一、实现指定格式的属性 要实现格式属性,请在 TClockLabel class 的发布部分添加以下一行: property Form…

CST电磁仿真77GHz汽车雷达保险杠

77G毫米波雷达仿真时,要考虑天线罩和保险杠的影响。通常保险杠都是多层结构,有的层非常薄。如果采用传统的3D建模方法,会导致网格数量巨大,进而影响到求解效率。 三维保险杠(bumper)模型如下图所示&…

【C++篇】探寻C++ STL之美:从string类的基础到高级操作的全面解析

文章目录 C string 类详解:从入门到精通前言第一章:C 语言中的字符串 vs C string 类1.1 C 语言中的字符串1.2 C string 类的优势 第二章:string 类的构造与基础操作2.1 string 类的构造方法2.1.1 示例代码:构造字符串 2.2 string…

部署自己的对话大模型,使用Ollama + Qwen2 +FastGPT 实现

部署资源 AUTODL 使用最小3080Ti 资源,cuda > 12.0使用云服务器,部署fastGPT oneAPI,M3E 模型 操作步骤 配置代理 export HF_ENDPOINThttps://hf-mirror.com下载qwen2模型 - 如何下载huggingface huggingface-cli download Qwen/Qwen2-…

flutter遇到问题及解决方案

目录 1、easy_refresh相关问题 2、 父子作用域关联问题 3. 刘海屏底部安全距离 4. 了解保证金弹窗 iOS端闪退 (待优化) 5. loading无法消失 6. dialog蒙版问题 7. 倒计时优化 8. scrollController.offset报错 9. 断点不走 10.我的出价报红 11…

Python3爬虫教程-HTTP基本原理

HTTP基本原理 1,URL组成部分详解2,HTTP和HTTPS3,HTTP请求过程4,请求(Request)请求方法(Request Method)请求的网址(Request URL)请求头(Request H…

Redmi Note 7 Pro(violet)免授权9008文件分享及刷机教程

获取文件 关注微信公众号 heStudio Community回复 violet_9008 获取下载链接。 刷机教程 下载搞机助手(可以从上方文件中获取)并安装。手机按音量减键和电源键进入 Fastboot 模式, 打开搞机助手,点击进入 9008 模式 等待手机…

IDEA 关闭自动补全功能(最新版本)

文章目录 一、前言二、关闭自动补全三、最终效果 一、前言 在最新的 IDEA 中发布了自动补全功能,当你输入代码时,IDEA 会自动显示你可能想输入的代码,减少手动输入的工作量,它会根据上下文提供正确的选项,提高代码的准…

Java-数据结构-二叉树-习题(三)  ̄へ ̄

文本目录: ❄️一、习题一(前序遍历非递归): ▶ 思路: ▶ 代码: ❄️二、习题二(中序遍历非递归): ▶ 思路: ▶ 代码: ❄️三、习题三(后序遍历非递归): ▶ 思路: …

vue使用PDF.JS踩的坑--部署到服务器上显示pdf.mjs viewer.mjs找不到资源

之前项目使用的pdf.js 是2.15.349版本,最近换了一个4.6.82的版本,在本地上浏览文件运行的好好的,但是发布到服务器(IIS)上打不开文件,控制台提示找不到pdf.mjs viewer.mjs。 之前使用的2.15.349pdf和viewer…

Git使用手册

1、初识Git 概述:Git 是一个开源的分布式版本控制系统,可以有效、高速地处理项目版本管理。 知识点补充: 版本控制:一种记录一个或若干文件内容变化,以便将来查阅特定版本修订情况的系统。 分布式:每个人…

M9410A VXT PXI 矢量收发信机,300/600/1200MHz带宽

M9410A PXI 矢量收发信机 -300/600/1200MHz带宽- M9410A VXT PXI 矢量收发信机,300/600/1200MHz带宽支持 5G 的 PXI 矢量收发信机(VXT)是一个 2 插槽模块,具有 1.2 GHz 的瞬时带宽 主要特点 Keysight M9410A VXT PXIe 矢量收发…

Leetcode 1039. 多边形三角形剖分的最低得分 枚举型区间dp C++实现

问题:Leetcode 1039. 多边形三角形剖分的最低得分 你有一个凸的 n 边形,其每个顶点都有一个整数值。给定一个整数数组 values ,其中 values[i] 是第 i 个顶点的值(即 顺时针顺序 )。 假设将多边形 剖分 为 n - 2 个三…

【QML】Button图标设置透明颜色,会变模糊有阴影

原图效果 1. 透明 1.1 效果 1.2 代码 Button{id: _mBtnwidth: parent.widthheight: parent.heightbackground: Rectangle{id: _mBgradius: 5antialiasing: truecolor: "white"}icon{source: _mRoot._mIconSourcecache: falsecolor: "transparent" //透明…

[spring]MyBatis介绍 及 用MyBatis操作简单数据库

文章目录 一. 什么是MyBatis二. MyBatis操作数据库步骤创建工程创建数据库创建对应实体类配置数据库连接字符串写持久层代码单元测试 三. MyBatis基础操作打印日志参数传递增删改查 四. MyBatis XML配置文件配置链接字符串和MyBatis写持久层代码方法定义Interface方法实现xml测…

JavaWeb纯小白笔记02:Tomcat的使用:发布项目的三种方式、配置虚拟主机、配置用户名和密码

通过Tomcat进行发布项目的目的是为了提供项目的访问能力:Tomcat作为Web服务器,能够处理HTTP请求和响应,将项目的内容提供给用户进行访问和使用。 一.Tomcat发布项目的三种方式: 第一种:直接在Tomcat文件夹里的webapp…

开源RK3588 AI Module7,并与Jetson Nano生态兼容的低功耗AI模块

RK3588 AI Module7 搭载瑞芯微 RK3588,提供强大的 64 位八核处理器,最高时钟速度为 2.4 GHz,6 TOPS NPU,并支持高达 32 GB 的内存。它与 Nvidia 的 Jetson Nano 接口兼容,具有升级和改进的 PCIe 连接。由于该模块的多功…

Leetcode面试经典150题-39.组合总数进阶:40.组合总和II

本题是扩展题,真实考过,看这个题之前先看一下39题 Leetcode面试经典150题-39.组合总数-CSDN博客 给定一个候选人编号的集合 candidates 和一个目标数 target ,找出 candidates 中所有可以使数字和为 target 的组合。 candidates 中的每个数…

9.23 My_string.cpp

my_string.h #ifndef MY_STRING_H #define MY_STRING_H#include <iostream> #include <cstring>using namespace std;class My_string { private:char *ptr; //指向字符数组的指针int size; //字符串的最大容量int len; //字符串当前…

【十八】MySQL 8.0 新特性

MySQL 8.0 新特性 目录 MySQL 8.0 新特性 概述 简述 1、数据字典 2、原子数据定义语句 3、升级过程 4、会话重用 5、安全和账户管理 6、资源管理 7、表加密管理 8、InnoDB增强功能 9、字符集支持 10、增强JSON功能 11、数据类型的支持 12、查询的优化 13、公用…