前馈神经网络正则化例子

直接看代码:

import torch  
import numpy as np  
import random  
from IPython import display  
from matplotlib import pyplot as plt  
import torchvision  
import torchvision.transforms as transforms   mnist_train = torchvision.datasets.MNIST(root='/MNIST', train=True, download=True, transform=transforms.ToTensor())  
mnist_test = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.ToTensor())  batch_size = 256 train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)  
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0)  num_inputs,num_hiddens,num_outputs =784, 256,10def init_param():W1 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens,num_inputs)), dtype=torch.float32)  b1 = torch.zeros(1, dtype=torch.float32)  W2 = torch.tensor(np.random.normal(0, 0.01, (num_outputs,num_hiddens)), dtype=torch.float32)  b2 = torch.zeros(1, dtype=torch.float32)  params =[W1,b1,W2,b2]for param in params:param.requires_grad_(requires_grad=True)  return W1,b1,W2,b2def relu(x):x = torch.max(input=x,other=torch.tensor(0.0))  return xdef net(X):  X = X.view((-1,num_inputs))  H = relu(torch.matmul(X,W1.t())+b1)  #myrelu =((matmal x,w1)+b1),return  matmal(myrelu,w2 )+ b2return relu(torch.matmul(H,W2.t())+b2 )return torch.matmul(H,W2.t())+b2def SGD(paras,lr):  for param in params:  param.data -= lr * param.grad  def l2_penalty(w):return (w**2).sum()/2def train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr=None,optimizer=None,mylambda=0):  train_ls, test_ls = [], []for epoch in range(num_epochs):ls, count = 0, 0for X,y in train_iter :X = X.reshape(-1,num_inputs)l=loss(net(X),y)+ mylambda*l2_penalty(W1) + mylambda*l2_penalty(W2)optimizer.zero_grad()l.backward()optimizer.step()ls += l.item()count += y.shape[0]train_ls.append(ls)ls, count = 0, 0for X,y in test_iter:X = X.reshape(-1,num_inputs)l=loss(net(X),y) + mylambda*l2_penalty(W1) + mylambda*l2_penalty(W2)ls += l.item()count += y.shape[0]test_ls.append(ls)if(epoch)%2==0:print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))return train_ls,test_lslr = 0.01num_epochs = 20Lamda = [0,0.1,0.2,0.3,0.4,0.5]Train_ls, Test_ls = [], []for lamda in Lamda:print("current lambda is %f"%lamda)W1,b1,W2,b2 = init_param()loss = torch.nn.CrossEntropyLoss()optimizer = torch.optim.SGD([W1,b1,W2,b2],lr = 0.001)train_ls, test_ls = train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr,optimizer,lamda)   Train_ls.append(train_ls)Test_ls.append(test_ls)x = np.linspace(0,len(Train_ls[1]),len(Train_ls[1]))plt.figure(figsize=(10,8))for i in range(0,len(Lamda)):plt.plot(x,Train_ls[i],label= f'L2_Regularization:{Lamda [i]}',linewidth=1.5)plt.xlabel('different epoch')plt.ylabel('loss')plt.legend(loc=2, bbox_to_anchor=(1.1,1.0),borderAxesPad = 0.)plt.title('train loss with L2_penalty')plt.show()

运行结果:

在这里插入图片描述

疑问和心得:

  1. 画图的实现和细节还是有些模糊。
  2. 正则化系数一般是一个可以根据算法有一定变动的常数。
  3. 前馈神经网络中,二分类最后使用logistic函数返回,多分类一般返回softmax值,若是一般的回归任务,一般是直接relu返回。
  4. 前馈神经网络的实现,从物理层上应该是全连接的,但是网上的代码一般都是两层单个神经元,这个容易产生误解。个人感觉,还是要使用nn封装的函数比较正宗。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/42759.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Docker:Windows container和Linux container

点击"Switch to Windows containers"菜单时: 提示 然后 实际上是运行:com.docker.admin.exe start-service

