搭建一个简单的深度神经网络

目录

一、引入所需要的库

二、制作数据集

三、搭建神经网络

四、训练网络

五、测试网络

本博客实验环境为jupyter

一、引入所需要的库

torch库是核心,其中torch.nn 提供了搭建网络所需的所有组件,nn即神经网络。matplotlib类似与matlab,其中pyplot用于进行数据可视化,如绘制图表、曲线等。%matplotlib inline: 这是IPython(Jupyter Notebook)的魔法命令,用于在Notebook中直接显示Matplotlib绘制的图表,而不是弹出一个新窗口显示。

import torch 
import torch.nn as nn 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
%matplotlib inline# 展示高清图 
from matplotlib_inline import backend_inline #导入Matplotlib库中的backend_inline模块,用于控制图表的显示方式。
backend_inline.set_matplotlib_formats('svg') #设置Matplotlib图表的显示格式为SVG格式,SVG格式的图表在显示时具有高清晰度,适合用于展示精细的图形。

二、制作数据集

主要任务是读取数据集,划分为训练集和测试集,一定要随机划分。

读取的数据集中共760组数据,共8个输入特征,1个输出特征。

其中第一列是索引,从0开始,70%为训练集,30%为测试集。

#读取数据
df = pd.read_csv('Data.csv', index_col=0)#之前的pandas库中有介绍到,即df为读取后的对象,以第一列为索引    
arr = df.values #转化为numpy数组              
arr = arr.astype(np.float32)#转化为深度学习常用的单精度浮点类型    
ts = torch.tensor(arr)#转化为张量tensor         
ts = ts.to('cuda')#送到cuda设备上即gpu上计算             # 划分训练集与测试集 
train_size = int(len(ts) * 0.7) #训练集的大小为百分之七十          
test_size = len(ts) - train_size #测试集的大小为百分之三十         
ts = ts[ torch.randperm( ts.size(0) ) , : ] #随机打乱数据集中样本的顺序    
train_Data = ts[ : train_size , : ] #将前百分之七十行给训练集       
test_Data = ts[ train_size : , : ]  #将百分之七十后的行给测试集        

三、搭建神经网络

主要是构建DNN类,需要对python的类定义有较为深入的理解能力。

class DNN(nn.Module): #定义了一个名为DNN的PyTorch模型类,该类继承自nn.Module类,表示这是一个神经网络模型。def __init__(self): #定义了模型的初始化方法''' 搭建神经网络各层 ''' super(DNN,self).__init__() #调用父类的初始化方法,确保模型的其他部分也能够被正确初始化。self.net = nn.Sequential(            # 按顺序搭建各层 nn.Linear(8, 32), nn.Sigmoid(),   # 第1层:全连接层 ,是一个包含32个神经元的全连接层,输入特征数为8(表示输入数据维度为8),并使用Sigmoid激活函数。nn.Linear(32, 8), nn.Sigmoid(),   # 第2层:全连接层 ,是一个包含8个神经元的全连接层,输入特征数为32,同样使用Sigmoid激活函数。nn.Linear(8, 4), nn.Sigmoid(),    # 第3层:全连接层 ,是一个包含4个神经元的全连接层,输入特征数为8,同样使用Sigmoid激活函数。nn.Linear(4, 1), nn.Sigmoid()    # 第4层:全连接层 ,是一个包含1个神经元的全连接层,输入特征数为4,同样使用Sigmoid激活函数。这是模型的输出层。) def forward(self, x): ''' 前向传播 ''' y = self.net(x)    # 将输入数据x通过模型定义的神经网络结构self.net进行前向传播计算,得到输出y。return y        # y即输出数据model = DNN().to('cuda:0')    # 创建子类的实例,并搬到GPU上 

这个代码可以当做模板,其中需要修改的部分为网络层的搭建,输入特征,中间层,输出特征一般都要为2的n次幂。

这就是该实例的各层。

四、训练网络

通过前向传播,反向传播等操作,本质上是不断调整权重和偏置。

loss_fn = nn.BCELoss(reduction='mean')#选择二元交叉熵损失函数作为模型的损失函数,其中reduction='mean'表示采用平均损失值作为最终的损失值。
learning_rate = 0.005    # 设置学习率 
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 
# 训练网络 
epochs = 5000     #设置训练的总轮数为5000。
losses = []        # 记录损失函数变化的列表 # 给训练集划分输入与输出 
X = train_Data[ : , : -1 ]                # 前8列为输入特征 
Y = train_Data[ : , -1 ].reshape((-1,1))    # 后1列为输出特征
for epoch in range(epochs):     #对于每个epoch进行循环。Pred = model(X)              #通过模型进行一次前向传播,得到模型的预测结果Pred。loss = loss_fn(Pred, Y)        #计算模型预测结果与实际标签之间的损失值。losses.append(loss.item())     #将当前轮次的损失值记录到losses列表中。optimizer.zero_grad()        #清空上一轮的梯度信息。loss.backward()             #进行反向传播,计算梯度。optimizer.step()             #根据优化算法更新模型的参数,完成一轮训练。#绘制损失函数随训练轮次变化的图像,用于可视化训练过程中损失值的变化。
Fig = plt.figure() 
plt.plot(range(epochs), losses) 
plt.ylabel('loss') 
plt.xlabel('epoch') 
plt.show() 

