[oneAPI] 手写数字识别-BiLSTM

[oneAPI] 手写数字识别-BiLSTM

  • 手写数字识别
    • 参数与包
    • 加载数据
    • 模型
    • 训练过程
    • 结果
  • oneAPI

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

手写数字识别

使用了pytorch以及Intel® Optimization for PyTorch,通过优化扩展了 PyTorch,使英特尔硬件的性能进一步提升,让手写数字识别问题更加的快速高效
在这里插入图片描述

使用MNIST数据集,该数据集包含了一系列以黑白图像表示的手写数字,每个图像的大小为28x28像素,数据集组成如下:

  • 训练集:包含60,000个图像和标签,用于训练模型。
  • 测试集:包含10,000个图像和标签,用于测试模型的性能。

每个图像都被标记为0到9之间的一个数字,表示图像中显示的手写数字。这个数据集常常被用来验证图像分类模型的性能,特别是在计算机视觉领域。

参数与包

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transformsimport intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003

加载数据

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',train=True,transform=transforms.ToTensor(),download=True)test_dataset = torchvision.datasets.MNIST(root='../../data/',train=False,transform=transforms.ToTensor())# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)

模型

# Bidirectional recurrent neural network (many-to-one)
class BiRNN(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(BiRNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)self.fc = nn.Linear(hidden_size * 2, num_classes)  # 2 for bidirectiondef forward(self, x):# Set initial statesh0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)  # 2 for bidirectionc0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)# Forward propagate LSTMout, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size*2)# Decode the hidden state of the last time stepout = self.fc(out[:, -1, :])return out

训练过程

model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.reshape(-1, sequence_length, input_size).to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 100 == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))# Test the model
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.reshape(-1, sequence_length, input_size).to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

结果

在这里插入图片描述

oneAPI

import intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# 模型
model = ConvNet(num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/45490.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python rtsp 硬件解码 二

上次使用了python的opencv模块 述说了使用PyNvCodec 模块,这个模块本身并没有rtsp的读写,那么读写rtsp是可以使用很多方法的,我们为了输出到pytorch直接使用AI程序,简化rtsp 输入,可以直接使用ffmpeg的子进程 方法一 …

STM8遇坑[EEPROM读取debug不正常release正常][ STVP下载成功单运行不成功][定时器消抖莫名其妙的跑不通流程]

EEPROM读取debug不正常release正常 这个超级无语,研究和半天,突然发现调到release就正常了,表现为写入看起来正常读取不正常,这个无语了,不想研究了 STVP下载不能够成功运行 本文摘录于:https://blog.csdn.net/qlexcel/article/details/71270780只是做学习备份之…

GEEMAP 中如何拉伸图像

图像拉伸是最基础的图像增强显示处理方法,主要用来改善图像显示的对比度,地物提取流程中往往首先要对图像进行拉伸处理。图像拉伸主要有三种方式:线性拉伸、直方图均衡化拉伸和直方图归一化拉伸。 GEE 中使用 .sldStyle() 的方法来进行图像的…

8.5.tensorRT高级(3)封装系列-基于生产者消费者实现的yolov5封装

目录 前言1. yolov5封装总结 前言 杜老师推出的 tensorRT从零起步高性能部署 课程,之前有看过一遍,但是没有做笔记,很多东西也忘了。这次重新撸一遍,顺便记记笔记。 本次课程学习 tensorRT 高级-基于生产者消费者实现的yolov5封装…

postgresql中基础sql查询

postgresql中基础sql查询 创建表插入数据创建索引删除表postgresql命令速查简单查询计算查询结果 利用查询条件过滤数据模糊查询 创建表 -- 部门信息表 CREATE TABLE departments( department_id INTEGER NOT NULL -- 部门编号,主键, department_name CHARACTE…

CentOS6.8图形界面安装Oracle11.2.0.1.0

Oracle11下载地址 https://edelivery.oracle.com/osdc/faces/SoftwareDelivery 一、环境 CentOS release 6.8 (Final),测试环境:内存2G,硬盘20G,SWAP空间4G Oracle版本:Release 11.2.0.1.0 安装包:V175…

Lookup Singularity

1. 引言 Lookup Singularity概念 由Barry WhiteHat在2022年11月在zkResearch论坛 Lookup Singularity中首次提出: 其主要目的是:让SNARK前端生成仅需做lookup的电路。Barry预测这样有很多好处,特别是对于可审计性 以及 形式化验证&#xff…

【学习FreeRTOS】第8章——FreeRTOS列表和列表项

1.列表和列表项的简介 列表是 FreeRTOS 中的一个数据结构,概念上和链表有点类似,列表被用来跟踪 FreeRTOS中的任务。列表项就是存放在列表中的项目。 列表相当于链表,列表项相当于节点,FreeRTOS 中的列表是一个双向环形链表列表的…

微软Win11 Dev预览版Build23526发布

近日,微软Win11 Dev预览版Build23526发布,修复了不少问题。牛比如斯Microsoft,也有这么多bug,所以你写再多bug也不作为奇啊。 主要更新问题 [开始菜单] 修复了在高对比度主题下,打开开始菜单中的“所有应…

Spring Boot通过企业邮箱发件被Gmail退回的解决方法

这两天给我们开发的Chrome插件:Youtube中文配音 增加了账户注册和登录功能,其中有一步是邮箱验证,所以这边会在Spring Boot后台给用户的邮箱发个验证信息。如何发邮件在之前的文章教程里就有,这里就不说了,着重说说这两…

通过 kk 创建 k8s 集群和 kubesphere

官方文档:多节点安装 确保从正确的区域下载 KubeKey export KKZONEcn下载 KubeKey curl -sfL https://get-kk.kubesphere.io | VERSIONv3.0.7 sh -为 kk 添加可执行权限: chmod x kk创建 config 文件 KubeSphere 版本:v3.3 支持的 Kuber…

Linux 安全技术和防火墙

目录 1 安全技术 2 防火墙 2.1 防火墙的分类 2.1.1 包过滤防火墙 2.1.2 应用层防火墙 3 Linux 防火墙的基本认识 3.1 iptables & netfilter 3.2 四表五链 4 iptables 4.2 数据包的常见控制类型 4.3 实际操作 4.3.1 加新的防火墙规则 4.3.2 查看规则表 4.3.…

企事业数字培训及知识库平台

前言 随着信息化的进一步推进,目前各行各业都在进行数字化转型,本人从事过医疗、政务等系统的研发,和客户深入交流过日常办公中“知识”的重要性,再加上现在倡导的互联互通、数据安全、无纸化办公等概念,所以无论是企业…

打家劫舍 II——力扣213

动规 int robrange(vector<int>& nums, int start, int end){int first=nums[start]

CountDownLatch和CyclicBarrie

前置提要 什么是闭锁对象 闭锁对象&#xff08;Latch Object&#xff09;是一种同步工具&#xff0c;用于控制线程的等待和执行顺序。闭锁对象可以让一个或多个线程等待&#xff0c;直到特定的条件满足后才能继续执行。 在Java中&#xff0c;CountDownLatch就是一种常见的闭锁对…

STC15单片机PM2.5空气质量检测仪

一、系统方案 本设计采用STC15单片机作为主控制器&#xff0c;PM2.5传感器、按键设置&#xff0c;液晶1602显示&#xff0c;蜂鸣器报警。 二、硬件设计 原理图如下&#xff1a; 三、单片机软件设计 1、首先是系统初始化&#xff1a; void lcd_init()//液晶初始化设置 { de…

SQLite数据库实现数据增删改查

当前文章介绍的设计的主要功能是利用 SQLite 数据库实现宠物投喂器上传数据的存储&#xff0c;并且支持数据的增删改查操作。其中&#xff0c;宠物投喂器上传的数据包括投喂间隔时间、水温、剩余重量等参数。 实现功能&#xff1a; 创建 SQLite 数据库表&#xff0c;用于存储宠…

第一讲:BeanFactory和ApplicationContext接口

BeanFactory和ApplicationContext接口 1. 什么是BeanFactory?2. BeanFactory能做什么&#xff1f;3.ApplicationContext对比BeanFactory的额外功能?3.1 MessageSource3.2 ResourcePatternResolver3.3 EnvironmentCapable3.4 ApplicationEventPublisher 4.总结 1. 什么是BeanF…

解决C#报“MSB3088 未能读取状态文件*.csprojAssemblyReference.cache“问题

今天在使用vscode软件C#插件&#xff0c;编译.cs文件时&#xff0c;发现如下warning: 图(1) C#报cache没有更新 出现该warning的原因&#xff1a;当前.cs文件修改了&#xff0c;但是其缓存文件*.csprojAssemblyReference.cache没有更新&#xff0c;需要重新清理一下工程&#x…

【机器学习实战】朴素贝叶斯:过滤垃圾邮件

【机器学习实战】朴素贝叶斯&#xff1a;过滤垃圾邮件 0.收集数据 这里采用的数据集是《机器学习实战》提供的邮件文件&#xff0c;该文件有ham 和 spam 两个文件夹&#xff0c;每个文件夹中各有25条邮件&#xff0c;分别代表着 正常邮件 和 垃圾邮件。 这里需要注意的是需要…