前馈神经网络正则化例子

直接看代码:

import torch  
import numpy as np  
import random  
from IPython import display  
from matplotlib import pyplot as plt  
import torchvision  
import torchvision.transforms as transforms   mnist_train = torchvision.datasets.MNIST(root='/MNIST', train=True, download=True, transform=transforms.ToTensor())  
mnist_test = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.ToTensor())  batch_size = 256 train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)  
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0)  num_inputs,num_hiddens,num_outputs =784, 256,10def init_param():W1 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens,num_inputs)), dtype=torch.float32)  b1 = torch.zeros(1, dtype=torch.float32)  W2 = torch.tensor(np.random.normal(0, 0.01, (num_outputs,num_hiddens)), dtype=torch.float32)  b2 = torch.zeros(1, dtype=torch.float32)  params =[W1,b1,W2,b2]for param in params:param.requires_grad_(requires_grad=True)  return W1,b1,W2,b2def relu(x):x = torch.max(input=x,other=torch.tensor(0.0))  return xdef net(X):  X = X.view((-1,num_inputs))  H = relu(torch.matmul(X,W1.t())+b1)  #myrelu =((matmal x,w1)+b1),return  matmal(myrelu,w2 )+ b2return relu(torch.matmul(H,W2.t())+b2 )return torch.matmul(H,W2.t())+b2def SGD(paras,lr):  for param in params:  param.data -= lr * param.grad  def l2_penalty(w):return (w**2).sum()/2def train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr=None,optimizer=None,mylambda=0):  train_ls, test_ls = [], []for epoch in range(num_epochs):ls, count = 0, 0for X,y in train_iter :X = X.reshape(-1,num_inputs)l=loss(net(X),y)+ mylambda*l2_penalty(W1) + mylambda*l2_penalty(W2)optimizer.zero_grad()l.backward()optimizer.step()ls += l.item()count += y.shape[0]train_ls.append(ls)ls, count = 0, 0for X,y in test_iter:X = X.reshape(-1,num_inputs)l=loss(net(X),y) + mylambda*l2_penalty(W1) + mylambda*l2_penalty(W2)ls += l.item()count += y.shape[0]test_ls.append(ls)if(epoch)%2==0:print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))return train_ls,test_lslr = 0.01num_epochs = 20Lamda = [0,0.1,0.2,0.3,0.4,0.5]Train_ls, Test_ls = [], []for lamda in Lamda:print("current lambda is %f"%lamda)W1,b1,W2,b2 = init_param()loss = torch.nn.CrossEntropyLoss()optimizer = torch.optim.SGD([W1,b1,W2,b2],lr = 0.001)train_ls, test_ls = train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr,optimizer,lamda)   Train_ls.append(train_ls)Test_ls.append(test_ls)x = np.linspace(0,len(Train_ls[1]),len(Train_ls[1]))plt.figure(figsize=(10,8))for i in range(0,len(Lamda)):plt.plot(x,Train_ls[i],label= f'L2_Regularization:{Lamda [i]}',linewidth=1.5)plt.xlabel('different epoch')plt.ylabel('loss')plt.legend(loc=2, bbox_to_anchor=(1.1,1.0),borderAxesPad = 0.)plt.title('train loss with L2_penalty')plt.show()

运行结果:

在这里插入图片描述

疑问和心得:

  1. 画图的实现和细节还是有些模糊。
  2. 正则化系数一般是一个可以根据算法有一定变动的常数。
  3. 前馈神经网络中,二分类最后使用logistic函数返回,多分类一般返回softmax值,若是一般的回归任务,一般是直接relu返回。
  4. 前馈神经网络的实现,从物理层上应该是全连接的,但是网上的代码一般都是两层单个神经元,这个容易产生误解。个人感觉,还是要使用nn封装的函数比较正宗。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/42759.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Docker:Windows container和Linux container

点击"Switch to Windows containers"菜单时: 提示 然后 实际上是运行:com.docker.admin.exe start-service

