机器学习-pytorch1(持续更新)

上一节我们学习了机器学习的线性模型和非线性模型的机器学习基础知识,这一节主要将公式变为代码

代码编写网站:https://colab.research.google.com/drive

学习课程链接:ML 2022 Spring

1、Load Data(读取数据)

这需要用到pytorch里面的两个函数Dataset和Dataloader

torch.utils.data.Dataset
torch.utils.data.DataLoader

Dataset:是用来存储数据样本和期望值

dataset = MyDataset(file)

Dataloader:批量对数据进行分组,启用多处理

dataloader = DataLoader(dataset, batch_size, shuffle=True)

// 其中对于shuffle的取值,True表示训练,false表示测试

关于Dataset和Dataloader的关系如下:

 

ML 2022 Spring为图片来源

我们读取完数据,是不是想知道我们的数据长什么样子呢?(我们称数据为Tensors)

首先,它可能是一个一维数据,比如一个音频、一个温度

其次,还可能是一个二维数据,比如一张二值图像

最后,还可能是一个三维数据,比如一个彩色的图像

又有问题了,我们怎么通过编程得到我们图像的大小?

可以使用pytorch里面的shape()函数

我们怎么通过编程创造我们的数据呢?

eg:
x = torch.tensor([[1,-1],[-1,1]])
x = torch.from_numpy(np.array([[1,-1],[-1,1]]))
全0或全1数据
x = torch.zeros([2,2])    # 2*2的全0数据
x = torch.ones([1,2,5])    # 1*2*5的全1数据

 其次,还支持矩阵的运算

Addition:z = x + y
Subtraction:z = x - y
Power:y = x.pow(2)
Summation:y = x.sum()
Mean:y = x.mean()
维度转换:x = x.transpose(dim0,dim1)
消除维度:x = x.squeeze(dim)
增加维度:x = x.unsqueeze(dim)
组合:w = torch.cat([x,y,z],dim=1)

拥有不同的数据类型:

使用.to()可以切换到不同的设备:

CPU: x = x.to('cpu')
GPU: x = x.to('cuda')

 这里就又涉及到如何检查你的GPU了?可以使用以下语句检查你的计算机是否有GPU:

torch.cuda.is_available()

如何计算梯度?

 // 注意矩阵一定要使用小数点

2、Define Neural Network(训练和测试神经网络)

torch.nn.Module

线性: 

 非线性:

Sigmoid Activation:nn.Sigmoid()

ReLU Activation:nn.ReLU()

下面我根据所学的知识构建我自己的神经网络:

3、Loss Function(损失函数) 

x = torch.nn.MSELoss    # 对于回归任务
x = torch.nn.CrossEntropyLoss etc.    # 对于分类任务
loss = x(model_output,expected_value)

4、Optimization Algorithm(优化)

torch.optim

这是基于梯度的优化算法,不断调整参数,减少误差

比如:随机梯度下降(SGD)

torch.optim.SGD(model.parameters(), lr, momentum = 0)

* 调用optimizer.zero_grad()重置模型参数的梯度。

*调用loss.backward()反向传播预测loss的梯度。

*调用optimizer.step()调整模型参数。 

5、Entire Procedure(整个程序)

