神经网络模型预训练

根据神经网络各个层的计算逻辑用程序实现相关的计算,主要是:前向传播计算、反向传播计算、损失计算、精确度计算等,并提供保存超参数到文件中。

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
from DeepLearn_Base.common.functions import *
from DeepLearn_Base.common.gradient import numerical_gradient
import pickle# 三层神经网络处理类(两层隐藏层+1层输出层)
class ThreeLayerNet:# input_size:输入层神经元数量,灰度图像的三维表示: 1 * 28 * 28 = 784# output_size: 输出层神经元数量,10,表示10个数字# hidden_size:第一层隐藏层神经元数量,50# second_hidden_size:第二层隐藏层神经元数量,100# weight_init_std:权重初始化def __init__(self, input_size, hidden_size, output_size, second_hidden_size, weight_init_std=0.01):# 初始化权重self.params = {}self.params['W1'] = weight_init_std * np.random.randn(input_size, hidden_size)self.params['b1'] = np.zeros(hidden_size)self.params['W2'] = weight_init_std * np.random.randn(hidden_size, second_hidden_size)self.params['b2'] = np.zeros(second_hidden_size)self.params['W3'] = weight_init_std * np.random.randn(second_hidden_size, output_size)self.params['b3'] = np.zeros(output_size)# 执行预测def predict(self, x):W1, W2, W3 = self.params['W1'], self.params['W2'], self.params['W3']b1, b2, b3 = self.params['b1'], self.params['b2'], self.params['b3']# 隐藏层第一层a1 = np.dot(x, W1) + b1z1 = sigmoid(a1)# 隐藏层第二层a2 = np.dot(z1, W2) + b2z2 = sigmoid(a2)# 输出层a3 = np.dot(z2, W3) + b3y = softmax(a3)return y# x:输入数据, t:监督数据def loss(self, x, t):y = self.predict(x)return cross_entropy_error(y, t)# 精确度计算def accuracy(self, x, t):y = self.predict(x)y = np.argmax(y, axis=1)t = np.argmax(t, axis=1)accuracy = np.sum(y == t) / float(x.shape[0])return accuracy# 梯度计算def gradient(self, x, t):W1, W2, W3 = self.params['W1'], self.params['W2'], self.params['W3']b1, b2, b3 = self.params['b1'], self.params['b2'], self.params['b3']grads = {}batch_num = x.shape[0]# forward# 隐藏层第一层a1 = np.dot(x, W1) + b1z1 = sigmoid(a1)# 隐藏层第二层a2 = np.dot(z1, W2) + b2z2 = sigmoid(a2)# 输出层a3 = np.dot(z2, W3) + b3y = softmax(a3)# backward# 两层隐藏层计算梯度# 输出层梯度: Loss与输出的导数,分类场景下,等于预测值-真实值# 权重梯度: 隐藏层输出的转置 * 损失函数梯度dy = (y - t) / batch_numgrads['W3'] = np.dot(z2.T, dy)grads['b3'] = np.sum(dy, axis=0)# 反向传播到隐藏层# 隐藏层梯度:Loss与输出的导数 * 输出层权重的转置da2 = np.dot(dy, W3.T)dz2 = sigmoid_grad(a2) * da2grads['W2'] = np.dot(z1.T, dz2)grads['b2'] = np.sum(dz2, axis=0)da1 = np.dot(da2, W2.T)dz1 = sigmoid_grad(a1) * da1grads['W1'] = np.dot(x.T, dz1)grads['b1'] = np.sum(dz1, axis=0)return grads# 保存参数到文件def save_params(self, file_name="params.pkl"):params = {}for key, val in self.params.items():params[key] = valwith open(file_name, 'wb') as f:pickle.dump(params, f)

预训练实现