生成结果为

可以发现随着训练的进行,loss开始减少。

五、测试网络

通过用训练好的模型对测试集进行测试,由于只有一个输出特征为0或者1,将大于0.5的置为1,小于0.5的置为0,可以类比成可能性从0到1。

# 测试网络 
# 给测试集划分输入与输出 
X = test_Data[ : , : -1 ]                
Y = test_Data[ : , -1 ].reshape((-1,1))    
with torch.no_grad():    #进入上下文管理器,表示接下来的计算不会被记录在计算图中,因此不会影响梯度的计算。Pred = model(X)     Pred[Pred>=0.5] = 1 Pred[Pred<0.5] = 0 correct = torch.sum( (Pred == Y).all(1) )    #统计预测正确的样本数,使用.all(1)表示在第1维度(即行)上进行比较,得到一个布尔张量,再进行求和操作。total = Y.size(0)   #获取试集样本总数。print(f'测试集精准度: {100*correct/total} %') 

一般精准度得百分之八九十才合格哦,所以精度不高很有可能是训练集或者环境的问题,所以训练前一定要做好准备工作,因为训练一个模型要花费很久时间。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/26255.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

透平油氧化安定性检测 发动机油运动粘度40℃检测

透平油氧化安定性检测 透平油&#xff0c;也称为涡轮机油或汽轮机油&#xff0c;是专门用于汽轮机的润滑油。它具有良好的抗氧化安定性和抗乳化性能&#xff0c;主要用于发电厂蒸气轮机、水电站水轮发电机以及其他需要深度精细润滑的场合。透平油的氧化安定性是衡量其在高温条件…

服务器防漏扫,主机加固方案来解决

什么是漏扫&#xff1f; 漏扫是漏洞扫描的简称。漏洞扫描是一种安全测试方法&#xff0c;用于发现计算机系统、网络或应用程序中的潜在漏洞和安全弱点。通过使用自动化工具或软件&#xff0c;漏洞扫描可以检测系统中存在的已知漏洞&#xff0c;并提供相关的报告和建议&#xf…

YOLOv10原理与实战训练自己的数据集

课程链接&#xff1a;YOLOv10原理与实战训练自己的数据集_在线视频教程-CSDN程序员研修院 YOLOv10是最近提出的YOLO的改进版本。在后处理方面&#xff0c;提出了一致性双重分配策略用于无NMS训练&#xff0c;从而实现了高效的端到端检测。在模型架构方面&#xff0c;引入了全面…

ubuntu如何查看ip地址

ubuntu如何查看ip地址 方法一&#xff1a;使用ifconfig方法二&#xff1a;使用ip命令 方法一&#xff1a;使用ifconfig 命令行输入ifconfig&#xff1a; 这里inet后跟的内容就是IP地址。 方法二&#xff1a;使用ip命令 命令行输入&#xff1a;ipa ddr&#xff1a; 这里ine…

轮到国产游戏统治Steam榜单

6月10日晚8点&#xff0c;《黑神话:悟空》实体版正式开启全款预售,预售开启不到5分钟,所有产品即宣告售罄。 Steam上&#xff0c;《黑神话:悟空》持续占据着热销榜榜首的位置。 但在《黑神话:悟空》傲人的光环下&#xff0c;还有一款国产游戏取得出色的成绩。 6月10日&#…

RK3568笔记三十二:PaddleSeg训练部署

一、环境 1、Autodl配置 PyTorch 1.7.0Python 3.8(ubuntu18.04)Cuda 11.02、所需环境需求 - OS: 64-bit - Python 3(3.6/3.7/3.8/3.9/3.10)&#xff0c;64-bit version - pip/pip3(9.0.1)&#xff0c;64-bit version - CUDA > 10.2 - cuDNN > 7.6 - PaddlePaddle (the…

“树莓派” 成为上市公司

“树莓派” 成为上市公司 树莓派公司昨日已在伦敦证券交易所首次亮相&#xff08;Raspberry Pi Holdings plc&#xff09;。早盘交易中&#xff0c;该公司股价大涨&#xff0c;这为伦敦首次公开发行&#xff08;IPO&#xff09;市场带去了一些动力。 Stable Diffusion 3 开源倒…

SaaS产品运营 | 千万不能踏入的PLG模式的六大误区

随着科技的迅速发展和市场竞争的日益激烈&#xff0c;越来越多的公司开始尝试采用PLG&#xff08;Product Led Growth&#xff0c;即产品驱动增长&#xff09;模式来推动其业务的发展。然而&#xff0c;尽管PLG模式在促进增长方面具有显著优势&#xff0c;但在实践中也容易出现…

