深度学习中简易FC和CNN搭建

  • TensorFlow是由谷歌开发的
  • PyTorch是由Facebook人工智能研究院(Facebook AI Research)开发的

Torch和cuda版本的对应,手动安装较好

全连接FC(Batch*Num)

搭建建议网络:

from torch import nnclass Mnist_NN(nn.Module):def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)self.hidden2 = nn.Linear(128, 256)self.out  = nn.Linear(256, 10)def forward(self, x):x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = self.out(x)return x

封装数据

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoadertrain_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)def get_data(train_ds, valid_ds, bs):return (DataLoader(train_ds, batch_size=bs, shuffle=True),DataLoader(valid_ds, batch_size=bs * 2),)

训练模型:

import numpy as npdef fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))

一般在训练模型时加上model.train(),这样会正常使用Batch NormalizationDropout;测试的时候一般选择model.eval(),这样就不会使用Batch NormalizationDropout

批量损失函数

from torch import optim
def get_model():model = Mnist_NN()return model, optim.SGD(model.parameters(), lr=0.001)def loss_batch(model, loss_func, xb, yb, opt=None):loss = loss_func(model(xb), yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(), len(xb)

优化器
SGD是一种简单且易于实现的优化算法,但在大规模数据集和复杂模型上收敛缓慢。
Adam是一种自适应学习率调整的优化算法,能够更快地收敛,但可能会占用更多的内存。
在实践中,根据具体问题和数据集的特点,选择适合的优化算法可以提高训练效果。

卷积神经网络CNN(Batch * C * H * W)

Channel First
引入py库

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms 
import matplotlib.pyplot as plt
import numpy as np

预处理

# 定义超参数 
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片# 训练集
train_dataset = datasets.MNIST(root='./data',  train=True,   transform=transforms.ToTensor(),  download=True) # 测试集
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

构建CNN

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)nn.Conv2d(in_channels=1,              # 灰度图out_channels=16,            # 要得到几多少个特征图kernel_size=5,              # 卷积核大小stride=1,                   # 步长padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1),                              # 输出的特征图为 (16, 28, 28)nn.ReLU(),                      # relu层nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14))self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)nn.ReLU(),                      # relu层nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),                # 输出 (32, 7, 7))self.conv3 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(32, 64, 5, 1, 2),     # 输出 (32, 14, 14)nn.ReLU(),             # 输出 (32, 7, 7))self.out = nn.Linear(64 * 7 * 7, 10)   # 全连接层得到的结果def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 32 * 7 * 7)output = self.out(x)return output

定义准确率

def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1] rights = pred.eq(labels.data.view_as(pred)).sum() return rights, len(labels) 

训练网络模型

# 实例化
net = CNN() 
#损失函数
criterion = nn.CrossEntropyLoss() 
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法#开始训练循环
for epoch in range(num_epochs):#当前epoch的结果保存下来train_rights = [] for batch_idx, (data, target) in enumerate(train_loader):  #针对容器中的每一个批进行循环net.train()                             output = net(data) loss = criterion(output, target) optimizer.zero_grad() loss.backward() optimizer.step() right = accuracy(output, target) train_rights.append(right) if batch_idx % 100 == 0: net.eval() val_rights = [] for (data, target) in test_loader:output = net(data) right = accuracy(output, target) val_rights.append(right)#准确率计算train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.data, 100. * train_r[0].numpy() / train_r[1], 100. * val_r[0].numpy() / val_r[1]))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/14604.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

力扣 968. 监控二叉树

题目来源:https://leetcode.cn/problems/binary-tree-cameras/description/ C题解(来源代码随想录):节点可以分为3个状态:0无覆盖;1有摄像头;2有覆盖。 要想放的摄像头最少,应当叶子…

SOC FPGA之HPS模型设计(一)

目录 一、建立HPS硬件系统模型 1.1 GHRD 1.2 从0开始搭建HPS 1.2.1 FPGA Interfaces 1.2.1.1 General 1.2.1.2 AXI Bridge 1.2.1.3 FPGA-to-HPS SDRAM Interface 1.2.1.4 DMA Peripheral Request 1.2.1.5 Interrupts 1.2.1.6 EMAC ptp interface 1.2.2 Peripheral P…

seata组件使用期间,获取全局事务状态

GlobalStatus枚举类展示全局事务状态 官网链接:http://seata.io/zh-cn/docs/user/appendix/global-transaction-status.html 获得全局事务状态 // 开启全局事务地方获取全局事务xid String xid RootContext.getXID(); // 通过全局事务xid获得GlobalStatus枚举类 …

Unity游戏源码分享-2.5D塔防类游戏

Unity游戏源码分享-2.5D塔防类游戏 项目地址: https://download.csdn.net/download/Highning0007/88118947

电子元器件选型与实战应用—01 电阻选型

大家好, 我是记得诚。 这是《电子元器件选型与实战应用》专栏的第一篇文章,今天的主角是电阻,在每一个电子产品中,都少不了电阻的身影,其重要性不言而喻。 文章目录 1. 入门知识1.1 基础1.2 常用品牌1.3 电阻的种类2. 贴片电阻标识2.1 三位数标注法2.2 四位数标注法2.3 小…

操作系统_进程与线程(二)

目录 2. 处理机调度 2.1 调度的基本概念 2.2 调度的层次 2.3 三级调度的联系 2.4 调度的目标 2.5 调度的实现 2.5.1 调度程序(调度器) 2.5.2 调度的时机、切换与过程 2.5.3 进程调度方式 2.5.4 闲逛进程 2.5.5 两种线程的调度 2.6 典型的调度…

