深度学习框架:Pytorch与Keras的区别与使用方法

  

☁️主页 Nowl

🔥专栏《机器学习实战》 《机器学习》

📑君子坐而论道,少年起而行之 

文章目录

Pytorch与Keras介绍

Pytorch

模型定义

模型编译

模型训练

输入格式

完整代码

Keras

模型定义

模型编译

模型训练

输入格式

完整代码

区别与使用场景

结语


Pytorch与Keras介绍

pytorch和keras都是一种深度学习框架,使我们能很便捷地搭建各种神经网络,但它们在使用上有一些区别,也各自有其特性,我们一起来看看吧

Pytorch

模型定义

我们以最简单的网络定义来学习pytorch的基本使用方法,我们接下来要定义一个神经网络,包括一个输入层,一个隐藏层,一个输出层,这些层都是线性的,给隐藏层添加一个激活函数Relu,给输出层添加一个Sigmoid函数

import torch
import torch.nn as nnclass SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(1, 32)self.relu = nn.ReLU()self.fc2 = nn.Linear(32, 1)self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.Sigmoid(x)return x

模型编译

我们在之前的机器学习文章中反复提到过,模型的训练是怎么进行的呢,要有一个损失函数与优化方法,我们接下来看看在pytorch中怎么定义这些

import torch.optim as optim# 实例化模型对象
model = SimpleNet()
# 定义损失函数
criterion = nn.MSELoss()# 定义优化器
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

我们上面创建的神经网络是一个类,所以我们实例化一个对象model,然后定义损失函数为mse,优化器为随机梯度下降并设置学习率

模型训练

# 创建随机输入数据和目标数据
input_data = torch.randn((100, 1))  # 100个样本,每个样本有1个特征
target_data = torch.randn((100, 1))  # 100个样本,每个样本有1个目标值# 训练模型
epochs = 100for epoch in range(epochs):# 前向传播output = model(input_data)# 计算损失loss = criterion(output, target_data)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()

以上步骤是先创建了一些随机样本,作为模型的训练集,然后定义训练轮次为100次,然后前向传播数据集,计算损失,再优化,如此反复

输入格式

关于输入格式是很多人在实战中容易出现问题的,对于pytorch创建的神经网络,我们的输入内容是一个torch张量,怎么创建呢

data = torch.Tensor([[1], [2], [3]])

很简单对吧,上面这个例子创建了一个torch张量,有三组数据,每组数据有1个特征

我们可以把这个数据输入到训练好的模型中,得到输出结果,如果输出不是torch张量,代码就会报错

完整代码

import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的神经网络模型
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(1, 32)self.relu = nn.ReLU()self.fc2 = nn.Linear(32, 1)self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.sigmoid(x)return xmodel = SimpleNet()
criterion = nn.MSELoss()# 定义优化器
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 创建随机输入数据和目标数据
input_data = torch.randn((100, 1))  # 100个样本,每个样本有1个特征
target_data = torch.randn((100, 1))  # 100个样本,每个样本有1个目标值# 训练模型
epochs = 100for epoch in range(epochs):# 前向传播output = model(input_data)# 计算损失loss = criterion(output, target_data)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()data = torch.Tensor([[1], [2], [3]])
prediction = model(data)print(prediction)

可以看到模型输出了三个预测值

注意,这个任务本身没有意义,因为我们的训练集是随机生成的,这里主要学习框架的使用方法

Keras

我们在这里把和上面相同的神经网络结构使用keras框架实现一遍

模型定义

from keras.models import Sequential
from keras.layers import Densemodel = Sequential([Dense(32, input_dim=1, activation='relu'),Dense(1, activation='sigmoid')
])

注意这里也是一层输入层,一层隐藏层,一层输出层,和pytorch一样,输入层是隐式的,我们的输入数据就是输入层,上述代码定义了一个隐藏层,输入维度是1,输出维度是32,还定义了一个输出层,输入维度是32,输出维度是1,和pytorch环节的模型结构是一样的 

模型编译

那么在Keras中模型又是怎么编译的呢

model.compile(loss='mse', optimizer='sgd')

非常简单,只需要这一行代码 ,设置损失函数为mse,优化器为随机梯度下降

模型训练

模型的训练也非常简单

# 训练模型
model.fit(input_data, target_data, epochs=100)

 因为我们已经编译好了损失函数和优化器,在fit里只需要输入数据,输出数据和训练轮次这些参数就可以训练了

