[pytorch、学习] - 3.13 丢弃法

参考

3.13 丢弃法

过拟合问题的另一种解决办法是丢弃法。当对隐藏层使用丢弃法时,隐藏单元有一定概率被丢弃。

3.12.1 方法

在这里插入图片描述

3.13.2 从零开始实现

import torch
import torch.nn as nn
import numpy as np
import sys
sys.path.append("..") 
import d2lzh_pytorch as d2ldef dropout(X, drop_prob):X = X.float()assert 0 <= drop_prob <= 1keep_prob = 1 - drop_prob# 这种情况下把全部元素都丢弃if keep_prob == 0:return torch.zeros_like(X)mask = (torch.rand(X.shape) < keep_prob).float()return mask * X / keep_prob
X = torch.arange(16).view(2, 8)
X

在这里插入图片描述

dropout(X, 0.5)

在这里插入图片描述

dropout(X, 1)

在这里插入图片描述

3.13.2.1 定义模型参数

num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256W1 = torch.tensor(np.random.normal(0, 0.01, size=(num_inputs, num_hiddens1)), dtype=torch.float, requires_grad=True)
b1 = torch.zeros(num_hiddens1, requires_grad=True)
W2 = torch.tensor(np.random.normal(0, 0.01, size=(num_hiddens1, num_hiddens2)), dtype=torch.float, requires_grad=True)
b2 = torch.zeros(num_hiddens2, requires_grad=True)
W3 = torch.tensor(np.random.normal(0, 0.01, size=(num_hiddens2, num_outputs)), dtype=torch.float, requires_grad=True)
b3 = torch.zeros(num_outputs, requires_grad=True)params = [W1, b1, W2, b2, W3, b3]

3.13.2.2 定义模型

drop_prob1, drop_prob2 = 0.2, 0.5def net(X, is_training=True):X = X.view(-1, num_inputs)H1 = (torch.matmul(X, W1) + b1).relu()if is_training:  # 只在训练模型时使用丢弃法H1 = dropout(H1, drop_prob1)  # 在第一层全连接后添加丢弃层H2 = (torch.matmul(H1, W2) + b2).relu()if is_training:H2 = dropout(H2, drop_prob2)  # 在第二层全连接后添加丢弃层return torch.matmul(H2, W3) + b3# 本函数已保存在d2lzh_pytorch
def evaluate_accuracy(data_iter, net):acc_sum, n = 0.0, 0for X, y in data_iter:if isinstance(net, torch.nn.Module):net.eval() # 评估模式, 这会关闭dropoutacc_sum += (net(X).argmax(dim=1) == y).float().sum().item()net.train() # 改回训练模式else: # 自定义的模型if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数# 将is_training设置成Falseacc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() else:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() n += y.shape[0]return acc_sum / n

3.13.2.3 训练和测试模型

num_epochs, lr, batch_size = 5, 100.0, 256
loss = torch.nn.CrossEntropyLoss()
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)

在这里插入图片描述

3.13.3 简洁实现

net = nn.Sequential(d2l.FlattenLayer(),nn.Linear(num_inputs, num_hiddens1),nn.ReLU(),nn.Dropout(drop_prob1),nn.Linear(num_hiddens1, num_hiddens2),nn.ReLU(),nn.Dropout(drop_prob2),nn.Linear(num_hiddens2, 10)
)for param in net.parameters():nn.init.normal_(param, mean=0, std= 0.01)optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/250168.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

springboot---request 中Parameter,Attribute区别

HttpServletRequest类既有getAttribute()方法&#xff0c;也由getParameter()方法&#xff0c;这两个方法有以下区别&#xff1a; &#xff08;1&#xff09;HttpServletRequest类有setAttribute()方法&#xff0c;而没有setParameter()方法 &#xff08;2&#xff09;当两个…

Python之令人心烦意乱的字符编码与转码

ASC-II码&#xff1a;英文1个字节&#xff08;8 byte&#xff09;&#xff0c;不支持中文&#xff1b; 高大上的中国&#xff0c;扩展出自己的gbk、gb2312、gb2318等字符编码。 由于各个国家都有自己的编码&#xff0c;于是就需要统一的编码形式用于国际流传&#xff0c;防止乱…

[pytorch、学习] - 4.1 模型构造

参考 4.1 模型构造 让我们回顾以下多重感知机的简洁实现中包含单隐藏层的多重感知机的实现方法。我们首先构造Sequential实例,然后依次添加两个全连接层。其中第一层的输出大小为256,即隐藏层单元个数是256;第二层的输出大小为10,即输出层单元个数是10. 4.1.1 继承Module类来…

springboot---基本模块详解

概述 1.基于Spring框架的“约定优先于配置&#xff08;COC&#xff09;”理念以及最佳实践之路。 2.针对日常企业应用研发各种场景的Spring-boot-starter自动配置依赖模块&#xff0c;且“开箱即用”&#xff08;约定spring-boot-starter- 作为命名前缀&#xff0c;都位于org.…

第二课 运算符(day10)

第二课 运算符(day10) 一、运算符 结果是值 算数运算 a 10 * 10 赋值运算 a a 1 a1 结果是布尔值 比较运算 a 1 > 5 逻辑运算 a 1>6 or 11 成员运算 a "蚊" in "郑建文" 二、基本数据类型 1、数值…

[pytorch、学习] - 4.2 模型参数的访问、初始化和共享

