pytorch实现Dropout与正则化防止过拟合

numpy实现dropout与L1,L2正则化请参考我另一篇博客

https://blog.csdn.net/fanzonghao/article/details/81079757

pytorch使用dropout与L2 

import torch
import matplotlib.pyplot as plt
torch.manual_seed(1)    # Sets the seed for generating random numbers.reproducibleN_SAMPLES = 20
N_HIDDEN = 300# training data
x = torch.unsqueeze(torch.linspace(-1, 1, N_SAMPLES), 1)
print('x.size()',x.size())# torch.normal(mean, std, out=None) → Tensor
y = x + 0.3*torch.normal(torch.zeros(N_SAMPLES, 1), torch.ones(N_SAMPLES, 1))
print(y.shape)
print(y)
# test data
test_x = torch.unsqueeze(torch.linspace(-1, 1, N_SAMPLES), 1)
test_y = test_x + 0.3*torch.normal(torch.zeros(N_SAMPLES, 1), torch.ones(N_SAMPLES, 1))# show data
plt.scatter(x.numpy(), y.numpy(), c='red', s=50, alpha=0.5, label='train')
plt.scatter(test_x.numpy(), test_y.numpy(), c='blue', s=50, alpha=0.5, label='test')
plt.legend(loc='upper left')
plt.ylim((-2.5, 2.5))
plt.show()net_overfitting = torch.nn.Sequential(torch.nn.Linear(1,N_HIDDEN),torch.nn.ReLU(),torch.nn.Linear(N_HIDDEN,N_HIDDEN),torch.nn.ReLU(),torch.nn.Linear(N_HIDDEN,1),
)net_dropped = torch.nn.Sequential(torch.nn.Linear(1,N_HIDDEN),torch.nn.Dropout(0.5), # 0.5的概率失活torch.nn.ReLU(),torch.nn.Linear(N_HIDDEN,N_HIDDEN),torch.nn.Dropout(0.5),torch.nn.ReLU(),torch.nn.Linear(N_HIDDEN,1),
)#no dropout
optimizer_ofit = torch.optim.Adam(net_overfitting.parameters(), lr=0.001)
#add dropout
optimizer_drop = torch.optim.Adam(net_dropped.parameters(), lr=0.01)
#add l2 penalty weight_decay
# optimizer_ofit = torch.optim.Adam(net_overfitting.parameters(), lr=0.001,weight_decay=0.001)
loss = torch.nn.MSELoss()for epoch in range(500):pred_ofit = net_overfitting(x)loss_ofit = loss(pred_ofit, y)optimizer_ofit.zero_grad()loss_ofit.backward()optimizer_ofit.step()#DROP OUTpred_drop = net_dropped(x)loss_drop = loss(pred_drop, y)optimizer_drop.zero_grad()loss_drop.backward()optimizer_drop.step()if epoch % 250 == 0:net_overfitting.eval()  # 将神经网络转换成测试形式,此时不会对神经网络dropoutnet_dropped.eval()  # 此时不会对神经网络dropouttest_pred_ofit = net_overfitting(test_x)test_pred_drop = net_dropped(test_x)# show dataplt.scatter(x.numpy(), y.numpy(), c='red', s=50, alpha=0.5, label='train')plt.scatter(test_x.numpy(), test_y.numpy(), c='blue', s=50, alpha=0.5, label='test')plt.plot(test_x.numpy(), test_pred_ofit.detach().numpy(), 'r-', lw=3, label='overfitting')plt.plot(test_x.numpy(), test_pred_drop.detach().numpy(), 'b--', lw=3, label='L2')plt.text(0, -1.2, 'overfitting loss=%.4f' % loss(test_pred_ofit, test_y).detach().numpy(),fontdict={'size': 20, 'color': 'red'})plt.text(0, -1.5, 'L2 loss=%.4f' % loss(test_pred_drop, test_y).detach().numpy(),fontdict={'size': 20, 'color': 'blue'})plt.legend(loc='upper left')plt.ylim((-2.5, 2.5))plt.pause(0.1)net_overfitting.train()net_dropped.train()plt.ioff()
plt.show()

数据:

使用dropout对比:可看出使用dropout具有防止过拟合的作用。

使用L2对比:可看出使用L2也具有防止过拟合作用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/493268.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

“蚁人”不再是科幻!MIT最新研究,能把任何材料物体缩小1000倍 | Science

来源:量子位科学加速,科幻成真也在加速。漫威世界中,蚁人是蚂蚁大小的超级英雄,靠一件“变身服”,人类就能在更微观的世界里大干一场。现在,类似的科幻想象,被MIT变成现实。丨小小小&#xff0c…

Android ARM指令学习

在逆向分析Android APK的时候,往往需要分析它的.so文件。这个.so文件就是Linux的动态链接库,只不过是在ARM-cpu下编译的。所以学习Android下的ARM指令很重要。目前,市面上的ARM-cpu基本都支持一种叫做THUMB的指令集模式。这个THUMB指令集可以…

cuda基础知识

nvidia-cuda 手册:https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#kernels nvidia cuda 教学视频 https://www.nvidia.cn/object/cuda_education_cn_old.html 介绍: CUDA编程模型是一个异构模型,需要CPU和GPU协同工作。在CUDA中,…

苹果着手自研调制解调器,以应对高通天价专利费

来源:DeepTech深科技近日,苹果官方发布一份招聘信息,其中有一个职位就非常惹人注意,根据信息,苹果准备招募两名蜂窝调制解调器系统架构师,一名构架师的工作地点在圣克拉拉,另一名构架师的工作地…

labelme标注文件转coco json,coco json转yolo txt格式,coco json转xml, labelme标注文件转分割,boxes转labelme json