先导小型五轴联动数控加工中心

先导小型五轴联动加工中心可以作为学校或培训机构的教学工具&#xff0c;帮助学生了解数控加工的基本原理和操作方法。它特别适用于机械、自动化、工业设计等相关专业的学生进行实践操作和课程项目。 小型五轴联动加工中心是一种能够同时控制五个自由度进行联动的加工设备。这五…

上午接到被裁员的通知,下午就收到涨薪30%的offer,我生怕公司反悔,当天就找HR签了离职协议,拿到了N+1赔偿!

大家好&#xff0c;我是瑶琴呀。 昨天看到一位网友分享自己被裁的经历&#xff1a;最近这段时间在面试&#xff0c;没成想上午刚被 HR 约谈裁员的事情&#xff0c;下午就收到下家公司涨薪 30% 的offer&#xff0c;这可真是天时人和&#xff0c;当天下午就找 HR 签了离职协议&a…

mysql索引B+树可视化演示地址

https://www.cs.usfca.edu/~galles/visualization/BPlusTree.html

【产品经理】ERP订单处理2

本次讲解订单初始化成功到ERP系统过程中的后续环节。 一、根据客服备注更新订单信息 初始化订单过程中&#xff0c;若订单中的客服备注信息对订单进行更新&#xff0c;包括可能改收货信息、改商品、加赠品、指定快递等。 注意&#xff1a;更新订单的过程中要注意订单当前状…

【云原生】Kubernetes----Helm包管理器

目录 引言 一、Helm概述 1.Helm价值概述 2.Helm的基本概念 3.Helm名词介绍 二、安装Helm 1.下载二进制包 2.部署Helm环境 3.添加补全信息 三、使用Helm部署服务 1.创建chart 2.查看文件信息 3.安装chart 4.卸载chart 5.自定义chart服务部署 6.版本升级 7.版本…

数字孪生技术及其广泛应用场景探讨

通过将实际物理世界中的物体或系统建模、模拟和分析&#xff0c;数字孪生技术可以提供更精确、更可靠、更高效的解决方案。数字孪生技术在智能制造、城市建设、智慧物流等众多领域中得到了广泛的应用。 通过将数据可视化呈现在虚拟环境中&#xff0c;我们可以更清晰地观察和理…

CodeArts Snap:辅助你编程的神器

CodeArts Snap - Visual Studio Marketplace 文心一言 CodeArts Snap&#xff1a;辅助你编程的神器 一、简介 CodeArts Snap是一款基于华为云研发大模型开发的智能开发助手&#xff0c;旨在覆盖软件开发的全生命周期&#xff0c;为开发者提供端到端的智能支持。自2023年7月…

网络编程2----UDP简单客户端服务器的实现

首先我们要知道传输层提供的协议主要有两种&#xff0c;TCP协议和UDP协议&#xff0c;先来介绍一下它们的区别&#xff1a; 1、TCP是面向连接的&#xff0c;UDP是无连接的。 连接的本质是双方分别保存了对方的关键信息&#xff0c;而面向连接并不意味着数据一定能正常传输到对…

【NLP】给Transformer降降秩,通过分层选择性降阶提高语言模型的推理能力

【NLP】给Transformer降降秩&#xff0c;通过分层选择性降阶提高语言模型的推理能力 文章目录 【自然语言处理-论文翻译与学习】序1、导论2、相关工作3、相关工具4、方案5、实验5.1 使用 GPT-J 对 CounterFact 数据集进行彻底分析5.1.1 数据集中的哪些事实是通过降阶恢复的&…

使用谷歌 Gemini API 构建自己的 ChatGPT(一)

AI领域一直由OpenAI和微软等公司主导&#xff0c;而Gemini则崭露头角&#xff0c;以更大的规模和多样性脱颖而出。它被设计用于无缝处理文本、图像、音频和视频&#xff1b;这些基础模型重新定义了人工智能交互的边界。随着谷歌在人工智能领域强势回归&#xff0c;了解Gemini如…

17.路由配置与页面创建

路由配置与页面创建 官网&#xff1a;https://router.vuejs.org/zh/ Vue Router 和 组合式 API | Vue Router (vuejs.org) 1. 修改index.ts import { RouteRecordRaw, createRouter, createWebHistory } from "vue-router"; import Layout from /layout/Index.vueco…

中国版Sora?快手「可灵」到底行不行?

“可灵”与Sora有相似的技术架构&#xff0c;生成的视频动作流畅、幅度大&#xff0c;对物理世界理解力与还原度很高。可生成120秒、每秒30帧的高清视频&#xff0c;分辨率高达1080p&#xff0c;并且支持多种不同的屏幕比例。 “中国版SORA”到底是不是名副其实&#xff1f;能…