机器学习-pytorch1(持续更新)

上一节我们学习了机器学习的线性模型和非线性模型的机器学习基础知识,这一节主要将公式变为代码

代码编写网站:https://colab.research.google.com/drive

学习课程链接:ML 2022 Spring

1、Load Data(读取数据)

这需要用到pytorch里面的两个函数Dataset和Dataloader

torch.utils.data.Dataset
torch.utils.data.DataLoader

Dataset:是用来存储数据样本和期望值

dataset = MyDataset(file)

Dataloader:批量对数据进行分组,启用多处理

dataloader = DataLoader(dataset, batch_size, shuffle=True)

// 其中对于shuffle的取值,True表示训练,false表示测试

关于Dataset和Dataloader的关系如下:

 

ML 2022 Spring为图片来源

我们读取完数据,是不是想知道我们的数据长什么样子呢?(我们称数据为Tensors)

首先,它可能是一个一维数据,比如一个音频、一个温度

其次,还可能是一个二维数据,比如一张二值图像

最后,还可能是一个三维数据,比如一个彩色的图像

又有问题了,我们怎么通过编程得到我们图像的大小?

可以使用pytorch里面的shape()函数

我们怎么通过编程创造我们的数据呢?

eg:
x = torch.tensor([[1,-1],[-1,1]])
x = torch.from_numpy(np.array([[1,-1],[-1,1]]))
全0或全1数据
x = torch.zeros([2,2])    # 2*2的全0数据
x = torch.ones([1,2,5])    # 1*2*5的全1数据

 其次,还支持矩阵的运算

Addition:z = x + y
Subtraction:z = x - y
Power:y = x.pow(2)
Summation:y = x.sum()
Mean:y = x.mean()
维度转换:x = x.transpose(dim0,dim1)
消除维度:x = x.squeeze(dim)
增加维度:x = x.unsqueeze(dim)
组合:w = torch.cat([x,y,z],dim=1)

拥有不同的数据类型:

使用.to()可以切换到不同的设备:

CPU: x = x.to('cpu')
GPU: x = x.to('cuda')

 这里就又涉及到如何检查你的GPU了?可以使用以下语句检查你的计算机是否有GPU:

torch.cuda.is_available()

如何计算梯度?

 // 注意矩阵一定要使用小数点

2、Define Neural Network(训练和测试神经网络)

torch.nn.Module

线性: 

 非线性:

Sigmoid Activation:nn.Sigmoid()

ReLU Activation:nn.ReLU()

下面我根据所学的知识构建我自己的神经网络:

3、Loss Function(损失函数) 

x = torch.nn.MSELoss    # 对于回归任务
x = torch.nn.CrossEntropyLoss etc.    # 对于分类任务
loss = x(model_output,expected_value)

4、Optimization Algorithm(优化)

torch.optim

这是基于梯度的优化算法,不断调整参数,减少误差

比如:随机梯度下降(SGD)

torch.optim.SGD(model.parameters(), lr, momentum = 0)

* 调用optimizer.zero_grad()重置模型参数的梯度。

*调用loss.backward()反向传播预测loss的梯度。

*调用optimizer.step()调整模型参数。 

5、Entire Procedure(整个程序)

import torch.utils.data as data
dataset = data.Dataset(file)              # 读取数据
tr_set = DataLoader(dataset,batch_size,shuffle=True)  # 对数据集进行分组
model = MyModel().to(device)              # 建立我的模型并且选择我的设备(cpu or gpu)
criterion = nn.MSELoss()                # 建立损失函数
optimizer = torch.optim.SGD(model.parameters(),0.1)   # 建立优化
# 训练
for epoch in range(n_epochs):             # 迭代数据model.train()                    # 训练模型for x, y in tr_set:               # 迭代数据集optimizer.zero_grad()              # 设置梯度为0x, y = x.to(device),y.to(device)       # 将数据移动到设备pred = model(x)                # 计算输出loss = criterion(pred,y)            # 计算损失函数loss.backward()                 # 计算反向梯度optimizer.model()                # 优化模型
# 验证
model.eval()                      # 将模型设置为评估模式
total_loss = 0          
for x,y in dv_set:                  # 对数据集进行迭代x,y = x.to(device),y.to(device)          # 将数据移动到涉笔with torch.no_grad():                # 不可迭代的计算pred = model(x)                # 计算输出loss = criterion(pred,y)           # 计算损失函数total_loss += loss.cpu().item()*len(x)      # 累加损失误差avg_loss = total_loss / len(dv_set.dataset)   # 计算平均损失
# 测试
model.eval()                       # 将模型设置为评估模式
preds = []
for x in dv_set:                   # 对数据集进行迭代x = x.to(device)                  # 将数据移动到涉笔with torch.no_grad():                # 不可迭代的计算pred = model(x)                # 计算输出preds.append(pred.cpu())             # 收集预测

// model.eval()  :更改模型的行为

//  with torch.no_grad() :防止对验证/测试数据进行意外训练

当我们训练完模型,也完成了测试,为了不使模型丢失,我们需要保存模型,pytorch也为我们提供了保存模型的方法。

保存模型:torch.save(model.state_dict(),path)

下次我们使用已经训练完成的模型,或者想继续训练,我们需要读取模型。

读取模型:ckpt = torch.load(path)     model.load_state_dict(ckpt)

// 这只是我根据所听的课自己写的笔记,如果有什么错误欢迎指正!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/733987.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

比较字符串

给定程序函数fun的功能是&#xff1a;比较两个字符串&#xff0c;将长的那个字符串的首地址作为函数值返回。 #include <stdio.h> /**********found**********/ char* fun(char *s, char *t) {int s10, t10;char *ss, *tt;sss;ttt; while(*ss){s1;/**********found****…

学习Java的第六天

一、变量 1、变量的定义 在程序执行的过程中变量的值会发生变化&#xff0c;直白来说就是用来存储可变化的数据。从本质上讲&#xff0c;变量其实指的是内存中的一小块存储空间&#xff0c;空间位置是确定的&#xff0c;但是里面放置的值不确定。比如屋子里有多个鞋柜&#x…

启发式算法:禁忌搜索 Tabu Search

文章目录 Tabu searchTabu search算法算例-旅行商TSP问题Tabu search 从一个初始可行解出发,选择一系列的特定搜索方向(移动)作为试探,选择实现让特定的目标函数值变化最多的移动。 为了避免陷入局部最优解,TS搜索中采用了一种灵活的“记忆”技术,对已经进行的优化过程进…

2024年目标检测研究进展

YOLOv9 图片来源网络 YOLO相关的研究&#xff1a;https://blog.csdn.net/yunxinan/article/details/103431338

网络安全: Kali Linux 进行 SSH 渗透与防御

目录 一、实验 1.环境 2.nmap扫描目标主机 3.Kali Linux 进行 SSH 渗透 3.Kali Linux 进行 SSH 防御 二、问题 1.SSH有哪些安全配置 一、实验 1.环境 &#xff08;1&#xff09;主机 表1 主机 系统版本IP备注Kali Linux2022.4 192.168.204.154&#xff08;动态&…

【软考】单元测试

目录 1. 概念2. 测试内容2.1 说明2.2 模块接口2.3 局部数据结构2.4 重要的执行路径 3. 测试过程2.1 说明2.2 单元测试环境图2.3 驱动模块2.4 桩模块 4. 模块接口测试与局部数据结构测试的区别 1. 概念 1.单元测试也称为模块测试&#xff0c;在模块编写完成且无编译错误后就可以…

使用React Context和Hooks在React Native中共享蓝牙数据

使用React Context和Hooks在React Native中共享蓝牙数据 **背景****实现步骤****步骤 1: 创建并导出bleContext****步骤 2: 在App.tsx中使用bleContext.Provider提供数据****步骤 3: 在父组件和子组件中访问共享的数据** **结论** 在开发React Native应用时&#xff0c;跨组件共…

16.Git从入门到进阶

一.Git 初识 1. 概念&#xff1a; 一个免费开源&#xff0c;分布式的代码版本控制系统&#xff0c;帮助开发团队维护代码 2. 作用&#xff1a; 记录代码内容&#xff0c;切换代码版本&#xff0c;多人开发时高效合并代码内容 3. 如何学&#xff1a; 个人本机使用&#xf…

SQL中的不加锁查询 with(nolock)

WITH(NOLOCK) 是一种 SQL Server 中的表提示&#xff08;table hint&#xff09;&#xff0c;可以用来告诉数据库引擎在查询数据时不要加锁&#xff0c;以避免因为锁等待导致查询性能下降。 当多个事务同时访问同一张表时&#xff0c;数据库引擎会对表进行锁定&#xff0c;以确…

数据库中 SQL Hint 是什么?

前言 最近在调研业界其他数据库中 SQL Hint 功能的设计和实现&#xff0c;整体上对 Oracle、Mysql、Postgresql、 Apache Calcite 中的 SQL Hint 的设计和功能都进行了解&#xff0c;这里整理一篇文章来对其进行梳理&#xff0c;一是帮助自己未来回顾&#xff0c;加深自己的思…

Python之Web开发中级教程----搭建Git环境三

Python之Web开发中级教程----搭建Git环境三 多人分布式使用仓库操作实例 场景&#xff1a;开发者A&#xff0c;开发者B在同一个项目协同开发&#xff0c;修改同一个代码文件。开发者A在Win10下&#xff0c;开发者B在Ubuntu下。 1、开发者A修改提交代码 从GitHub: Let’s bu…

44岁「台偶一哥」成现实版「王子变青蛙」,育一子一女成人生赢家

电影《周处除三害》近日热度极高&#xff0c;男主角阮经天被大赞演技出色&#xff0c;最让人意想不到&#xff0c;因为该片在内地票房报捷&#xff0c;很多人走去恭喜另一位台湾男艺人明道&#xff0c;皆因二人出道时外貌神似&#xff0c;至今仍有不少人将两人搞混。 多年过去&…

11.Node.js入门

一.什么是 Node.js Node.js 是一个独立的 JavaScript 运行环境&#xff0c;能独立执行 JS 代码&#xff0c;因为这个特点&#xff0c;它可以用来编写服务器后端的应用程序 Node.js 作用除了编写后端应用程序&#xff0c;也可以对前端代码进行压缩&#xff0c;转译&#xff0c;…

Linux最小系统安装无法查看IP地址

1&#xff0c;出现原因 服务器重启完成之后&#xff0c;我们可以通过linux的指令 ip addr 来查询Linux系统的IP地址&#xff0c;具体信息如下: 从图中我们可以看到&#xff0c;并没有获取到linux系统的IP地址&#xff0c;这是为什么呢&#xff1f;这是由于启动服务器时未加载网…

《探索虚拟与现实的边界:VR与AR谁更能引领未来?》

引言 在当今数字时代,虚拟现实(VR)和增强现实(AR)技术正以惊人的速度发展,并逐渐渗透到我们的日常生活中。它们正在重新定义人与技术、人与环境之间的关系,同时也为各行各业带来了全新的可能性。然而,究竟是VR还是AR更有潜力改变未来?本文将围绕这一问题展开深入探讨。…

【DevOps基础篇之k8s】如何应用Kubernetes中的Role Based Access Control(RBAC)

【DevOps基础篇之k8s】如何应用Kubernetes中的Role Based Access Control(RBAC) 目录 【DevOps基础篇之k8s】如何应用Kubernetes中的Role Based Access Control(RBAC)背景Kubernetes身份验证和授权基于角色的访问控制(RBAC)用户账户 vs. 服务账户角色 vs. 集群角色RoleBi…

ES6 中的 class 是什么?和函式构造函式差别是什么?

ES6 class JavaScript 在 ECMAScript 6 (ES6) 之前并没有 class 的语法,而是会透过函式构造函式建立物件,并再通过 new 关键字实例。 在 ES6 时引入了 class 的概念,JavaScript class 使用的语法 class 似于其他 OOP 程式语言中的 class,但 JavaScript 的 class 是一种语…

LeetCode 2710.移除字符串中的尾随零

给你一个用字符串表示的正整数 num &#xff0c;请你以字符串形式返回不含尾随零的整数 num 。 示例 1&#xff1a; 输入&#xff1a;num “51230100” 输出&#xff1a;“512301” 解释&#xff1a;整数 “51230100” 有 2 个尾随零&#xff0c;移除并返回整数 “512301” …

AI 改变生活

2024 年 AI 辅助研发趋势随着人工智能技术的持续发展与突破&#xff0c;2024年AI辅助研发正成为科技界和工业界瞩目的焦点。从医药研发到汽车设计&#xff0c;从软件开发到材料科学&#xff0c;AI正逐渐渗透到研发的各个环节&#xff0c;变革着传统的研发模式。在这一背景下&am…

Oracle数据库system表空间

导读 Oracle数据库中的System表空间是一个特殊的表空间&#xff0c;它存储了数据库的核心系统对象和元数据信息。System表空间对数据库的正常运行至关重要&#xff0c;因为它包含了诸如数据字典、系统表、视图等重要的数据库对象。在本文中&#xff0c;我们将深入探讨Oracle Sy…