pytoch M2芯片测试

今天才发现我的新片是M2芯片,而不是M1芯片,有点尴尬
在这里插入图片描述
参考网址
https://www.oldcai.com/ai/pytorch-train-MNIST-with-gpu-on-mac/

测试结果如下

M2_cpu.py

# https://www.oldcai.com/ai/pytorch-train-MNIST-with-gpu-on-mac/
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensorprint(f"PyTorch version: {torch.__version__}")# Check PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)
print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
print(f"Is MPS available? {torch.backends.mps.is_available()}")# Set the device
device = "cpu"
device = torch.device(device)
print(f"Using device: {device}")# Define the CNN model
class HandwritingRecognitionModel(nn.Module):def __init__(self):super().__init__()# Define the convolutional layersself.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)# Define the pooling and dropout layersself.pool = nn.MaxPool2d(2, 2)self.dropout1 = nn.Dropout(0.25)self.dropout2 = nn.Dropout(0.5)# Define the fully connected layersself.fc1 = nn.Linear(32 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):# Pass the input through the convolutional layersx = self.conv1(x)x = self.pool(x)x = self.dropout1(x)x = self.conv2(x)x = self.pool(x)x = self.dropout2(x)# Reshape the output for the fully connected layersx = x.view(-1, 32 * 7 * 7)# Pass the output through the fully connected layersx = self.fc1(x)x = self.fc2(x)# Return the final outputreturn x# Load the MNIST dataset
train_dataset = MNIST("./data", train=True, download=True, transform=ToTensor())
test_dataset = MNIST("./data", train=False, download=True, transform=ToTensor())# Define the data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# Define the model
model = HandwritingRecognitionModel().to(device)# Define the loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)from time import time 
t0 = time()
# Train the model for 10 epochs
for epoch in range(10):# Set the model to training modemodel.train()# Iterate over the training datafor images, labels in train_loader:images, labels = images.to(device), labels.to(device)# Pass the input through the modeloutputs = model(images)# Compute the lossloss = loss_fn(outputs, labels)# Backpropagate the errorloss.backward()# Update the model parametersoptimizer.step()# Set the model to evaluation modemodel.eval()# Evaluate the model on the validation setwith torch.no_grad():correct = 0total = 0for images, labels in test_loader:images, labels = images.to(device), labels.to(device)# Pass the input through the modeloutputs = model(images)# Get the predicted labels_, predicted = torch.max(outputs.data, 1)# Update the total and correct countstotal += labels.size(0)correct += (predicted == labels).sum()# Print the accuracyprint(f"Epoch {epoch + 1}: Accuracy = {100 * correct / total:.2f}%")t1 =time()
print("10 epoch cost {}s".format(t1-t0))

结果如下
在这里插入图片描述

M2_MPS.py

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensorprint(f"PyTorch version: {torch.__version__}")# Check PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)
print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
print(f"Is MPS available? {torch.backends.mps.is_available()}")# Set the device
device = "mps" if torch.backends.mps.is_available() else "cpu"
device = torch.device(device)
print(f"Using device: {device}")# Define the CNN model
class HandwritingRecognitionModel(nn.Module):def __init__(self):super().__init__()# Define the convolutional layersself.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)# Define the pooling and dropout layersself.pool = nn.MaxPool2d(2, 2)self.dropout1 = nn.Dropout(0.25)self.dropout2 = nn.Dropout(0.5)# Define the fully connected layersself.fc1 = nn.Linear(32 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):# Pass the input through the convolutional layersx = self.conv1(x)x = self.pool(x)x = self.dropout1(x)x = self.conv2(x)x = self.pool(x)x = self.dropout2(x)# Reshape the output for the fully connected layersx = x.view(-1, 32 * 7 * 7)# Pass the output through the fully connected layersx = self.fc1(x)x = self.fc2(x)# Return the final outputreturn x# Load the MNIST dataset
train_dataset = MNIST("./data", train=True, download=True, transform=ToTensor())
test_dataset = MNIST("./data", train=False, download=True, transform=ToTensor())# Define the data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# Define the model
model = HandwritingRecognitionModel().to(device)# Define the loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)from time import time 
t0 = time()
# Train the model for 10 epochs
for epoch in range(10):# Set the model to training modemodel.train()# Iterate over the training datafor images, labels in train_loader:images, labels = images.to(device), labels.to(device)# Pass the input through the modeloutputs = model(images)# Compute the lossloss = loss_fn(outputs, labels)# Backpropagate the errorloss.backward()# Update the model parametersoptimizer.step()# Set the model to evaluation modemodel.eval()# Evaluate the model on the validation setwith torch.no_grad():correct = 0total = 0for images, labels in test_loader:images, labels = images.to(device), labels.to(device)# Pass the input through the modeloutputs = model(images)# Get the predicted labels_, predicted = torch.max(outputs.data, 1)# Update the total and correct countstotal += labels.size(0)correct += (predicted == labels).sum()# Print the accuracyprint(f"Epoch {epoch + 1}: Accuracy = {100 * correct / total:.2f}%")t1 =time()
print("10 epoch cost {}s".format(t1-t0))

