【Python】MacBook M系列芯片Anaconda下载Pytorch,并开发一个简单的数字识别代码(附带踩坑记录)

文章目录

  • 配置镜像源
  • 下载Pytorch
  • 验证
  • 使用Pytorch进行数字识别

配置镜像源

Anaconda下载完毕之后,有两种方式下载pytorch,一种是用页面可视化的方式去下载,另一种方式就是直接用命令行工具去下载。
在这里插入图片描述
但是由于默认的Anaconda走的是外网,所以下载很慢,我们得首先配置镜像源,这里推荐用清华的,之前用中科大的出问题了,换成清华马上就好了。。。

打开Termial或者iTerm2
输入如下命令

conda config --add channels http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
conda config --add channels http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free
conda config --add channels http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r
conda config --add channels http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/pro
conda config --add channels http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2

然后输入如下命令查看是否ok了

conda config --show channels

在输入如下命令

conda config --set show_channel_urls yes

这个时候你的配置基本就完成了,接下来你就可以开始下载了

下载Pytorch

pytorch官网
进入到官网,然后基于你的机器配置选择命令
在这里插入图片描述
然后将命令放入到命令行中进行运行。
特别注意!!!
这里一定要把梯子等工具都关掉,不然会出现HTTP相关的异常。
可以考虑使用如下命令处理一下

conda config --set ssl_verify false

如果踩坑了,从如下几个地方思考:

  1. 镜像源问题,换镜像源
  2. ssl验证关闭,使用上面的命令
  3. 别开梯子!!!!!!!

验证

使用如下命令就可以查看是否安装成功了

conda list | grep pytorch

在这里插入图片描述

使用Pytorch进行数字识别

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
from PIL import Image# 定义神经网络模型
class Net(torch.nn.Module):def __init__(self):super().__init__()self.fc1 = torch.nn.Linear(28*28, 64)  # 第一个全连接层,将输入从784维映射到64维self.fc2 = torch.nn.Linear(64, 64)     # 第二个全连接层,将输入从64维映射到64维self.fc3 = torch.nn.Linear(64, 64)     # 第三个全连接层,将输入从64维映射到64维self.fc4 = torch.nn.Linear(64, 10)     # 第四个全连接层,将输入从64维映射到10维(对应10个类别)def forward(self, x):x = torch.nn.functional.relu(self.fc1(x))  # 应用ReLU激活函数x = torch.nn.functional.relu(self.fc2(x))  # 应用ReLU激活函数x = torch.nn.functional.relu(self.fc3(x))  # 应用ReLU激活函数x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)  # 应用log_softmax激活函数return x# 定义数据加载函数
def get_data_loader(is_train):to_tensor = transforms.Compose([transforms.ToTensor()])  # 定义数据转换data_set = MNIST("", is_train, transform=to_tensor, download=True)  # 加载MNIST数据集return DataLoader(data_set, batch_size=15, shuffle=True)  # 创建数据加载器# 定义模型评估函数
def evaluate(test_data, net):n_correct = 0n_total = 0with torch.no_grad():  # 禁用梯度计算for (x, y) in test_data:outputs = net.forward(x.view(-1, 28*28))  # 前向传播计算输出for i, output in enumerate(outputs):if torch.argmax(output) == y[i]:  # 比较预测结果与真实标签n_correct += 1n_total += 1return n_correct / n_total  # 返回准确率# 定义模型保存函数
def save_model(net, path="mnist_model.pth"):torch.save(net.state_dict(), path)  # 保存模型权重到文件# 定义模型加载函数
def load_model(net, path="mnist_model.pth"):net.load_state_dict(torch.load(path))  # 从文件加载模型权重# 定义图像预测函数
def predict_image(image, net):net.eval()  # 设置为评估模式with torch.no_grad():  # 禁用梯度计算output = net(image.view(-1, 28*28))  # 前向传播计算输出predicted = torch.argmax(output, dim=1)  # 获取预测结果return predicted.item()  # 返回预测类别# 定义图像加载函数
def load_image(image_path):image = Image.open(image_path).convert('L')  # 打开图像并转换为灰度图transform = transforms.Compose([transforms.Resize((28, 28)), transforms.ToTensor()])  # 定义图像转换image = transform(image)  # 应用转换return image  # 返回处理后的图像def main():train_data = get_data_loader(is_train=True)  # 加载训练数据test_data = get_data_loader(is_train=False)  # 加载测试数据net = Net()  # 初始化神经网络模型# 训练模型optimizer = torch.optim.Adam(net.parameters(), lr=0.001)  # 定义Adam优化器for epoch in range(2):  # 训练2个epochfor (x, y) in train_data:net.zero_grad()  # 清零梯度output = net.forward(x.view(-1, 28*28))  # 前向传播计算输出loss = torch.nn.functional.nll_loss(output, y)  # 计算损失loss.backward()  # 反向传播计算梯度optimizer.step()  # 更新模型参数print("epoch", epoch, "accuracy:", evaluate(test_data, net))  # 打印每个epoch后的准确率# 保存模型save_model(net)# 加载模型net = Net()  # 初始化新的神经网络模型load_model(net)  # 加载已保存的模型权重print("Loaded model accuracy:", evaluate(test_data, net))  # 打印加载模型后的准确率# 使用模型预测新图像image_path = "path_to_your_image.png"  # 替换为你要预测的图像路径image = load_image(image_path)  # 加载并预处理图像prediction = predict_image(image, net)  # 使用模型进行预测print(f"Predicted digit: {prediction}")  # 打印预测结果if __name__ == "__main__":main()  # 运行main函数

