从0开始深度学习(12)——多层感知机的逐步实现

依然以Fashion-MNIST图像分类数据集为例,手动实现多层感知机和激活函数的编写,大部分代码均在从0开始深度学习(9)——softmax回归的逐步实现中实现过

1 读取数据

import torch
from torchvision import transforms
import torchvision
from torch.utils import data# 读取数据
def load_data_fashion_mnist(batch_size, resize=None):  #@savetrans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="D:/DL_Data/", train=True, transform=trans, download=False)mnist_test = torchvision.datasets.FashionMNIST(root="D:/DL_Data/", train=False, transform=trans, download=False)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=12),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=12))train_iter, test_iter = load_data_fashion_mnist(256, resize=28)

2 初始化模型参数

以单隐藏层的多层感知机为例,选择使用256个隐藏单元

from torch import nn# 初始化模型参数
num_inputs=784      # 28*28
num_outputs=10
num_hiddens=256     # 我们选择使用256个隐藏单元,注意,一般选择使用2的若干次幂,因为内存的特殊性,可以在计算上更高效w1 = nn.Parameter(torch.randn(num_inputs,num_hiddens,requires_grad=True)*0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens,requires_grad=True))w2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))params = [w1, b1, w2, b2]

3 激活函数、损失函数、建立模型

# 激活函数
def relu(x):a=torch.zeros_like(x) # 保证全零张量和x的形状一致,利于广播计算return torch.max(x,a)# 损失函数
loss = nn.CrossEntropyLoss(reduction='none')#建立模型
def net(x):x=x.reshape((-1,num_inputs))#展开H=relu(x@w1+b1)# @表示矩阵乘法return (H@w2+b2)

4 训练模型

优化器使用SGD

#训练,优化器使用sgd
num_epochs=5
lr=00.1
updater=torch.optim.SGD(params,lr=lr)def train_epoch(net, train_iter, loss, updater):if isinstance(net, torch.nn.Module):net.train()  # 将模型设置为训练模式metric = Accumulator(3)  # 训练损失总和、训练准确度总和、样本数for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y).mean()if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()l.backward()updater.step()else:l.backward()updater([w, b], lr, batch_size)metric.add(float(l) * y.numel(), compute_accuracy(y_hat, y), y.numel())return metric[0] / metric[2], metric[1] / metric[2]def train(net, train_iter, test_iter, loss, num_epochs, updater):for epoch in range(num_epochs):train_metrics = train_epoch(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)print(f'Epoch {epoch + 1}: Train Loss {train_metrics[0]:.3f}, Train Acc {train_metrics[1]:.3f}, Test Acc {test_acc:.3f}')class Accumulator:  #@save"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]def compute_accuracy(y_hat, y):  # 预测值、真实值if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)  # 找到一个样本中,对应的最大概率的类别cmp = y_hat.type(y.dtype) == y  # 将预测值 y_hat 与真实标签 y 进行比较,生成一个布尔张量 cmpreturn float(cmp.type(y.dtype).sum())# 计算在指定数据集上模型的准确率
def evaluate_accuracy(net, data_iter):  if isinstance(net, torch.nn.Module):net.eval()  # 将模型设置为评估模式metric = Accumulator(2)  # 累加多个变量的总和。这里初始化了一个包含两个元素的累加器,分别用来存储正确预测的数量和总的预测数量。with torch.no_grad():for X, y in data_iter:metric.add(compute_accuracy(net(X), y), y.numel())return metric[0] / metric[1]train(net, train_iter, test_iter, loss, num_epochs, updater)

在这里插入图片描述

5 预测

import matplotlib.pyplot as plt
# 定义 Fashion-MNIST 标签的文本描述
def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]# 预测并显示结果
def predict(net, test_iter, n=6):for X, y in test_iter:break  # 只取一个批次的数据trues = get_fashion_mnist_labels(y)preds = get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true + '\n' + pred for true, pred in zip(trues, preds)]n = min(n, X.shape[0])fig, axs = plt.subplots(1, n, figsize=(12, 3))for i in range(n):axs[i].imshow(X[i].permute(1, 2, 0).squeeze().numpy(), cmap='gray')axs[i].set_title(titles[i])axs[i].axis('off')plt.show()# 调用预测函数
predict(net, test_iter, n=6)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/55878.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

查找与排序-交换排序

交换排序是基于“比较”和“交换”两种操作来实现的排序方法 。 由于选择“比较”的基准元素不同,可将交换排序分为以下两种: 冒泡排序快速排序 一、冒泡排序 1.冒泡排序基本思想 因为其实现与气泡从水中往上冒的过程类似而得名。 每一趟的…

基于SpringBoot+Vue+uniapp微信小程序的垃圾分类系统的详细设计和实现(源码+lw+部署文档+讲解等)

项目运行截图 技术框架 后端采用SpringBoot框架 Spring Boot 是一个用于快速开发基于 Spring 框架的应用程序的开源框架。它采用约定大于配置的理念,提供了一套默认的配置,让开发者可以更专注于业务逻辑而不是配置文件。Spring Boot 通过自动化配置和约…

Redux (八) 路由React-router、嵌套路由、路由传参、路由懒加载

文章目录 一、React-Router的基本使用1. 安装及基本使用(路由映射配置)2. 路由跳转Link与NavLink3. Navigate导航4. 处理路径不存在的情况 二、嵌套路由三、手动跳转 (类似编程式路由导航)1. 函数式组件2. 类组件实现手动跳转 四、路由传参1. 路径设置占位符(params)2. search传…

Java面试指南:Java基础介绍

