实现第一个神经网络

PyTorch 包含创建和实现神经网络的特殊功能。在本节实验中,将创建一个简单的神经网络,其中一个隐藏层开发一个输出单元。

通过以下步骤使用 PyTorch 实现第一个神经网络。

第1步

首先,需要使用以下命令导入 PyTorch 库。

In [1]:

import torch

import torch.nn as nn

第2步

定义所有层和批量大小以开始执行神经网络,如下所示。

In [2]:

# Defining input size, hidden layer size, output size and batch size respectively

n_in, n_h, n_out, batch_size = 10, 5, 1, 10

第3步

由于神经网络包含输入数据的组合以获取相应的输出数据,将遵循以下相同的程序。

In [3]:

# Create dummy input and target tensors (data)

x = torch.randn(batch_size, n_in)

y = torch.tensor([[1.0], [0.0], [0.0],

[1.0], [1.0], [1.0], [0.0], [0.0], [1.0], [1.0]])

第4步

借助内置函数创建顺序模型。使用下面的代码行,创建一个顺序模型。

In [4]:

# Create a model

model = nn.Sequential(nn.Linear(n_in, n_h),

    nn.ReLU(),

    nn.Linear(n_h, n_out),

    nn.Sigmoid())

第5步

在梯度下降优化器的帮助下构造损失函数,如下所示。

In [5]:

#Construct the loss function

criterion = torch.nn.MSELoss()

# Construct the optimizer (Stochastic Gradient Descent in this case)

optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)

第6步

使用给定的代码行使用迭代循环实现梯度下降模型。

In [6]:

# Gradient Descent

for epoch in range(50):

    # Forward pass: Compute predicted y by passing x to the model

    y_pred = model(x)

    # Compute and print loss

    loss = criterion(y_pred, y)

    print('epoch: ', epoch,' loss: ', loss.item())

    # Zero gradients, perform a backward pass, and update the weights.

    optimizer.zero_grad()

    # perform a backward pass (backpropagation)

    loss.backward()

    # Update the parameters

    optimizer.step()

epoch:  0  loss:  0.26868727803230286

epoch:  1  loss:  0.2684982419013977

epoch:  2  loss:  0.2683093845844269

epoch:  3  loss:  0.2681207060813904

epoch:  4  loss:  0.267932265996933

epoch:  5  loss:  0.2677440345287323

epoch:  6  loss:  0.26755601167678833

epoch:  7  loss:  0.2673681080341339

epoch:  8  loss:  0.267180472612381

epoch:  9  loss:  0.2669930160045624

epoch:  10  loss:  0.2668057382106781

epoch:  11  loss:  0.2666186988353729

epoch:  12  loss:  0.2664318084716797

epoch:  13  loss:  0.2662450969219208

epoch:  14  loss:  0.2660585939884186

epoch:  15  loss:  0.2658722996711731

epoch:  16  loss:  0.2656860947608948

epoch:  17  loss:  0.26550015807151794

epoch:  18  loss:  0.26531440019607544

epoch:  19  loss:  0.2651287913322449

epoch:  20  loss:  0.264943391084671

epoch:  21  loss:  0.2647581398487091

epoch:  22  loss:  0.26457303762435913

epoch:  23  loss:  0.26438820362091064

epoch:  24  loss:  0.2642034590244293

epoch:  25  loss:  0.2640189528465271

epoch:  26  loss:  0.26383453607559204

epoch:  27  loss:  0.2636503577232361

epoch:  28  loss:  0.2634662687778473

epoch:  29  loss:  0.26329419016838074

epoch:  30  loss:  0.2631300091743469

epoch:  31  loss:  0.2629658579826355

epoch:  32  loss:  0.2628018260002136

epoch:  33  loss:  0.2626378536224365

epoch:  34  loss:  0.262474000453949

epoch:  35  loss:  0.2623102366924286

epoch:  36  loss:  0.26214656233787537

epoch:  37  loss:  0.2619829773902893

epoch:  38  loss:  0.2618195414543152

epoch:  39  loss:  0.26165610551834106