参考 4.2 模型参数的访问、初始化和共享 在3.3节(线性回归的简洁实现)中,我们通过init模块来初始化模型的参数。我们也介绍了访问模型参数的简单方法。本节将深入讲解如何访问和初始化模型参数,以及如何在多个层之间共享同一份模型参数。 import torch from torch import nn…

spring-boot注解详解(三)

1.SpringBoot/spring SpringBootApplication: 包含Configuration、EnableAutoConfiguration、ComponentScan通常用在主类上&#xff1b; Repository: 用于标注数据访问组件&#xff0c;即DAO组件&#xff1b; Service: 用于标注业务层组件&#xff1b; RestController: 用于…

IEnumerableT和IQueryableT区分

哎&#xff0c;看了那么多&#xff0c;这个知识点还是得开一个文章 IQueryable和IEnumerable都是延时执行(Deferred Execution)的&#xff0c;而IList是即时执行(Eager Execution) IQueryable和IEnumerable在每次执行时都必须连接数据库读取&#xff0c;而IList读取一次后&…

表的转置 行转列: DECODE(Oracle) 和 CASE WHEN 的异同点

异同点 都可以对表行转列&#xff1b;DECODE功能上和简单Case函数比较类似&#xff0c;不能像Case搜索函数一样&#xff0c;进行更复杂的判断在Case函数中&#xff0c;可以使用BETWEEN, LIKE, IS NULL, IN, EXISTS等等&#xff08;也可以使用NOT IN和NOT EXISTS&#xff0c;但是…

[pytorch、学习] - 4.4 自定义层

参考 4.4 自定义层 深度学习的一个魅力在于神经网络中各式各样的层,例如全连接层和后面章节将要用介绍的卷积层、池化层与循环层。虽然PyTorch提供了大量常用的层,但有时候我们依然希望自定义层。本节将介绍如何使用Module来自定义层,从而可以被重复调用。 4.4.1 不含模型参…

树的存储

父亲表示法 顾名思义&#xff0c;就是只记录每个结点的父结点。 int n; int p[MAX_N]; // 指向每个结点的父结点 孩子表示法 如上&#xff0c;就是只记录每个结点的子结点。 int n; int cnt[MAX_N]; // 记录每个结点的子结点的数量 int p[MAX_N][MAX_CNT]; // 指向每个结点的子…

spring-boot注解详解(四)

repository repository跟Service,Compent,Controller这4种注解是没什么本质区别,都是声明作用,取不同的名字只是为了更好区分各自的功能.下图更多的作用是mapper注册到类似于以前mybatis.xml中的mappers里. 也是因为接口没办法在spring.xml中用bean的方式来配置实现类吧(接口…

令人叫绝的EXCEL函数功能

http://club.excelhome.net/thread-166725-1-1.html https://wenku.baidu.com/view/db319da0bb0d4a7302768e9951e79b8969026864.html转载于:https://www.cnblogs.com/cqufengchao/articles/9150401.html

[pytorch、学习] - 4.5 读取和存储

参考 4.5 读取和存储 到目前为止,我们介绍了如何处理数据以及如何构建、训练和测试深度学习模型。然而在实际中,我们有时需要把训练好的模型部署到很多不同的设备。在这种情况下,我们可以把内存中训练好的模型参数存储在硬盘上供后续读取使用。 4.5.1 读写tensor 我们可以直…

JAVA排序的方法

//冒泡排序法&#xff1a; package fuxi;public class Bubble { public static void main(String[] args) { int a[] { 10,23,11,56,45,26,59,28,84,79 }; int i,temp; System.out.println("输出原始数组数据&#xff1a;"); for (i…

spring-boot注解详解(五)

AutoWired 首先要知道另一个东西&#xff0c;default-autowire&#xff0c;它是在xml文件中进行配置的&#xff0c;可以设置为byName、byType、constructor和autodetect&#xff1b;比如byName&#xff0c;不用显式的在bean中写出依赖的对象&#xff0c;它会自动的匹配其它bea…

什么是p12证书?ios p12证书怎么获取?

.cer是苹果的默认证书&#xff0c;在xcode开发打包可以使用&#xff0c;如果在lbuilder、phonegap、HBuilder、AppCan、APICloud这些跨平台开发工具打包&#xff0c;就需要用到p12文件。 .cer证书仅包含公钥&#xff0c;.p12证书可能既包含公钥也包含私钥&#xff0c;这就是他们…

[pytorch、学习] - 4.6 GPU计算

参考 4.6 GPU计算 到目前为止,我们一直使用CPU进行计算。对复杂的神经网络和大规模数据来说,使用CPU来计算可能不够高效。 在本节中,将要介绍如何使用单块NIVIDA GPU进行计算 4.6.1 计算设备 PyTorch可以指定用来存储和计算的设备,如果用内存的CPU或者显存的GPU。默认情况下…

adb connect 192.168.1.10 failed to connect to 192.168.1.10:5555

adb connect 192.168.1.10 输出 failed to connect to 192.168.1.10:5555 关闭安卓端Wi-Fi&#xff0c;重新打开连接即可 转载于:https://www.cnblogs.com/sea-stream/p/10020995.html

创建oracle数据库表空间并分配用户

我们在本地的oracle上或者virtualbox的oracle上 创建新的数据库表空间操作&#xff1a;通过system账号来创建并授权/*--创建表空间create tablespace YUJKDATAdatafile c:\yujkdata200.dbf --指定表空间对应的datafile文件的具体的路径size 100mautoextend onnext 10m*/ /*--创…