MLP感知机python实现

from torch import nn
from softmax回归 import train_ch3
import torch
import torchvision
from torch.utils import data
from torchvision import transforms# ①准备数据集
def load_data_fashion_mnist(batch_size, resize=None):# PyTorch中的一个转换函数,它的作用是将一个PIL Image或numpy.ndarray图像转换为一个Tensor数据类型。trans = [transforms.ToTensor()]# 是否需要改变大小if resize:trans.insert(0, transforms.Resize(resize))# 函数compose将这些转换操作组合起来trans = transforms.Compose(trans)# 训练数据mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)# 测试数据mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)# 返回值return (torch.utils.data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=4),torch.utils.data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=4))
# 批量大小为256
batch_size = 256
# 获取训练数据集和测试数据集
train_iter, test_iter = load_data_fashion_mnist(batch_size)# ②实现一个具有单隐藏层的多层感知机,它包含256个隐藏单元
# 定义输入,输出,隐藏层大小
num_inputs, num_outputs, num_hiddens = 784, 10, 256
# 定义W1、b1、W2、b2
W1 = nn.Parameter(torch.randn(num_inputs, num_hiddens, requires_grad=True) * 0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
W2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))params = [W1, b1, W2, b2]# ③实现ReLU激活函数
def relu(X):a = torch.zeros_like(X)return torch.max(X, a)# ④实现模型
def net(X):# x=256*784X = X.reshape((-1, num_inputs))# torch.matmul(X,W1)= 256*256H = relu(torch.matmul(X,W1) + b1)# torch.matmul(H,W2) = 256*10return (torch.matmul(H,W2) + b2)# ⑤定义损失函数
loss = nn.CrossEntropyLoss()# ⑥训练
# 定义学习率
lr =  0.1
# 优化函数
updater = torch.optim.SGD(params, lr=lr)# 训练
if __name__ == '__main__':num_epochs = 10train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

训练结果

训练损失:0.0015049066459139188
训练精度:0.86405
测试精度:0.8453

貌似是比softmax十次好一些

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/152079.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

037、目标检测-算法速览

之——常用算法速览 目录 之——常用算法速览 杂谈 正文 1.区域卷积神经网络 - R-CNN 2.单发多框检测SSD,single shot detection 3.yolo 杂谈 快速过一下目标检测的各类算法。 正文 1.区域卷积神经网络 - R-CNN region_based CNN,奠基性的工作。…

AI自动写代码:GitHub copilot插件在Idea的安装和使用教程

GitHub Copilot 是微软与OpenAI共同推出的一款AI编程工具,基于GitHub及其他网站的源代码,根据上文提示为程序员自动编写下文代码,可以极大地提高编写代码的效率。 先看看ChatGpt是怎么回答Copilot的功能特点: 给大家简单提取一…

什么是缓存雪崩、击穿、穿透?

背景 数据一般是存储于数据库中,数据库中的数据都是存在磁盘上的,磁盘读写的速度相较于内存或者CPU中的寄存器来说是非常慢的了。 如果用户的请求都直接访问数据库的话,请求数量一上来,数据库很容易就崩溃了,所以为了…

Visio免费版!Visio国产平替软件,终于被我找到啦!

作为一个职场人士,我经常需要绘制各种流程图和图表,而Visio一直是我使用的首选工具。但是,随着公司的发展和工作的需要,我逐渐发现了Visio的优点和不足。 首先,让我们来看看Visio的优点。Visio是一个专业的流程图和图…

C#可空类型

在C#中,可空类型(Nullable types)允许值类型(比如int, double, bool等)接受null值。这是特别有用的,因为在很多应用程序中,如数据库交互和数据解析,值类型的字段可能需要表示没有值&…

注册表单mvc 含源代码

总结 jsp给我们的ControllerServlet.java,ControllerServlet.java获取参数,信息封装给RegisterFormBean.java的对象看是否符合格式,符合格式再信息封装给UserBean对象,调用Dbutil插入方法查重.]]要创建一个user集合成功跳哪个界面,打印信息注意什么时候要加getsession失败跳哪…

VS+Qt+C++ Yolov8物体识别窗体程序onnx模型

程序示例精选 VSQtC Yolov8物体识别窗体程序onnx模型 如需安装运行环境或远程调试,见文章底部个人QQ名片,由专业技术人员远程协助! 前言 这篇博客针对《VSQtC Yolov8物体识别窗体程序onnx模型》编写代码,代码整洁,规…

C++ 与Qt之间的关系

C 和 Qt 之间有着密切的关系。Qt 是一个用于开发 GUI 和应用程序的跨平台 C 库,它提供了许多功能和工具,可以帮助开发者创建高质量的桌面应用程序和移动应用程序。 以下是 C 和 Qt 之间的一些主要关系: Qt 是用 C 编写的:Qt 的核…

Linux远程工具专家推荐(二)

8. Apache Guacamole Apache Guacamole 是一款免费开源的无客户端远程桌面网关,支持 VNC、RDP 和 SSH 等标准协议。无需插件或客户端软件;只需使用 HTML5 Web 应用程序(例如 Web 浏览器)即可。 这意味着您的计算机的使用不受任何一…

【开源】基于Vue和SpringBoot的民宿预定管理系统

项目编号: S 058 ,文末获取源码。 \color{red}{项目编号:S058,文末获取源码。} 项目编号:S058,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 用例设计2.2 功能设计2.2.1 租客角色…

【监控系统】日志可视化监控体系ELK搭建

1.ELK架构是什么 ELK是ElasticsearchLogstashKibana的简称。 Elasticsearch是一个开源的分布式搜索和分析引擎,可以用于全文检索、结构化检索和分析,它构建在Lucene搜索引擎库之上,是当前使用较为广泛的开源搜索引擎之一。 Logstash是一个…

CISP模拟试题(二)

免责声明 文章仅做经验分享用途,利用本文章所提供的信息而造成的任何直接或者间接的后果及损失,均由使用者本人负责,作者不为此承担任何责任,一旦造成后果请自行承担!!! 1.中国信息安全测评中心对CISP注册信息安全专业人员有保持认证要求,在证书有效期内,应完成至少6…

雪花算法的使用

雪花算法的使用(工具类utils) import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component;// 雪花算法 Component public class SnowflakeUtils { // Generated ID: 1724603634882318341; // Generated ID: 1724603…

Databend 开源周报第 120 期

Databend 是一款现代云数仓。专为弹性和高效设计,为您的大规模分析需求保驾护航。自由且开源。即刻体验云服务:https://app.databend.cn 。 Whats On In Databend 探索 Databend 本周新进展,遇到更贴近你心意的 Databend 。 使用自定义 CON…

常见Web安全

一.Web安全概述 以下是百度百科对于web安全的解释: Web安全,计算机术语,随着Web2.0、社交网络、微博等等一系列新型的互联网产品的诞生,基于Web环境的互联网应用越来越广泛,企业信息化的过程中各种应用都架设在Web平台…

java mybatisplus generator 修改字段类型

最新版生成代码的指定字段类型 FastAutoGenerator .dataSourceConfig(config -> {config.typeConvertHandler(new ITypeConvertHandler() {Overridepublic NotNull IColumnType convert(GlobalConfig globalConfig, TypeRegistry typeRegistry, TableField.MetaInfo metaIn…

【框架整合】Redis限流方案

1、Redis实现限流方案的核心原理&#xff1a; redis实现限流的核心原理在于redis 的key 过期时间&#xff0c;当我们设置一个key到redis中时&#xff0c;会将key设置上过期时间&#xff0c;这里的实现是采用lua脚本来实现原子性的。2、准备 引入相关依赖 <dependency>…

MySQL 之多版本并发控制 MVCC

MySQL 之多版本并发控制 MVCC 1、MVCC 中的两种读取方式1.1、快照读1.2、当前读 2、MVCC实现原理之 ReadView2.1、隐藏字段2.2、ReadView2.3、读已提交和可重复读隔离级别下&#xff0c;产生 ReadView 时机的区别 3、MVCC 解决幻读4、总结 MVCC&#xff08;多版本并发控制&…

springboot引入第三方jar包放到项目目录中,添加web.xml

参考博客&#xff1a;https://www.cnblogs.com/mask-xiexie/p/16086612.html https://zhuanlan.zhihu.com/p/587605618 1、在resources目录下新建lib文件夹&#xff0c;将jar包放到lib文件夹中 2、修改pom.xml文件 <dependency><groupId>com.lanren312</grou…

android 点击EdiText 禁止弹出系统键盘

找了很多都不好用&#xff0c;最后使用 editText.setShowSoftInputOnFocus(false); editText.getViewTreeObserver().addOnGlobalLayoutListener(new ViewTreeObserver.OnGlobalLayoutListener() {Overridepublic void onGlobalLayout() {InputMethodManager imm (InputMetho…