epoch:  40  loss:  0.2614927887916565

epoch:  41  loss:  0.2613295316696167

epoch:  42  loss:  0.26116639375686646

epoch:  43  loss:  0.2610033452510834

epoch:  44  loss:  0.26084035634994507

epoch:  45  loss:  0.2606774866580963

epoch:  46  loss:  0.26051464676856995

epoch:  47  loss:  0.2603519558906555

epoch:  48  loss:  0.2601892650127411

epoch:  49  loss:  0.2600266933441162

5. 神经网络到功能块

训练深度学习算法涉及以下步骤:

  • 构建数据管道
  • 构建网络架构
  • 使用损失函数评估架构
  • 使用优化算法优化网络架构权重

训练特定的深度学习算法是将神经网络转换为功能块的确切要求,如下所示。

关于上图,任何深度学习算法都涉及获取输入数据,构建各自的架构,其中包括嵌入其中的一堆层。

观察上图,使用损失函数对神经网络权重的优化进行评估。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/40722.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决mysql数据库连接报错:Authentication plugin ‘caching_sha2_password‘ cannot be loaded

解决mysql数据库连接报错:Authentication plugin ‘caching_sha2_password’ cannot be loaded OperationalError: (2059, “Authentication plugin ‘caching_sha2_password’ cannot be loaded: /usr/lib/mysql/plugin/caching_sha2_password.so: cannot open sha…

启动Nuxt-hub-starter: Failed to initialize wrangler bindings proxy write EOF

重新安装 node.js 这样做可以确保下载到了适合的 Windows 框架、Chocolatey(一款Windows包管理工具)、Python 等资源。 这个错误与Node版本、pnpm/yarn 的版本无关! Node.js — Download Node.js (nodejs.org)

Selenium 监视数据收发

实际上,在我提供的示例中,确实使用了浏览器实例。webdriver.Chrome()这行代码正是创建了一个Chrome浏览器的WebDriver实例。Selenium Wire扩展了标准的Selenium WebDriver,允许你通过这个浏览器实例来监听网络请求。 当你运行类似这样的代码…

汉光联创HGLM2200N黑白激光多功能一体机加粉及常见问题处理

基本参数: 机器型号:HGLM2200N 产品名称:A4黑白激光多功能一体机 基础功能:打印、扫描、复印 打印速度:22页/分钟 纸张输入容量:150-249页 单面支持纸张尺寸:A4、A5、A6 产品尺寸&#x…

MySQL数据恢复(适用于误删后马上发现)

首先解释一下标题,之所以适用于误删后马上发现是因为太久了之后时间和当时操作的数据表可能会记不清楚,不是因为日志丢失 1.首先确保自己的数据库开启了binlog(我的是默认开启的我没有配置过) 根据这篇博客查看自己的配置和自己…

MS32008N低压 5V 多通道电机驱动器

MS32008N 是一款多通道电机驱动芯片,其中包 含两路步进电机驱动,一路直流电机驱动;每个步 进电机驱动通道的最大工作电流 1.0A ;支持两相四 线与四相五线步进电机。 芯片采用可选的 I 2 C 或 SPI 串行总线控制模式&…

安装 Mamba、Conv1d 时报错 “bare_metal_version“

报错详情1(pip install mamba/causal_conv1d): Preparing metadata (setup.py) ... errorerror: subprocess-exited-with-error python setup.py egg_info did not run successfully.│ exit code: 1╰─> [13 lines of output]/tmp/pip-…

鸿蒙开发HarmonyOS NEXT (三) 熟悉ArkTs

一、自定义组件 1、自定义组件 自定义组件,最基础的结构如下: Component struct Header {build() {} } 提取头部标题部分的代码,写成自定义组件。 1、新建ArkTs文件,把Header内容写好。 2、在需要用到的地方,导入…

centos7 安装mysql8.0.34

在 CentOS 7 上安装 MySQL 8.0.34 的步骤如下: 1. 卸载 MariaDB(如果已安装) CentOS 7 默认使用 MariaDB 作为数据库管理系统,因此在安装 MySQL 之前需要卸载 MariaDB。 rpm -qa | grep mariadb rpm -e --nodeps mariadb-libs-5…

Linux 摄像头编号固化

一、前言 在工业领域,一台设备会有很多个摄像头,可以使用命令:ll /dev/video* 进行查看; 在代码中,如果需要使用摄像头,那么都是需要具体到哪个摄像头编号的,例如 open("/dev/video4"…

[Day 24] 區塊鏈與人工智能的聯動應用:理論、技術與實踐

AI在自動駕駛中的應用 1. 簡介 自動駕駛技術是現代交通領域的一個革命性進展。通過結合人工智能(AI)、機器學習(ML)、深度學習(DL)和傳感器技術,自動駕駛汽車可以在無人干預的情況下安全駕駛。…

线段树求区间最值问题

引言 今天主要还是练了两道题,是有关线段树如何去求一个区间内的最值问题的,我们可以用线段树来解决。 对应一个无法改变顺序的数组,我们想要去求一个区间内的最值,假设有n个结点,m次询问,暴力的解决办法…

51、基于主成分分析和聚类分析的基因表达分析(matlab)

1、主成分分析和聚类分析简介 主成分分析(Principal Component Analysis, PCA)和聚类分析(Cluster Analysis)是两种常用的数据分析方法,用于降维和数据分类。 1)主成分分析(PCA) 主成分分析是一种常用的多元统计数据分析方法,旨在通过找到数据中最重要的变量(主成…

股票分析-20240628

今日关注: 20240626 六日涨幅最大: ------1--------300386--------- 飞天诚信 五日涨幅最大: ------1--------300386--------- 飞天诚信 四日涨幅最大: ------1--------300386--------- 飞天诚信 三日涨幅最大: ------1--------300386--------- 飞天诚信 二日涨幅最…

基于go-gmsm静态库编写的SM2椭圆曲线公钥密码算法PHP扩展 相较于openssl-ext-sm2编译更方便 增加了密文指定排序、识别ans1编码等功能

go-ext-sm2 介绍 基于go-gmsm静态库编写的SM2椭圆曲线公钥密码算法PHP扩展 相较于openssl-ext-sm2编译更方便 增加了密文指定排序、识别ans1编码等功能 特性:非对称加密 git地址:https://gitee.com/state-secret-series/go-ext-sm2.git 软件架构 zend 常规PHP扩展结构 …

vue-org-tree搜索到对应项高亮展开

效果图&#xff1a; 代码&#xff1a; <template><div class"AllTree"><el-form :inline"true" :model"formInline" class"demo-form-inline"><el-form-item><el-input v-model"formInline.user&quo…

c++ using namespace std的作用及注意事项

在C中&#xff0c;using namespace std; 是一个常见的指令&#xff0c;它用于简化标准库&#xff08;Standard Library&#xff09;中类和函数的引用。下面我将详细解释这个指令的作用和使用时的注意事项。 作用 在c/c标准库中&#xff0c;许多类和函数的定义都在std(standar…

【Git】远程仓库操作

创建远程仓库 在官网进行注册登录&#xff1a;Gitee或Github 进入后点击新建仓库&#xff0c;默认选项创建即可 **仓库创建完成后可以看到SSH的仓库地址&#xff1a;gitgitee.com:username/test.git**或gitgithub.com:Toukensan/test.git 配置SSH公钥 在本地通过命令行创建…

js学习--制作猜数字

猜数字制作 <!DOCTYPE html> <html><head><meta charset"utf-8"><title></title></head><body><script>function fun() {alert("1-100猜数字");let num Math.floor(Math.random() * 100) 1;for …

MDB-RS232 控制自动售货机MDB年龄验证设备

(以下是与台湾ICT的DCM5年龄验证设备测试数据) &#xff08;如果需要使用年龄验证设备&#xff0c;一定要记得购买MDB-RS232的PRO版本&#xff0c;也就是专业版&#xff09; 指令 HEX 代码 描述 RESET(复位) 68H 复位设备 SETUP(配置) 69H 读取年龄验证设备配置 Expa…