softmax回归的从零开始实现

 1.初始化模型参数

import torch
from IPython import display
from d2l import torch as d2l
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
num_inputs = 784
num_outputs = 10
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

2. 定义模型

def softmax(X):
X_exp = torch.exp(X)
partition = X_exp.sum(1, keepdim=True)
return X_exp / partition 

注意,虽然这在数学上看起来是正确的,但我们在代码实现中有点草率。矩阵中的非常大或非常小的元素可
能造成数值上溢或下溢,但我们没有采取措施来防止这点。

def net(X):
return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

先线性模型得出值,softmax转换为(0,1)之间。正向传播

3 定义损失函数

def cross_entropy(y_hat, y):
return - torch.log(y_hat[range(len(y_hat)), y])

4.梯度下降

updater核心作用是更新参数 

lr = 0.1
def updater(batch_size):
return d2l.sgd([W, b], lr, batch_size)
def sgd(params, lr, batch_size): #@save
"""小批量随机梯度下降"""
with torch.no_grad():
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_()

5.训练函数

单epoch的

def train_epoch_ch3(net, train_iter, loss, updater): #@save
"""训练模型一个迭代周期(定义见第3章)"""
# 将模型设置为训练模式if isinstance(net, torch.nn.Module):net.train()
# 训练损失总和、训练准确度总和、样本数metric = Accumulator(3)for X, y in train_iter:
# 计算梯度并更新参数y_hat = net(X)l = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):
# 使用PyTorch内置的优化器和损失函数,这个没有batchsize,传入的是meanupdater.zero_grad()l.mean().backward() updater.step()else:
# 使用定制的优化器和损失函数  传入batchsize就行,会/batchsize,直接sum更新就好了l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
# 返回训练损失和训练精度return metric[0] / metric[2], metric[1] / metric[2]

 多epoch的

def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater): #@save
"""训练模型(定义见第3章)"""animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],legend=['train loss', 'train acc', 'test acc'])for epoch in range(num_epochs):train_metrics = train_epoch_ch3(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)animator.add(epoch + 1, train_metrics + (test_acc,))#把原来的两个元素的元组变成三个的
train_loss, train_acc = train_metrics
assert train_loss < 0.5, train_loss
assert train_acc <= 1 and train_acc > 0.7, train_acc
assert test_acc <= 1 and test_acc > 0.7, test_acc

 6.训练

num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

 

7.预测

def predict_ch3(net, test_iter, n=6): #@save
"""预测标签(定义见第3章)"""for X, y in test_iter:breaktrues = d2l.get_fashion_mnist_labels(y)preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true +'\n' + pred for true, pred in zip(trues, preds)]d2l.show_images(X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])
predict_ch3(net, test_iter)

 

附加:计算分类精度

y_hat是矩阵,那么假定第二个维度存储每个类的预测分数。我们使用argmax获得每行中最大元素的索引来获得预测类别。最后得到正确预测的数量

def accuracy(y_hat, y): #@save
"""计算预测正确的数量"""
if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
y_hat = y_hat.argmax(axis=1)
cmp = y_hat.type(y.dtype) == y
return float(cmp.type(y.dtype).sum())

 实用程序类Accumulator,用于对多个变量进行累加。

Accumulator实例中创建了2个变量,分别用于存储正确预测的数量和预测的总数量。当我们遍历数据集时,两者都将随着时间的推移而累加。

class Accumulator: #@save
"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]
#        这一步,zip并行加载两个数组,a为原数据,b为要添加的数据
#        如:[1,1]+[1,2]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]

 评估在任意模型net的精度

def evaluate_accuracy(net, data_iter): #@save
"""计算在指定数据集上模型的精度"""if isinstance(net, torch.nn.Module):net.eval() # 将模型设置为评估模式metric = Accumulator(2) # 正确预测数、预测总数
with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel())
return metric[0] / metric[1]

调包实现

1 初始化模型参数
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
def init_weights(m):
if type(m) == nn.Linear:
nn.init.normal_(m.weight, std=0.01)
net.apply(init_weights);2.损失函数
loss = nn.CrossEntropyLoss(reduction='none')3.优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.1)4 训练
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/979.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【数据结构】复习题(二)

