神经网络极简入门

神经网络是深度学习的基础,正是深度学习的兴起,让停滞不前的人工智能再一次的取得飞速的发展。

其实神经网络的理论由来已久,灵感来自仿生智能计算,只是以前限于硬件的计算能力,没有突出的表现,直至谷歌的AlphaGO的出现,才让大家再次看到神经网络相较于传统机器学习的优异表现。

本文主要介绍神经网络中的重要基础概念,然后基于这些概念手工实现一个简单的神经网络。希望通过理论结合实践的方式让大家更容易的理解神经网络。

1. 神经网络是什么

神经网络就像人脑一样,整体看上去非常复杂,但是其基础组成部分并不复杂。其组成部分中最重要的就是神经元neural),sigmod函数layer)。

1.1. 神经元

神经元(neural)是神经网络最基本的元素,一个神经元包含3个部分:

  • 获取输入:获取多个输入的数据

  • 数学处理:对输入的数据进行数学计算

  • 产生输出:计算后多个输入数据变成一个输出数据

image.png

从上图中可以看出,神经元中的处理有2个步骤。第一个步骤:从蓝色框变成红色框,是对输入的数据进行加权计算后合并为一个值(N)。N=x1w1+x2w2𝑁=𝑥1𝑤1+𝑥2𝑤2 其中,w1,w2𝑤1,𝑤2分别是输入数据x1,x2𝑥1,𝑥2的权重。一般在计算N𝑁的过程中,除了权重,还会加上一个偏移参数b𝑏,最终得到:N=x1w1+x2w2+b𝑁=𝑥1𝑤1+𝑥2𝑤2+𝑏

第二个步骤:从红色框变成绿色框,通过sigmoid函数是对N进一步加工得到的神经元的最终输出(M)。

1.2. sigmoid函数

sigmoid函数也被称为S函数,因为的形状类似S形

image.png

它是神经元中的重要函数,能够将输入数据的值映射到(0,1)(0,1)之间。最常用的sigmoid函数是 f(x)=11+e−x𝑓(𝑥)=11+𝑒−𝑥,当然,不是只有这一种sigmoid函数。

至此,神经元通过两个步骤,就把输入的多个数据,转换为一个(0,1)(0,1)之间的值。

1.3. 层

多个神经元可以组合成一层,一个神经网络一般包含一个输入层和一个输出层,以及多个隐藏层。

image.png

比如上图中,有2个隐藏层,每个隐藏层中分别有4个和2个神经元。实际的神经网络中,隐藏层数量和其中的神经元数量都是不固定的,根据模型实际的效果来进行调整。

1.4. 网络

通过神经元和层的组合就构成了一个网络,神经网络的名称由此而来。神经网络可大可小,可简单可复杂,不过,太过简单的神经网络模型效果一般不会太好。

因为一只果蝇就有10万个神经元,而人类的大脑则有大约1000亿个神经元,这就是为什么训练一个可用的神经网络模型需要庞大的算力,这也是为什么神经网络的理论1943年就提出了,但是基于深度学习的AlphaGO却诞生于2015年

2. 实现一个神经网络

了解上面的基本概念只能形成一个感性的认知。下面通过自己动手实现一个最简单的神经网络,来进一步认识神经元sigmoid函数以及隐藏层是如何发挥作用的。

2.1. 准备数据

数据使用sklearn库中经典的鸢尾花数据集,这个数据集中有3个分类的鸢尾花,每个分类50条数据。为了简化,只取其中前100条数据来使用,也就是取2个分类的鸢尾花数据。

from sklearn.datasets import load_irisds = load_iris(as_frame=True, return_X_y=True)
data = ds[0].iloc[:100]
target = ds[1][:100]print(data)
print(target)

image.png

变量data100条数据,每条数据包含4个属性,分别是花萼的宽度和长度,花瓣的宽度和长度。

image.png

变量target中也是100条数据,只有0和1两种值,表示两种不同种类的鸢尾花。

2.2. 实现神经元

准备好了数据,下面开始逐步实现一个简单的神经网络。首先,实现最基本的单元--神经元。本文第一节中已经介绍了神经元中主要的2个步骤,分别计算出N𝑁和M𝑀。

image.png

计算N𝑁时,依据每个输入元素的权重(w1,w2𝑤1,𝑤2)和整体的偏移b𝑏;计算M𝑀时,通过sigmoid函数。