结果如下
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/101725.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

antd Form shouldUpdate 关联展示 form 数组赋值

form 数组中嵌套数值更新 注意:数组是引用类型 项目需求,表单中包含多个产品信息,使用form.list 数组嵌套,提货方式如果是邮寄展示地址,如果是自提,需要在该条目中增加两项 代码如下:// An hi…

Nacos(替代Eureka)注册中心

Nacos初步学习 Nacos 是一个开源的服务注册和配置中心,它允许您注册、注销和发现服务实例,并提供了配置管理的功能。下面是Nacos的最基础用法: 1. 服务注册和发现: 首先,您需要将您的应用程序或服务注册到Nacos中。…

黑马JVM总结(三十一)

(1)类加载器-概述 启动类加载器-扩展类类加载器-应用程序类加载器 双亲委派模式: 类加载器,加载类的顺序是先依次请问父级有没有加载,没有加载自己才加载,扩展类加载器在getParent的时候为null 以为Boots…

《设计一款2轮车充电桩系统》

以深圳为例,深圳有400万台电动2轮车,以每个月电费20元计算,深圳每个月用在2轮车充电上的费用为8000万左右。1年10个亿的市场规模。 前景可观,竞争也非常激烈。 本文主要讨论技术实现方案。 方法: 24v/36v直流输出 需…

接口自动化测试 —— 协议、请求流程

一、架构 CRM客户关系管理系统 SAAS Software As A Service 软件即服务 PAAS Platform AS A Service 平台即服务 快速交付→ 快:自己去干、有结果、事事有回音、持续改进 单体架构——》垂直架构——》面向服务架构——》微服务架构(分布式&#xf…

C#(Csharp)我的基础教程(四)(我的菜鸟教程笔记)-Windows项目结构分析、UI设计和综合事件应用的探究与学习

目录 windows项目是我们.NET学习一开始必备的内容。 1、窗体类(主代码文件窗体设计器后台代码文件) 主窗体对象的创建:在Program类里面: Application.Run(new FrmMain());这句代码就决定了,当前窗体是项目的主窗体。…

Vuex基础使用存取值+异步请求

一.Vuex简介 vuex是什么? Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式。它采用集中式存储管理应用的所有组件的状态,并以相应的规则保证状态以一种可预测的方式发生变化。 什么情况使用 Vuex? Vuex 可以帮助我们管理共享状态&#…

阿里云域名免费配置HTTPS

阿里云域名配置HTTPS - 知乎

vue原生实现element上传多张图片浏览删除

vue原生实现element上传多张图片浏览删除 <div class"updata-component" style"width:100%;"><div class"demo-upload-box clearfix"><div class"demo-upload-image-box" v-if"imageUrlArr && imageUrlAr…

算法-动态规划-编辑距离

算法-动态规划-编辑距离 1 题目概述 1.1 题目出处 https://leetcode.cn/problems/longest-increasing-subsequence/ 1.2 题目描述 2 动态规划 2.1 思路 dp[i][j] 表示 word1[0,i) 变换为 word2[0,j)的最少步数&#xff0c;那么转移表达式&#xff1a; i和j上的字符相同时…

[网鼎杯 2018]Comment git泄露 / 恢复 二次注入 .DS_Store bash_history文件查看

首先我们看到账号密码有提示了 我们bp爆破一下 我首先对数字爆破 因为全字符的话太多了 爆出来了哦 所以账号密码也出来了 zhangwei zhangwei666 没有什么用啊 扫一下吧 有git git泄露 那泄露看看 真有 <?php include "mysql.php"; session_start(); if(…

子层连接结构

目录 1、子层连接结构介绍 2、子层连接结构 3、代码实现 1、子层连接结构介绍 输入到每个子层以及规范化层的过程中&#xff0c;还使用了残差连接&#xff08;跳跃连接&#xff09;&#xff0c;因此我们把这一部分整体结构叫子层连接&#xff08;代表子层及其连接结构&#xf…

常见Http请求形式

一、请求参数的类型 我们在做boot项目时&#xff0c;常常会向接口发起请求&#xff0c;有些请求需要附带一些参数&#xff0c;比如说分页查询&#xff0c;就需要带上pageNum(当前页)和pageSize(页面大小)等参数 有两种方式可以传递这样的参数 query类型&#xff0c;参数通过…

iPhone 15分辨率,屏幕尺寸,PPI 详细数据对比 iPhone 15 Plus、iPhone 15 Pro、iPhone 15 Pro Max

史上最全iPhone 机型分辨率&#xff0c;屏幕尺寸&#xff0c;PPI详细数据&#xff01;已更新到iPhone 15系列&#xff01; 点击放大查看高清图 &#xff01;

MDK自动生成带校验带SVN版本号的升级文件

MDK自动生成带校验带SVN版本号的升级文件 获取SVN版本信息 确保SVN安装了命令行工具&#xff0c;默认安装时不会安装命令行工具 编写一个模板头文件 svn_version.temp.h, 版本号格式为 1_0_0_SVN版本号 #ifndef __SVN_VERSION_H #define __SVN_VERSION_H#define SVN_REVISIO…

web前端面试-- js深拷贝的一些bug,特殊对象属性(RegExp,Date,Error,Symbol,Function)处理,循环引用weekmap处理

本人是一个web前端开发工程师&#xff0c;主要是vue框架&#xff0c;整理了一些面试题&#xff0c;今后也会一直更新&#xff0c;有好题目的同学欢迎评论区分享 ;-&#xff09; web面试题专栏&#xff1a;点击此处 文章目录 深拷贝和浅拷贝的区别浅拷贝示例深拷贝示例 特殊对象…

ODrive移植keil(五)—— 开环控制和电流变换

目录 一、开环控制1.1、控制原理1.2、硬件接线1.3、代码说明1.4、程序演示1.5、程序架构的体现 二、电流变换2.1、理论说明2.2、代码说明 ODrive、VESC和SimpleFOC 教程链接汇总&#xff1a;请点击 一、开环控制 在SimpleFOC系列中有开环控制的教程&#xff0c;SimpleFOC移植S…

【C进阶】内存函数

strcpy拷贝的仅仅是字符串&#xff0c;但是内存中的数据不仅仅是字符&#xff0c;所以就有了memcpy函数 1. memcpy void *memcpy &#xff08;void * destination &#xff0c;const void * source , size_t num) 函数memcpy从source的位置开始向后拷贝num个字节的数据到desti…

基于nodejs+vue驾校预约管理系统

通过科技手段提高自身的优势&#xff1b;对于驾校预约管理系统当然也不能排除在外&#xff0c;随着网络技术的不断成熟&#xff0c;带动了驾校预约管理系统&#xff0c; 随着科学技术的飞速发展&#xff0c;各行各业都在努力与现代先进技术接轨&#xff0c;驾校预约管理系统&am…

1.go web之gin框架

Gin框架 一、准备 1.下载依赖 go get -u github.com/gin-gonic/gin2.引入依赖 import "github.com/gin-gonic/gin"3. &#xff08;可选&#xff09;如果使用诸如 http.StatusOK 之类的常量&#xff0c;则需要引入 net/http 包 import "net/http"二、基…