多旋翼物流无人机节能轨迹规划(Python代码实现)

目录 💥1 概述 📚2 运行结果 🌈3 Python代码实现 🎉4 参考文献 💥1 概述 多旋翼物流无人机的节能轨迹规划是一项重要的技术,可以有效减少无人机的能量消耗,延长飞行时间,提高物流效率…

了解Unity编辑器之组件篇Layout(八)

Layout:用于管理和控制UI元素的排列和自动调整一、Aspect Ratio Fitter:用于根据宽高比自动调整UI元素的大小 Aspect Mode:用于定义纵横比适配的行为方式。Aspect Mode属性有以下几种选项: (1)None&#xf…

【etcd】解决 go-zero 注册 etcd 出现 “Auto sync endpoints failed.” 的问题

go: v1.20.3 go-zero: v1.5.4 etcd: v3.5.9 问题描述 在 go-zero 中用 etcd 去实现服务注册发现,rpc 服务可以注册到 etcd,同时其他服务可以发现注册的微服务,也可以访问。但是,注册的 rpc 服务的日志,就是一直报以…

bat一键批量、有序启动jar

将脚本文件后缀改为 bat,脚本文件和 jar 包放在同一个目录 echo offstart cmd /c "java -jar register.jar " ping 192.0.2.2 -n 1 -w 10000 > nulstart cmd /c "java -jar admin.jar " ping 192.0.2.2 -n 1 -w 30000 > nulstart cmd /c…

基于ARM+FPGA (STM32+ Cyclone 4)的滚动轴承状态监测系统

状态监测系统能够在故障早期及时发现机械设备的异常状态,避免故障的 进一步恶化造成不必要的损失,滚动轴承是机械设备的易损部件,本文对以滚动 轴承为研究对象的状态监测系统展开研究。现有的监测技术多采用定时上传监 测数据,…

Spring MVC学习笔记,包含mvc架构使用,过滤器、拦截器、执行流程等等

😀😀😀创作不易,各位看官点赞收藏. 文章目录 Spring MVC 习笔记1、Spring MVC demo2、Spring MVC 中常见注解3、数据处理3.1、请求参数处理3.2、响应数据处理 4、RESTFul 风格5、静态资源处理6、HttpMessageConverter 转换器7、过…

Open3D(C++) 根据索引提取点云

目录 一、功能概述1、主要函数2、源码二、代码实现三、结果展示本文由CSDN点云侠原创,原文链接。爬虫网站自重,把自己当个人 一、功能概述 1、主要函数 std::shared_ptr<PointCloud> SelectByIn

spring boot 2 配置上传文件大小限制

一、起因&#xff1a;系统页面上传一个文件超过日志提示的文件最大100M的限制&#xff0c;需要更改配置文件 二、经过&#xff1a; 1、在本地代码中找到配置文件&#xff0c;修改相应数值后交给运维更新生产环境配置&#xff0c;但是运维说生产环境没有这行配置&#xff0c;遂…

MODBUS-TCP转Ethernet IP 网关连接空压机 配置案例

本案例是工业现场应用捷米特JM-EIP-TCP的Ethernet/IP转Modbus-TCP网关连接欧姆龙PLC与空压机的配置案例。使用设备&#xff1a;欧姆龙PLC&#xff0c;捷米特JM-EIP-TCP网关&#xff0c; ETHERNET/IP 的电气连接 ETHERNET/IP 采用标准的 T568B 接法&#xff0c;支持直连和交叉接…

[个人笔记] Linux配置NTP时间同步

Linux - 运维篇 第四章 Linux配置NTP时间同步 Linux - 运维篇系列文章回顾Linux配置NTP时间同步Linux配置CST时区 参考来源 系列文章回顾 第一章 Linux扩容LVM分区 第二章 Linux虚拟机安装VMware Tools插件 第三章 ssh-keygen和openssl工具的使用 Linux配置NTP时间同步 仅实验…

Ubuntu通用镜像加速配置

备份 cp -rf /etc/apt/sources.list /etc/apt/sources.list.bak开始配置 阿里云 sed -i shttp://archive.ubuntu.comhttps://mirrors.aliyun.comg /etc/apt/sources.listsed -i shttp://security.ubuntu.comhttps://mirrors.aliyun.comg /etc/apt/sources.list清华源 sed -i …

【etcd】docker 启动单点 etcd

etcd: v3.5.9 etcd-browser: rustyx/etcdv3-browser:latest 本文档主要描述用 docker 部署单点的 etcd&#xff0c; 用 etcd-browser 来查看注册到 etcd 的 key 默认配置启动 docker run -d --name ai-etcd --networkhost --restart always \-v $PWD/etcd.conf.yml:/opt/bitn…

【Linux】线程池

1 线程池的介绍 1.1 线程池 一种线程使用模式。线程过多会带来调度开销&#xff0c;进而影响局部性和整体性能。而线程池维护多个线程&#xff0c;等待着监督管理者分配可并发执行的任务。这避免了在处理短时间任务创建与销毁线程的代价。线程池不仅能够保证内核的充分利用&am…

生命在于学习——指纹混淆技术学习

一、前言 本篇文章仅为学习笔记记录&#xff0c;不得用于违规用途。 本篇文章为安全社公众号的Poker安全所发&#xff0c;本文仅为学习复现。 二、介绍 指纹混淆技术&#xff0c;顾名思义&#xff0c;就是迷惑指纹扫描识别技术。 三、思路 作者的思路&#xff1a; 1、伪…