softmax回归的从零开始实现

 1.初始化模型参数

import torch
from IPython import display
from d2l import torch as d2l
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
num_inputs = 784
num_outputs = 10
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

2. 定义模型

def softmax(X):
X_exp = torch.exp(X)
partition = X_exp.sum(1, keepdim=True)
return X_exp / partition 

注意,虽然这在数学上看起来是正确的,但我们在代码实现中有点草率。矩阵中的非常大或非常小的元素可
能造成数值上溢或下溢,但我们没有采取措施来防止这点。

def net(X):
return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

先线性模型得出值,softmax转换为(0,1)之间。正向传播

3 定义损失函数

def cross_entropy(y_hat, y):
return - torch.log(y_hat[range(len(y_hat)), y])

4.梯度下降

updater核心作用是更新参数 

lr = 0.1
def updater(batch_size):
return d2l.sgd([W, b], lr, batch_size)
def sgd(params, lr, batch_size): #@save
"""小批量随机梯度下降"""
with torch.no_grad():
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_()

5.训练函数

单epoch的

def train_epoch_ch3(net, train_iter, loss, updater): #@save
"""训练模型一个迭代周期(定义见第3章)"""
# 将模型设置为训练模式if isinstance(net, torch.nn.Module):net.train()
# 训练损失总和、训练准确度总和、样本数metric = Accumulator(3)for X, y in train_iter:
# 计算梯度并更新参数y_hat = net(X)l = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):
# 使用PyTorch内置的优化器和损失函数,这个没有batchsize,传入的是meanupdater.zero_grad()l.mean().backward() updater.step()else:
# 使用定制的优化器和损失函数  传入batchsize就行,会/batchsize,直接sum更新就好了l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
# 返回训练损失和训练精度return metric[0] / metric[2], metric[1] / metric[2]

 多epoch的

def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater): #@save
"""训练模型(定义见第3章)"""animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],legend=['train loss', 'train acc', 'test acc'])for epoch in range(num_epochs):train_metrics = train_epoch_ch3(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)animator.add(epoch + 1, train_metrics + (test_acc,))#把原来的两个元素的元组变成三个的
train_loss, train_acc = train_metrics
assert train_loss < 0.5, train_loss
assert train_acc <= 1 and train_acc > 0.7, train_acc
assert test_acc <= 1 and test_acc > 0.7, test_acc

 6.训练

num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

 

7.预测

def predict_ch3(net, test_iter, n=6): #@save
"""预测标签(定义见第3章)"""for X, y in test_iter:breaktrues = d2l.get_fashion_mnist_labels(y)preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true +'\n' + pred for true, pred in zip(trues, preds)]d2l.show_images(X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])
predict_ch3(net, test_iter)

 

附加:计算分类精度

y_hat是矩阵,那么假定第二个维度存储每个类的预测分数。我们使用argmax获得每行中最大元素的索引来获得预测类别。最后得到正确预测的数量

def accuracy(y_hat, y): #@save
"""计算预测正确的数量"""
if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
y_hat = y_hat.argmax(axis=1)
cmp = y_hat.type(y.dtype) == y
return float(cmp.type(y.dtype).sum())

 实用程序类Accumulator,用于对多个变量进行累加。

Accumulator实例中创建了2个变量,分别用于存储正确预测的数量和预测的总数量。当我们遍历数据集时,两者都将随着时间的推移而累加。

class Accumulator: #@save
"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]
#        这一步,zip并行加载两个数组,a为原数据,b为要添加的数据
#        如:[1,1]+[1,2]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]

 评估在任意模型net的精度

def evaluate_accuracy(net, data_iter): #@save
"""计算在指定数据集上模型的精度"""if isinstance(net, torch.nn.Module):net.eval() # 将模型设置为评估模式metric = Accumulator(2) # 正确预测数、预测总数
with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel())
return metric[0] / metric[1]

调包实现

