深度学习框架:Pytorch与Keras的区别与使用方法

  

☁️主页 Nowl

🔥专栏《机器学习实战》 《机器学习》

📑君子坐而论道,少年起而行之 

文章目录

Pytorch与Keras介绍

Pytorch

模型定义

模型编译

模型训练

输入格式

完整代码

Keras

模型定义

模型编译

模型训练

输入格式

完整代码

区别与使用场景

结语


Pytorch与Keras介绍

pytorch和keras都是一种深度学习框架,使我们能很便捷地搭建各种神经网络,但它们在使用上有一些区别,也各自有其特性,我们一起来看看吧

Pytorch

模型定义

我们以最简单的网络定义来学习pytorch的基本使用方法,我们接下来要定义一个神经网络,包括一个输入层,一个隐藏层,一个输出层,这些层都是线性的,给隐藏层添加一个激活函数Relu,给输出层添加一个Sigmoid函数

import torch
import torch.nn as nnclass SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(1, 32)self.relu = nn.ReLU()self.fc2 = nn.Linear(32, 1)self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.Sigmoid(x)return x

模型编译

我们在之前的机器学习文章中反复提到过,模型的训练是怎么进行的呢,要有一个损失函数与优化方法,我们接下来看看在pytorch中怎么定义这些

import torch.optim as optim# 实例化模型对象
model = SimpleNet()
# 定义损失函数
criterion = nn.MSELoss()# 定义优化器
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

我们上面创建的神经网络是一个类,所以我们实例化一个对象model,然后定义损失函数为mse,优化器为随机梯度下降并设置学习率

模型训练

# 创建随机输入数据和目标数据
input_data = torch.randn((100, 1))  # 100个样本,每个样本有1个特征
target_data = torch.randn((100, 1))  # 100个样本,每个样本有1个目标值# 训练模型
epochs = 100for epoch in range(epochs):# 前向传播output = model(input_data)# 计算损失loss = criterion(output, target_data)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()

以上步骤是先创建了一些随机样本,作为模型的训练集,然后定义训练轮次为100次,然后前向传播数据集,计算损失,再优化,如此反复

输入格式

关于输入格式是很多人在实战中容易出现问题的,对于pytorch创建的神经网络,我们的输入内容是一个torch张量,怎么创建呢

data = torch.Tensor([[1], [2], [3]])

很简单对吧,上面这个例子创建了一个torch张量,有三组数据,每组数据有1个特征

我们可以把这个数据输入到训练好的模型中,得到输出结果,如果输出不是torch张量,代码就会报错

完整代码

import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的神经网络模型
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(1, 32)self.relu = nn.ReLU()self.fc2 = nn.Linear(32, 1)self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.sigmoid(x)return xmodel = SimpleNet()
criterion = nn.MSELoss()# 定义优化器
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 创建随机输入数据和目标数据
input_data = torch.randn((100, 1))  # 100个样本,每个样本有1个特征
target_data = torch.randn((100, 1))  # 100个样本,每个样本有1个目标值# 训练模型
epochs = 100for epoch in range(epochs):# 前向传播output = model(input_data)# 计算损失loss = criterion(output, target_data)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()data = torch.Tensor([[1], [2], [3]])
prediction = model(data)print(prediction)

可以看到模型输出了三个预测值

注意,这个任务本身没有意义,因为我们的训练集是随机生成的,这里主要学习框架的使用方法

Keras

我们在这里把和上面相同的神经网络结构使用keras框架实现一遍

模型定义

from keras.models import Sequential
from keras.layers import Densemodel = Sequential([Dense(32, input_dim=1, activation='relu'),Dense(1, activation='sigmoid')
])

注意这里也是一层输入层,一层隐藏层,一层输出层,和pytorch一样,输入层是隐式的,我们的输入数据就是输入层,上述代码定义了一个隐藏层,输入维度是1,输出维度是32,还定义了一个输出层,输入维度是32,输出维度是1,和pytorch环节的模型结构是一样的 

模型编译

那么在Keras中模型又是怎么编译的呢

model.compile(loss='mse', optimizer='sgd')

非常简单,只需要这一行代码 ,设置损失函数为mse,优化器为随机梯度下降

模型训练

模型的训练也非常简单

# 训练模型
model.fit(input_data, target_data, epochs=100)

 因为我们已经编译好了损失函数和优化器,在fit里只需要输入数据,输出数据和训练轮次这些参数就可以训练了

输入格式

