【智能计算系统】神经网络基础代码实现

参考课程:智能计算系统

神经网络中常见的组成部分有:全连接层,激活函数,Softmax层。

全连接层

全连接层输入为一维向量(维度为m),输出为一维向量(维度为n)。前向传播的公式为:

y = W T x + b y=W^Tx+b y=WTx+b

其中,W为二维权重矩阵,维度为 m × n m×n m×n;偏置b是一维向量,维度为n。

以手写数字识别为例,输入的手写数字图像信息,经过第一轮的正向传播后,会产生一个预测结果(结果以概率的形式体现),但这个概率与真实情况的概率(比如数字7的手写图像,理想情况下它为7的概率为1,为其他数字的概率为0)有差别,这样便产生了差值。利用差值来调整网络中的参数(权重、偏置等),得到较为准确的预测结果。

反向传播在理论上通过求偏导的方式来解决。实际应用中通常使用批量随机梯度下降的方法进行反向传播计算。

梯度下降法:通过不断地沿着函数的负梯度方向更新参数,从而逼近函数的局部最小值。(是一种基于一阶导数的优化算法)

代码实现:

# 要引入numpy, os, time等包
class FullyConnectedLayer(object):def __init__(self, num_input, num_output):  # 全连接层初始化self.num_input = num_input # 输入的一维向量维度self.num_output = num_output # 输出的一维向量维度print('\tFully connected layer with input %d, output %d.' % (self.num_input, self.num_output))def init_param(self, std=0.01):  # 参数初始化# 正态分布随机生成数# loc=0.0 正态分布的均值# scale=std:正态分布的标准差# size:生成随机数的形状,这里是二维数组,行数为self.num_input,列数为self.num_output# 对应于全连接层的权重矩阵self.weight = np.random.normal(loc=0.0, scale=std, size=(self.num_input, self.num_output)) self.bias = np.zeros([1, self.num_output]) # 生成全零数组def forward(self, input):  # 前向传播计算start_time = time.time()self.input = input#全连接层的前向传播,计算输出结果#Y = WX + bself.output = np.dot(self.input, self.weight) + self.bias # np.dot:求两个数组的内积return self.outputdef backward(self, top_diff):  # 反向传播的计算# 全连接层的反向传播,计算参数梯度和本层损失# top_diff是“从上一层传递下来的梯度”(损失函数关于上一层输出的偏导数,这些偏导数用于计算当前层的参数梯度(即权重和偏置的更新量)以及进一步传递给下一层)self.d_weight = np.dot(self.input.T, top_diff)self.batch_size = top_diff.shape[0]self.d_bias = np.dot(np.ones(shape=(1,self.batch_size)),top_diff)bottom_diff = np.dot(top_diff, self.weight.T)return bottom_diffdef update_param(self, lr):  # 参数更新# 对全连接层参数利用参数进行更新self.weight = self.weight - lr * self.d_weightself.bias = self.bias - lr * self.d_biasdef load_param(self, weight, bias):  # 参数加载assert self.weight.shape == weight.shapeassert self.bias.shape == bias.shapeself.weight = weightself.bias = biasdef save_param(self):  # 参数保存return self.weight, self.bias

激活函数

为什么要使用激活函数:因为神经网络中每一层的输入输出都是一个线性求和的过程,下一层的输出只是承接了上一层输入函数的线性变换,所以如果没有激活函数,那么无论你构造的神经网络多么复杂,有多少层,最后的输出都是输入的线性组合,纯粹的线性组合并不能够解决更为复杂的问题。而引入激活函数之后,我们会发现常见的激活函数都是非线性的,因此也会给神经元引入非线性元素,使得神经网络可以逼近其他的任何非线性函数,这样可以使得神经网络应用到更多非线性模型中。

参考资料:知乎

ReLU激活函数:x<0时y=0,x>0时y=x。

代码实现:

class ReLULayer(object):def __init__(self):print('\tReLU layer.')def forward(self, input):  # 前向传播的计算start_time = time.time()self.input = input# ReLU层的前向传播,计算输出结果# y(i) = max(0, x(i))output = np.maximum(0, self.input)return outputdef backward(self, top_diff):  # 反向传播的计算# ReLU层的反向传播,计算本层损失(计算损失函数对本层的第i个输入的偏导)# 公式:x(i)>=0时为:根据损失函数对输出的偏导▽yL计算损失函数对输入的偏导▽xL# x(i)<0时取0bottom_diff = np.zeros_like(self.input)mask = self.input >= 0np.putmask(bottom_diff, mask, top_diff)#bottom_diff[mask] = top_diff #TypeError: NumPy boolean array indexing assignment requires a 0 or 1-dimensional input, input has 2 dimensions. changereturn bottom_diff

Softmax损失层

Softmax损失层是目前多分类问题中最常用的损失函数层。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/763981.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ThingsBoard初始化数据库Postgres+Cassandra

本章将介绍ThingsBoard初始化数据PostgresCassandra&#xff0c;两种数据库结合使用&#xff0c;以及源码的编译安装。本机环境&#xff1a;Centos7、Docker、Postgres、Cassandra 环境安装 开发环境要求&#xff1a; docker &#xff1b;Docker&#xff1b;Postgres:Cassandr…

Qwen及Qwen-audio大模型微调项目汇总

Qwen及Qwen-audio可微调项目调研 可用来微调方法/项目汇总ps.大语言模型基础资料 可用来微调方法/项目汇总 Qwen github 项目自带的finetune脚本 可以参考https://blog.csdn.net/qq_45156060/article/details/135153920PAI-DSW中微调千问大模型&#xff08;阿里云的一个产品&a…

maven archetype 和普通的maven有啥区别

Maven是一个项目管理和构建自动化工具&#xff0c;主要用于Java项目&#xff0c;它基于项目对象模型&#xff08;POM&#xff09;。Maven可以通过其POM文件管理项目的构建、报告和文档。 而Maven Archetype是Maven的一个插件&#xff0c;主要用于为用户创建基于某个模版的新项…

Django日志(一)

一、概念与配置 1.1、概述 日志是程序员经常在代码中使用快速和方便的调试工具。它在调试方面比print更加的优雅和灵活 而且日志记录对于调试很有用,可以提供更多,更好的结构化,有关应用程序的状态和运行状况的信息 Django框架的日志通过python内置的logging模块实现的,可…

TCP TLS

TCP&#xff08;传输控制协议&#xff09;是一种面向连接的协议&#xff0c;用于在网络上可靠地传输数据。它提供了数据分段、重传、流量控制和拥塞控制等功能&#xff0c;以确保数据的可靠传输。TCP在传输层上工作&#xff0c;它使用IP&#xff08;Internet协议&#xff09;作…

LeetCode2671. Frequency Tracker

文章目录 一、题目二、题解 一、题目 Design a data structure that keeps track of the values in it and answers some queries regarding their frequencies. Implement the FrequencyTracker class. FrequencyTracker(): Initializes the FrequencyTracker object with …

前端视角如何理解“时间复杂度O(n)”

定义 时间复杂度是O(n) 意味着算法的执行时间与输入数据的大小成正比。 这里的n表示输入数据的数量。 假设有一个数组&#xff0c;需要遍历这个数组并打印出每个元素的值。 这个操作的时间复杂度就是O(n)&#xff0c;因为你需要执行n次操作&#xff0c;其中n是数组的长度。 …

力扣由浅至深 每日一题.11 加一

少年气&#xff0c;是历经千帆举重若轻地沉淀&#xff0c;也是乐观淡然笑对生活的豁达 —— 24.3.22 加一 给定一个由 整数 组成的 非空 数组所表示的非负整数&#xff0c;在该数的基础上加一。 最高位数字存放在数组的首位&#xff0c; 数组中每个元素只存储单个数字。 你可以…

IPC通信--socket

1.windows环境 在C中&#xff0c;Windows环境下实现socket通信的客户端与服务端的流程如下&#xff1a; 创建套接字&#xff1a;使用socket()函数创建一个套接字。绑定套接字&#xff1a;使用bind()函数将套接字与一个地址&#xff08;IP和端口&#xff09;绑定在一起。监听连…

基于Gabor滤波器的指纹图像识别,Matlab实现