1 初始化模型参数
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
def init_weights(m):
if type(m) == nn.Linear:
nn.init.normal_(m.weight, std=0.01)
net.apply(init_weights);2.损失函数
loss = nn.CrossEntropyLoss(reduction='none')3.优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.1)4 训练
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/979.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【数据结构】复习题(二)

Hello&#xff01;大家好&#xff0c;这一篇数据结构复习题是我上个学期复习的时候写的&#xff08;刚刚在草稿箱发现了&#xff01;&#xff09;有一些题目过程都是配了图片的&#xff0c;希望对正在复习数据结构的宝宝们有帮助哦&#xff01;(还有一个数据结构复习题(一)可以…

代码随想录算法训练营第60天|84.柱状图中最大的矩形

代码随想录算法训练营第60天|84.柱状图中最大的矩形 |有了之前单调栈的铺垫&#xff0c;这道题目就不难了。 84.柱状图中最大的矩形 https://programmercarl.com/0084.%E6%9F%B1%E7%8A%B6%E5%9B%BE%E4%B8%AD%E6%9C%80%E5%A4%A7%E7%9A%84%E7%9F%A9%E5%BD%A2.html class Soluti…

Ant Design Pro + springboot实现文件上传功能

前端代码 <a-upload:fileList"fileList":beforeUpload"beforeUpload":customRequest"customRequest" ><a-button style"margin-left: 50px" type"primary" ref"btn">导入配置文件 </a-button>…

代码随想录算法训练营第三十二天| 122.买卖股票的最佳时机II ,55. 跳跃游戏,45.跳跃游戏II

目录 题目链接&#xff1a;122.买卖股票的最佳时机II 思路 代码 题目链接&#xff1a;55. 跳跃游戏 思路 代码 题目链接&#xff1a;45.跳跃游戏II 思路 代码 总结 题目链接&#xff1a;122.买卖股票的最佳时机II 思路 每天可以重复买卖&#xff0c;所以只需要计算每…

Spring JdbcTemplate基本使用

1. JdbcTemplate概述 它是spring框架中提供的一个对象&#xff0c;是对原始繁琐的JdbcAPI对象的简单封装。spring框架为我们提供了很多的操作模板类。例如:操作关系型数据的JdbcTemplate和HibermateTemplate&#xff0c;操作nosql数据库的RedisTemplate&#xff0c;操作消息队…

p2949(简单反悔贪心)

题目链接 #include<bits/stdc.h> using namespace std; const int N2e520; using ll long long; long long ans; struct node{int x,y;bool operator < (node p) const{return x<p.x; //按截至时间从小到大排序 } }a[N]; long long ti; int main() {map<int,in…

Hadoop学习总结(Hive的远程服务、数据模型操作、数据操作)

在启动hive时要先启动Hadoop。 在SecurityCRT 或者在 Xshell 进行虚拟机链接 &#xff08;这里使用Xshell &#xff09; 一、Hive 的管理 1、CLI 方式 &#xff08;1&#xff09;启动 Hive 直接输入 hive &#xff08;2&#xff09;退出 直接输入以下一条命令&#xff0…

预付费电表管理系统:WEB端的高效解决方案

1.系统概述 预付费电表管理系统&#xff0c;尤其是基于WEB端的版本&#xff0c;是一种现代化的电力管理工具&#xff0c;旨在提高能源效率&#xff0c;优化电费支付流程&#xff0c;并提供实时的用电数据监控。它通过互联网技术&#xff0c;使得用户能够在线充值、查询电量、远…

如何排查oracle连接数不足问题

最近oracle数据库莫名其妙的连接不上&#xff0c;plsql连接报错&#xff0c;sqlplus终端打开时提示ora-00020错误&#xff0c;下面记录一下本次问题的解决过程。 1.sqlplus 登录数据库 show parameter processes;–当前默认配置的process是多少。 select count(*) from v$pr…

开源全方位运维监控工具:HertzBeat

HertzBeat&#xff1a;实时监控系统性能&#xff0c;精准预警保障业务稳定- 精选真开源&#xff0c;释放新价值。 概览 HertzBeat是一款深受广大开发者喜爱的开源实时监控解决方案。它以其简洁直观的设计理念和免安装Agent的特性&#xff0c;实现了对各类服务器、数据库及应用…