import torch.utils.data as data
dataset = data.Dataset(file)              # 读取数据
tr_set = DataLoader(dataset,batch_size,shuffle=True)  # 对数据集进行分组
model = MyModel().to(device)              # 建立我的模型并且选择我的设备(cpu or gpu)
criterion = nn.MSELoss()                # 建立损失函数
optimizer = torch.optim.SGD(model.parameters(),0.1)   # 建立优化
# 训练
for epoch in range(n_epochs):             # 迭代数据model.train()                    # 训练模型for x, y in tr_set:               # 迭代数据集optimizer.zero_grad()              # 设置梯度为0x, y = x.to(device),y.to(device)       # 将数据移动到设备pred = model(x)                # 计算输出loss = criterion(pred,y)            # 计算损失函数loss.backward()                 # 计算反向梯度optimizer.model()                # 优化模型
# 验证
model.eval()                      # 将模型设置为评估模式
total_loss = 0          
for x,y in dv_set:                  # 对数据集进行迭代x,y = x.to(device),y.to(device)          # 将数据移动到涉笔with torch.no_grad():                # 不可迭代的计算pred = model(x)                # 计算输出loss = criterion(pred,y)           # 计算损失函数total_loss += loss.cpu().item()*len(x)      # 累加损失误差avg_loss = total_loss / len(dv_set.dataset)   # 计算平均损失
# 测试
model.eval()                       # 将模型设置为评估模式
preds = []
for x in dv_set:                   # 对数据集进行迭代x = x.to(device)                  # 将数据移动到涉笔with torch.no_grad():                # 不可迭代的计算pred = model(x)                # 计算输出preds.append(pred.cpu())             # 收集预测

// model.eval()  :更改模型的行为

//  with torch.no_grad() :防止对验证/测试数据进行意外训练

当我们训练完模型,也完成了测试,为了不使模型丢失,我们需要保存模型,pytorch也为我们提供了保存模型的方法。

保存模型:torch.save(model.state_dict(),path)

下次我们使用已经训练完成的模型,或者想继续训练,我们需要读取模型。

读取模型:ckpt = torch.load(path)     model.load_state_dict(ckpt)

// 这只是我根据所听的课自己写的笔记,如果有什么错误欢迎指正!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/733987.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

学习Java的第六天

一、变量 1、变量的定义 在程序执行的过程中变量的值会发生变化,直白来说就是用来存储可变化的数据。从本质上讲,变量其实指的是内存中的一小块存储空间,空间位置是确定的,但是里面放置的值不确定。比如屋子里有多个鞋柜&#x…

2024年目标检测研究进展

YOLOv9 图片来源网络 YOLO相关的研究:https://blog.csdn.net/yunxinan/article/details/103431338

网络安全: Kali Linux 进行 SSH 渗透与防御