成集云 | 乐享问题邀请同步企微提醒 | 解决方案

源系统成集云目标系统 方案介绍 腾讯乐享是腾讯公司开发的一款企业社区化知识管理平台,它提供了包括知识库、问答、课堂、考试、活动、投票和论坛等核心应用。这个平台凝聚了腾讯10年的管理经验,可以满足政府、企业和学校在知识管理、学习培训、文化建…

【gitkraken】gitkraken自动更新问题

GitKraken 会自动升级&#xff01;一旦自动升级&#xff0c;你的 GitKraken 自然就不再是最后一个免费版 6.5.1 了。 在安装 GitKraken 之后&#xff0c;在你的安装目录&#xff08;C:\Users\<用户名>\AppData\Local\gitkraken&#xff09;下会有一个名为 Update.exe 的…

Linux环境变量

环境变量 一.基本概念二.常见的环境变量1.PATH&#xff1a;指令搜索路径2.HOME&#xff1a; 指定用户的主工作目录3.SHELL&#xff1a;当前Shell,它的值通常是/bin/bash 三.查看环境变量的方法四.命令行参数五.环境变量增加和删除六.本地变量 一个问题&#xff1a;我们在写一段…

Kotlin~Bridge桥接模式

概念 抽象和现实之间搭建桥梁&#xff0c;分离实现和抽象。 抽象&#xff08;What&#xff09;实现&#xff08;How&#xff09;用户可见系统正常工作的底层代码产品付款方式定义数据类型的类。处理数据存储和检索的类 角色介绍 Abstraction&#xff1a;抽象 定义抽象接口&…

《Go 语言第一课》课程学习笔记(五)

入口函数与包初始化&#xff1a;搞清 Go 程序的执行次序 main.main 函数&#xff1a;Go 应用的入口函数 Go 语言中有一个特殊的函数&#xff1a;main 包中的 main 函数&#xff0c;也就是 main.main&#xff0c;它是所有 Go 可执行程序的用户层执行逻辑的入口函数。 Go 程序在…

一起创建Vue脚手架吧

目录 一、安装Vue CLI1.1 配置 npm 淘宝镜像1.2 全局安装1.3 验证是否成功 二、创建vue_test项目2.1 cmd进入桌面2.2 创建项目2.3 运行项目2.4 查看效果 三、脚手架结构分析3.1 文件目录结构分析3.2 vscode终端打开项目 一、安装Vue CLI CLI&#xff1a;command-line interface…

日常BUG——微信小程序提交代码报错

&#x1f61c;作 者&#xff1a;是江迪呀✒️本文关键词&#xff1a;日常BUG、BUG、问题分析☀️每日 一言 &#xff1a;存在错误说明你在进步&#xff01; 一、问题描述 在使用微信小程序开发工具进行提交代码时&#xff0c;报出如下错误&#xff1a; Invalid a…

Git提交规范指南

在开发过程中&#xff0c;Git每次提交代码&#xff0c;都需要写Commit message&#xff08;提交说明&#xff09;&#xff0c;规范的Commit message有很多好处&#xff1a; 方便快速浏览查找&#xff0c;回溯之前的工作内容可以直接从commit 生成Change log(发布时用于说明版本…

5、flink任务中可以使用哪些转换算子(Transformation)

1、什么是Flink中的转换算子 在使用 Flink DataStream API 开发流式计算任务时&#xff0c;可以将一个或多个 DataStream 转换成新的 DataStream&#xff0c;在应用程序中可以将多个数据转换算子合并成一个复杂的数据流拓扑图。 2、常用的转换算子 Flink提供了功能各异的转换算…

[论文笔记]ON LAYER NORMALIZATION IN THE TRANSFORMER ARCHITECTURE

引言 这是论文ON LAYER NORMALIZATION IN THE TRANSFORMER ARCHITECTURE的阅读笔记。本篇论文提出了通过Pre-LN的方式可以省掉Warm-up环节,并且可以加快Transformer的训练速度。 通常训练Transformer需要一个仔细设计的学习率warm-up(预热)阶段:在训练开始阶段学习率需要设…