一次性找出数组中的最小值和次小值

一次性找出数组中的最小值和次小值 代码&#xff1a; #include <stdio.h> int main() {int arr[] {5, 4, 6, 3, 9, 12, 35, 42, 18, 29, 30};int min arr[0], submin arr[0];for (int i 0; i < sizeof(arr) / sizeof(arr[0]); i){if (arr[i] < min){submin m…

Youtube DNN

目录 1. 挑战 2. 系统整体结构 3.召回 4. 排序 5. 训练和测试样本的处理 1. 挑战 &#xff08;1&#xff09;规模。很多现有的推荐算法在小规模上效果好&#xff0c;但Youtobe规模很大。 &#xff08;2&#xff09;新颖度。Youtobe语料库是动态的&#xff0c;每秒都会有…

前后端连接完后的各种安全问题

前后端连接完后的各种安全问题&#xff1a; 当我们完成前后端链接后&#xff0c;这只是第一步&#xff0c;接下来各种安全问题才是前后端交互的重中之重。 后端&#xff1a; 一.管理员 当前端把账号密码之类的用户信息传来后端后&#xff0c;我们还需要一个管理员来保存这些…

芯片数字后端设计入门书单推荐(可下载)

数字后端设计&#xff0c;作为数字集成电路设计的关键环节&#xff0c;承担着将逻辑设计转化为物理实现的重任。它不仅要求设计师具备深厚的电路理论知识&#xff0c;还需要对EDA工具有深入的理解和熟练的操作技能。尽管数字后端工作不像前端设计那样频繁涉及代码编写&#xff…

VTK----VTK数据结构详解(几何篇)

在讲VTK的数据结构之前,我们先了解可视化数据的两个特征:离散性、有规则或无规则。 离散性。当我们使用计算机去表示我们的数据时,一般都是基于有限数量的点做信息的采样(或插值),因此可视化的数据是以一种离散的方式表示的。有规则或无规则(也叫结构化或非结构化)。针…

从0开始深入理解Spring(1)--SpringApplication构造

从0开始深入理解Spring-启动类构造 引言: 从本篇开始&#xff0c;打算深入理解Spring执行过程及原理&#xff0c;个人理解极有可能有偏差&#xff0c;评论区欢迎指正错误&#xff0c;下面开始进入正片内容。 ps: springboot版本使用2.4.8 Springboot项目启动时&#xff0c;是通…

PMSM MATLAB

// s-function搭建变参数PMSM模型 永磁同步电机dq轴电感和其内部结构有何关系&#xff1f;​​​​​​​ 矢量控制&#xff0c;SVPWM开关频率一般20kHZ&#xff0c;是不是开关频率提越高控制效果越好&#xff1f;频率提高有没有意义&#xff1f; 一般来说&#xff0c;电机电…

【Leetcode】链表专题

leetcode链表专题 主要根据CSview一个校招刷题准备网站 做这个网站的人真的很厉害&#xff01;进行整理 太困了&#xff0c;一学习就困&#xff0c;来刷刷题 文章目录 leetcode链表专题前言一、leetcode 206.反转链表1.题目描述&#xff1a;2.主要有两种方法&#xff0c;迭代法…

发送钉钉、邮件、手机信息

其中下列部分用到了Hutool中的工具,可先导入Hutool依赖 <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.8.16</version></dependency>钉钉 public void sendDingDing(PoMaster poMa…

Ugee手写板Ex08 S在不同软件中的设置

手写笔的结构 功能对应于鼠标的作用笔尖鼠标左键上面第一个键鼠标右键&#xff08;效果有时候也不完全等同&#xff09;上面第二个键鼠标中键 以下测试的软件版本 软件版本windows10WPS2024春季16729Office2007SimpleTex0.2.5Ex08 S驱动版本4.2.4.231109 WPS-word ①点击审…