def sigmoid(x):return 1 / (1 + np.exp(-1 * x))@dataclass
class Neuron:weights: list[float] = field(default_factory=lambda: [])bias: float = 0.0N: float = 0.0M: float = 0.0def compute(self, inputs):self.N = np.dot(self.weights, inputs) + self.biasself.M = sigmoid(self.N)return self.M

上面的代码中,Neuron类表示神经元,这个类有4个属性:其中属性weightsbias是计算N𝑁时的权重和偏移;属性NM分别是神经元中两步计算的结果。

Neuron类的compute方法根据输入的数据计算神经元的输出。

2.3. 实现神经网络

神经元实现之后,下面就是构建神经网络。我们的输入数据是带有4个属性(花萼的宽度和长度,花瓣的宽度和长度)的鸢尾花数据,所以神经网络的输入层有4个值(x1,x2,x3,x4𝑥1,𝑥2,𝑥3,𝑥4)。

为了简单起见,我们的神经网络只构建一个隐藏层,其中包含3个神经元。最后就是输出层,输出层最后输出一个值,表示鸢尾花的种类。

由此形成的简单神经网络如下图所示:

image.png

实现的代码:

@dataclass
class MyNeuronNetwork:HL1: Neuron = field(init=False)HL2: Neuron = field(init=False)HL3: Neuron = field(init=False)O1: Neuron = field(init=False)def __post_init__(self):self.HL1 = Neuron()self.HL1.weights = np.random.dirichlet(np.ones(4))self.HL1.bias = np.random.normal()self.HL2 = Neuron()self.HL2.weights = np.random.dirichlet(np.ones(4))self.HL2.bias = np.random.normal()self.HL3 = Neuron()self.HL3.weights = np.random.dirichlet(np.ones(4))self.HL3.bias = np.random.normal()self.O1 = Neuron()self.O1.weights = np.random.dirichlet(np.ones(3))self.O1.bias = np.random.normal()def compute(self, inputs):m1 = self.HL1.compute(inputs)m2 = self.HL2.compute(inputs)m3 = self.HL3.compute(inputs)output = self.O1.compute([m1, m2, m3])return output

MyNeuronNetwork类是自定义的神经网络,其中的属性是4个神经元HL1HL2HL3隐藏层的3个神经元;O1输出层的神经元。

__post__init__函数是为了初始化各个神经元。因为输入层是4个属性(x1,x2,x3,x4𝑥1,𝑥2,𝑥3,𝑥4),所以神经元HL1HL2HL3weights初始化为4个随机数组成的列表,偏移(bias)用一个随时数来初始化。

对于神经元O1,它的输入是隐藏层的3个神经元,所以它的weights初始化为3个随机数组成的列表,偏移(bias)还是用一个随时数来初始化。

最后还有一个compute函数,这个函数描述的就是整个神经网络的计算过程。首先,根据输入层(x1,x2,x3,x4𝑥1,𝑥2,𝑥3,𝑥4)的数据计算隐藏层的神经元(HL1HL2HL3);然后,以隐藏层的神经元(HL1HL2HL3)的输出作为输出层的神经元(O1)的输入,并将O1的计算结果作为整个神经网络的输出。

2.4. 训练模型

上面的神经网络中各个神经元的中的参数(主要是weightsbias)都是随机生成的,所以直接使用这个神经网络,效果一定不会很好。所以,我们需要给神经网络(MyNeuronNetwork类)加一个训练函数,用来训练神经网络中各个神经元的参数(也就是个各个神经元中的weightsbias)。

@dataclass
class MyNeuronNetwork:# 略...def train(self, data: pd.DataFrame, target: pd.Series):## 使用 随机梯度下降算法来训练pass

上面的train函数有两个参数data(训练数据)和target(训练数据的标签)。我们使用随机梯度下降算法来训练模型的参数。这里略去了具体的代码,完整的代码可以在文章的末尾下载。

此外,再实现一个预测函数predict,传入测试数据集,然后用我们训练好的神经网络模型来预测测试数据集的标签。

@dataclass
class MyNeuronNetwork:# 略...def predict(self, data: pd.DataFrame):results = []for idx, row in enumerate(data.values):pred = self.compute(row)results.append(round(pred))return results

2.5. 验证模型效果

最后就是验证模型的效果。

def main():# 加载数据ds = load_iris(as_frame=True, return_X_y=True)# 只用前100条数据data = ds[0].iloc[:100]target = ds[1][:100]# 划分训练数据,测试数据# test_size=0.2 表示80%作为训练数据,20%作为测试数据X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2)# 创建神经网络nn = MyNeuronNetwork()# 用训练数据集来训练模型nn.train(X_train, y_train)# 检验模型# 用训练过的模型来预测测试数据的标签results = nn.predict(X_test)df = pd.DataFrame()df["预测值"] = resultsdf["实际值"] = y_test.valuesprint(df)