读取MNIST训练数据集,总共有60000个。每次从60000个训练数据中随机取出100个数据 (图像数据和正确解标签数据)。然后,对这个包含100笔数据的批数据求梯度,使用随机梯度下降法(SGD)更新参数。这里,梯度法的更新次数(循环的次数)为10000。每更新一次,都对训练数据计算损失函数的值,并把该值添加到数组中。

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
import matplotlib.pyplot as plt
from DeepLearn_Base.dataset.mnist import load_mnist
from three_layer_net import ThreeLayerNet# 读入数据
# x_train.sharp 60000 * 784
# t_train.sharp 60000 * 10
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)network = ThreeLayerNet(input_size=784, hidden_size=50, second_hidden_size=100, output_size=10)iters_num = 10000  # 适当设定循环的次数
# 训练集大小 60000
train_size = x_train.shape[0]
batch_size = 100
learning_rate = 0.1train_loss_list = []
train_acc_list = []
test_acc_list = []# 每批次迭代数量:600
iter_per_epoch = max(train_size / batch_size, 1)for i in range(iters_num):# 从训练集中选取100个为一批次进行训练batch_mask = np.random.choice(train_size, batch_size)x_batch = x_train[batch_mask]t_batch = t_train[batch_mask]# 更新超参数梯度grad = network.gradient(x_batch, t_batch)# 更新超参数W,b# 基于SGD算法更新梯度,上面是随机选择的批数据处理,因此更新时,也是随即更新梯度for key in ('W1', 'b1', 'W2', 'b2', 'W3', 'b3'):network.params[key] -= learning_rate * grad[key]loss = network.loss(x_batch, t_batch)train_loss_list.append(loss)if i % iter_per_epoch == 0:train_acc = network.accuracy(x_train, t_train)test_acc = network.accuracy(x_test, t_test)train_acc_list.append(train_acc)test_acc_list.append(test_acc)print("train acc, test acc | " + str(train_acc) + ", " + str(test_acc))# 绘制图形
markers = {'train': 'o', 'test': 's'}
x = np.arange(len(train_acc_list))
plt.plot(x, train_acc_list, label='train acc')
plt.plot(x, test_acc_list, label='test acc', linestyle='--')
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()# 输出到文件保存参宿后
network.save_params("E:\\workcode\\code\\DeepLearn_Base\\ch04\\myparams.pkl")

用图像来表示这个损失函数的值的推移,如图所示;并保存最终的超参数到pkl文件

应用自训练超参数

将之前用于预测图像文字中使用的超参数文件替换为自己预训练生成的pkl参数文件,并执行代码,打印出精确度。
这是基于默认的超参数进行推理后的精确度:

替换超参数文件,进行图像识别推理

 

精确度上涨了0.01,因此选择合适的梯度更新超参数,是保证推理精确度好坏的关键。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/192710.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Python百练——第3练】矩形类及操作

💐作者:insist-- 💐个人主页:insist-- 的个人主页 理想主义的花,最终会盛开在浪漫主义的土壤里,我们的热情永远不会熄灭,在现实平凡中,我们终将上岸,阳光万里 ❤️欢迎点…

Golang 原生Rpc Server实现

Golang 原生Rpc Server实现 引言源码解析服务端数据结构服务注册请求处理 客户端数据结构建立连接请求调用 延伸异步调用定制服务名采用TPC协议建立连接自定义编码格式自定义服务器 参考 引言 本文我们来看看golang原生rpc库的实现 , 首先来看一下golang rpc库的demo案例: 服…

python的制图

测试数据示例: day report_user_cnt report_user_cnt_2 label 2023-10-01 3 3 欺诈 2023-10-02 2 4 欺诈 2023-10-03 6 5 欺诈 2023-10-04 2 1 正常 2023-10-05 4 3 正常 2023-10-06 4 4 正常 2023-10-07 2 6 正常 2023-10-08 3 7 正常 2023-10-09 3 12 正常 2023-…

找不到DNS地址的解决方案

找不到DNS地址的解决方案 第一种解决方案:刷新DNS缓存第二种解决方案: 配置Internet协议版本4(TCP/IPv4)配置IP地址配置DNS地址 如何查看本机IPv4地址、子网掩码与默认网关 第一种解决方案:刷新DNS缓存 WINR输入cmd回…

基于SSH三大框架的员工管理系统

基于SSH三大框架的员工管理系统 摘要 本系统为本人学习SSH三大框架时所做的整合实例,系统角色包括普通用户和管理员两种,首页有管理员登录入口链接。系统功能主要包括管理员对用户的基本增、删、改、查和分页显示用户信息等。 系统环境 本系统使用ec…

【C++练级之路】【Lv.1】C++,启动!(命名空间,缺省参数,函数重载,引用,内联函数,auto,范围for,nullptr)

目录 引言入门须知一、命名空间1.1 作用域限定符1.2 命名空间的意义1.3 命名空间的定义1.4 命名空间的使用 二、C输入&输出2.1 cout输出2.2 cin输入2.3 std命名空间的使用惯例 三、缺省参数3.1 缺省参数概念3.2 缺省参数分类 四、函数重载4.1 函数重载概念4.2 函数重载分类…

BUUCTF 间谍启示录 1