博主简介&#xff1a; 专注、专一于Matlab图像处理学习、交流&#xff0c;matlab图像代码代做/项目合作可以联系&#xff08;QQ:3249726188&#xff09; 个人主页&#xff1a;Matlab_ImagePro-CSDN博客 原则&#xff1a;代码均由本人编写完成&#xff0c;非中介&#xff0c;提供…

2024年【山东省安全员C证】考试试卷及山东省安全员C证复审模拟考试

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 山东省安全员C证考试试卷是安全生产模拟考试一点通生成的&#xff0c;山东省安全员C证证模拟考试题库是根据山东省安全员C证最新版教材汇编出山东省安全员C证仿真模拟考试。2024年【山东省安全员C证】考试试卷及山东省…

一个线程可以有几个Handler,几个Looper,几个MessageQueue对象?

一个线程可以有多个Handler&#xff0c;但是只有一个Looper对象,只有一个MessageQueue对象。 在Looper的prepare方法中创建了Looper对象&#xff0c;并放入到ThreadLocal中&#xff0c;并通过ThreadLocal来获取looper的对象, ThreadLocal的内部维护了一个ThreadLocalMap类, 里…

【Linux】信号的处理{信号处理的时机/了解寄存器/内核态与用户态/信号操作函数}

文章目录 0.对于信号捕捉的理解1.信号处理的时机1.1 何时处理信号&#xff1f;1.2 内核态和用户态1.3 内核态和用户态的切换 2.了解寄存器3.信号捕捉的原理4.信号操作函数4.1sighandler_t signal(int signum, sighandler_t handler);4.2int sigaction(int signum, const struct…

express+mysql+vue,从零搭建一个商城管理系统15--快递查询(对接快递100)

提示&#xff1a;学习express&#xff0c;搭建管理系统 文章目录 前言一、安装md5&#xff0c;axios二、新建config/logistics.js三、修改routes/order.js四、查询物流信息五、试错与误区总结 前言 需求&#xff1a;主要学习express&#xff0c;所以先写service部分 快递100API…

工业项目中你连DCS系统都没见过?

什么是DCS DCS&#xff0c;即分散控制系统&#xff0c;是一种用于监控和控制工业过程的系统。它通过连接多个控制器、传感器和执行器实现自动化控制&#xff0c;提高生产效率和安全性。在中国&#xff0c;随着工业化和自动化水平的提高&#xff0c;DCS技术得到了广泛应用和快速…

创建Message对象的方式及区别?Message.obtain()怎么维护消息池 ?Handler 有哪些发送消息的方法?

Message对象创建的方式有哪些&#xff0c; 区别&#xff1f; 直接new一个obtain&#xff08;&#xff09;方法获取handler.obtainMessage()方法获取。 下面两个方式是从对象池中获取&#xff0c;可以避免message对象重复的创建。 Message.obtain()怎么维护消息池的Handler &…

外包干了10天,技术倒退明显

先说情况&#xff0c;大专毕业&#xff0c;18年通过校招进入湖南某软件公司&#xff0c;干了接近6年的功能测试&#xff0c;今年年初&#xff0c;感觉自己不能够在这样下去了&#xff0c;长时间呆在一个舒适的环境会让一个人堕落!而我已经在一个企业干了四年的功能测试&#xf…

推免保研夏令营/预推免面试记录—北大软微

0x00简述 0x01 面试经历 0x02 相关资料下载 0x00简述 0x01 面试经历 0x02 相关资料下载 挖坑待写

SpringMVC 简介及入门级的快速搭建详细步骤

MVC 回顾 MVC&#xff0c;即Model-View-Controller&#xff08;模型-视图-控制器&#xff09;设计模式&#xff0c;是一种广泛应用于软件工程中&#xff0c;特别是Web应用开发中的架构模式。它将应用程序分为三个核心组件&#xff1a; Model&#xff08;模型&#xff09;&#…

面试问答示范

文章目录 请做个自我介绍您的学历是统招吗&#xff1f;可以在学信网查询吗是全日制吗是双证吗&#xff1f;请介绍一下你上家公司的情况。介绍一下你们公司的服务器架构&#xff08;网络架构&#xff09;。说说你在工作中处理过的最棘手的技术问题讲一讲上家公司做过的项目为什么…