对于Keras模型的输入,我们要把它转化为numpy数组,不然会报错

data = np.array([[1], [2], [3]])

完整代码

from keras.models import Sequential
from keras.layers import Dense
import numpy as np# 定义模型
model = Sequential([Dense(32, input_dim=1, activation='relu'),Dense(1, activation='sigmoid')
])# 创建随机输入数据和目标数据
input_data = np.random.randn(100, 1)  # 100个样本,每个样本有10个特征
target_data = np.random.randn(100, 1)  # 100个样本,每个样本有5个目标值# 编译模型
model.compile(loss='mse', optimizer='sgd')
# 训练模型
model.fit(input_data, target_data, epochs=10)data = np.array([[1], [2], [3]])prediction = model(data)
print(prediction)

可以看到,同样的任务,Keras的代码量小很多

区别与使用场景

Keras代码量少,使用便捷,适用于快速实验和快速神经网络设计

而pytorch由于结构是由类定义的,可以更加灵活地组建神经网络层,这对于要求细节的任务更有利,同时,pytorch还采用动态计算图,使得模型的结构可以在运行时根据输入数据动态调整,但这个特点我还没有接触到,之后可能会详细讲解

结语

Keras和Pytorch都各有各的优点,请读者根据需求选择,同时有些深度学习教程偏向于使用某一种框架,最好都学习一点,以适应不同的场景

 

感谢阅读,觉得有用的话就订阅下本专栏吧 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/183150.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

渗透测试考核(靶机1)

信息收集 主机发现 nbtscan -r 172.16.17.0/24 发现在局域网内,有两台主机名字比较可疑,177和134,猜测其为目标主机,其余的应该是局域网内的其他用户,因为其主机名字比较显眼,有姓名的拼音和笔记本电脑的…

【Python】SqlmapAPI调用实现自动化SQL注入安全检测

文章目录 简单使用优化 应用案例:前期通过信息收集拿到大量的URL地址,这个时候可以配置sqlmapAP接口进行批量的SQL注入检测 (SRC挖掘) 查看sqlmapapi使用方法 python sqlmapapi.py -h启动sqlmapapi 的web服务: 任务流…

【论文笔记】SDCL: Self-Distillation Contrastive Learning for Chinese Spell Checking

文章目录 论文信息Abstract1. Introduction2. Methodology2.1 The Main Model2.2 Contrastive Loss2.3 Implementation Details(Hyperparameters) 3. Experiments代码实现个人总结值得借鉴的地方 论文信息 论文地址:https://arxiv.org/pdf/2210.17168.pdf Abstrac…

idea doc 注释 插件及使用

开启rendered view https://blog.csdn.net/Leiyi_Ann/article/details/124145492 生成doc https://blog.csdn.net/qq_42581682/article/details/105018239 把注释加到类名旁边插件 https://blog.csdn.net/qq_30231473/article/details/128825306

聚类分析例题 (多元统计分析期末复习)

例一 动态聚类,K-means法,随机选取凝聚点(题目直接给出) 已知5个样品的观测值为:1,4,5,7,11。试用K均值法分为两类(凝聚点分别取1,4与1,11) 解&…

找不到 sun.misc.BASE64Decoder ,sun.misc.BASE64Encoder 类

找不到 sun.misc.BASE64Decoder ,sun.misc.BASE64Encoder 类 1. 现象 idea 引用报错 找不到对应的包 import sun.misc.BASE64Decoder; import sun.misc.BASE64Encoder;2. 原因 因为sun.misc.BASE64Decoder和sun.misc.BASE64Encoder是Java的内部API,通…

VR虚拟教育展厅,为教学领域开启创新之路

线上虚拟展厅是一项全新的展示技术,可以为参展者带来不一样的观展体验。传统的实体展览存在着空间限制、时间限制以及高昂的成本,因此对于教育领域来说,线上虚拟教育展厅的出现,可以对传统教育方式带来改革,凭借强大的…

ORA-00837: Specified value of MEMORY_TARGET greater than MEMORY_MAX_TARGET

有个11g rac环境,停电维护后,orcl1正常启动了,orcl2启动报错如下 SQL*Plus: Release 11.2.0.4.0 Production on Wed Nov 29 14:04:21 2023 Copyright (c) 1982, 2013, Oracle. All rights reserved. Connected to an idle instance. SYS…

从0开始学习JavaScript--JavaScript 模板字符串的全面应用