成集云 | 乐享问题邀请同步企微提醒 | 解决方案

源系统成集云目标系统 方案介绍 腾讯乐享是腾讯公司开发的一款企业社区化知识管理平台,它提供了包括知识库、问答、课堂、考试、活动、投票和论坛等核心应用。这个平台凝聚了腾讯10年的管理经验,可以满足政府、企业和学校在知识管理、学习培训、文化建…

【gitkraken】gitkraken自动更新问题

GitKraken 会自动升级&#xff01;一旦自动升级&#xff0c;你的 GitKraken 自然就不再是最后一个免费版 6.5.1 了。 在安装 GitKraken 之后&#xff0c;在你的安装目录&#xff08;C:\Users\<用户名>\AppData\Local\gitkraken&#xff09;下会有一个名为 Update.exe 的…

Linux环境变量

环境变量 一.基本概念二.常见的环境变量1.PATH&#xff1a;指令搜索路径2.HOME&#xff1a; 指定用户的主工作目录3.SHELL&#xff1a;当前Shell,它的值通常是/bin/bash 三.查看环境变量的方法四.命令行参数五.环境变量增加和删除六.本地变量 一个问题&#xff1a;我们在写一段…

Kotlin~Bridge桥接模式

概念 抽象和现实之间搭建桥梁&#xff0c;分离实现和抽象。 抽象&#xff08;What&#xff09;实现&#xff08;How&#xff09;用户可见系统正常工作的底层代码产品付款方式定义数据类型的类。处理数据存储和检索的类 角色介绍 Abstraction&#xff1a;抽象 定义抽象接口&…

一起创建Vue脚手架吧

目录 一、安装Vue CLI1.1 配置 npm 淘宝镜像1.2 全局安装1.3 验证是否成功 二、创建vue_test项目2.1 cmd进入桌面2.2 创建项目2.3 运行项目2.4 查看效果 三、脚手架结构分析3.1 文件目录结构分析3.2 vscode终端打开项目 一、安装Vue CLI CLI&#xff1a;command-line interface…

日常BUG——微信小程序提交代码报错

&#x1f61c;作 者&#xff1a;是江迪呀✒️本文关键词&#xff1a;日常BUG、BUG、问题分析☀️每日 一言 &#xff1a;存在错误说明你在进步&#xff01; 一、问题描述 在使用微信小程序开发工具进行提交代码时&#xff0c;报出如下错误&#xff1a; Invalid a…

5、flink任务中可以使用哪些转换算子(Transformation)

1、什么是Flink中的转换算子 在使用 Flink DataStream API 开发流式计算任务时&#xff0c;可以将一个或多个 DataStream 转换成新的 DataStream&#xff0c;在应用程序中可以将多个数据转换算子合并成一个复杂的数据流拓扑图。 2、常用的转换算子 Flink提供了功能各异的转换算…

msvcp140.dll丢失的解决方法,如何预防msvcp140.dll丢失

在电脑操作系统中经常会弹出类似msvcp140.dll丢失的错误提示窗口&#xff0c;导致软件无法正常运行。为什么会出现msvcp140.dll丢失的情况呢&#xff1f;出现这种情况应该如何解决呢&#xff1f;小编有三种解决方法分享给大家。 一.msvcp140.dll丢失的原因 1.安装过程中受损:在…

Elasticsearch 查询之Function Score Query

前言 ES 的主查询评分模式分为两种&#xff0c;是信息检索领域的重要算法&#xff1a; TF-IDF 算法 和 BM25 算法。 Elasticsearch 从版本 5.0 开始引入了 BM25 算法作为默认的文档评分&#xff08;relevance scoring&#xff09;算法。在此之前&#xff0c;Elasticsearch 使…

sip网络号角喇叭 sip音柱 POE供电广播音箱 ip网络防水对讲终端 sip网络功放

