李沐深度学习-多层感知机从零开始

!!!梯度的产生是由于反向传播,在自定义从零开始编写代码时,第一次反向传播前应该对params参数的梯度进行判断

import torch
import numpy as np
import torch.utils.data as Data
import torchvision.datasets
import torchvision.transforms as transforms
import syssys.path.append("路径")
import d2lzh_pytorch as d2l'''
--------------------------------------------------获取和读取数据
'''
batch_size = 256
train_mnist = torchvision.datasets.FashionMNIST(root='路径',download=True, train=True, transform=transforms.ToTensor())
test_mnist = torchvision.datasets.FashionMNIST(root='路径',download=True, train=False, transform=transforms.ToTensor())
train_iter = Data.DataLoader(train_mnist, batch_size=batch_size, shuffle=True)
test_iter = Data.DataLoader(test_mnist, batch_size=batch_size, shuffle=False)'''
--------------------------------------------------定义模型参数
'''
num_inputs = 784
num_outputs = 10
num_hidden = 256
# 有几个隐藏层就要设置几个参数,简洁实现中,linear网络会自动配置初始参数,自己可以使用init.normal_()设置参数初始值
w1 = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_hidden)), dtype=torch.float)
b1 = torch.zeros(num_hidden, dtype=torch.float)
w2 = torch.tensor(np.random.normal(0, 0.1, (num_hidden, num_outputs)), dtype=torch.float)
b2 = torch.zeros(num_outputs, dtype=torch.float)
params = [w1, b1, w2, b2]
for param in params:param.requires_grad_(requires_grad=True)
'''
---------------------------------------------定义激活函数
'''def relu(X):return torch.max(input=X, other=torch.tensor(0.0))'''
---------------------------------------------------定义模型
'''# 使用view函数将输入的样本转换成inputs特征数大小的图像
def net(X):X = X.view((-1, num_inputs))H = relu(torch.matmul(X, w1) + b1)  # torch.mm(X, w1) + b1得到隐藏层输出# 对隐藏层变量进行激活函数变换,然后作为下一个全连接层的输入# 第一层不是隐藏层,直接线性计算,隐藏层输出作为输出层输入的时候,对隐藏层进行非线性变换,然后传入输入层return torch.matmul(H, w2) + b2  # 隐藏层作为输出层的输入   n层layer有最多n-2个激活函数'''
-----------------------------------------------------定义损失函数
'''
loss = torch.nn.CrossEntropyLoss()  # 包含了softmax运算和交叉熵运算'''
------------------------------------------------------softmax操作,用于训练模型中训练集准确率调用
'''def softmax(X):X_exp = X.exp()  # 幂指数化partition = X_exp.sum(dim=1, keepdim=True)  # 求和每行的元素值return X_exp / partition  # 做比值得预测概率'''
----------------------------------------------------测试集准确率函数,训练模型中测试集准确率调用
'''def evaluate_accuracy(test_data):acc_num, num = 0.0, 0for X, y in test_data:  # X,y分别是一个元组acc_num += (softmax(net(X)).argmax(dim=1) == y).float().sum().item()num += y.shape[0]return acc_num / num'''
------------------------------------------------------训练模型
'''
num_epochs, lr = 5, 100def train():for epoch in range(num_epochs):train_acc, train_l, test_acc, n, num = 0.0, 0.0, 0.0, 0, 0for X, y in train_iter:  #l = loss(net(X), y)  # CrossEntropyLoss 函数已经是对一个批次内所有样本的平均损失计算了if params[0].grad is not None:  # 第一次训练迭代前是没有梯度产生的,梯度是由于反向传播才产生的for param in params:  # 参数梯度清零param.grad.data.zero_()l.backward()  # 反向传播d2l.sgd(params, lr, batch_size)  # 梯度下降操作train_l += l.item()# net(X)返回每个样本各个类别的预测值,有n个样本返回train_acc += (softmax(net(X)).argmax(dim=1) == y).float().sum().item()  # 累加预测正确个数n += y.shape[0]num += 1test_acc = evaluate_accuracy(test_iter)print(f'epoch %d, loss %.4f, train_acc %.3f, test_acc %.3f'% (epoch + 1, train_l / num, train_acc / n, test_acc))train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/636985.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

对#多种编程语言 性能的研究和思考 go/c++/rust java js ruby python

对#多种编程语言 性能的研究和思考 打算学习一下rust 借着这个契机 简单的写了计算圆周率代码的各种语言的版本 比较了一下性能 只比拼单线程简单计算能力 计算十亿次循环 不考虑多线程 go/c/rust java js ruby python 耗时秒数 1:1:1:22:3:250:450 注:能启用则启…

SpringBoot ES 重建 Mapping

SpringBoot ES 重建 Mapping 1 复制数据2 删除老索引3 重建索引4 复制回数据 1 复制数据 POST http://elastic:123456127.0.0.1:9200/_reindex{"source": {"index": "老索引名称"},"dest": {"index": "备份索引名称&q…

web蓝桥杯真题--11、蓝桥知识网

介绍 蓝桥为了帮助大家学习,开发了一个知识汇总网站,现在想设计一个简单美观的首页。本题请根据要求来完成一个首页布局。 准备 开始答题前,需要先打开本题的项目代码文件夹,目录结构如下: ├── css │ └──…

Stream toList不能滥用以及与collect(Collectors.toList())的区别

Stream toList()返回的是只读List原则上不可修改,collect(Collectors.toList())默认返回的是ArrayList,可以增删改查 1. 背景 在公司看到开发环境突然发现了UnsupportedOperationException 报错,想到了不是自己throw的应该就是操作collection不当。 发…

spawn_group | spawn_group_template | linked_respawn

字段介绍 spawn_group | spawn_group_template 用来记录与脚本事件或boss战斗有关的 creatures | gameobjects 的刷新数据linked_respawn 用来将 creatures | gameobjects 和 boss 联系起来,这样如果你杀死boss, creatures | gameobjects 在副本重置之前…

测试覆盖与矩阵

4. Coverage - 衡量测试的覆盖率 我们已经掌握了如何进行单元测试。接下来,一个很自然的问题浮现出来,我们如何知道单元测试的质量呢?这就提出了测试覆盖率的概念。覆盖率测量通常用于衡量测试的有效性。它可以显示您的代码的哪些部分已被测…

【网络安全】【密码学】【北京航空航天大学】实验五、古典密码(中)【C语言实现】

实验五、古典密码(中) 实验目的和原理简介参见博客:古典密码(上) 一、实验内容 1、弗纳姆密码(Vernam Cipher) (1)、算法原理 加密原理: 加密过程可以用…

【跳槽面试】Redis中分布式锁的实现

分布式锁常见的三种实现方式: 数据库乐观锁;基于Redis的分布式锁;基于ZooKeeper的分布式锁。 本地面试考点是,你对Redis使用熟悉吗?Redis中是如何实现分布式锁的。 在Redis中,分布式锁的实现主要依赖于R…

对比一下HelpLook和Bloomfire知识库软件:谁更胜一筹?

在当今知识经济的浪潮中,知识库工具作为企业不可或缺的利器,对于提高工作效率、加强团队协作和优化员工培训等方面起着至关重要的作用。HelpLook和Bloomfire是众多知识库工具中的两款佼佼者,它们各自拥有独特的优势和特点。 一、HelpLook&…

解决 java.lang.NoClassDefFoundError: org/apache/poi/POIXMLTypeLoader 报错

在使用POI导出Excel表格的时候&#xff0c;本地运行导出没问题&#xff0c;但是发布到服务器后提示 “java.lang.NoClassDefFoundError: org/apache/poi/POIXMLTypeLoader” 下面是pom.xml中的配置 <dependency><groupId>org.apache.poi</groupId><art…

Linux查找二进制文件命令——whereis

whereis whereis 命令是一个 Linux/Unix 系统下的命令行命令&#xff0c;用于查询指定命令或程序的二进制文件、源代码文件和帮助文件的位置。 whereis 命令的语法如下&#xff1a; whereis [options] command其中&#xff0c;command 为要查询的命令或程序名称&#xff0c;…

【算法详解】力扣162.寻找峰值

​ 目录 一、题目描述二、思路分析 一、题目描述 力扣链接&#xff1a;力扣162.寻找峰值 峰值元素是指其值严格大于左右相邻值的元素。 给你一个整数数组 nums&#xff0c;找到峰值元素并返回其索引。数组可能包含多个峰值&#xff0c;在这种情况下&#xff0c;返回 任何一个…

大创项目推荐 深度学习验证码识别 - 机器视觉 python opencv

文章目录 0 前言1 项目简介2 验证码识别步骤2.1 灰度处理&二值化2.2 去除边框2.3 图像降噪2.4 字符切割2.5 识别 3 基于tensorflow的验证码识别3.1 数据集3.2 基于tf的神经网络训练代码 4 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x…

Webpack5入门到原理15:提取 Css 成单独文件

提取 Css 成单独文件 Css 文件目前被打包到 js 文件中&#xff0c;当 js 文件加载时&#xff0c;会创建一个 style 标签来生成样式 这样对于网站来说&#xff0c;会出现闪屏现象&#xff0c;用户体验不好 我们应该是单独的 Css 文件&#xff0c;通过 link 标签加载性能才好 …

gin介绍及helloworld

1. 介绍 Gin是一个golang的微框架&#xff0c;封装比较优雅&#xff0c;API友好&#xff0c;源码注释比较明确&#xff0c;具有快速灵活&#xff0c;容错方便等特点 对于golang而言&#xff0c;web框架的依赖要远比Python&#xff0c;Java之类的要小。自身的net/http足够简单&…

未来 AI 可能给哪些产业带来哪些进步与帮助?

AI时代如何要让公司在创新领域领先吗&#xff1f;拥抱这5种创新技能&#xff0c;可以帮助你的公司应对不断变化。包括人工智能、云平台应用、数据分析、 网络安全和体验设计。这些技能可以帮助你提高业务效率、保护公司知识资产、明智决策、满足客户需求并提高销售额。 现在就加…

Redis 面试题 | 01.精选Redis高频面试题

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

Java 代理模式详解

1. 代理模式 代理模式是一种比较好理解的设计模式。简单来说就是 我们使用代理对象来代替对真实对象(real object)的访问&#xff0c;这样就可以在不修改原目标对象的前提下&#xff0c;提供额外的功能操作&#xff0c;扩展目标对象的功能。 代理模式的主要作用是扩展目标对象…

低代码技术杂谈

一、探讨低代码的定义 “Low-Code”是什么&#xff1f;身为技术人员听到这种技术名词&#xff0c;咱们第一反应就是翻看维基百科 或者其他相关技术论文&#xff0c;咱们想看维基百科的英文介绍&#xff1a; A low-code development platform (LCDP) provides a development env…

AI辅助编程工具—Github Copilot

一、概述 Copilot是一种基于Transformer模型的神经网络&#xff0c;具有12B个参数。是GitHub和OpenAPI共同开发的编程辅助工具。GitHubCopilot是一款由人工智能驱动的结对编程编辑器&#xff0c;旨在帮助开发人员更加高效地工作。它利用OpenAICodex技术&#xff0c;将开发…