CNN的小体验

用的pytorch。

训练代码cnn.py:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F# 定义超参数
num_epochs = 10
batch_size = 100
learning_rate = 0.001# 数据预处理和加载
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)# 定义卷积神经网络
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2)self.fc1 = nn.Linear(32 * 8 * 8, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 32 * 8 * 8)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 初始化模型、损失函数和优化器
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')# 测试模型
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the model on the 10000 test images: {100 * correct / total}%')# 保存模型
torch.save(model.state_dict(), 'cnn.pth')

推断代码cnn2.py

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F# 定义卷积神经网络(与之前的定义保持一致)
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2) # 第一个卷积层self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) # 池化层self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2) # 第二个卷积层self.fc1 = nn.Linear(32 * 8 * 8, 128) # 全连接层self.fc2 = nn.Linear(128, 10) # 输出层def forward(self, x):x = self.pool(F.relu(self.conv1(x))) # 通过第一个卷积层和池化层x = self.pool(F.relu(self.conv2(x))) # 通过第二个卷积层和池化层x = x.view(-1, 32 * 8 * 8) # 展平x = F.relu(self.fc1(x)) # 通过全连接层x = self.fc2(x) # 通过输出层return x# 加载模型
model = SimpleCNN()
model.load_state_dict(torch.load('cnn.pth'))
model.eval() # 设置模型为评估模式# 预处理图片
def preprocess_image(image_path):transform = transforms.Compose([transforms.Resize((32, 32)), # 调整图像大小到32x32transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])image = Image.open(image_path)image = transform(image)image = image.unsqueeze(0) # 增加批量维度return image# 加载并预处理图片
image_path = 'test.jpg' # 替换为你要分析的图片路径
image = preprocess_image(image_path)# 使用模型进行推理
with torch.no_grad():outputs = model(image)_, predicted = torch.max(outputs.data, 1)class_index = predicted.item()# CIFAR-10类别标签
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']# 输出预测结果
print(f'Predicted class: {classes[class_index]}')

可惜出来的东西跟弱智一般。

python3 cnn2.py
Predicted class: horse

python3 cnn2.py
Predicted class: bird

几个小点:

1 使用的数据集是CIFAR10

2 训练真的挺耗时的,我用的阿里云,一共搞了差不多10分钟(训练一个弱智)。

3 环境依然麻烦,python,numpy的版本都不能太高。否则要出问题。。。

4 最后实事求是的说,我不太懂的一点是怎么分类出来的。。。晚点再看看。。。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/864632.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用Python绘制彩虹效果:动态彩虹动画

文章目录 引言准备工作前置条件 代码实现与解析导入必要的库初始化Pygame定义绘制彩虹函数定义颜色列表主循环 完整代码 引言 彩虹是自然界中最美丽的现象之一。通过编程,我们可以将这一奇妙的景象带到屏幕上。在这篇博客中,我们将使用Python来创建一个…

聊聊 golang 的 map

1、哈希表 哈希表是一个很常见的数据结构,用来存储无序的 key/value 对,给定的 key 可以在 O(1) 时间复杂度内查找、更新或删除对应的 value。 设计一个好的哈希表,需要着重关注两个关键点:哈希函数、冲突处理。 1.1 哈希函数 …

Redis 高级数据结构业务实践

0、前言 本文所有代码可见 > 【gitee code demo】 本文会涉及 hyperloglog 、GEO、bitmap、布隆过滤器的介绍和业务实践 1、HyperLogLog 1.1、功能 基数统计(去重) 1.2、redis api 命令作用案例PFADD key element [element ...]添加元素到keyPF…

力扣 用队列实现栈(Java)

核心思想:因为队列都是一端进入另一端出(先进先出,后进后出),因此一个队列肯定是不能实现栈的功能的,这里就创建两个队列来模拟栈的先进后出,后进先出。 比如说如果是push操作我们肯定是要弹出栈…

STM32自己从零开始实操08:电机电路原理图

一、LC滤波电路 其实以下的滤波都可以叫低通滤波器。 1.1倒 “L” 型 LC 滤波电路 1.1.1定性分析 1.1.2仿真实验 电感:通低频阻高频的。仿真中高频信号通过电感,因为电感会阻止电流发生变化,故说阻止高频信号 电容:隔直通交。…

65、基于卷积神经网络的调制分类(matlab)

1、基于卷积神经网络的调制分类的原理及流程 基于卷积神经网络(CNN)的调制分类是一种常见的信号处理任务,用于识别或分类不同调制方式的信号。下面是基于CNN的调制分类的原理和流程: 原理: CNN是一种深度学习模型&a…

SpringBoot学习06-[SpringBoot与AOP、SpringBoot自定义starter]

SpringBoot自定义starter SpringBoot与AOP SpringBoot与AOP 使用AOP实现用户接口访问日志功能 添加AOP场景启动器 <!--添加AOP场景启动器--><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-aop</…

都2024了,现在搞本HCIE真的还来得及?

信息技术的迅猛发展&#xff0c;网络的飞速进步&#xff0c;网络工程师这一职业的需求也在不断增加。 作为华为认证体系中的顶级认证&#xff0c;HCIE一直以来都是网络工程师追求的目标之一。 都2024了&#xff0c;厂商认证都火了十几年来&#xff0c;很多人犹犹豫豫&#xff0…

Mybatis1(JDBC编程和ORM模型 MyBatis简介 实现增删改查 MyBatis生命周期)

目录 一、JDBC编程和ORM模型 1. JDBC回顾 2. JDBC的弊端 3. ORM模型 Mybatis和hibernate 区别: 4. mybatis 解决了jdbc 的问题 二、MyBatis简介 1. MyBatis快速开始 1.1 导入jar包 1.2 引入 mybatis-config.xml 配置文件 1.3 引入 Mapper 映射文件 1.3 测试 …

Ubuntu Server 和 Ubuntu Desktop 组合使用

1.常见的组合使用方式 Ubuntu Server 和 Ubuntu Desktop 确实可以组合使用&#xff0c;但具体要看你的需求和使用场景。以下是一些常见的组合使用方式&#xff1a; 单一设备上安装&#xff1a;你可以在一台设备上同时安装 Ubuntu Server 和 Ubuntu Desktop。这样&#xff0c;你…

【ARM系列】1of N SPI

1 of N模式 SPI 概述配置流程 概述 GIC-600AE支持1 of N模式SPI。在此模式下可以将SPI target到多个core&#xff0c;并且GIC-600AE可以选择哪些内核接收SPI。 GIC-600AE只向处于powered up 并且使能中断组的core发送SPI。 GIC-600AE会优先考虑那些被认为是active的核&#xf…

OOCL东方海外不定位置旋转验证码识别代码

样例图如下 这款验证码的识别最大难度在于&#xff0c;旋转的位置不固定&#xff0c;需要识别旋转图片的位置。 第二大难点就是旋转角度的识别。所以我们采集了大量样例图片进行训练&#xff0c;如下图所示 最终训练得到的模型需要两张图片输入&#xff0c;才能完成旋转角度识…

阿里 Mobile-Agent-v2:基于大模型的安卓鸿蒙自动化工具

与之前介绍的 DigiRL类似, Mobile-Agent-v2是一个支持安卓和鸿蒙系统的自动化工具&#xff0c;它使用视觉模型理解手机屏幕&#xff0c;并利用 ADB 来实现操作手机&#xff0c;你可以在本地运行&#xff0c;或者通过手机截图在线体验 Mobile-Agent-v2 从演示来看&#xff0c;可…

短信接口平台的核心功能有哪些?如何使用?

短信接口平台怎么有效集成&#xff1f;选择短信接口平台的技巧&#xff1f; 短信接口平台作为一种重要的通信工具&#xff0c;广泛应用于各种企业和组织。通过短信接口平台&#xff0c;企业能够高效、便捷地与客户进行互动和沟通。AoKSend将详细介绍短信接口平台的核心功能。 …

Android --- 新电脑安装Android Studio 使用 Android 内置模拟器电脑直接卡死,鼠标和键盘都操作不了

新电脑安装Android Studio 使用 Android 内置模拟器电脑直接卡死&#xff0c;鼠标和键盘都操作不了 大概原因就是,初始化默认Google的安卓模拟器占用的RAM内存是2048&#xff0c;如果电脑的性能和内存一般的话就可能卡死&#xff0c;解决方案是手动修改安卓模拟器的config文件&…

Python酷库之旅-第三方库openpyxl(20)

目录 一、 openpyxl库的由来 1、背景 2、起源 3、发展 4、特点 4-1、支持.xlsx格式 4-2、读写Excel文件 4-3、操作单元格 4-4、创建和修改工作表 4-5、样式设置 4-6、图表和公式 4-7、支持数字和日期格式 二、openpyxl库的优缺点 1、优点 1-1、支持现代Excel格式…

架构练习题目

【2022下架构真题第24题&#xff1a;红色】 24.在分布式系统中&#xff0c;中间件通常提供两种不同类型的支持&#xff0c;即&#xff08;27) A.数据支持和交互支持 B.交互支持和提供公共服务 C.数据支持和提供公共服务 D.安全支持和提供公共服务 解答&#xff1a;答案选择B。…

【知识图谱系列】(实例)python操作neo4j构建企业间的业务往来的知识图谱

本章节通过聚焦于"金额"这一核心属性&#xff0c;构建了一幅知识图谱&#xff0c;旨在揭示"销售方"与"购买方"间的商业互动网。在这张图谱中&#xff0c;绿色节点象征着购买方&#xff0c;而红色节点则代表了销售方。这两类节点间的紧密连线&…

苹果手机+AI手机概念股名单一览表

苹果智能将成为AI手机引领者&#xff0c;推动原生智能加速渗透&#xff0c;据Canlys预计2025年iOS操作系统将占据全球AI手机出货的55%。 AI手机端侧算力提升&#xff0c;将带动产业链部件升级创新 端侧算力提升或带动手机芯片及零部件升级&#xff0c;如 1&#xff09;SoC芯片&…

无人机智能追踪反制系统技术详解

随着无人机技术的飞速发展&#xff0c;无人机在各个领域的应用越来越广泛。然而&#xff0c;无人机的无序飞行和非法使用也带来了一系列安全隐患和威胁。因此&#xff0c;无人机智能追踪反制系统应运而生&#xff0c;成为维护公共安全和防止无人机滥用的重要工具。本文将详细介…