JavaScript 模板字符串是 ES6 引入的一项强大特性,它提供了一种更优雅、更灵活的字符串拼接方式。在本文中,将深入探讨模板字符串的基本语法、高级用法以及在实际项目中的广泛应用,通过丰富的示例代码带你领略模板字符串的魅力。 模板字符串…

亚马逊云科技基于 Polygon 推出首款 Amazon Managed Blockchain Access,助 Web3 开发人员降低区块链节点运行成本

2023 年 11 月 26 日,亚马逊 (Amazon) 旗下 Amazon Web Services(Amazon)在其官方博客上宣布,Amazon Managed Blockchain (AMB) Access 已支持 Polygon Proof-of-Stake(POS) 网络,并将满足各种场景的需求,包…

删除list中除最后一个之外所有的数据

1.你可以新建一个list List<Integer> listnew ArrayList<>();int i0;while (i<100){list.add(i);}List<Integer> subList list.subList(list.size()-1, list.size());System.out.println("原list大小--"list.size());System.out.println("…

群晖安装portainer

一、下载镜像 打开【Container Manager】 ,搜索portainer&#xff0c;双击【6053537/portainer-ce】下载汉化版本 二、创建映射文件夹 打开【File Station】&#xff0c;在docker目录下创建【portainer】文件夹 三、开启SSH 群晖 - 【控制面板】-【终端机和SNMP】 勾选【启动…

第二十章 多线程总结

继承Thread 类 Thread 类时 java.lang 包中的一个类&#xff0c;从类中实例化的对象代表线程&#xff0c;程序员启动一个新线程需要建立 Thread 实例。 Thread 对象需要一个任务来执行&#xff0c;任务是指线程在启动时执行的工作&#xff0c;start() 方法启动线程&…

五、初识FreeRTOS之FreeRTOS的任务创建和删除

本节主要学习以下内容&#xff1a; 1&#xff0c;任务创建和删除的API函数&#xff08;熟悉&#xff09; 2&#xff0c;任务创建和删除&#xff08;动态方法&#xff09;&#xff08;掌握&#xff09; 3&#xff0c;任务创建和删除&#xff08;静态方法&#xff09;&#xf…

mongodb基本操作命令

mongodb快速搭建及使用 1.mongodb安装1.1 docker安装启动mongodb 2.mongo shell常用命令2.1 插入文档2.1.1 插入单个文档2.1.2 插入多个文档2.1.3 用脚本批量插入 2.2 查询文档 前言&#xff1a;本篇默认你是对nongodb的基础概念有了了解&#xff0c;操作是非常基础的。但是与关…

微信小程序——给按钮添加点击音效

今天来讲解一下如何给微信小程序的按钮添加点击音效 注意&#xff1a;这里的按钮不一定只是 <button>&#xff0c;也可以是一张图片&#xff0c;其实只是添加一个监听点击事件的函数而已 首先来看下按钮的定义 <button bind:tap"onInput" >点我有音效&…

C++面向对象复习笔记暨备忘录

C指针 指针作为形参 交换两个实际参数的值 #include <iostream> #include<cassert> using namespace std;int swap(int *x, int* y) {int a;a *x;*x *y;*y a;return 0; } int main() {int a 1;int b 2;swap(&a, &b);cout << a << &quo…

【开源视频联动物联网平台】为什么需要物联网网关?

在一些物联网项目中&#xff0c;物联网网关这一产品经常被涉及。那么&#xff0c;物联网网关究竟有何作用&#xff1f;具备哪些功能&#xff1f;同时&#xff0c;我们也发现有些物联网设备并不需要网关。那么&#xff0c;究竟在何时需要物联网网关呢&#xff1f; 物联网的架构…

LaTeX插入裁剪后的pdf图像

画图 VSCode Draw.io Integration插件 有数学公式的打开下面的选项&#xff1a; 导出 File -> Export -> .svg导出成svg格式的文件。然后用浏览器打开svg文件后CtrlP选择另存为PDF&#xff0c;将图片存成pdf格式。 裁剪 只要安装了TeXLive&#xff0c;就只需要在图…

LVS-NAT实验

实验前准备&#xff1a; LVS负载调度器&#xff1a;ens33&#xff1a;192.168.20.11 ens34&#xff1a;192.168.188.3 Web1节点服务器1&#xff1a;192.168.20.12 Web2节点服务器2&#xff1a;192.168.20.13 NFS服务器&#xff1a;192.168.20.14 客户端&#xff08;win11…