Hello&#xff01;大家好&#xff0c;这一篇数据结构复习题是我上个学期复习的时候写的&#xff08;刚刚在草稿箱发现了&#xff01;&#xff09;有一些题目过程都是配了图片的&#xff0c;希望对正在复习数据结构的宝宝们有帮助哦&#xff01;(还有一个数据结构复习题(一)可以…

Spring JdbcTemplate基本使用

1. JdbcTemplate概述 它是spring框架中提供的一个对象&#xff0c;是对原始繁琐的JdbcAPI对象的简单封装。spring框架为我们提供了很多的操作模板类。例如:操作关系型数据的JdbcTemplate和HibermateTemplate&#xff0c;操作nosql数据库的RedisTemplate&#xff0c;操作消息队…

Hadoop学习总结(Hive的远程服务、数据模型操作、数据操作)

在启动hive时要先启动Hadoop。 在SecurityCRT 或者在 Xshell 进行虚拟机链接 &#xff08;这里使用Xshell &#xff09; 一、Hive 的管理 1、CLI 方式 &#xff08;1&#xff09;启动 Hive 直接输入 hive &#xff08;2&#xff09;退出 直接输入以下一条命令&#xff0…

预付费电表管理系统:WEB端的高效解决方案

1.系统概述 预付费电表管理系统&#xff0c;尤其是基于WEB端的版本&#xff0c;是一种现代化的电力管理工具&#xff0c;旨在提高能源效率&#xff0c;优化电费支付流程&#xff0c;并提供实时的用电数据监控。它通过互联网技术&#xff0c;使得用户能够在线充值、查询电量、远…

开源全方位运维监控工具:HertzBeat

HertzBeat&#xff1a;实时监控系统性能&#xff0c;精准预警保障业务稳定- 精选真开源&#xff0c;释放新价值。 概览 HertzBeat是一款深受广大开发者喜爱的开源实时监控解决方案。它以其简洁直观的设计理念和免安装Agent的特性&#xff0c;实现了对各类服务器、数据库及应用…

Youtube DNN

目录 1. 挑战 2. 系统整体结构 3.召回 4. 排序 5. 训练和测试样本的处理 1. 挑战 &#xff08;1&#xff09;规模。很多现有的推荐算法在小规模上效果好&#xff0c;但Youtobe规模很大。 &#xff08;2&#xff09;新颖度。Youtobe语料库是动态的&#xff0c;每秒都会有…

芯片数字后端设计入门书单推荐(可下载)

数字后端设计&#xff0c;作为数字集成电路设计的关键环节&#xff0c;承担着将逻辑设计转化为物理实现的重任。它不仅要求设计师具备深厚的电路理论知识&#xff0c;还需要对EDA工具有深入的理解和熟练的操作技能。尽管数字后端工作不像前端设计那样频繁涉及代码编写&#xff…

【Leetcode】链表专题

leetcode链表专题 主要根据CSview一个校招刷题准备网站 做这个网站的人真的很厉害&#xff01;进行整理 太困了&#xff0c;一学习就困&#xff0c;来刷刷题 文章目录 leetcode链表专题前言一、leetcode 206.反转链表1.题目描述&#xff1a;2.主要有两种方法&#xff0c;迭代法…

Ugee手写板Ex08 S在不同软件中的设置

手写笔的结构 功能对应于鼠标的作用笔尖鼠标左键上面第一个键鼠标右键&#xff08;效果有时候也不完全等同&#xff09;上面第二个键鼠标中键 以下测试的软件版本 软件版本windows10WPS2024春季16729Office2007SimpleTex0.2.5Ex08 S驱动版本4.2.4.231109 WPS-word ①点击审…

《R语言与农业数据统计分析及建模》学习——创建与访问数据框

