动手学深度学习v2笔记 —— 线性回归 + 基础优化算法

二 动手学深度学习v2 —— 线性回归 + 基础优化算法


目录:

  1. 线性回归
  2. 基础优化方法

1. 线性回归
总结

  • 线性回归是对n维输入的加权,外加偏差
  • 使用平方损失来衡量预测值和真实值的差异
  • 线性回归有显示解
  • 线性回归可以看作是单层神经网络

2. 基础优化方法
梯度下降
在这里插入图片描述
小批量随机梯度下降
在这里插入图片描述

3. 总结

  • 梯度下降通过不断沿着反梯度方向更新参数求解
  • 小批量随机梯度下降是深度学习默认的求解算法
  • 两个重要的超参数是批量大小和学习率

4. 线性回归的从零开始实现

import torch
import random
def synthetic_data(w, b, num):"生成 y = Xw + b + 噪声"''' 根据带有噪声的线性模型构造一个人造数据集。 我们使用线性模型参数𝐰 = [2,−3.4]⊤w=[2,−3.4]⊤、𝑏 = 4.2和噪声项𝜖ϵ生成数据集及其标签:𝐲 = 𝐗𝐰 + 𝑏 + 𝜖'''X = torch.normal(0, 1, (num, len(w)))y = torch.matmul(X, w) + by += torch.normal(0, 0.01, y.shape)return X, y.reshape(-1, 1)true_w = torch.tensor([2, -3.4])
true_b = 4.2
num_examples = 1000
batch_size = 10
features, labels = synthetic_data(true_w, true_b, num_examples)def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size):num_indices = torch.tensor(indices[i: min(i +batch_size, num_examples)])yield features[num_indices], labels[num_indices]def linreg(X, w, b):return torch.matmul(X, w) + bdef squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape))**2/2def sgd(params, lr, batch_size):with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)loss = squared_loss
net = linreg
lr = 0.03
epoches = 3for epoch in range(epoches):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y)l.sum().backward()sgd([w, b], lr, batch_size)with torch.no_grad():train_l = loss(net(features, w, b), labels)print(f'epoch: {epoch + 1}, loss {float(train_l.mean()):f}')

5. 线性回归的简洁实现

