CNN的小体验

用的pytorch。

训练代码cnn.py:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F# 定义超参数
num_epochs = 10
batch_size = 100
learning_rate = 0.001# 数据预处理和加载
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)# 定义卷积神经网络
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2)self.fc1 = nn.Linear(32 * 8 * 8, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 32 * 8 * 8)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 初始化模型、损失函数和优化器
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')# 测试模型
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the model on the 10000 test images: {100 * correct / total}%')# 保存模型
torch.save(model.state_dict(), 'cnn.pth')

推断代码cnn2.py

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F# 定义卷积神经网络(与之前的定义保持一致)
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2) # 第一个卷积层self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) # 池化层self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2) # 第二个卷积层self.fc1 = nn.Linear(32 * 8 * 8, 128) # 全连接层self.fc2 = nn.Linear(128, 10) # 输出层def forward(self, x):x = self.pool(F.relu(self.conv1(x))) # 通过第一个卷积层和池化层x = self.pool(F.relu(self.conv2(x))) # 通过第二个卷积层和池化层x = x.view(-1, 32 * 8 * 8) # 展平x = F.relu(self.fc1(x)) # 通过全连接层x = self.fc2(x) # 通过输出层return x# 加载模型
model = SimpleCNN()
model.load_state_dict(torch.load('cnn.pth'))
model.eval() # 设置模型为评估模式# 预处理图片
def preprocess_image(image_path):transform = transforms.Compose([transforms.Resize((32, 32)), # 调整图像大小到32x32transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])image = Image.open(image_path)image = transform(image)image = image.unsqueeze(0) # 增加批量维度return image# 加载并预处理图片
image_path = 'test.jpg' # 替换为你要分析的图片路径
image = preprocess_image(image_path)# 使用模型进行推理
with torch.no_grad():outputs = model(image)_, predicted = torch.max(outputs.data, 1)class_index = predicted.item()# CIFAR-10类别标签
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']# 输出预测结果
print(f'Predicted class: {classes[class_index]}')

可惜出来的东西跟弱智一般。

python3 cnn2.py
Predicted class: horse

python3 cnn2.py
Predicted class: bird

几个小点:

1 使用的数据集是CIFAR10

2 训练真的挺耗时的,我用的阿里云,一共搞了差不多10分钟(训练一个弱智)。

3 环境依然麻烦,python,numpy的版本都不能太高。否则要出问题。。。

4 最后实事求是的说,我不太懂的一点是怎么分类出来的。。。晚点再看看。。。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/864632.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用Python绘制彩虹效果:动态彩虹动画

文章目录 引言准备工作前置条件 代码实现与解析导入必要的库初始化Pygame定义绘制彩虹函数定义颜色列表主循环 完整代码 引言 彩虹是自然界中最美丽的现象之一。通过编程,我们可以将这一奇妙的景象带到屏幕上。在这篇博客中,我们将使用Python来创建一个…

聊聊 golang 的 map

1、哈希表 哈希表是一个很常见的数据结构,用来存储无序的 key/value 对,给定的 key 可以在 O(1) 时间复杂度内查找、更新或删除对应的 value。 设计一个好的哈希表,需要着重关注两个关键点:哈希函数、冲突处理。 1.1 哈希函数 …

Redis 高级数据结构业务实践

0、前言 本文所有代码可见 > 【gitee code demo】 本文会涉及 hyperloglog 、GEO、bitmap、布隆过滤器的介绍和业务实践 1、HyperLogLog 1.1、功能 基数统计(去重) 1.2、redis api 命令作用案例PFADD key element [element ...]添加元素到keyPF…

【Linux】systemctl系统和服务管理命令

systemctl 是 systemd 系统和服务管理器的主命令行工具,用于启动、停止、重启、启用、禁用和检查服务状态,以及管理系统状态。systemd 是现代 Linux 发行版中广泛使用的初始化系统(init system),取代了旧的 SysV 和 Upstart 系统。以下是一些常见的 systemctl 命令和用途的…

力扣 用队列实现栈(Java)

核心思想:因为队列都是一端进入另一端出(先进先出,后进后出),因此一个队列肯定是不能实现栈的功能的,这里就创建两个队列来模拟栈的先进后出,后进先出。 比如说如果是push操作我们肯定是要弹出栈…

中英双语介绍美国最早的13个殖民地( Thirteen Original Colonies)

中文版 美国最早的十三个殖民地简介 美国最早的十三个殖民地是英格兰在北美洲建立的殖民地,它们后来成为了美国最初的十三个州。这些殖民地是美国历史的重要组成部分,它们在美国革命期间起到了关键作用。以下是对这十三个殖民地的详细介绍,…

STM32自己从零开始实操08:电机电路原理图

一、LC滤波电路 其实以下的滤波都可以叫低通滤波器。 1.1倒 “L” 型 LC 滤波电路 1.1.1定性分析 1.1.2仿真实验 电感:通低频阻高频的。仿真中高频信号通过电感,因为电感会阻止电流发生变化,故说阻止高频信号 电容:隔直通交。…

65、基于卷积神经网络的调制分类(matlab)

1、基于卷积神经网络的调制分类的原理及流程 基于卷积神经网络(CNN)的调制分类是一种常见的信号处理任务,用于识别或分类不同调制方式的信号。下面是基于CNN的调制分类的原理和流程: 原理: CNN是一种深度学习模型&a…

通过nginx上传大文件失败

调用nginx的接口上传文件的时候加载到最后就异常了. 发现是在nginx限制了body大小 解决方法 在这个server设置client_max_body_size大小 #设置NGINX能处理的最大请求主体大小client_max_body_size 2028M; ##

TG群发机器人:高效自动化消息分发指南

在这个信息爆炸的时代,TG作为一个流行的即时通讯平台,其机器人功能为自动化消息分发提供了强大的支持。本文将引导您从零开始搭建一个TG群发机器人,实现高效的消息管理和分发。 引言 TG群发机器人能够自动向多个用户或群组发送消息&#xf…

SpringBoot学习06-[SpringBoot与AOP、SpringBoot自定义starter]

SpringBoot自定义starter SpringBoot与AOP SpringBoot与AOP 使用AOP实现用户接口访问日志功能 添加AOP场景启动器 <!--添加AOP场景启动器--><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-aop</…

都2024了,现在搞本HCIE真的还来得及?

信息技术的迅猛发展&#xff0c;网络的飞速进步&#xff0c;网络工程师这一职业的需求也在不断增加。 作为华为认证体系中的顶级认证&#xff0c;HCIE一直以来都是网络工程师追求的目标之一。 都2024了&#xff0c;厂商认证都火了十几年来&#xff0c;很多人犹犹豫豫&#xff0…

Redis Stream 作为消息队列的详尽笔记

概述 Redis Stream 是 Redis 5.0 版本引入的数据结构&#xff0c;用于消息传递。 基础概念 结构&#xff1a;消息链表&#xff0c;每个消息有唯一 ID 和内容。命名&#xff1a;每个 Stream 有唯一名称&#xff0c;对应 Redis Key。消费组&#xff08;ConsumerGroup&#xff…

Mybatis1(JDBC编程和ORM模型 MyBatis简介 实现增删改查 MyBatis生命周期)

目录 一、JDBC编程和ORM模型 1. JDBC回顾 2. JDBC的弊端 3. ORM模型 Mybatis和hibernate 区别: 4. mybatis 解决了jdbc 的问题 二、MyBatis简介 1. MyBatis快速开始 1.1 导入jar包 1.2 引入 mybatis-config.xml 配置文件 1.3 引入 Mapper 映射文件 1.3 测试 …

Ubuntu Server 和 Ubuntu Desktop 组合使用

1.常见的组合使用方式 Ubuntu Server 和 Ubuntu Desktop 确实可以组合使用&#xff0c;但具体要看你的需求和使用场景。以下是一些常见的组合使用方式&#xff1a; 单一设备上安装&#xff1a;你可以在一台设备上同时安装 Ubuntu Server 和 Ubuntu Desktop。这样&#xff0c;你…

【ARM系列】1of N SPI

1 of N模式 SPI 概述配置流程 概述 GIC-600AE支持1 of N模式SPI。在此模式下可以将SPI target到多个core&#xff0c;并且GIC-600AE可以选择哪些内核接收SPI。 GIC-600AE只向处于powered up 并且使能中断组的core发送SPI。 GIC-600AE会优先考虑那些被认为是active的核&#xf…

OOCL东方海外不定位置旋转验证码识别代码

样例图如下 这款验证码的识别最大难度在于&#xff0c;旋转的位置不固定&#xff0c;需要识别旋转图片的位置。 第二大难点就是旋转角度的识别。所以我们采集了大量样例图片进行训练&#xff0c;如下图所示 最终训练得到的模型需要两张图片输入&#xff0c;才能完成旋转角度识…

阿里 Mobile-Agent-v2:基于大模型的安卓鸿蒙自动化工具

与之前介绍的 DigiRL类似, Mobile-Agent-v2是一个支持安卓和鸿蒙系统的自动化工具&#xff0c;它使用视觉模型理解手机屏幕&#xff0c;并利用 ADB 来实现操作手机&#xff0c;你可以在本地运行&#xff0c;或者通过手机截图在线体验 Mobile-Agent-v2 从演示来看&#xff0c;可…

短信接口平台的核心功能有哪些?如何使用?

短信接口平台怎么有效集成&#xff1f;选择短信接口平台的技巧&#xff1f; 短信接口平台作为一种重要的通信工具&#xff0c;广泛应用于各种企业和组织。通过短信接口平台&#xff0c;企业能够高效、便捷地与客户进行互动和沟通。AoKSend将详细介绍短信接口平台的核心功能。 …

C++中的C++中的虚析构函数的作用和重要性

在C中&#xff0c;虚析构函数&#xff08;virtual destructor&#xff09;的作用和重要性主要体现在多态和继承的上下文中。了解这一点之前&#xff0c;我们先简要回顾一下多态和继承的基本概念。 继承与多态 继承&#xff1a;允许我们定义一个基类&#xff08;也称为父类或超…