[oneAPI] 手写数字识别-LSTM

[oneAPI] 手写数字识别-LSTM

  • 手写数字识别
    • 参数与包
    • 加载数据
    • 模型
    • 训练过程
    • 结果
  • oneAPI

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

手写数字识别

使用了pytorch以及Intel® Optimization for PyTorch,通过优化扩展了 PyTorch,使英特尔硬件的性能进一步提升,让手写数字识别问题更加的快速高效
在这里插入图片描述

使用MNIST数据集,该数据集包含了一系列以黑白图像表示的手写数字,每个图像的大小为28x28像素,数据集组成如下:

  • 训练集:包含60,000个图像和标签,用于训练模型。
  • 测试集:包含10,000个图像和标签,用于测试模型的性能。

每个图像都被标记为0到9之间的一个数字,表示图像中显示的手写数字。这个数据集常常被用来验证图像分类模型的性能,特别是在计算机视觉领域。

参数与包

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transformsimport intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.01

加载数据

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',train=True,transform=transforms.ToTensor(),download=True)test_dataset = torchvision.datasets.MNIST(root='../../data/',train=False,transform=transforms.ToTensor())# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)

模型

# Recurrent neural network (many-to-one)
class RNN(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, num_classes)def forward(self, x):# Set initial hidden and cell states h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)# Forward propagate LSTMout, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size)# Decode the hidden state of the last time stepout = self.fc(out[:, -1, :])return out

训练过程

model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.reshape(-1, sequence_length, input_size).to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 100 == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))# Test the model
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.reshape(-1, sequence_length, input_size).to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

结果

在这里插入图片描述

oneAPI

import intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# 模型
model = ConvNet(num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/40387.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Curson 编辑器

Curson 汉化与vacode一样 Curson 自带chat功能 1、快捷键ctrlk(代码中编辑) 2、快捷键ctrll 右侧打开窗口

为什么hive会出现_HIVE_DEFAULT_PARTITION分区

问题: 为什么hive表中出现_HIVE_DEFAULT_PARTITION分区? 解答: 因为在业务sql中使用的是动态分区,并且hive启用动态分区时,对于指定的分区键如果存在空值时,会对空值部分创建一个默认分区用于存储该部分…

小程序项目组件的基本应用

宿主环境:程序运行必须依赖的环境 小程序的宿主环境 ---->手机微信(定位、扫码、支付等) 小程序的通信模型: 渲染层和逻辑层之间的通信(微信客户端转发)逻辑层和第三方服务器之间的通信(微信客户端转发) 小程序的运行机制: 启动&#xff1…

c#实现工厂模式

可以使用以下代码实现C#中的工厂模式: 首先,定义一个接口作为产品的抽象: public interface IProduct {void Operation(); }然后,创建具体的产品类: public class ConcreteProductA : IProduct {public void Operat…

vue基础知识五:请描述下你对vue生命周期的理解?在created和mounted这两个生命周期中请求数据有什么区别呢?

一、生命周期是什么 生命周期(Life Cycle)的概念应用很广泛,特别是在政治、经济、环境、技术、社会等诸多领域经常出现,其基本涵义可以通俗地理解为“从摇篮到坟墓”(Cradle-to-Grave)的整个过程在Vue中实…

41 | 京东商家书籍评论数据分析

京东作为中国领先的电子商务平台,积累了大量商品评论数据,这些数据蕴含了丰富的信息。通过文本数据分析,我们可以了解用户对产品的态度、评价的关键词、消费者的需求等,从而有助于商家优化产品和服务,以及消费者作出更明智的购买决策。 本文将详细阐述如何获取京东商家评…

Python opennsfw/opennsfw2 图片/视频 鉴黄 笔记

nsfw&#xff08; Not Suitable for Work&#xff09;直接翻译就是 工作的时候不适合看&#xff0c;真文雅 nsfw效果&#xff0c;注意底部的分数 大体流程&#xff0c;输入图片/视频&#xff0c;输出0-1之间的数字&#xff0c;一般情况下&#xff0c;Scores < 0.2 认为是非…

7zip分卷压缩

前言 有些项目上传文件大小有限制 压缩包大了之后传输也会比较慢 解决方案 我们可以利用7zip压缩工具对文件进行分卷压缩 利用7zip压缩工具进行分卷压缩 查看待压缩文件大小 压缩完成之后有300多M&#xff0c;我们用100M去进行分卷压缩 选择待压缩的文件夹&#xff0c;右…

网络安全 Day30-运维安全项目-容器架构上

容器架构上 1. 什么是容器2. 容器 vs 虚拟机(化) :star::star:3. Docker极速上手指南1&#xff09;使用rpm包安装docker2) docker下载镜像加速的配置3) 载入镜像大礼包&#xff08;老师资料包中有&#xff09; 4. Docker使用案例1&#xff09; 案例01&#xff1a;:star::star::…

《内网穿透》无需公网IP,公网SSH远程访问家中的树莓派

文章目录 前言 如何通过 SSH 连接到树莓派步骤1. 在 Raspberry Pi 上启用 SSH步骤2. 查找树莓派的 IP 地址步骤3. SSH 到你的树莓派步骤 4. 在任何地点访问家中的树莓派4.1 安装 Cpolar内网穿透4.2 cpolar进行token认证4.3 配置cpolar服务开机自启动4.4 查看映射到公网的隧道地…

【JavaEE基础学习打卡02】是时候了解Java EE了!

目录 前言一、为什么要学习Java EE二、Java EE规范介绍1.什么是规范&#xff1f;2.什么是Java EE规范&#xff1f;3.Java EE版本 三、Java EE应用程序模型1.模型前置说明2.模型具体说明 总结 前言 &#x1f4dc; 本系列教程适用于 Java Web 初学者、爱好者&#xff0c;小白白。…

java接口导出csv

1、背景介绍 项目中需要导出数据质检结果&#xff0c;本来使用Excel&#xff0c;但是质检结果数据行数过多&#xff0c;导致用hutool报错&#xff0c;因此转为导出csv格式数据。 2、参考文档 https://blog.csdn.net/ityqing/article/details/127879556 工程环境&#xff1a;…

Redis-分布式锁!

分布式锁&#xff0c;顾名思义&#xff0c;分布式锁就是分布式场景下的锁&#xff0c;比如多台不同机器上的进程&#xff0c;去竞争同一项资源&#xff0c;就是分布式锁。 分布式锁特性 互斥性:锁的目的是获取资源的使用权&#xff0c;所以只让一个竞争者持有锁&#xff0c;这…

PyTorch: clamp函数与梯度的关系

本文主要以下探究这一点&#xff1a;梯度反向传播过程中&#xff0c;测试强行修改后的预测结果是否还会传递loss&#xff1f; clamp应用场景&#xff1a;在深度学习计算损失函数的过程中&#xff0c;会有这样一个问题&#xff0c;如果Label是1.0&#xff0c;而预测结果是0.0&a…

【算法】排序+双指针——leetcode三数之和、四数之和

三数之和 &#xff08;1&#xff09;排序双指针 算法思路&#xff1a; 和之前的两数之和类似&#xff0c;我们对暴力枚举进行了一些优化&#xff0c;利用了排序双指针的思路&#xff1a; 我们先排序&#xff0c;然后固定⼀个数 a &#xff0c;接着我们就可以在这个数后面的区间…

Mybatis Plus Interceptor

Mybatis Plus Interceptor 1 获取表名2 获取SQL 1 获取表名 Component public class MybatisInterceptor implements Interceptor {private static final List<String> EXCLUDE_TABLE new ArrayList<>();static {EXCLUDE_TABLE.add("test");}private s…

OpenCV实例(九)基于深度学习的运动目标检测(一)YOLO运动目标检测算法

基于深度学习的运动目标检测&#xff08;一&#xff09; 1.YOLO算法检测流程2.YOLO算法网络架构3.网络训练模型3.1 训练策略3.2 代价函数的设定 2012年&#xff0c;随着深度学习技术的不断突破&#xff0c;开始兴起基于深度学习的目标检测算法的研究浪潮。 2014年&#xff0c;…

电脑突然黑屏的解决办法

记录一次电脑使用问题 问题描述 基本情况&#xff1a;雷神游戏笔记本 windows10操作系统 64位 使用时间 4年 日期&#xff1a;2023年8月11日 当时 电脑充着电 打开了两个浏览器&#xff1a;edge[页面加载5个左右]&#xff0c;火狐[页面加载1个左右] 两个文件夹 一个百度网盘…

Davinci 报表工具 0.3.0-rc release 文本框模糊查询不生效问题

背景: 在使用过程中发现davinci 的控制器配置中, 取值配置的对应关系设置 包含 或 不包含时 不生效, 不能实现模糊匹配效果, 只能精确查询; 问题分析: 通过跟踪接口及相应代码, 发现在sql 拼接时没有对 like 和 not like 类型的值两侧添加百分号, 导致模糊查询失败 调用过程…

CentOS系统环境搭建(七)——Centos7安装MySQL

centos系统环境搭建专栏&#x1f517;点击跳转 坦诚地说&#xff0c;本文中百分之九十的内容都来自于该文章&#x1f517;Linux&#xff1a;CentOS7安装MySQL8&#xff08;详&#xff09;&#xff0c;十分佩服大佬文章结构合理&#xff0c;文笔清晰&#xff0c;我曾经在这篇文章…