1、数据框的概念和特点 数据框是二维的表格形式数据结构&#xff0c;是R语言中最常用的数据结构之一。有如下特点&#xff1a; &#xff08;1&#xff09;异质性&#xff1a;各列不同的数据类型 &#xff08;2&#xff09;命名索引&#xff1a;每列都有一个名称 &#xff08;3&…

开源Windows12网页版HTML源码

源码介绍 开源Windows12网页版HTML源码&#xff0c;无需安装就能用的Win12网页版来了Windows12概念版&#xff08;PoweredbyPowerPoint&#xff09;后深受启发&#xff0c;于是通过使用HTML、CSS、js等技术做了这样一个模拟板的Windows12系统&#xff0c;并已发布至github进行…

蓝桥杯2024年第十五届省赛真题-小球反弹

以下两个解法感觉都靠谱&#xff0c;并且网上的题解每个人答案都不一样&#xff0c;目前无法判断哪个是正确答案。 方法一&#xff1a;模拟 代码参考博客 #include <iostream> #include <cmath> #include <vector>using namespace std;int main() {const i…

(二十)C++自制植物大战僵尸游戏僵尸进攻控制实现

植物大战僵尸游戏开发教程专栏地址http://t.csdnimg.cn/8UFMs 文件位置 实现功能的代码文件位置在Class\Scenes\GameScene文件夹中&#xff0c;具体如下图所示。 ZombiesAppearControl.h /* 僵尸出现波数控制 */ class ZombiesAppearControl { public:/***对于进攻的不同波数…

【吊打面试官系列】Java高并发篇 - 如何停止一个正在运行的线程?

大家好&#xff0c;我是锋哥。今天分享关于 【如何停止一个正在运行的线程&#xff1f;】面试题&#xff0c;希望对大家有帮助&#xff1b; 如何停止一个正在运行的线程&#xff1f; java如何停止一个正在运行的线程? 在Java中&#xff0c;可以使用Thread.stop()方法来停止一…

Android自带模拟器如何获得ROOT权限

如果在模拟器中不能切换到root权限&#xff0c;很可能是镜像使用的不对。 一.选择镜像标准&#xff1a; 1.运行在PC端选X86_64镜像&#xff0c;才能流畅运行 2.不带google api的镜像 二.步骤 在虚拟机管理器中新建AVD&#xff0c;并下载符合要求的镜像文件 三.验证

【MATLAB】App 设计 (入门)

设计APP 主界面 函数方法 定时器 classdef MemoryMonitorAppExample < matlab.apps.AppBase% Properties that correspond to app componentsproperties (Access public)UIFigure matlab.ui.FigureStopButton matlab.ui.control.ButtonStartButton matlab.ui.cont…

大模型ChatGPT里面的一些技术和发展方向

文章目录 如何炼成ChatGPT如何调教ChatGPT如何武装ChatGPT一些大模型的其他方向 这个是基于视频 https://www.bilibili.com/video/BV17t4218761&#xff0c;可以了解一下大模型里面的一些技术和最近的发展&#xff0c;基本都是2022你那以来的发展&#xff0c;比较新。然后本文…

最新版的GPT-4.5-Turbo有多强

OpenAI再次用实力证明了&#xff0c;GPT依然是AI世界最强的玩家&#xff01;在最新的AI基准测试中&#xff0c;OpenAI几天前刚刚发布的GPT-4-Turbo-2024-04-09版本&#xff0c;大幅超越了Claude3 Opus&#xff0c;重新夺回了全球第一的AI王座&#xff1a; 值得一提的是&#xf…

C++ 模板详解——template<class T>

一. 前言 在我们学习C时&#xff0c;常会用到函数重载。而函数重载&#xff0c;通常会需要我们编写较为重复的代码&#xff0c;这就显得臃肿&#xff0c;且效率低下。重载的函数仅仅只是类型不同&#xff0c;代码的复用率比较低&#xff0c;只要有新类型出现时&#xff0c;就需…

文章解读与仿真程序复现思路——电力自动化设备EI\CSCD\北大核心《考虑碳市场风险的热电联产虚拟电厂低碳调度》

本专栏栏目提供文章与程序复现思路&#xff0c;具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源…