pytorch升级打怪(一)

基础介绍

  • 学习基础知识
    • 机器学习的基本流程
  • 快速入门
    • 一个简单的目标分类任务
    • 执行过程

学习基础知识

机器学习的基本流程

  • 数据处理
  • 创建模型
  • 优化模型参数
  • 保存训练的模型

快速入门

一个简单的目标分类任务

识别衣服的类型

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor# Define model
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsdef train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationloss.backward()optimizer.step()optimizer.zero_grad()if batch % 100 == 0:loss, current = loss.item(), (batch + 1) * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")if __name__ == "__main__":# 处理数据# Download training data from open datasets.training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),)# Download test data from open datasets.test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor(),)batch_size = 64# Create data loaders.train_dataloader = DataLoader(training_data, batch_size=batch_size)test_dataloader = DataLoader(test_data, batch_size=batch_size)for X, y in test_dataloader:print(f"Shape of X [N, C, H, W]: {X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")break# 创建模型# Get cpu, gpu or mps device for training.device = ("cuda"if torch.cuda.is_available()else "mps"if torch.backends.mps.is_available()else "cpu")print(f"Using {device} device")model = NeuralNetwork().to(device)print(model)# 优化模型参数# 损失函数loss_fn = nn.CrossEntropyLoss()# 优化器optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)epochs = 20for t in range(epochs):print(f"Epoch {t + 1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)print("Done!")# 保存模型torch.save(model.state_dict(), "model.pth")print("Saved PyTorch Model State to model.pth")# 加载模型model = NeuralNetwork().to(device)model.load_state_dict(torch.load("model.pth"))# 使用训练的模型进行预测"""“t恤/顶”,“裤子”,“套衫”,“衣服”,“外套”,“凉鞋”,“衬衫”,“运动鞋”,“包”,“踝靴”,"""classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",]model.eval()x, y = test_data[0][0], test_data[0][1]with torch.no_grad():x = x.to(device)pred = model(x)predicted, actual = classes[pred[0].argmax(0)], classes[y]print(f'Predicted: "{predicted}", Actual: "{actual}"')

执行过程


/Users/futuredeng/anaconda3/envs/pyspide6_study/bin/python -X pycache_prefix=/Users/futuredeng/Library/Caches/JetBrains/PyCharm2024.1/cpython-cache /Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevd.py --multiprocess --qt-support=auto --client 127.0.0.1 --port 52646 --file /Users/futuredeng/PycharmProjects/pyspide6_study/s_torch/demo.py 
已连接到 pydev 调试器(内部版本号 241.14494.19)Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
Using mps device
NeuralNetwork((flatten): Flatten(start_dim=1, end_dim=-1)(linear_relu_stack): Sequential((0): Linear(in_features=784, out_features=512, bias=True)(1): ReLU()(2): Linear(in_features=512, out_features=512, bias=True)(3): ReLU()(4): Linear(in_features=512, out_features=10, bias=True))
)
Epoch 1
-------------------------------
loss: 2.298847  [   64/60000]
loss: 2.291248  [ 6464/60000]
loss: 2.278691  [12864/60000]
loss: 2.270169  [19264/60000]
loss: 2.247777  [25664/60000]
loss: 2.226532  [32064/60000]
loss: 2.221170  [38464/60000]
loss: 2.191688  [44864/60000]
loss: 2.186391  [51264/60000]
loss: 2.159593  [57664/60000]
Test Error: Accuracy: 48.9%, Avg loss: 2.151264 Epoch 2
-------------------------------
loss: 2.160394  [   64/60000]
loss: 2.149280  [ 6464/60000]
loss: 2.097811  [12864/60000]
loss: 2.111865  [19264/60000]
loss: 2.052902  [25664/60000]
loss: 2.002435  [32064/60000]
loss: 2.016076  [38464/60000]
loss: 1.941067  [44864/60000]
loss: 1.946122  [51264/60000]
loss: 1.868463  [57664/60000]
Test Error: Accuracy: 58.9%, Avg loss: 1.870417 Epoch 3
-------------------------------
loss: 1.907855  [   64/60000]
loss: 1.873430  [ 6464/60000]
loss: 1.759730  [12864/60000]
loss: 1.795776  [19264/60000]
loss: 1.683292  [25664/60000]
loss: 1.641434  [32064/60000]
loss: 1.654433  [38464/60000]
loss: 1.561658  [44864/60000]
loss: 1.587600  [51264/60000]
loss: 1.478837  [57664/60000]
Test Error: Accuracy: 60.4%, Avg loss: 1.501103 Epoch 4
-------------------------------
loss: 1.573581  [   64/60000]
loss: 1.536037  [ 6464/60000]
loss: 1.387229  [12864/60000]
loss: 1.457505  [19264/60000]
loss: 1.340816  [25664/60000]
loss: 1.339017  [32064/60000]
loss: 1.352573  [38464/60000]
loss: 1.279530  [44864/60000]
loss: 1.314921  [51264/60000]
loss: 1.217413  [57664/60000]
Test Error: Accuracy: 63.3%, Avg loss: 1.242960 Epoch 5
-------------------------------
loss: 1.320792  [   64/60000]
loss: 1.301409  [ 6464/60000]
loss: 1.135017  [12864/60000]
loss: 1.243455  [19264/60000]
loss: 1.120873  [25664/60000]
loss: 1.144230  [32064/60000]
loss: 1.168045  [38464/60000]
loss: 1.104519  [44864/60000]
loss: 1.145055  [51264/60000]
loss: 1.060252  [57664/60000]
Test Error: Accuracy: 64.9%, Avg loss: 1.082238 Done!
Saved PyTorch Model State to model.pth
Predicted: "Ankle boot", Actual: "Ankle boot"进程已结束,退出代码为 0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/734465.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

红包题第一弹

下载附件,发现有86个压缩包 现每个压缩包里面都有图片,010打开图片末尾都有base64部分,并且每个压缩包里面图片末尾的base64长度一样,刚好每一张的base64长度为100。猜测需要拼接起来然后解码 写个python脚本 import os import …

sql server 数据删除操作

删除 DELETE FROM table_name WHERE id < 10000; DELETE FROM [allmedia_restore] WHERE [SYS_DOCUMENTID] < 10000; 清空表 truncate table tableA truncate table [allmedia_restore].[dbo].[DOM_4_DOCLIB] 倒序排序 order by [字段] desc SELECT TOP 1000 [SYS_…

阿里云服务器9元1个月优惠价格表

阿里云服务器9元1个月优惠价格表&#xff0c;用不上9元&#xff0c;又降价了&#xff0c;只要5元。阿里云服务器一个月多少钱&#xff1f;最便宜5元1个月。阿里云轻量应用服务器2核2G3M配置61元一年&#xff0c;折合5元一个月&#xff0c;2核4G服务器30元3个月&#xff0c;2核2…

详解Mysql中redo log、undo log、bin log

目录 1 redo log&#xff08;重做日志&#xff09;2 undo log&#xff08;回滚日志&#xff09;3 Binlog&#xff08;二进制日志&#xff09;4 两阶段提交4.1 执行过程4.2 系统崩溃后重启如何刷新数据4.3 redo log 和 bin log区别 MySQL是一个关系型数据库管理系统&#xff0c;…

对象注入的几种方式

⭐ 作者&#xff1a;小胡_不糊涂 &#x1f331; 作者主页&#xff1a;小胡_不糊涂的个人主页 &#x1f4c0; 收录专栏&#xff1a;JavaEE &#x1f496; 持续更文&#xff0c;关注博主少走弯路&#xff0c;谢谢大家支持 &#x1f496; 注入对象 1. 属性注入2. 构造方法注入3. S…

微信小程序uniapp+django+python的酒店民宿预订系统ea9i3

Android的民宿预订系统设计的目的是为用户提供民宿客房、公告信息等方面的平台。 与PC端应用程序相比&#xff0c;Android的民宿预订系统的设计主要面向于民宿&#xff0c;旨在为管理员和用户、商家提供一个Android的民宿预订系统。用户可以通过Android及时查看民宿客房等。 An…

随机选择器

说明&#xff1a; 在阅读本公司源码时发现了一段实现随机选择器的代码&#xff0c;感觉不错&#xff0c;现分享出来。 public class RandomSelector {private final NavigableMap<Integer, Object> map new TreeMap<>();private Integer total 0;Random random …

滑动窗算一下rms

clear clc close all fs20; width16; height16; t(1/fs:1/fs:200); signalsin(2*pi*0.1)rand(length(t),1)3/100*t; figure(1) set(gcf,units,centimeters,Position,[1,2height,width,height]) plot(t,signal) smoothed_avg_values smooth(signal, 20); % 这里的10是…

013 Linux_互斥

前言 本文将会向你介绍互斥的概念&#xff0c;如何加锁与解锁&#xff0c;互斥锁的底层原理是什么 线程ID及其地址空间布局 每个线程拥有独立的线程上下文&#xff1a;一个唯一的整数线程ID, 独立的栈和栈指针&#xff0c;程序计数器&#xff0c;通用的寄存器和条件码。 和其…

【C++】深度解剖多态

> 作者简介&#xff1a;დ旧言~&#xff0c;目前大二&#xff0c;现在学习Java&#xff0c;c&#xff0c;c&#xff0c;Python等 > 座右铭&#xff1a;松树千年终是朽&#xff0c;槿花一日自为荣。 > 目标&#xff1a;了解什么是多态&#xff0c;熟练掌握多态的定义&a…

【SpringCloud】微服务重点解析

微服务重点解析 1. Spring Cloud 组件有哪些&#xff1f; 2. 服务注册和发现是什么意思&#xff1f;Spring Cloud 如何实现服务注册和发现的&#xff1f; 如果写过微服务项目&#xff0c;可以说做过的哪个微服务项目&#xff0c;使用了哪个注册中心&#xff0c;常见的有 eurek…

图片在div完全显示

效果图&#xff1a; html代码&#xff1a; <div class"container" style" display: flex;width: 550px;height: 180px;"><div class"box" style" color: red; background-color:blue; width: 50%;"></div><div …

python实现回溯算法

什么是回溯算法&#xff1f; 回溯算法是一种经典的解决组合优化问题、搜索问题以及求解决策问题的算法。它通过不断地尝试各种可能的候选解&#xff0c;并在尝试过程中搜索问题的解空间&#xff0c;直到找到问题的解或者确定问题无解为止。回溯算法常用于解决诸如排列、组合、…

30m二级分类土地利用数据Arcgis预处理及获取

本篇以武汉市为例&#xff0c;主要介绍将土地利用数据转换成武汉市内各区土地利用详情的过程以及分区统计每个区内各地类面积情况&#xff0c;后面还有制作过程中遇到的面积制表后数据过小的解决方法以及一些相关的知识点&#xff1a; 示例数据下载链接&#xff1a;数据下载链…

2024年阿里云服务器新用户购买一个月多少钱?

阿里云服务器一个月多少钱&#xff1f;最便宜5元1个月。阿里云轻量应用服务器2核2G3M配置61元一年&#xff0c;折合5元一个月&#xff0c;2核4G服务器30元3个月&#xff0c;2核2G3M带宽服务器99元12个月&#xff0c;轻量应用服务器2核4G4M带宽165元12个月&#xff0c;4核16G服务…

UnicodeDecodeError: ‘gbk‘和Error: Command ‘pip install ‘pycocotools>=2.0

今天重新弄YOLOv5的时候发现不能用了&#xff0c;刚开始给我报这个错误 subprocess.CalledProcessError: Command ‘pip install ‘pycocotools&#xff1e;2.0‘‘ returned non-zero exit statu 说这个包安装不了 根据他的指令pip install ‘pycocotools&#xff1e;2.0这个根…

哥德巴赫猜想

七十年代末八十年代初&#xff0c;哥德巴赫猜想在中国风靡一时&#xff0c;来源于徐迟的一篇同名报告文学。我还是小孩子&#xff0c;记得大人们叽里咕噜疯传。 “哇&#xff0c;不得了。陈景润证明了1&#xff0b;2&#xff1d;3&#xff0c;离1&#xff0b;1&#xff1d;2就…

misc30

rar解压得到 发现只有中间的图片可以分析&#xff0c;另外两个都有密码 那就先分析星空&#xff0c;属性里面发现 使用该密码可以解压doc文本&#xff0c;发现doc隐写 使用此密码&#xff08;Hello friend!)解压图片,得到一个二维码 扫码得到flag flag{welcome_to_ctfshow}

【Web】浅聊Java反序列化之Rome——关于其他利用链

目录 前言 JdbcRowSetImpl利用链 BasicDataSource利用链 Hashtable利用链 BadAttributeValueExpException利用链 HotSwappableTargetSource利用链 前文&#xff1a;【Web】浅聊Java反序列化之Rome——EqualsBean&ObjectBean-CSDN博客 前言 Rome中ToStringBean的利用…

(001)UV 的使用以及导出

文章目录 UV窗口导出模型的主要事项导出时材质的兼容问题unity贴图导出导出FBX附录 UV窗口 1.uv主要的工作区域&#xff1a; 2.在做 uv 和贴图之前&#xff0c;最好先应用下物体的缩放、旋转。 导出模型的主要事项 1.将原点设置到物体模型的底部&#xff1a; 2.应用修改器的…