输入格式

对于Keras模型的输入,我们要把它转化为numpy数组,不然会报错

data = np.array([[1], [2], [3]])

完整代码

from keras.models import Sequential
from keras.layers import Dense
import numpy as np# 定义模型
model = Sequential([Dense(32, input_dim=1, activation='relu'),Dense(1, activation='sigmoid')
])# 创建随机输入数据和目标数据
input_data = np.random.randn(100, 1)  # 100个样本,每个样本有10个特征
target_data = np.random.randn(100, 1)  # 100个样本,每个样本有5个目标值# 编译模型
model.compile(loss='mse', optimizer='sgd')
# 训练模型
model.fit(input_data, target_data, epochs=10)data = np.array([[1], [2], [3]])prediction = model(data)
print(prediction)

可以看到,同样的任务,Keras的代码量小很多

区别与使用场景

Keras代码量少,使用便捷,适用于快速实验和快速神经网络设计

而pytorch由于结构是由类定义的,可以更加灵活地组建神经网络层,这对于要求细节的任务更有利,同时,pytorch还采用动态计算图,使得模型的结构可以在运行时根据输入数据动态调整,但这个特点我还没有接触到,之后可能会详细讲解

结语

Keras和Pytorch都各有各的优点,请读者根据需求选择,同时有些深度学习教程偏向于使用某一种框架,最好都学习一点,以适应不同的场景

 

感谢阅读,觉得有用的话就订阅下本专栏吧 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/183150.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

渗透测试考核(靶机1)

信息收集 主机发现 nbtscan -r 172.16.17.0/24 发现在局域网内,有两台主机名字比较可疑,177和134,猜测其为目标主机,其余的应该是局域网内的其他用户,因为其主机名字比较显眼,有姓名的拼音和笔记本电脑的…

【Python】SqlmapAPI调用实现自动化SQL注入安全检测

文章目录 简单使用优化 应用案例:前期通过信息收集拿到大量的URL地址,这个时候可以配置sqlmapAP接口进行批量的SQL注入检测 (SRC挖掘) 查看sqlmapapi使用方法 python sqlmapapi.py -h启动sqlmapapi 的web服务: 任务流…

【论文笔记】SDCL: Self-Distillation Contrastive Learning for Chinese Spell Checking

文章目录 论文信息Abstract1. Introduction2. Methodology2.1 The Main Model2.2 Contrastive Loss2.3 Implementation Details(Hyperparameters) 3. Experiments代码实现个人总结值得借鉴的地方 论文信息 论文地址:https://arxiv.org/pdf/2210.17168.pdf Abstrac…

游戏APP接入哪些广告类型

当谈到游戏应用程序(APP)接入广告时,选择适合用户体验和盈利的广告类型至关重要。游戏开发者通常考虑以下几种广告类型: admaoyan猫眼聚合 横幅广告: 这些广告以横幅形式显示在游戏界面的顶部或底部。它们不会打断游戏…

idea doc 注释 插件及使用

开启rendered view https://blog.csdn.net/Leiyi_Ann/article/details/124145492 生成doc https://blog.csdn.net/qq_42581682/article/details/105018239 把注释加到类名旁边插件 https://blog.csdn.net/qq_30231473/article/details/128825306

解决QT信号在信号和槽连接前发出而导致槽函数未调用问题

1.使用QMetaObject::invokeMethod 当使用 QMetaObject::invokeMethod 将函数放入事件队列时,该函数会在适当时机被执行,然后被从事件队列中移除。 "适当时机" 指的是函数被安排在事件队列中,等待事件循环处理时机。这个时机取决于…

聚类分析例题 (多元统计分析期末复习)

例一 动态聚类,K-means法,随机选取凝聚点(题目直接给出) 已知5个样品的观测值为:1,4,5,7,11。试用K均值法分为两类(凝聚点分别取1,4与1,11) 解&…

找不到 sun.misc.BASE64Decoder ,sun.misc.BASE64Encoder 类

找不到 sun.misc.BASE64Decoder ,sun.misc.BASE64Encoder 类 1. 现象 idea 引用报错 找不到对应的包 import sun.misc.BASE64Decoder; import sun.misc.BASE64Encoder;2. 原因 因为sun.misc.BASE64Decoder和sun.misc.BASE64Encoder是Java的内部API,通…