SV-7042TP网络号角喇叭 一、描述 SV-7042TP是我司的一款SIP网络号角喇叭&#xff0c;具有10/100M以太网接口&#xff0c;内置有一个高品质扬声器&#xff0c;将网络音源通过自带的功放和喇叭输出播放&#xff0c;可达到功率30W。SV-7042TP作为SIP系统的播放终端&#xff0c;可…

【脚踢数据结构】常见树总结(图码结和版)

(꒪ꇴ꒪ )&#xff0c;Hello我是祐言QAQ我的博客主页&#xff1a;C/C语言&#xff0c;Linux基础&#xff0c;ARM开发板&#xff0c;软件配置等领域博主&#x1f30d;快上&#x1f698;&#xff0c;一起学习&#xff0c;让我们成为一个强大的攻城狮&#xff01;送给自己和读者的…

如何构造不包含字母和数字的webshell

利用不含字母与数字进行绕过 1.异或进行绕过 2.取反进行绕过 3.利用php语法绕过 利用不含字母与数字进行绕过 基本代码运行思路理解 <?php echo "A"^""; ?> 运行结果为! 我们可以看到&#xff0c;输出的结果是字符"!"。之所以会…

对比 VPN 与远程桌面软件,为什么远程桌面更优越

数字格局不断演变&#xff0c;我们的工作和连接方式也在不断变化。企业纷纷转向远程运营&#xff0c;有关推进向远程过渡的最佳技术的争论从未停止。争论的焦点通常是虚拟专用网络&#xff08;VPN&#xff09;和远程桌面软件。 长期以来&#xff0c;VPN 一直被用作访问公司网络…

【C++】函数指针

2023年8月18日&#xff0c;周五上午 今天在B站看Qt教学视频的时候遇到了 目录 语法和typedef或using结合我的总结 语法 返回类型 (*指针变量名)(参数列表)以下是一些示例来说明如何声明不同类型的函数指针&#xff1a; 声明一个不接受任何参数且返回void的函数指针&#xf…

【Flink】Flink窗口触发器

数据进入到窗口的时候,窗口是否触发后续的计算由窗口触发器决定,每种类型的窗口都有对应的窗口触发机制。WindowAssigner 默认的 Trigger通常可解决大多数的情况。我们通常使用方式如下,调用trigger()方法把我们想执行触发器传递进去: SingleOutputStreamOperator<Produ…

kubernetes--技术文档--基本概念--《10分钟快速了解》

官网主页&#xff1a; Kubernetes 什么是k8s Kubernetes 也称为 K8s&#xff0c;是用于自动部署、扩缩和管理容器化应用程序的开源系统。 它将组成应用程序的容器组合成逻辑单元&#xff0c;以便于管理和服务发现。Kubernetes 源自Google 15 年生产环境的运维经验&#xff0c…

《一个操作系统的实现》windows用vm安装CentOS——从bochs环境搭建到第一个demo跑通

vm安装CentOS虚拟机带有桌面的版本。su输入密码123456。更新yum -y update 。一般已经安装好后面这2个工具&#xff1a;yum install -y net-tools wget。看下ip地址ifconfig&#xff0c;然后本地终端连接ssh root192.168.249.132输入密码即可&#xff0c;主要是为了复制网址方便…

Netty+springboot开发即时通讯系统笔记(四)终

实时性 1.线程池多线程&#xff0c;把消息同步给其他端和对方用户&#xff0c;其中数据持久化往往是最浪费时间的操作&#xff0c;可以使用mq异步存储&#xff0c;因为其他业务不需要拿着整条数据&#xff0c;只需要这条数据的id进行操作。 2。消息校验前置&#xff0c;放在t…

Vim的插件管理器之Vundle

1、安装Vundle插件管理器 Vim可以安装插件&#xff0c;但是需要手动安装比较麻烦&#xff0c;Vim本身没有提供插件管理器&#xff0c;所以会有很多的第三方的插件管理器&#xff0c;有一个vim的插件叫做 “vim-easymotion”&#xff0c;在它的github的安装说明里有列出对于不同…