参考:https://github.com/wkentaro/labelme 一.labelme标注文件转coco json 1.标注时带图片ImageData信息,将一个文件夹下的照片和labelme的标注文件,分成了train和val的coco json文件和照片, (COCO的格式: [x1,y1,…

“深度学习之父”大谈AI:寒冬不会出现,论文评审机制有损创新

来源: AI科技大本营整理:琥珀近日《连线》杂志发表了一篇文章,记录了与“深度学习之父” Geoffrey Hinton 围绕人工智能伦理、技术、学术等领域的采访实录。当被问到如今人工智能是否将走进寒冬时,Hinton 的回答非常坚决&#xff…

GDataXML解析XML文档

一、GDataXMLNode说明GDataXMLNode是Google提供的用于XML数据处理的类集。该类集对libxml2--DOM处理方式进行了封装,能对较小或中等的xml文档进行读写操作且支持XPath语法。 使用方法:1、获取GDataXMLNode.h/m文件,将GDataXMLNode.h/m文件添加…

RetinaNet+focal loss

one stage 精度不高,一个主要原因是正负样本的不平衡,以YOLO为例,每个grid cell有5个预测,本来正负样本的数量就有差距,再相当于进行5倍放大后,这种数量上的差异更会被放大。 文中提出新的分类损失函数Foca…

真实用户首次披露Waymo无人车服务体验: 为避开左转, 故意绕路

来源 :Ars Technica编译 :机器之能 高璇外国网友炸了:「就像看了一部大导演导的烂片一样。」在过去的 18 个月里,Waymo 的汽车一直在凤凰城的东南角运送乘客。该公司在合同中明确规定禁止乘客讨论用户体验,对项目信息进…

“横平竖直”进行连线+将相邻框进行合并

一.横平竖直”进行连线 解法1.将一些坐标点按照x相等,y相等连起来 解法1.根据 x或y总有一个相等的,用np.sum来找出和为1的点,然后在连起来,存在重复连线的问题. import numpy as npcoord np.array([[10, 60],[10, 20],[20, 20],[40, 40],[40, 60],[20, 40]])img np.zeros(…

一文看透汽车芯片!巨头布局技术路线全解密【附下载】| 智东西内参

来源:智东西摘要:一文看透汽车芯片!巨头布局技术路线全解密智能驾驶涉及人机交互、视觉处理、智能决策等,核心是 AI 算法和芯片。伴随汽车电子化提速,汽车半导体加速成长,2017 年全球市场规模 288 亿美元&a…

详细介绍软件架构设计的三个维度

如果你对项目管理、系统架构有兴趣,请加微信订阅号“softjg”,加入这个PM、架构师的大家庭 架构设计是一个非常大的话题,不管写几篇文章,接触到的始终只是冰山一角,更多的是实践中去体会。这篇文章主要介绍面向对象OO、…

中国智能语音行业研究

报告来源:中信证券作者:刘雯蜀 杨泽原 张若海智能语音作为人机交互的新型方式,有望大规模推广,中国市场是更适合语音交互的市场。2017年中国人工智能市场规模达约220亿元,智能语音占中国人工智能市场份额的22%&#…

SQL2012 附加数据库提示5120错误解决方法

在win8.1 x64系统上使用sql2012进行附加数据库(包括在x86系统正在使用的数据库文件,直接拷贝附加在X64系统中)时,提示无法打开文件,5120错误。 这个错误是因为没有操作权限,所以附加的时候出错,…

pytorch利用rnn通过sin预测cos 利用lstm预测手写数字

一.利用rnn通过sin预测cos 1.首先可视化一下数据 import numpy as np from matplotlib import pyplot as plt def show(sin_np,cos_np):plt.figure()plt.title(Sin and Cos, fontsize18)plt.plot(steps, sin_np, r-, labelsin)plt.plot(steps, cos_np, b-, labelcos)plt.lege…

高德纳咨询公司(Gartner)预测:2019年七大人工智能科技趋势

来源:创新研究摘要:人工智能技术对我们的工作环境、工作种类等等正在产生日益深刻的影响,其结果或好或坏都有可能。为应对这种改变,特别是负面的变化,高德纳咨询公司(Gartner)于2018年12月13日发…

美爆!《自然》公布2018年19张最震撼的科学图片

来源:前瞻网 摘要:2018年注定将载入科学史册:气候上,从加利福尼亚烧到开普敦的致命野火和极端干旱、历史罕见;医学上,克隆和成像技术的进步既带来希望,也产生了争议;生物上,一系列事件让人们意识…

python实现Trie 树+朴素匹配字符串+RK算法匹配字符串+kmp算法匹配字符串

一.trie树应用: 相应leetcode 常用于搜索提示,如当输入一个网址,可以自动搜索出可能的选择。当没有完全匹配的搜索结果,可以返回前缀最相似的可能。 例如三个单词app, apple, add,我们按照以下规则创建了一颗Trie树.对于从树的根…

天才也勤奋!DeepMind哈萨比斯自述:领导400名博士向前,每天工作至凌晨4点

来源:量子位你见过凌晨4点的伦敦吗?哈萨比斯天天见。这位DeepMind创始人、AlphaGo之父,一直是全球赞颂的当世天才,但每天要到凌晨4点,才能睡下。这是哈萨比斯最新采访中透露的作息时间,他告诉《星期日泰晤士…

RNN知识+LSTM知识+encoder-decoder+ctc+基于pytorch的crnn网络结构

一.基础知识: 下图是一个循环神经网络实现语言模型的示例,可以看出其是基于当前的输入与过去的输入序列,预测序列的下一个字符. 序列特点就是某一步的输出不仅依赖于这一步的输入,还依赖于其他步的输入或输…