运行结果可以看出,模型的效果还不错,20条测试数据的预测结果都正确。

image.png

3. 总结

本文中的的神经网络示例是为了介绍神经网络的一些基本概念,所以对神经网络做了尽可能的简化,为了方便去手工实现。

而实际环境中的神经网络,不仅神经元的个数,隐藏层的数量极其庞大,而且其计算和训练的方式也很复杂,手工去实现不太可能,一般会借助TensorFlowKerasPyTorch等等知名的python深度学习库来帮助我们实现。

文章转载自:wang_yb

原文链接:https://www.cnblogs.com/wang_yb/p/18176563

体验地址:引迈 - JNPF快速开发平台_低代码开发平台_零代码开发平台_流程设计器_表单引擎_工作流引擎_软件架构

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/8032.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mysql workbench如何导出insert语句?

进行导出设置 导出的sql文件 CREATE DATABASE IF NOT EXISTS jeesite /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci */ /*!80016 DEFAULT ENCRYPTIONN */; USE jeesite; -- MySQL dump 10.13 Distrib 8.0.28, for Win64 (x86_64) -- -- Host: 127.0…

如何使用dockerfile文件将项目打包成镜像

要根据Dockerfile文件来打包一个Docker镜像,你需要遵循以下步骤。这里假设你已经安装了Docker环境。 1. 准备Dockerfile 确保你的Dockerfile文件已经准备就绪,并且位于你希望构建上下文的目录中。Dockerfile是一个文本文件,包含了用户可以调…

顺序表的实现(迈入数据结构的大门)(1)

什么是数据结构 数据结构是由:“数据”与“结构”两部分组成 数据与结构 数据:如我们所看见的广告、图片、视频等,常见的数值,教务系统里的(姓名、性别、学号、学历等等); 结构:当…

线性表--数据结构设计与操作

单链表 1.单链表的定义&#xff1a; typedef struct LNode{Elemtype data;struct Lnode *next; }LNode ,*LinkList;//单链表的数据结构&#xff08;手写&#xff09; #include<iostream> #include<vector> #include<algorithm>typedef int TypeElem; //单链表…

OpenAI API搭建的智能家居助手;私密大型语言模型(LLM)聊天机器人;视频和音频文件的自动化识别和翻译工具

✨ 1: GPT Home 基于Raspberry Pi和OpenAI API搭建的智能家居助手 GPT Home是一个基于Raspberry Pi和OpenAI API搭建的智能家居助手&#xff0c;功能上类似于Google Nest Hub或Amazon Alexa。通过详细的设置指南和配件列表&#xff0c;用户可以自行组装和配置这个设备&#x…

Ansible自动运维工具之playbook

一.inventory主机清单 1.定义 Inventory支持对主机进行分组&#xff0c;每个组内可以定义多个主机&#xff0c;每个主机都可以定义在任何一个或多个主机组内。 2.变量 &#xff08;1&#xff09;主机变量 [webservers] 192.168.10.14 ansible_port22 ansible_userroot ans…

使用sqlmodel实现唯一性校验

代码&#xff1a; from sqlmodel import Field, Session, SQLModel, create_engine# 声明模型 class User(SQLModel, tableTrue):id: int | None Field(defaultNone, primary_keyTrue)# 不能为空&#xff0c;必须唯一name: str Field(nullableFalse, uniqueTrue)age: int | …

Flutter弹窗链-顺序弹出对话框

效果 前言 弹窗的顺序执行在App中是一个比较常见的应用场景。比如进入App首页&#xff0c;一系列的弹窗就会弹出。如果不做处理就会导致弹窗堆积的全部弹出&#xff0c;严重影响用户体验。 如果多个弹窗中又有判断逻辑&#xff0c;根据点击后需要弹出另一个弹窗&#xff0c;这…

大数据Scala教程从入门到精通第五篇:Scala环境搭建

一&#xff1a;安装步骤 1&#xff1a;scala安装 1&#xff1a;首先确保 JDK1.8 安装成功: 2&#xff1a;下载对应的 Scala 安装文件 scala-2.12.11.zip 3&#xff1a;解压 scala-2.12.11.zip 4&#xff1a;配置 Scala 的环境变量 在Windows上安装Scala_windows安装scala…

docker搭建代码审计平台sonarqube