import torch
from d2l import torch as d2l
from torch.utils import datatrue_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)def load_array(batch_size, data_array, is_train=True):dataset = data.TensorDataset(*data_array)return data.DataLoader(dataset, batch_size, shuffle=is_train)batch_size = 10
data_iter = load_array(batch_size, (features, labels))from torch import nnnet = nn.Sequential(nn.Linear(2, 1))
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)loss = nn.MSELoss()trainer = torch.optim.SGD(net.parameters(), lr = 0.03)num_epoches = 3iter_temp = iter(data_iter)for epoch in range(num_epoches):for X, y in next(iter_temp):l = loss(net(X), y)l.backward()trainer.step()trainer.zero_grad()l = loss(net(features), labels)print(f'epoch {epoch + 1}, loss {l:f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/15028.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring的创建及使用

文章目录 什么是SpringSpring项目的创建存储Bean对象读取Bean对象getBean()方法 更简单的读取和存储对象的方式路径配置使用类注解存储Bean对象关于五大类注解使用方法注解Bean存储对象Bean重命名 Bean对象的读取 使用Resource注入对象Resource VS Autowired同一类型多个bean对…

echart折线图,调节折线点和y轴的间距(亲测可用)

options代码: options {tooltip: {trigger: axis, //坐标轴触发,主要在柱状图,折线图等会使用类目轴的图表中使用。},xAxis: {type: category,//类目轴,适用于离散的类目数据,为该类型时必须通过 data 设置类目数据。…

iOS开发-启动页广告实现

iOS开发-启动页广告实现 启动页广告实现是一个非常常见的广告展示模式。 就是在启动时候显示广告,之后点击跳转到广告页面或者其他APP。 一、实现启动页广告 启动页广告控件实现,将View放置在keyWindow上,显示广告图片,点击广告…

Pytorch(二)

一、分类任务 构建分类网络模型 必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数无需写反向传播函数,nn.Module能够利用autograd自动实现反向传播Module中的可学习参数可以通过named_parameters()返回迭代器 from torch import nn import torch.nn.f…

C++部署学习

gcc -E src/main.c -o src/main.i gcc -S src/main.c -o src/main.s gcc -C src/main.c -o src/main.o gcc src/main.c -o exec ./exec

RabbitMQ 教程 | 第3章 客户端开发向导

👨🏻‍💻 热爱摄影的程序员 👨🏻‍🎨 喜欢编码的设计师 🧕🏻 擅长设计的剪辑师 🧑🏻‍🏫 一位高冷无情的编码爱好者 大家好,我是 DevO…

排序算法汇总

每日一句:你的日积月累终会成为别人的望尘莫及 目录 常数时间的操作 选择排列 冒泡排列 【异或运算】 面试题: 1)在一个整形数组中,已知只有一种数出现了奇数次,其他的所有数都出现了偶数次,怎么找到…

面试之CurrentHashMap的底层原理

首先回答HashMap的底层原理? HashMap是数组链表组成。数字组是HashMap的主体,链表则是主要为了解决哈希冲突而存在的。要将key 存储到(put)HashMap中,key类型实现必须计算hashcode方法,默认这个方法是对象的地址。接…

【应用层】Http协议总结

文章目录 一、续->Http协议的学习 1.http请求中的get方法和post方法 2.http的状态码 3.http的报头 4.长链接 5.cookie(会话保持)总结 继续上一篇的内容: 上一篇的最后我们讲到了web根目录,知道…

使用Docker部署EMQX

原文链接:http://www.ibearzmblog.com/#/technology/info?id9dd5bf4159d07f6a4e69a6b379ce4244 前言 在物联网中,大多通信协议使用的都是MQTT,而EMQX是基于 Erlang/OTP 平台开发的 MQTT 消息服务器,它的优点很多,我…

《吐血整理》进阶系列教程-拿捏Fiddler抓包教程(12)-Fiddler设置IOS手机抓包,你知多少???

1.简介 Fiddler不但能截获各种浏览器发出的 HTTP 请求,也可以截获各种智能手机发出的HTTP/ HTTPS 请求。 Fiddler 能捕获Android 和 Windows Phone 等设备发出的 HTTP/HTTPS 请求。同理也可以截获iOS设备发出的请求,比如 iPhone、iPad 和 MacBook 等苹…

【BMC】OpenBMC使用基础(WSL2版本)

代码准备 OpenBMC是一个开源的项目,用于开发BMC固件。官网是https://www.openbmc.org/,不过里面似乎没有什么内容,所以还需要依赖其它的网站,https://github.com/openbmc,在这里可以下载到需要的代码和文档。其主体部…

C#,数值计算——对数正态分布(logarithmic normal distribution)的计算方法与源程序

对数正态分布(logarithmic normal distribution)是指一个随机变量的对数服从正态分布,则该随机变量服从对数正态分布。对数正态分布从短期来看,与正态分布非常接近。但长期来看,对数正态分布向上分布的数值更多一些。 …

Tailwind CSS:基础使用/vue3+ts+Tailwind

一、理解Tailwind 安装 - TailwindCSS中文文档 | TailwindCSS中文网 Installation - Tailwind CSS 1.1、词义 我们简单理解就是搭上CSS的顺风车,事半功倍。 1.2、Tailwind CSS有以下优势 1. 快速开发:Tailwind CSS 提供了一些现成的 class / 可复用…

ARM裸机-4

1、什么是交叉编译 1.1、两种开发模式 非嵌入式开发,A(类)机编写(源代码)、编译得到可执行程序,发布给A(类)机运行。 嵌入式开发,A(类)机编写&am…

Spring源码(三)Spring Bean生命周期

Bean的生命周期就是指:在Spring中,一个Bean是如何生成的,如何销毁的 Bean生命周期流程图 1、生成BeanDefinition Spring启动的时候会进行扫描,会先调用org.springframework.context.annotation.ClassPathScanningCandidateCompo…

Qt C++实现Excel表格的公式计算

用Qt的QTableViewQStandardItemModelQStyledItemDelegate实现类似Excel表格的界面,在parser 模块中提供解析表格单元格输入的公式。单元格编辑结束后按回车进行计算和更新显示。 效果如下: 支持的公式计算可以深度嵌套,目前parser模块中仅提…

【Java】零基础上手SpringBoot学习日记(day1)

前言 此帖为本人学习Springboot时的笔记,由于是个接触计算机一年左右的新手,也没有网站开发经验,所以有些地方的理解会比较浅显并且可能会出现错误,望大佬们多多包涵和指正。 Web应用开发 在我的理解中,Web应用的开发…

测试|测试分类

测试|测试分类 文章目录 测试|测试分类1.按照测试对象分类(部分掌握)2.是否查看代码:黑盒、白盒灰盒测试3.按开发阶段分:单元、集成、系统及验收测试4.按实施组织分:α、β、第三方测试5.按是否运行代码:静…