BUUCTF:https://buuoj.cn/challenges 题目描述: 在城际公路的小道上,罪犯G正在被警方追赶。警官X眼看他正要逃脱,于是不得已开枪击中了罪犯G。罪犯G情急之下将一个物体抛到了前方湍急的河流中,便头一歪突然倒地。警官X接近一看&…

公平锁和非公平锁以及他们的实现原理是什么

文章目录 什么是非公平锁和公平锁呢?我们来看看acquire(1)的源码如下:这里的判断条件主要做两件事:在tryAcquire()方法中,主要是做了以下几件事:公平锁的tryAcquire(),实现的原理图如下:我们来看…

ORA-00257: archiver error. Connect internal only, until freed 的解决方法

归档文件存储空间不足,导致出现该问题。 当我们将数据库的模式修改为归档模式的时候,如果没有指定归档目录,默认的归档文件就会放到Flash Recovery Area的目录,但是这个目录是有大小限制的,如果超过了这个大小&#x…

C#基础学习--命名空间和程序集

引用其他程序集 编译器接受源代码文件并生成一个名为程序集的输出文件。 在许多项目中,会想使用来自其他程序集的类或类型。这些程序集可能来自BCL或第三方供应商,或者自己创建的。这些程序集称为类库,而且它们的程序集文件的名称通常以dll…

微信小程序组件与插件有啥区别?怎么用?

目录 一、微信小程序介绍 二、微信小程序组件 三、微信小程序插件 四、微信小程序组件与插件有啥区别 一、微信小程序介绍 微信小程序是一种基于微信平台的应用程序,它可以在微信客户端内直接运行,无需下载和安装。微信小程序具有轻量、便捷、跨平台…

对比ProtoBuf和JSON的序列化和反序列化能力

1.序列化能力对比验证 在这里让我们分别使用PB与JSON的序列化与反序列化能力,对值完全相同的一份结构化数据进行不同次数的性能测试。 为了可读性,下面这一份文本使用JSON格式展示了需要被进行测试的结构化数据内容: {"age" : 20,"name…

线程安全的问题以及解决方案

线程安全 线程安全的定义 线程安全:某个代码无论是在单线程上运行还是在多线程上运行,都不会产生bug. 线程不安全:单线程上运行正常,多线程上运行会产生bug. 观察线程不安全 看看下面的代码: public class ThreadTest1 {public static int count 0;public static void main…

数据结构和算法-树与二叉树的存储结构以及树和二叉树和森林的遍历

文章目录 二叉树的存储结构二叉树的顺序存储二叉树的链式存储小结 二叉树的先中后序遍历例题小结 二叉树的层次遍历小结 由遍历序列构造二叉树一个遍历序列即使给定了前中后序,也不能确定该二叉树的形态可以确定的序列组合前序中序后序中序层序中序 小结若前序&…

算力基础设施领域国家标准发布

2023 年 11 月 27 日,国家标准 GB/T 43331-2023《互联网数据中心(IDC)技术和分级要求》正式发布。这一国家标准由中国信息通信研究院(简称“中国信通院”)联合多家企事业单位编制,旨在满足当前国家算力基础…

强化学习(一)——基本概念及DQN

1 基本概念 智能体 agent ,做动作的主体,(大模型中的AI agent) 环境 environment:与智能体交互的对象 状态 state ;当前所处状态,如围棋棋局 动作 action:执行的动作,…

C#——Delegate(委托)与Event(事件)

C#——Delegate(委托)与Event(事件) 前言一、Delegate(委托)1.是什么?2.怎么用?Example 1:无输入无返回值Example 2:有输入Example 3:有返回值Exa…

【C#】接口定义和使用知多少

给自己一个目标,然后坚持一段时间,总会有收获和感悟! 最近在封装和参考sdk时,看到一个不错的写法,并且打破自己对接口和实现类固定的观念,这也充分说明自己理解掌握的知识点还不够深。 目录 前言一、什么是…

Kubernetes(K8s)_16_CSI

Kubernetes(K8s)_16_CSI CSICSI实现CSI接口CSI插件 CSI CSI(Container Storage Interface): 实现容器存储的规范 本质: Dynamic Provisioning、Attach/Detach、Mount/Unmount等功能的抽象CSI功能通过3个gRPC暴露服务: IdentityServer、ControllerServe…

C++二维数组名到底代表个啥

题目先导 int a[3][4]; 则对数组元素a[i][j]正确的引用是*(*(ai)j)先翻译一下这个*(*(ai)j),即a后移i解引用,再后移j再解引用,这么看来a就应该是个二维数组,第一层存储行向量,一次解引用获得行向量的地址,…