这是《Java面试指南》系列的第1篇,本篇主要是介绍Java的一些基础内容: 1、Java语言的起源 2、Java EE、Java SE、Java ME介绍 3、Java语言的特点 4、Java和C的区别和联系? 5、面向对象和面向过程的比较 6、Java面向对象的三大特性&#xff1a…

leetcode30:串联所有单词的字串

给定一个字符串 s 和一个字符串数组 words。 words 中所有字符串 长度相同。 s 中的 串联子串 是指一个包含 words 中所有字符串以任意顺序排列连接起来的子串。 例如,如果 words ["ab","cd","ef"], 那么 "abcdef…

1. 解读DLT698.45-2017通信规约--预连接响应

国家电网有限公司企业标准,面向对象的用电信息数据交换协议DLT698.45-2017 为提高用电信息采集系统的业务适应性、采集效率、安全性和数据溯源性,规范用电信息数据交换协议的通信架构、数据链路层、应用层、接口类与对象标识,制定本标准。 …

Linux系统:(Linux系统概述与安装)

硬件计算机硬件是指计算机系统中所有物理部件的总称。包括计算机主机、显示器、键盘、鼠标、内存、硬盘、处理器、主板等等。这些硬件部件是计算机系统运行的基础 不管是电脑系统(个人电脑、服务器等)、还是移动端操作系统(手机、平板等)。它的功能就是做为用户和硬件之间的桥梁…

前端求职简历-待补充

当然可以,针对大厂的前端岗位,一个吸引人的简历应该突出你的技术能力、项目经验、教育背景以及任何能体现你学习能力和团队协作能力的证明。以下是一个简历大纲示例,你可以根据自己的实际情况进行调整: 个人信息 姓名联系方式&a…

图文深入介绍oracle资源管理(续)

1. 引言: 本文将承接上篇继续深入介绍oracle资源管理。本文重点介绍如何使用oracle资源管理器管理好DB。 2. 资源管理器: 可以使用图形界面 OEM$或命令行调用 DBMS RESOURCE MANAGER 程序包的过程进行数据库资源管理。 调用资源管理器的先决条件&…

瑞数后缀加密怎么处理

前言: 瑞数我们经常补环境通过,但是遇到瑞数后缀不知道怎么处理 就拿瑞数4来讲 解决方法: (1)传明文加密参数 一般情况,我们传明文加密参数也能访问 (2)再补环境基础调用open …

基于stm32的4G模块点灯实验

led模块功能封装 #include "led.h" #include "sys.h"//初始化GPIO函数 void led_init(void) {GPIO_InitTypeDef gpio_initstruct;//打开时钟__HAL_RCC_GPIOB_CLK_ENABLE();//调用GPIO初始化函数gpio_initstruct.Pin GPIO_PIN_8 | GPIO_PIN_9;gpio_inits…

排序算法 —— 直接插入排序

目录 1.直接插入排序的思想 2.直接插入排序的实现 实现分析 实现代码 3.直接插入排序的分析 时间复杂度分析 空间复杂度分析 稳定性 1.直接插入排序的思想 直接插入排序的思想就是把待排序的元素按其关键码值的大小依次插入到一个已经排好序的有序序列中&#xff0c…

pycharm调试带参数命令行的程序

点击 run configuration 点击加号,选择python name填写程序名字,script填写程序路径,下一行填写传入的参数 点击apply,再点ok,即可debug 参考: pycharm 调试模式下命令行参数的传递_pycharm debug传参 ya…

项目实战:构建 effet.js 人脸识别交互系统的实战之路

📝个人主页🌹:Eternity._ 🌹🌹期待您的关注 🌹🌹 ❀构建 effet.js 📒1. 什么是effet.js📜2. 为什么需要使用effet.js📝3. effet.js的功能📚4. 使用…

【项目案例】-音乐播放器-Android前端实现-Java后端实现

精品专题: 01.C语言从不挂科到高绩点 https://blog.csdn.net/yueyehuguang/category_12753294.html?spm1001.2014.3001.5482https://blog.csdn.net/yueyehuguang/category_12753294.html?spm1001.2014.3001.5482 02. SpringBoot详细教程 https://blog.csdn.ne…

HTML之表单设计

1、HTML表单 HTML表单是用于收集用户输入的信息,并将用户输入的内容信息传到后台服务器中。 表单是通过form标签实现。 特别注意:如果一些内容提交后,没有将内容提交给后台服务器,那么需要添加一个name属性,语法&am…

NC 单据模板自定义项 设置参照(自定义参照)

NC 单据模板自定义项 设置参照(自定义参照) 如图下图,NC 单据模板自定义项 设置参照: 1、选择需要设置参照的自定义字段,选择高级属性页签,在类型设置中,数据类型选择参照信息,即bd…

【热门主题】000004 案例 Vue.js组件开发

前言:哈喽,大家好,今天给大家分享一篇文章!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏关注哦 💕 目录 【热…

JavaWeb合集11-Maven高级

十一、Maven高级 1、分模块设计与开发 为什么?将项目按照功能拆分成若干个子模块,方便项目的管理维护、扩展,也方便模块间的相互调用,资源共享。 分模块开发需要先针对模块功能进行设计,再进行编码。不会先将工程开发完毕,然后进行拆分。 实现步骤&…

RabbitMQ下载与配置

安装Erlang Erlang 下载地址如下: https://erlang.org/download/otp_versions_tree.html 安装 RabbitMQ RabbitMQ 下载地址如下: https://www.rabbitmq.com/install-windows.html 查看服务,服务已经正常启动 打开Command Prompt 输入rabb…