机器学习 - 神经网络中的训练模型

接着上一篇机器学习-创建一个PyTorch classification model做进一步陈述。

训练模型的步骤:

  1. Forward pass: The model goes through all of the training data once, performing its forward() function calculations (model(x_train))
  2. Calculate the loss: 使用 loss = loss_fn(y_pred, y_train)
  3. Zero gradients: optimizer.zero_grad()
  4. Perform backpropagation on the loss: Computes the gradient of the loss with respect for every model parameter to be updated (each parameter with requires_grad=True). This is known as backpropagation, hence “backwards” (loss.backward())
  5. Step the optimizer (gradient descent): Update the parameters with requires_grad = True with respect to the loss gradients in order to improve them (optimizer.step())

# View the first 5 outputs of the forward pass on the test data
y_logits = model_0(X_test.to("cpu"))[:5]
print(y_logits)# Use sigmoid on model logits 
y_pred_probs = torch.sigmoid(y_logits)
print(y_pred_probs)# Find the predicted labels (round the prediction probabilities)
y_preds = torch.round(y_pred_probs)# In full 
y_pred_labels = torch.round(torch.sigmoid(model_0(X_test))[:5])# Check for equality 
print(torch.eq(y_preds.squeeze(), y_pred_labels.squeeze()))print(y_preds.squeeze())# 结果如下
tensor([[ 0.3798],[ 0.2257],[ 0.4383],[ 0.3647],[-0.1101]], grad_fn=<SliceBackward0>)
tensor([[0.5938],[0.5562],[0.6078],[0.5902],[0.4725]], grad_fn=<SigmoidBackward0>)
tensor([True, True, True, True, True])
tensor([1., 1., 1., 1., 0.], grad_fn=<SqueezeBackward0>)

The use of the sigmoid activation function is often only for binary classification logits.
The use of the sigmoid activation function is not required when passing the model’s raw outputs to the nn.BCEWithLogitsLoss (the “logits” in logits loss is because it works on the model’s raw logits output), this is because it has a sigmoid function built-in.


创建 training 和 testing loop

# 创建一个 loss function
loss_fn = nn.BCEWithLogitsLoss()def accuracy_fn(y_true, y_pred):correct = torch.eq(y_true, y_pred).sum().item()acc = (correct / len(y_pred)) * 100return acc# Build a train and test loop torch.manual_seed(42)# Set the number of epochs
epochs = 100# Put data into device
X_train, y_train = X_train.to("cpu"), y_train.to("cpu")
X_test, y_test = X_test.to("cpu"), y_test.to("cpu")# Build training and evaluation loop
for epoch in range(epochs):### Trainingmodel_0.train()# 1. Forward pass (model outputs raw logits)y_logits = model_0(X_train).squeeze()y_pred = torch.round(torch.sigmoid(y_logits)) # turn logits -> pred probs -> pred labls# 2. Calculate loss/accuracyloss = loss_fn(y_logits,y_train)acc = accuracy_fn(y_true = y_train,y_pred = y_pred)# 3. Optimizer zero grad optimizer.zero_grad()# 4. Loss backwardsloss.backward()# 5. Optimizer step optimizer.step() ### Testing model_0.eval()with torch.inference_mode():# 1. Forward passtest_logits = model_0(X_test).squeeze()test_pred = torch.round(torch.sigmoid(test_logits))# 2. Caculate loss/accuracytest_loss = loss_fn(test_logits,y_test)test_acc = accuracy_fn(y_true = y_test,y_pred = test_pred)if epoch % 10 == 0:print(f"Epoch: {epoch} | Loss: {loss:.5f}, Accuracy: {acc:.2f}% | Test loss: {test_loss:.5f}, Test acc: {test_acc:.2f}%")# 输出结果
Epoch: 0 | Loss: 0.70758, Accuracy: 50.25% | Test loss: 0.70294, Test acc: 56.00%
Epoch: 10 | Loss: 0.70192, Accuracy: 50.25% | Test loss: 0.69895, Test acc: 52.50%
Epoch: 20 | Loss: 0.69892, Accuracy: 50.00% | Test loss: 0.69713, Test acc: 50.00%
Epoch: 30 | Loss: 0.69716, Accuracy: 49.75% | Test loss: 0.69626, Test acc: 51.50%
Epoch: 40 | Loss: 0.69603, Accuracy: 49.75% | Test loss: 0.69582, Test acc: 51.50%
Epoch: 50 | Loss: 0.69527, Accuracy: 49.75% | Test loss: 0.69561, Test acc: 51.00%
Epoch: 60 | Loss: 0.69474, Accuracy: 49.25% | Test loss: 0.69551, Test acc: 52.50%
Epoch: 70 | Loss: 0.69435, Accuracy: 49.00% | Test loss: 0.69547, Test acc: 51.00%
Epoch: 80 | Loss: 0.69406, Accuracy: 49.75% | Test loss: 0.69545, Test acc: 51.00%
Epoch: 90 | Loss: 0.69384, Accuracy: 49.25% | Test loss: 0.69545, Test acc: 51.50%