oracle java.sql.SQLException: Invalid column type: 1111

1.遇到的问题 org.mybatis.spring.MyBatisSystemException: nested exception is org.apache.ibatis.type.TypeException: Could not set parameters for mapping: ParameterMapping{propertyuuid, modeIN, javaTypeclass java.lang.String, jdbcTypenull, numericScalenull, r…

VR虚拟教育展厅,为教学领域开启创新之路

线上虚拟展厅是一项全新的展示技术,可以为参展者带来不一样的观展体验。传统的实体展览存在着空间限制、时间限制以及高昂的成本,因此对于教育领域来说,线上虚拟教育展厅的出现,可以对传统教育方式带来改革,凭借强大的…

An illegal reflective access operation has occurred问题记录

报错 2023-11-30T01:08:18.7440800 [ERROR] [system.err] WARNING: An illegal reflective access operation has occurred 2023-11-30T01:08:18.7450800 [ERROR] [system.err] WARNING: Illegal reflective access by com.intellij.ui.JreHiDpiUtil to method sun.java2d.Sun…

ORA-00837: Specified value of MEMORY_TARGET greater than MEMORY_MAX_TARGET

有个11g rac环境,停电维护后,orcl1正常启动了,orcl2启动报错如下 SQL*Plus: Release 11.2.0.4.0 Production on Wed Nov 29 14:04:21 2023 Copyright (c) 1982, 2013, Oracle. All rights reserved. Connected to an idle instance. SYS…

1091 Acute Stroke (三维搜索)

题目可能看起来很难的样子,但是看懂了其实挺简单的。(众所周知,pat考察英文水平) 题目意思大概是:给你一个L*M*N的01长方体,求全为1的连通块的总体积大小。(连通块体积大于T才计算在内&#xf…

从0开始学习JavaScript--JavaScript 模板字符串的全面应用

JavaScript 模板字符串是 ES6 引入的一项强大特性,它提供了一种更优雅、更灵活的字符串拼接方式。在本文中,将深入探讨模板字符串的基本语法、高级用法以及在实际项目中的广泛应用,通过丰富的示例代码带你领略模板字符串的魅力。 模板字符串…

亚马逊云科技基于 Polygon 推出首款 Amazon Managed Blockchain Access,助 Web3 开发人员降低区块链节点运行成本

2023 年 11 月 26 日,亚马逊 (Amazon) 旗下 Amazon Web Services(Amazon)在其官方博客上宣布,Amazon Managed Blockchain (AMB) Access 已支持 Polygon Proof-of-Stake(POS) 网络,并将满足各种场景的需求,包…

vueRouter常用属性

vueRouter常用属性 basemodehashhistoryhistory模式下可能会遇到的问题及解决方案 routesprops配置(最佳方案) scrollBehavior base 基本的路由请求的路径 如果整个单页应用服务在 /app/ 下,然后 base 就应该设为 “/app/”,所有的请求都会在url之后加上/app/ new …

删除list中除最后一个之外所有的数据

1.你可以新建一个list List<Integer> listnew ArrayList<>();int i0;while (i<100){list.add(i);}List<Integer> subList list.subList(list.size()-1, list.size());System.out.println("原list大小--"list.size());System.out.println("…

el-table根据返回数据回显选择复选框

接口给你返回一个集合&#xff0c;然后如果这个集合里面的status2&#xff0c;就把这一行的复选框给选中 注意&#xff1a; 绑定的ref :row-key"getRowKeys" this.$refs.multiTableInst.toggleRowSelection(this.list[i], true); <el-table :data"list"…

群晖安装portainer

一、下载镜像 打开【Container Manager】 ,搜索portainer&#xff0c;双击【6053537/portainer-ce】下载汉化版本 二、创建映射文件夹 打开【File Station】&#xff0c;在docker目录下创建【portainer】文件夹 三、开启SSH 群晖 - 【控制面板】-【终端机和SNMP】 勾选【启动…

创建conan包-不同/相同repo中的配方和来源

创建conan包-不同/相同repo中的配方和来源 1 Recipe and Sources in a Different Repo1.1 source()的方法1.2 使用scm 属性 2 Recipe and Sources in the Same Repo2.1 Exporting the Sources with the Recipe: exports_sources2.2 Capturing the Remote and Commit: scm 本文是…