docker搭建代码审计平台sonarqube 一、代码审计关注的质量指标二、静态分析技术分类三、sonarqube流程四、快速搭建sonarqube五、sonarqube scanner的安装和使用 一、代码审计关注的质量指标 代码坏味道 代码规范技术债评估 bug和漏洞代码重复度单测与集成 测试用例数量覆盖率…

node报错——解决Error: error:0308010C:digital envelope routines::unsupported——亲测可用

今天在打包vue2项目时&#xff0c;遇到一个报错&#xff1a; 最关键的代码如下&#xff1a; Error: error:0308010C:digital envelope routines::unsupportedat new Hash (node:internal/crypto/hash:80:19)百度后发现是node版本的问题。 在昨天我确实操作了一下node&…

Ansible——Playbook剧本

目录 一、Playbook概述 1.Playbook定义 2.Playbook组成 3.Playbook配置文件详解 4.运行Playbook 4.1Ansible-Playbook相关命令 4.2运行Playbook启动httpd服务 4.3变量的定义和引用 4.4指定远程主机sudo切换用户 4.5When——条件判断 4.6迭代 4.6.1创建文件夹 4.6.2…

[Linux][网络][TCP][四][流量控制][拥塞控制]详细讲解

目录 1.流量控制2.拥塞控制0.为什么要有拥塞控制&#xff0c;不是有流量控制么&#xff1f;1.什么是拥塞窗口&#xff1f;和发送窗口有什么关系呢&#xff1f;2.怎么知道当前网络是否出现了拥塞呢&#xff1f;3.拥塞控制有哪些算法&#xff1f;4.慢启动5.拥塞避免6.拥塞发生7.快…

劝退计算机?CS再过几年会没落!?

事实上&#xff0c;未来计算机不仅不会没落&#xff0c;国家还会大力发展 只不过大家认为的计算机就是什么Java web&#xff0c;真正的计算机行业是老美那样的&#xff0c;涉及到方方面面&#xff0c;比如&#xff1a; web&#xff0c;图形学&#xff0c;Linux系统开发&#…

2024DCIC海上风电出力预测Top方案 + 光伏发电出力高分方案学习记录

海上风电出力预测 赛题数据 海上风电出力预测的用电数据分为训练组和测试组两大类&#xff0c;主要包括风电场基本信息、气象变量数据和实际功率数据三个部分。风电场基本信息主要是各风电场的装机容量等信息&#xff1b;气象变量数据是从2022年1月到2024年1月份&#xff0c;…

Skywalking数据持久化与自定义链路追踪

学习本篇文章之前首先要了解一下Sky walking的基础知识 分布式链路追踪工具Skywalking详解 一&#xff0c;Sky walking数据持久化 Sky walking提供了es&#xff0c;MySQL等数据持久化方案&#xff0c;默认使用h2基于内存的数据库&#xff0c;重启之后数据即会丢失。 在实际工…

【Git】Git学习-16:git merge,且解决合并冲突

学习视频链接&#xff1a; 【GeekHour】一小时Git教程_哔哩哔哩_bilibili​编辑https://www.bilibili.com/video/BV1HM411377j/?vd_source95dda35ac10d1ae6785cc7006f365780 1 创建分支dev&#xff0c;并用merge合并master分支&#xff0c;使dev分支合并上master分支中内容为…

【学习笔记】HarmonyOS 4.0 鸿蒙Next 应用开发--安装开发环境

开发前的准备 首先先到官网去下载Devco Studio 这个开发工具&#xff0c;https://developer.harmonyos.com/cn/develop/deveco-studio/#download 提供了WIndows和Mac的开发环境&#xff0c;我自己是Windows的开发环境。 所以下载之后直接点击exe进行安装即可。 如果之前安装过…

Eplan带你做项目——如何实现项目的交付

前言 Eplan作为一款专业的电气工程设计软件&#xff0c;不仅在设计阶段为电气工程师提供了强大的绘图、计算、仿真等功能&#xff0c;还具备丰富的数据管理与交换能力&#xff0c;能够便捷、准确地导出软件设计、生产制造所需的数据&#xff0c;实现电气设计与软件设计、生产制…

反汇编一个ARM64的机器码

文章目录 使用objdump直接阅读ARM64手册使用反汇编网站 有下面一个机器码&#xff1a;0x929ffee9&#xff0c;如何翻译成汇编呢&#xff1f; 下面介绍几种做法&#xff1a; 使用objdump 将这个机器码写到文件中&#xff0c;然后使用objdump去反汇编 创建一个二进制文件 dd…