JDK 1.6与JDK 1.8的区别

ArrayList使用默认的构造方式实例 jdk1.6默认初始值为10jdk1.8为0,第一次放入值才初始化&#xff0c;属于懒加载 Hashmap底层 jdk1.6与jdk1.8都是数组链表 jdk1.8是链表超过8时&#xff0c;自动转为红黑树 静态方式不同 jdk1.6是先初始化static后执行main方法。 jdk1.8是懒加…

设置PHP的fpm的系统性能参数pm.max_children

1 介绍 PHP从Apache module换成了Fpm&#xff0c;跑了几天突然发现网站打不开了。 页面显示超时&#xff0c;检查MySQL、Redis一众服务都正常。 进入Fpm容器查看日志&#xff0c;发现了如下的错误信息&#xff1a; server reached pm.max_children setting (5), consider r…

python中的svm:介绍和基本使用方法

python中的svm&#xff1a;介绍和基本使用方法 支持向量机&#xff08;Support Vector Machine&#xff0c;简称SVM&#xff09;是一种常用的分类算法&#xff0c;可以用于解决分类和回归问题。SVM通过构建一个超平面&#xff0c;将不同类别的数据分隔开&#xff0c;使得正负样…

2023全国大学生数学建模竞赛A题B题C题D题E题思路+模型+代码+论文

目录 一. 2023国赛数学建模思路&#xff1a; 赛题发布后会第一时间发布选题建议&#xff0c;思路&#xff0c;模型代码等 详细思路获取见文末名片&#xff0c;9.7号第一时间更新 二.国赛常用的模型算法&#xff1a; 三、算法简介 四.超重要&#xff01;&#xff01;&…

msvcp140.dll丢失的解决方法,如何预防msvcp140.dll丢失

在电脑操作系统中经常会弹出类似msvcp140.dll丢失的错误提示窗口&#xff0c;导致软件无法正常运行。为什么会出现msvcp140.dll丢失的情况呢&#xff1f;出现这种情况应该如何解决呢&#xff1f;小编有三种解决方法分享给大家。 一.msvcp140.dll丢失的原因 1.安装过程中受损:在…

前端框架学习-ES6新特性(尚硅谷web笔记)

ECMASript是由 Ecma 国际通过 ECMA-262 标准化的脚本程序设计语言。javaScript也是该规范的一种实现。 新特性目录 笔记出处&#xff1a;b站ES6let 关键字const关键字变量的解构赋值模板字符串简化对象写法箭头函数rest参数spread扩展运算符Promise模块化 ES8async 和 await E…

云原生周刊:Kubernetes v1.28 新特性一览 | 2023.8.14

推荐一个 GitHub 仓库&#xff1a;Fast-Kubernetes。 Fast-Kubernetes 是一个涵盖了 Kubernetes 的实验室&#xff08;LABs&#xff09;的仓库。它提供了关于 Kubernetes 的各种主题和组件的详细内容&#xff0c;包括 Kubectl、Pod、Deployment、Service、ConfigMap、Volume、…

CF1013B And 题解

题目传送门 题目意思&#xff1a; 给你一个长度为 n n n 的序列 a i a_i ai​&#xff0c;再给一个数 x x x。每一步你可以将序列中的一个数与上 x x x。请问最少要多少步才可以使得序列中出现两个相同的数&#xff0c;如果无解输出 − 1 -1 −1。 思路&#xff1a; 首…

Vue页面刷新常用的4种方法

Vue项目里,有时候我们需要刷新页面,重新加载页面数据,常用方法如下: 方法一:location.reload() 方法全局刷新 使用 location.reload() 方法可以简单地实现当前页面的刷新,这个方法会重新加载当前页面,类似于用户点击浏览器的刷新按钮。 在 Vue 中,可以将该方法绑定到…