目录 一、实验 1.环境 2.nmap扫描目标主机 3.Kali Linux 进行 SSH 渗透 3.Kali Linux 进行 SSH 防御 二、问题 1.SSH有哪些安全配置 一、实验 1.环境 (1)主机 表1 主机 系统版本IP备注Kali Linux2022.4 192.168.204.154(动态&…

【软考】单元测试

目录 1. 概念2. 测试内容2.1 说明2.2 模块接口2.3 局部数据结构2.4 重要的执行路径 3. 测试过程2.1 说明2.2 单元测试环境图2.3 驱动模块2.4 桩模块 4. 模块接口测试与局部数据结构测试的区别 1. 概念 1.单元测试也称为模块测试,在模块编写完成且无编译错误后就可以…

16.Git从入门到进阶

一.Git 初识 1. 概念: 一个免费开源,分布式的代码版本控制系统,帮助开发团队维护代码 2. 作用: 记录代码内容,切换代码版本,多人开发时高效合并代码内容 3. 如何学: 个人本机使用&#xf…

数据库中 SQL Hint 是什么?

前言 最近在调研业界其他数据库中 SQL Hint 功能的设计和实现,整体上对 Oracle、Mysql、Postgresql、 Apache Calcite 中的 SQL Hint 的设计和功能都进行了解,这里整理一篇文章来对其进行梳理,一是帮助自己未来回顾,加深自己的思…

Python之Web开发中级教程----搭建Git环境三

Python之Web开发中级教程----搭建Git环境三 多人分布式使用仓库操作实例 场景:开发者A,开发者B在同一个项目协同开发,修改同一个代码文件。开发者A在Win10下,开发者B在Ubuntu下。 1、开发者A修改提交代码 从GitHub: Let’s bu…

44岁「台偶一哥」成现实版「王子变青蛙」,育一子一女成人生赢家

电影《周处除三害》近日热度极高,男主角阮经天被大赞演技出色,最让人意想不到,因为该片在内地票房报捷,很多人走去恭喜另一位台湾男艺人明道,皆因二人出道时外貌神似,至今仍有不少人将两人搞混。 多年过去&…

11.Node.js入门

一.什么是 Node.js Node.js 是一个独立的 JavaScript 运行环境,能独立执行 JS 代码,因为这个特点,它可以用来编写服务器后端的应用程序 Node.js 作用除了编写后端应用程序,也可以对前端代码进行压缩,转译,…

Linux最小系统安装无法查看IP地址

1,出现原因 服务器重启完成之后,我们可以通过linux的指令 ip addr 来查询Linux系统的IP地址,具体信息如下: 从图中我们可以看到,并没有获取到linux系统的IP地址,这是为什么呢?这是由于启动服务器时未加载网…

《探索虚拟与现实的边界:VR与AR谁更能引领未来?》

引言 在当今数字时代,虚拟现实(VR)和增强现实(AR)技术正以惊人的速度发展,并逐渐渗透到我们的日常生活中。它们正在重新定义人与技术、人与环境之间的关系,同时也为各行各业带来了全新的可能性。然而,究竟是VR还是AR更有潜力改变未来?本文将围绕这一问题展开深入探讨。…

Unity ShaderGraph实现地面积水效果

先看看效果 右侧参数,能够控制水高,波纹的速度等,但是这个效果需要修改高度图和凹凸图,毕竟有些模型并不是平面,对于具有斜面的模型就需要修改贴图。 ShaderGraph如下

基于pytorch的视觉变换器-Vision Transformer(ViT)的介绍与应用

近年来,计算机视觉领域因变换器模型的出现而发生了革命性变化。最初为自然语言处理任务设计的变换器,在捕捉视觉数据的空间依赖性方面也显示出了惊人的能力。视觉变换器(Vision Transformer,简称ViT)就是这种变革的一个…

第一代高通S7和S7 Pro音频平台:超旗舰性能,全面革新音频体验

以下文章来源于高通中国 如今,音频内容与形式日渐丰富,可满足人们放松心情、提升自我、获取资讯等需求。得益于手机、手表、耳机、车载音箱等智能设备的广泛应用,音频内容可以更快速触达用户。从《音频产品使用现状调研报告2023》中发现&…

幕译--本地字幕生成与翻译--Whisper客户端

幕译–本地字幕生成与翻译 本地离线的字幕生成与翻译,支持GPU加速。可免费试用,无次数限制 基于Whisper,希望做最好的Whisper客户端 功能介绍 本地离线,不用担心隐私问题支持GPU加速支持多种模型支持(中文、英语、日…

连接时序分类 Connectionist Temporal Classification (CTC)

CTC全称Connectionist temporal classification,是一种常用在语音识别、文本识别等领域的算法,用来解决输入和输出序列长度不一、无法对齐的问题。在CRNN中,它实际上就是模型对应的损失函数(CTC loss)。 一、背景 字母和语音的对齐(align)非…

【数据通信】数据通信基础知识---信号

1. 信息、数据、信号 信息是人们通过施加于数据的一些规定而赋予数据的特定含义(ISO定义)通信就是在信源和信宿之间传递信息。 信息和消息的关系:消息中包含信息,消息不等于信息。 消息所包含信息的多少,与在收到消息…

transformer--使用transformer构建语言模型

什么是语言模型? 以一个符合语言规律的序列为输入,模型将利用序列间关系等特征,输出一个在所有词汇上的概率分布.这样的模型称为语言模型. # 语言模型的训练语料一般来自于文章,对应的源文本和目标文本形如: src1"Ican do",tgt1…

Revit-二开之不同个立面/剖面上点的处理-(8)

由上图我们可以知道,在不同的立面坐标系是不同的。在很多业务逻辑处理的时候,需要对不同的立面进行处理,在此封装了一个方法,便于处理不同立面上点的计算。 viewSection 立面或者剖面 point 立面或者剖面上的点 horizontalOffset 点在屏幕中水平方向上的偏移量 verticalOf…

Android14之解决报错:No module named sepolgen(一百九十二)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒…