001.从0开始实现线性回归(pytorch)

000动手从0实现线性回归

0. 背景介绍

我们构造一个简单的人工训练数据集,它可以使我们能够直观比较学到的参数和真实的模型参数的区别。
设训练数据集样本数为1000,输入个数(特征数)为2。给定随机生成的批量样本特征 X∈R1000×2
X∈R 1000×2 ,我们使用线性回归模型真实权重 w=[2,−3.4]⊤ 和偏差 b=4.2以及一个随机噪声项 ϵϵ 来生成标签
在这里插入图片描述

# 需要导入的包
import numpy as np
import torch
import random
from d2l import torch as d2l
from IPython import display
from matplotlib import pyplot as plt

1. 生成数据集合(待拟合)

使用python生成待拟合的数据

num_input = 2
num_example = 1000
w_true = [2,-3.4]
b_true = 4.2
features = torch.randn(num_example,num_input)
print('features.shape = '+ str(features.shape) )
labels =  w_true[0] * features[:,0] + w_true[1] * features[:,1] + b_true
labels += torch.tensor(np.random.normal(0,0.01 , size = labels.size() ),dtype = torch.float32)
print(features[0],labels[0])

2.数据的分批量处理

def data_iter(batch_size, features, labels):num_example = len(labels)indices = list(range(num_example))random.shuffle(indices)for i in range(0, num_example, batch_size):j = torch.tensor( indices[i:min(i+ batch_size,num_example)])yield features.index_select(0,j) ,labels.index_select(0,j)

3. 模型构建及训练

3.1 定义模型:

def linreg(X, w, b):return torch.mm(X,w)+b

3.2 定义损失函数

def square_loss(y, y_hat):return (y_hat - y.view(y_hat.size()))**2/2

3.3 定义优化算法

def sgd(params , lr ,batch_size):for param in params:param.data  -= lr * param.grad / batch_size

3.4 模型训练

# 设置超参数
lr = 0.03
num_epochs =5
net = linreg
loss = square_loss
batch_size = 10
for epoch in range(num_epochs):for X,y in data_iter(batch_size= batch_size,features=features,labels= labels):l = loss(net(X,w,b),y).sum()l.backward()sgd([w,b],lr,batch_size=batch_size)#梯度清零避免梯度累加w.grad.data.zero_()b.grad.data.zero_()train_l = loss(net(features,w,b),labels)print('epoch %d, loss %f' %(epoch +1 ,train_l.mean().item()))

epoch 1, loss 0.032550
epoch 2, loss 0.000133
epoch 3, loss 0.000053
epoch 4, loss 0.000053
epoch 5, loss 0.000053


基于pytorch的线性模型的实现

  1. 相关数据和初始化与上面构建相同
  2. 定义模型
import torch
from torch import nn
class LinearNet(nn.Module):def __init__(self, n_feature):# 调用父类的初始化super(LinearNet,self).__init__()# Linear(输入特征数,输出特征的数量,是否含有偏置项)self.linera = nn.Linear(n_feature,1)def forward(self,x):y = self.linera(x)return y
#打印模型的结构:
net = LinearNet(num_input)
print(net) 
# LinearNet( (linera): Linear(in_features=2, out_features=1, bias=True)
)
  1. 初始化模型的参数
from torch.nn import init
init.normal_(net.linera.weight,mean=0,std= 0.1)
init.constant_(net.linera.bias ,val=0)
  1. 定义损失函数
loss = nn.MSELoss()

5.定义优化算法

import torch.optim as optim
optimizer =  optim.SGD(net.parameters(),lr = 0.03)
print(optimizer)
  1. 训练模型:
num_epochs = 3
for epoch in range(1,num_epochs+1):for X,y in data_iter(batch_size= batch_size,features=features,labels= labels):output= net(X)l = loss(output,y.view(-1,1))optimizer.zero_grad()l.backward()optimizer.step()print('epoch %d ,loss: %f' %(epoch,l.item()) )

epoch 1 ,loss: 0.000159
epoch 2 ,loss: 0.000089
epoch 3 ,loss: 0.000066

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/54582.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Delphi】扩展现有组件创建新的 FireMonkey 组件(步骤二)

实现指定格式的属性 步骤 1 中创建的 TClockLabel 组件需要在显示当前时间时定义日期时间格式作为属性,以便组件用户可以指定。 一、实现指定格式的属性 要实现格式属性,请在 TClockLabel class 的发布部分添加以下一行: property Form…

Mybatis-Plus使用例过程中BaseMapper中的一些selectList方法找不到的问题

如果需要用到Mybatis-Plus里面BaseMapper的默认实现的一些方法,而不通过mapper.xml的方式自定义方法,那么在构建SqlSessionFactory的时候需要注意应该通过Mybatis-Plus中的MybatisSqlSessionFactoryBean来获取SqlSessionFactory,使用Mybatis中…

CST电磁仿真77GHz汽车雷达保险杠

77G毫米波雷达仿真时,要考虑天线罩和保险杠的影响。通常保险杠都是多层结构,有的层非常薄。如果采用传统的3D建模方法,会导致网格数量巨大,进而影响到求解效率。 三维保险杠(bumper)模型如下图所示&…

vue3.5新特性

vue在2024.09.03发布了3.5正式版本,其中包含多方面的升级和优化 性能优化 响应式系统重构优化,在内存占用、性能方面均有收益 Memory Usage Improvements Given a test case with 1000 refs 2000 computeds (1000 chained pairs) 1000 effects sub…

【C++篇】探寻C++ STL之美:从string类的基础到高级操作的全面解析

文章目录 C string 类详解:从入门到精通前言第一章:C 语言中的字符串 vs C string 类1.1 C 语言中的字符串1.2 C string 类的优势 第二章:string 类的构造与基础操作2.1 string 类的构造方法2.1.1 示例代码:构造字符串 2.2 string…

二期 1.4 Nacos安装部署 - Window版

本文目录 Nacos支持三种部署模式环境准备下载Nacos启动登录服务注册与查看Nacos支持三种部署模式 单机模式 - 用于测试和单机试用。集群模式 - 用于生产环境,确保高可用。多集群模式 - 用于多数据中心场景。以 Window单机模式 抛转引玉,其它部署方式参考官方文档: https://n…

使用Python实现深度学习模型:智能电影制作与剪辑

随着人工智能技术的飞速发展,深度学习在各个领域的应用越来越广泛。在电影制作与剪辑领域,深度学习技术也展现出了巨大的潜力。本文将介绍如何使用Python实现一个简单的深度学习模型,用于智能电影制作与剪辑。我们将使用TensorFlow和Keras库来构建和训练模型,并展示如何应用…

部署自己的对话大模型,使用Ollama + Qwen2 +FastGPT 实现

部署资源 AUTODL 使用最小3080Ti 资源,cuda > 12.0使用云服务器,部署fastGPT oneAPI,M3E 模型 操作步骤 配置代理 export HF_ENDPOINThttps://hf-mirror.com下载qwen2模型 - 如何下载huggingface huggingface-cli download Qwen/Qwen2-…

Python编程入门指南

本文为初学者提供了全面的Python编程入门指南,涵盖了基础语法、控制结构、函数、数据结构、文件操作、异常处理、模块与包、面向对象编程以及一些高级特性,帮助读者快速掌握Python编程的核心知识。通过详细解释Python编程的各个方面,文章旨在…

flutter遇到问题及解决方案

目录 1、easy_refresh相关问题 2、 父子作用域关联问题 3. 刘海屏底部安全距离 4. 了解保证金弹窗 iOS端闪退 (待优化) 5. loading无法消失 6. dialog蒙版问题 7. 倒计时优化 8. scrollController.offset报错 9. 断点不走 10.我的出价报红 11…

Python3爬虫教程-HTTP基本原理

HTTP基本原理 1,URL组成部分详解2,HTTP和HTTPS3,HTTP请求过程4,请求(Request)请求方法(Request Method)请求的网址(Request URL)请求头(Request H…

Redmi Note 7 Pro(violet)免授权9008文件分享及刷机教程

获取文件 关注微信公众号 heStudio Community回复 violet_9008 获取下载链接。 刷机教程 下载搞机助手(可以从上方文件中获取)并安装。手机按音量减键和电源键进入 Fastboot 模式, 打开搞机助手,点击进入 9008 模式 等待手机…

IDEA 关闭自动补全功能(最新版本)

文章目录 一、前言二、关闭自动补全三、最终效果 一、前言 在最新的 IDEA 中发布了自动补全功能,当你输入代码时,IDEA 会自动显示你可能想输入的代码,减少手动输入的工作量,它会根据上下文提供正确的选项,提高代码的准…

Java-数据结构-二叉树-习题(三)  ̄へ ̄

文本目录: ❄️一、习题一(前序遍历非递归): ▶ 思路: ▶ 代码: ❄️二、习题二(中序遍历非递归): ▶ 思路: ▶ 代码: ❄️三、习题三(后序遍历非递归): ▶ 思路: …

C++——打印以下图案:用字符数组方法。

没注释的源代码 #include <iostream> using namespace std; int main() { char a[5]{*,*,*,*,*}; for(int i0;i<5;i) { for(int j0;j<i;j) { cout<<" "; } for(int k0;k<5;k) …

vue使用PDF.JS踩的坑--部署到服务器上显示pdf.mjs viewer.mjs找不到资源

之前项目使用的pdf.js 是2.15.349版本&#xff0c;最近换了一个4.6.82的版本&#xff0c;在本地上浏览文件运行的好好的&#xff0c;但是发布到服务器&#xff08;IIS&#xff09;上打不开文件&#xff0c;控制台提示找不到pdf.mjs viewer.mjs。 之前使用的2.15.349pdf和viewer…

Git使用手册

1、初识Git 概述&#xff1a;Git 是一个开源的分布式版本控制系统&#xff0c;可以有效、高速地处理项目版本管理。 知识点补充&#xff1a; 版本控制&#xff1a;一种记录一个或若干文件内容变化&#xff0c;以便将来查阅特定版本修订情况的系统。 分布式&#xff1a;每个人…

oracle 多表查询

3.6多表查询 当查询的数据并不是来源一个表时&#xff0c;需要使用多表连接操作完成查询。多表连接查询通过表之间的关联字段&#xff0c;一次查询出多个表的数据。 3.6.1等值连接 等值连接也称为简单连接(Simple Joins)或者内连接(Inner Join)。通过等号来判断连接条件中的数据…

M9410A VXT PXI 矢量收发信机,300/600/1200MHz带宽

M9410A PXI 矢量收发信机 -300/600/1200MHz带宽- M9410A VXT PXI 矢量收发信机&#xff0c;300/600/1200MHz带宽支持 5G 的 PXI 矢量收发信机&#xff08;VXT&#xff09;是一个 2 插槽模块&#xff0c;具有 1.2 GHz 的瞬时带宽 主要特点 Keysight M9410A VXT PXIe 矢量收发…

切换笔记本键盘的启用与禁用状态

使用批处理脚本切换笔记本键盘的启用与禁用状态 背景描述详细步骤及代码解释1. 在管理员模式下运行脚本2. 脚本内容3. 解释 背景描述 在笔记本电脑中&#xff0c;在外接键盘的时候&#xff0c;有时我们希望禁用内置键盘&#xff0c;防止意外按键。Windows 系统中&#xff0c;键…