看到这了,给个赞呗~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/773029.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

format(C++20)

1. std::format format_01.cpp // g format_01.cpp -stdc20 #include <iostream> #include <string> #include <format>void test_01() {// 使用字符串填充std::cout << std::format("Hello {}!\n", "World"); // Hello World!…

Open CASCADE学习|适配器

OpenCascade适配器在OpenCASCADE软件框架中起着至关重要的作用。它提供了一种方便的方式&#xff0c;用于在OpenCASCADE模型和其他软件之间进行数据交换和转换&#xff0c;从而使得OpenCASCADE更加灵活和实用。具体来说&#xff0c;适配器类在OpenCASCADE中实现了适配器模式&am…

[webpack-cli] Invalid options object 报错

[webpack-cli] Invalid options object. Dev Server has been initialized using an options object that does not match the API schema devServer: {contentBase: ./src, // 告诉服务器从哪里提供内容&#xff0c;默认情况下&#xff0c;它会使用当前工作目录作为根目录c…

深度学习pytorch——减少过拟合的几种方法(持续更新)

1、增加数据集 2、正则化(Regularization) 正则化&#xff1a;得到一个更加简单的模型的方法。 以一个多项式为例&#xff1a; 随着最高次的增加&#xff0c;会得到一个更加复杂模型&#xff0c;模型越复杂就会更好的拟合输入数据的模型&#xff08;图-1&#xff09;&#…

Oracle中实现根据条件对数据的增删改操作——Merge Into

一、需求描述 在我们进行项目开发的过程中&#xff0c;会遇到这样的场景&#xff0c;需要根据某个条件对数据进行增、删、改的操作&#xff1b;遇到这种情况我们有2种方法进行解决&#xff1a; 方法一&#xff1a;①查询指定条件&#xff1b;②根据查询出的指定条件结果在执行…

阿里云国际DDoS高防的定制场景策略

DDoS高防的定制场景策略允许您在特定的业务突增时段&#xff08;例如新业务上线、双11大促销等&#xff09;选择应用独立于通用防护策略的定制防护策略模板&#xff0c;保证适应业务需求的防护效果。您可以根据需要设置定制场景策略。 背景信息 定制场景策略提供基于业务场景…

【图论 | 数据结构】用链式前向星存图(保姆级教程,详细图解+完整代码)

一、概述 链式前向星是一种用于存储图的数据结构,特别适合于存储稀疏图,它可以有效地存储图的边和节点信息,以及边的权重。 它的主要思想是将每个节点的所有出边存储在一起,通过数组的方式连接(类似静态数组实现链表)。这种方法的优点是存储空间小,查询速度快,尤其适…

金融投贷通--功能测试分析与设计

金融投贷通功能测试分析与设计 测试点分析借款业务测试点投资业务测试点 测试用例借款业务测试用例投资业务测试用例 缺陷面试题 测试报告 测试点分析 借款业务测试点 投资业务测试点 测试用例 借款业务测试用例 借款成功&#xff08;主业务&#xff09;、借款成功&#xff…

FFMPEG AVFrame AVPacket内存管理相关API说明

AVFrame和AVPacket是ffmpeg中保存音视频数据的结构体&#xff0c;AVFrame保存未压缩的原始音视频数据&#xff0c;AVPacket保存编码后的音视频数据&#xff0c;AVFrame和AVPacket都是使用引用计数进行的内存管理。 一、AVFrame 内存分配&#xff1a; 视频&#xff1a; AVFra…

iOS——【CGD】

GCD 什么是GCD GCD指的是Grand Central Dispatch&#xff0c;它是苹果公司开发的一套多线程编程技术。GCD提供了一种简单而有效的方式来管理应用程序中的并发任务。它通过将任务提交到适当的队列&#xff08;串行队列或并发队列&#xff09;来管理并发执行的任务&#xff0c;…

WebAR开发简介

WebAR 开发使企业能够以独特且高度有趣的方式向客户和员工提供信息。 它提供增强现实 (AR) 内容&#xff0c;人们在智能手机上将其视为视觉叠加。 然而&#xff0c;WebAR 可在手机的普通网络浏览器上运行&#xff0c;无需下载任何应用程序。 WebAR 的多种用途包括帮助零售和在…

跟张良均老师学大数据人工智能——数据挖掘集训营开营

集训营特色&#xff1a; 知识点深入浅出&#xff0c;实现以学促用 以业务内容为主线&#xff0c;数据挖掘技能嵌入 多行业项目实战&#xff0c;全面提升职业素养 全程线上辅导&#xff0c;助力熟练掌握技能 惊喜优惠&#xff1a; 限时“六折”&#xff01; 师傅带练 方向…

docker centos7离线安装ElasticSearch单机版

目录 1.下载ES并解压2.新建elasticsearch用户3.修改ES配置文件4.启动ES服务5.设置开机启动 本文以 elasticsearch-7.8.1为例。 1.下载ES并解压 cd /root/install wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.8.1-linux-x86_64.tar.gz tar -z…

Qt|读写数据库管理图片资源工具

文章目录 创建项目设置UI布局控制数据库类实现界面实现类主函数 功能&#xff1a;用来管理数据库中图像资源 开发环境&#xff1a;windows10VS2017Qt5.14.2开发 创建项目 首先创建Qt Widgets Application next->next->finish就创建好了 设置UI布局 打开已经创建好的U…

笔记本如何调节亮度?笔记本亮度调节方法

对于经常长时间面对笔记本电脑的小伙伴们来说&#xff0c;屏幕亮度过暗或者过亮&#xff0c;都会对眼睛造成伤害。那么&#xff0c;我们如何调节笔记本亮度至适中呢?下面为大家介绍3种简单的调节屏幕亮度的方法&#xff0c;一起来看看吧! 笔记本亮度调节方法一&#xff1a; 1、…

Amuse:.NET application for stable diffusion

目录 Welcome to Amuse! Features Why Choose Amuse? Key Highlights Paint To Image Text To Image Image To Image Image Inpaint Model Manager Hardware Requirements Compute Requirements Memory Requirements System Requirements Realtime Requirements…

Electron 入门 - 创建应用的全流程 - npm 踩坑版

说明 本文记录一下&#xff0c;使用Electron创建一个简单的客户端应用的全流程。 在官方文档的基础上&#xff0c;针对依赖安装过程中出现的异常&#xff0c;进行了补充&#xff0c;确保可以正常的创建应用。 创建步骤 0、校验node版本 官方文档建议使用 最新版本的 NodeJS …

Codigger用户篇:安全、稳定、高效的运行环境(一)

在当今数字化时代&#xff0c;个人数据的安全与隐私保护显得尤为重要。为了满足用户对数据信息的安全需求&#xff0c;我们推出Codigger分布式操作系统&#xff0c;它提供了一个运行私有应用程序的平台&#xff0c;旨在为用户提供一个安全、稳定、高效的私人应用运行环境。Codi…

html 元素宽度自适应 占据剩余宽度

弹性盒实现 父元素设置display: flex; 需要自适应宽度的子元素设置flex: 1; <html lang"en"> <head><style>*{margin: 0;padding: 0;}.main{display: flex;}.box1,.box2{width: 100px;height: 200px;}.box1{background: rgb(134 187 233);}.box2…

【javaWeb 第五篇】后端-Http协议学习

HTTP协议 HTTP概述HTTP-请求数据格式HTTP响应格式HTTP-协议解析 HTTP概述 Hyper Text Transfer Protocol,超文本传输协议&#xff0c;规定了浏览器和服务器之间的数据传输规则 简述概念就是&#xff0c;浏览器需要向服务器发送请求&#xff0c;想要得到服务器中的数据&#xff…