第一次运行的时候,会加载数字识别模型到本地,第二次运行的时候,你就可以把训练过程的代码都注释掉了,直接使用这个最终的模型
在这里插入图片描述
第二次运行
你的模型就是这个pth文件
在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/865030.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

杂学可以查看各个网络学习

WZMIAOMIAO (WuZhe) GitHub这个是一个各个深度学习的合集,有代码

主干网络篇 | YOLOv8改进之引入YOLOv10的主干网络 | 全网最新改进

前言:Hello大家好,我是小哥谈。YOLOv10是由清华大学研究人员利用Ultralytics Python软件包开发的,它通过改进模型架构并消除非极大值抑制(NMS)提供了一种新颖的实时目标检测方法。这些优化使得模型在保持先进性能的同时,降低了计算需求。与以往的YOLO版本不同,YOLOv10的…

突发!Runway的Gen-3向所有人开放,媲美Sora!

7月2日凌晨,著名生成式AI平台Runway在官网宣布,其文生视频模型Gen-3 Alpha向所有用户开放使用。 上周日Runway只向部分用户提供了Gen-3的使用权限,「AIGC开放社区」也为大家解读了10个非常有代表性的视频案例。(点击查看&#xf…

快速上手文心一言指令:解锁AI交互新体验

文心一言,作为百度研发的预训练语言模型,以其强大的语言理解和生成能力,为用户提供了丰富的交互体验。通过一系列精心设计的指令,用户可以轻松地与模型进行对话,获取信息、解决问题、启发灵感。本文将详细介绍如何快速…

晚上睡觉要不要关路由器?一语中的

前言 前几天小白去了一个朋友家,有朋友说:路由器不关机的话会影响睡眠吗? 这个影响睡眠嘛,确实是会的。毕竟一时冲浪一时爽,一直冲浪一直爽……刷剧刷抖音刷到根本停不下来,肯定影响睡眠。 所以晚上睡觉要…

昇思MindSpore学习笔记2-04 LLM原理和实践--文本解码原理--以MindNLP为例

摘要: 介绍了昇思MindSpore AI框架采用贪心搜索、集束搜索计算高概率词生成文本的方法、步骤,并为解决重复等问题所作的多种尝试。 这一节完全看不懂,猜测是如何用一定范围的词造句。 一、概念 自回归语言模型 文本序列概率分布 分解为每…

XML Schema 实例

XML Schema 实例 XML Schema 是一种用于定义 XML 文档结构和内容的语言。它提供了一种强大的方式来描述 XML 文档中允许的元素、属性和数据类型。在本篇文章中,我们将通过一系列实例来探讨 XML Schema 的基本概念和高级特性。 XML Schema 基础 XML Schema 是基于 XML 的,这…

多模态融合 + 慢病精准预测

多模态融合 慢病精准预测 慢病预测算法拆解子解法1:多模态数据集成子解法2:实时数据处理与更新子解法3:采用大型语言多模态模型(LLMMs)进行深度学习分析 慢病预测更多模态 论文:https://arxiv.org/pdf/2406…

自动化测试用例设计-软件测试基本概念解析

软件测试基本概念解析 1. 引言:软件测试的重要性​ 在当今这个数字化时代,软件质量直接关系到企业的竞争力和用户满意度。一个小小的bug可能造成重大经济损失,甚至影响品牌形象。因此,软件测试成为了确保软件可靠性的关键环节&a…

发电机保护屏组成都有哪些,如何选择

发电机保护屏组成都有哪些,如何选择 发电机是电力系统中最常用的一种电力设备。例如水力发电机,柴油发电机,风力发电机,火力发电等等。发电机保护是保证发电机安全、稳定运行的重要手段之一。对于一些小型机组的发电机&#xff0c…

探囊取物之多形式注册页面(基于BootStrap4)

基于BootStrap4的注册页面,支持手机验证码注册、账号密码注册 低配置云服务器,首次加载速度较慢,请耐心等候;演练页面可点击查看源码 预览页面:http://www.daelui.com/#/tigerlair/saas/preview/ly4gax38ub9j 演练页…

RTSP协议在视频监控系统中的典型应用、以及视频监控设备的rtsp地址格式介绍

目录 一、协议概述 1、定义 2、提交者 3、位置 二、主要特点 1、实时性 2、可扩展性 3、控制功能 4、回放支持 5、网络适应性 三、RTSP的工作原理 1、会话准备 2、会话建立 3、媒体流控制 4、会话终止 5、媒体数据传输 四、协议功能 1、双向性 2、带外协议 …

趣玩双色球APP-PyQt5实现

开发环境及软件主要功能说明 开发环境 win10 Vscode Python10.5-64_bit 使用的python库 requests,bs4,pandas,PyQt5 主要功能说明: 数据库更新,保存,另存为功能过滤显示,根据期数,开奖日期,开间期号过…

NativeMemoryTracking查看java内存信息

默认该功能是禁用的,因为会损失5-10%的性能 开启命令 -XX:NativeMemoryTrackingdetail 打印命令 jcmd 45064 VM.native_memory summary scaleMB > NativeMemoryTracking.log 具体的日志信息 ➜ ~ ➜ ~ jcmd 45064 VM.native_memory summary scaleMB 45064…

AndroidStudio activity-1.8.0.aar依赖报错

在使用Androidstudio自帶的創建activity及配套 xml時,構建項目失敗,報錯内容: Null extracted folder for artifact: ResolvedArtifact(componentIdentifierandroidx.activity:activity:1.8.0, variantNamenull, artifactFileC:\Users\hhhh\.…

Golang 开发实战day15 - Input info

🏆个人专栏 🤺 leetcode 🧗 Leetcode Prime 🏇 Golang20天教程 🚴‍♂️ Java问题收集园地 🌴 成长感悟 欢迎大家观看,不执着于追求顶峰,只享受探索过程 Golang 开发实战day15 - 用户…

object对象类型截取实现数组的slice效果

slice是数组的方法,而对象(Object)和数组是两种不同的数据结构。对象没有索引(index)的概念。 对象的属性是通过键(key)来访问的,而这些键并不保证是整数或连续的。 1、获取对象的键…

AMEYA360:类比半导体推出36V超低输入偏置电流高性能通用运算放大器

在精密信号处理领域,每一次技术创新都意味着性能的飞跃与应用的拓展。上海类比半导体技术有限公司(以下简称“类比半导体”)凭借其在模拟及数模混合芯片设计领域的深厚积累,今日正式宣布推出其全新OPJ301x系列超低输入偏置电流高性能通用运算放大器。该系…

Canvas 指纹:它是什么以及如何绕过它

什么是 Canvas 指纹? 网络浏览器在执行其功能时会收集各种信息。当这些信息中的某些被用于识别网站用户时,这被称为浏览器指纹。 浏览器指纹包括以下有关浏览器的信息:设备型号、浏览器类型和版本、操作系统 (OS)、屏幕分辨率、时区、p0p 文…

AI大模型对话(上下文)缓存能力

互联网应用中,为了提高数据获取的即时性,产生了各种分布式缓存组件,比如Redis、Memcached等等。 大模型时代,除非是免费模型,否则每次对话都会花费金钱来进行对话,对话是不是也可以参照缓存的做法来提高命…