神经网络入门—自定义网络

网络模型

定义一个两层网络

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F# 定义神经网络模型
class Net(nn.Module):def __init__(self, init_x=0.0):super().__init__()self.fc1 = nn.Linear(1, 10)self.fc2 = nn.Linear(10, 1)def forward(self, x):x = self.fc1(x)x = F.relu(x)x = self.fc2(x)return x# 初始化模型
model = Net()# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 生成一些示例数据
x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float32)
y_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]], dtype=torch.float32)# 训练模型
num_epochs = 1000
for epoch in range(num_epochs):# 清零梯度optimizer.zero_grad()# 前向计算outputs = model(x_train)loss = criterion(outputs, y_train)# 反向传播loss.backward()# 更新参数optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 保存模型
torch.save(model.state_dict(), 'model.pth')# 加载模型
loaded_model = Net()
loaded_model.load_state_dict(torch.load('model.pth'))
loaded_model.eval()  # 将模型设置为评估模式# 输入新数据进行预测
new_input = torch.tensor([[5.0]], dtype=torch.float32)
with torch.no_grad():prediction = loaded_model(new_input)print(f"输入 {new_input.item()} 的预测结果: {prediction.item()}")

运行结果

训练好的参数值:
参数名: fc1.weight, 参数值: tensor([[ 0.5051],
        [ 0.2675],
        [ 0.4080],
        [ 0.3069],
        [ 0.9132],
        [ 0.2250],
        [-0.2428],
        [ 0.4821],
        [ 0.0998],
        [ 0.6737]])
参数名: fc1.bias, 参数值: tensor([ 0.5201, -0.0252,  0.0504,  0.6593, -0.4250,  0.6001,  0.9645, -0.2310,
        -0.2038,  0.2116])
参数名: fc2.weight, 参数值: tensor([[ 0.5492,  0.2550,  0.3046,  0.3183,  0.8147,  0.3062, -0.4165,  0.2969,
          0.0482,  0.5535]])
参数名: fc2.bias, 参数值: tensor([0.0147])

  • fc1 层

    • fc1.weight:这是输入层到隐藏层的权重矩阵,其形状为 (10, 1),意味着输入层有 1 个神经元,隐藏层有 10 个神经元。矩阵中的每个元素代表从输入神经元到对应隐藏层神经元的连接权重。
    • fc1.bias:这是隐藏层每个神经元的偏置项,形状为 (10,),也就是每个隐藏层神经元都有一个对应的偏置值。
  • fc2 层

    • fc2.weight:这是隐藏层到输出层的权重矩阵,形状为 (1, 10),表明隐藏层有 10 个神经元,输出层有 1 个神经元。矩阵中的每个元素代表从隐藏层神经元到输出层神经元的连接权重。
    • fc2.bias:这是输出层神经元的偏置项,形状为 (1,),即输出层只有一个神经元,所以只有一个偏置值。

不同的优化器

神经网络入门—计算函数值-CSDN博客

激活函数解析

激活函数的作用

激活函数赋予神经网络非线性映射能力,使其能够更好地处理复杂的现实世界数据2。常见的激活函数包括ReLU、PReLU等。激活函数通常用于卷积层和全连接层,以增加模型的表达能力。

常见的激活函数

Sigmoid 函数

  • 公式σ(x)= ​\frac{1}{1+e^{-x}}
  • 特点:输出范围在 (0, 1) 之间,能够把输入映射为概率值,常用于二分类问题。不过它存在梯度消失问题,当输入值非常大或者非常小时,梯度会趋近于 0。
import torch
import torch.nn.functional as Fx = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
sigmoid_output = torch.sigmoid(x)
print("Sigmoid 输出:", sigmoid_output)

Tanh 函数

  • 公式:\(\tanh(x)=\frac{e^{x}-e^{-x}}{e^{x}+e^{-x}}\)
  • 特点:输出范围在 (-1, 1) 之间,零中心化,相较于 Sigmoid 函数,梯度消失问题有所缓解,但仍然存在。
import torch
import torch.nn.functional as Fx = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
tanh_output = torch.tanh(x)
print("Tanh 输出:", tanh_output)

ReLU 函数

  • 公式:\(ReLU(x)=\max(0, x)\)
  • 特点:计算简单,能够有效缓解梯度消失问题,在深度学习中被广泛使用。不过它存在死亡 ReLU 问题,即某些神经元可能永远不会被激活。
import torch
import torch.nn.functional as Fx = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
relu_output = F.relu(x)
print("ReLU 输出:", relu_output)

 Leaky ReLU 函数

  • 公式:\(LeakyReLU(x)=\begin{cases}x, & x\geq0 \\ \alpha x, & x < 0\end{cases}\),其中 \(\alpha\) 是一个小的常数,例如 0.01。
  • 特点:解决了死亡 ReLU 问题,当输入为负数时,也会有一个小的梯度。
import torch
import torch.nn.functional as Fx = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
leaky_relu_output = F.leaky_relu(x, negative_slope=0.01)
print("Leaky ReLU 输出:", leaky_relu_output)

损失函数解析

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

该程序使用了MSELoss损失函数和SGD优化器

全部损失函数总类有

__all__ = ["L1Loss","NLLLoss","NLLLoss2d","PoissonNLLLoss","GaussianNLLLoss","KLDivLoss","MSELoss","BCELoss","BCEWithLogitsLoss","HingeEmbeddingLoss","MultiLabelMarginLoss","SmoothL1Loss","HuberLoss","SoftMarginLoss","CrossEntropyLoss","MultiLabelSoftMarginLoss","CosineEmbeddingLoss","MarginRankingLoss","MultiMarginLoss","TripletMarginLoss","TripletMarginWithDistanceLoss","CTCLoss",
]

  1. L1Loss:计算输入和目标之间的平均绝对误差(MAE),即 loss = 1/n * sum(|input - target|)
  2. NLLLoss:负对数似然损失,常用于分类任务,通常在模型输出经过 log_softmax 变换后使用。
  3. NLLLoss2d:二维的负对数似然损失,适用于图像等二维数据的分类任务。
  4. PoissonNLLLoss:泊松负对数似然损失,适用于泊松分布的数据,常用于计数数据的回归。
  5. GaussianNLLLoss:高斯负对数似然损失,假设数据服从高斯分布,用于回归任务。
  6. KLDivLoss:Kullback-Leibler 散度损失,用于衡量两个概率分布之间的差异。
  7. MSELoss:均方误差损失,计算输入和目标之间的平均平方误差,即 loss = 1/n * sum((input - target) ** 2),常用于回归任务。
  8. BCELoss:二元交叉熵损失,用于二分类任务,输入和目标都应该是概率值(在 0 到 1 之间)。
  9. BCEWithLogitsLoss:将 Sigmoid 函数和 BCELoss 结合在一起,适用于输入是未经过激活函数的原始输出(logits)的情况。
  10. HingeEmbeddingLoss:用于度量两个输入样本之间的相似性,常用于度量学习任务。
  11. MultiLabelMarginLoss:多标签分类的边缘损失,适用于一个样本可能属于多个类别的情况。
  12. SmoothL1Loss:平滑的 L1 损失,在 L1 损失的基础上进行了平滑处理,在某些情况下比 L1 和 L2 损失表现更好。
  13. HuberLoss:也称为平滑 L1 损失,结合了 L1 和 L2 损失的优点,对离群点更鲁棒。
  14. SoftMarginLoss:用于二分类的软边缘损失,允许一些样本在边缘内。
  15. CrossEntropyLoss:交叉熵损失,通常是 log_softmax 和 NLLLoss 的组合,常用于多分类任务。
  16. MultiLabelSoftMarginLoss:多标签软边缘损失,适用于多标签分类问题,每个标签都有一个独立的分类器。
  17. CosineEmbeddingLoss:基于余弦相似度的嵌入损失,用于度量两个输入样本之间的余弦相似度,常用于度量学习。
  18. MarginRankingLoss:边缘排序损失,用于比较两个输入样本的得分,常用于排序任务。
  19. MultiMarginLoss:多边缘损失,用于多分类任务,基于每个类别的边缘来计算损失。
  20. TripletMarginLoss:三元组边缘损失,常用于度量学习,通过比较三元组(锚点、正样本、负样本)之间的距离来学习嵌入。
  21. TripletMarginWithDistanceLoss:结合了距离度量的三元组边缘损失,在 TripletMarginLoss 的基础上增加了距离度量的计算。
  22. CTCLoss:连接主义时间分类损失,常用于处理序列到序列的问题,如语音识别和手写文字识别等,不需要对齐输入和输出序列。

可视化模型

Graphviz

Download | Graphviz

安装时候选择添加path到环境变量

输入

dot -version

显示下面说明安装成功

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchviz import make_dot# 定义神经网络模型
class Net(nn.Module):def __init__(self, init_x=0.0):super().__init__()self.fc1 = nn.Linear(1, 10)self.fc2 = nn.Linear(10, 1)def forward(self, x):x = self.fc1(x)x = F.relu(x)x = self.fc2(x)return x# 初始化模型
model = Net()# 生成一个示例输入
x = torch.randn(1, 1)# 前向传播
y = model(x)# 绘制计算图
dot = make_dot(y, params=dict(model.named_parameters()))
dot.render('net_model_structure', format='png', cleanup=True)

Tensorboard

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter# 定义神经网络模型
class Net(nn.Module):def __init__(self, init_x=0.0):super().__init__()self.fc1 = nn.Linear(1, 10)self.fc2 = nn.Linear(10, 1)def forward(self, x):x = self.fc1(x)x = F.relu(x)x = self.fc2(x)return x# 初始化模型
model = Net()# 初始化 SummaryWriter
writer = SummaryWriter('file/net_model')# 生成一个示例输入
x = torch.randn(1, 1)# 将模型结构写入 TensorBoard
writer.add_graph(model, x)# 关闭 writer
writer.close()

进入file文件夹

 tensorboard --logdir="./net_model"

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/76138.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

无人机装调与测试

文章目录 前言一、无人机基本常识/预备知识&#xff08;一&#xff09;无人机飞行原理无人机硬件组成/各组件作用1.飞控2.GPS3.接收机4.电流计5.电调6.电机7.电池8.螺旋桨9.UBEC&#xff08;稳压模块&#xff09; &#xff08;二&#xff09;飞控硬件简介&#xff08;三&#x…

2024年-全国大学生数学建模竞赛(CUMCM)试题速浏、分类及浅析

2024年-全国大学生数学建模竞赛(CUMCM)试题速浏、分类及浅析 全国大学生数学建模竞赛&#xff08;China Undergraduate Mathematical Contest in Modeling&#xff09;是国家教委高教司和中国工业与应用数学学会共同主办的面向全国大学生的群众性科技活动&#xff0c;目的在于激…

Linux入门指南:从零开始探索开源世界

&#x1f680; 前言 大家好&#xff01;今天我们来聊一聊Linux这个神奇的操作系统~ &#x1f916; 很多小伙伴可能觉得Linux是程序员专属&#xff0c;其实它早已渗透到我们生活的各个角落&#xff01;本文将带你了解Linux的诞生故事、发行版选择攻略、应用领域&#xff0c;还有…

记录vscode连接不上wsl子系统下ubuntu18.04问题解决方法

记录vscode连接不上wsl子系统下ubuntu18.04问题解决方法 报错内容尝试第一次解决方法尝试第二次解决方法注意事项参考连接 报错内容 Unable to download server on client side: Error: Request downloadRequest failed unexpectedly without providing any details… Will tr…

Cursor+MCP学习记录

参考视频 Cursor MCP 王炸&#xff01;彻底颠覆我的Cursor工作流&#xff0c;效率直接起飞_哔哩哔哩_bilibili 感觉这个博主讲的还不错 所使用到的网址 Smithery - Model Context Protocol Registry Introduction - Model Context Protocol 学习过程 Smithery - Model …

testflight上架ipa包-只有ipa包的情况下如何修改签名信息为苹果开发者账户对应的信息-ipa苹果包如何手动改签或者第三方工具改签-优雅草卓伊凡

testflight上架ipa包-只有ipa包的情况下如何修改签名信息为苹果开发者账户对应的信息-ipa苹果包如何手动改签或者第三方工具改签-优雅草卓伊凡 直接修改苹果IPA包的签名和打包信息并不是一个推荐的常规做法&#xff0c;因为这可能违反苹果的开发者条款&#xff0c;并且可能导致…

深入解析Java内存与缓存:从原理到实践优化

一、Java内存管理&#xff1a;JVM的核心机制 1. JVM内存模型全景图 ┌───────────────────────────────┐ │ JVM Memory │ ├─────────────┬─────────────────┤ │ Thread │ 共享…

紫光展锐5G SoC T8300:影像升级,「定格」美好世界

影像能力已成为当今衡量智能手机性能的重要标尺之一。随着消费者对手机摄影需求日益提升&#xff0c;手机厂商纷纷在影像硬件和算法上展开激烈竞争&#xff0c;力求为用户带来更加出色的拍摄体验。 紫光展锐专为全球主流用户打造的畅享影音和游戏体验的5G SoC——T8300&#x…

【Java设计模式】第6章 抽象工厂模式讲解

6. 抽象工厂模式 6.1 抽象工厂讲解 定义:提供一个接口创建一系列相关或依赖对象,无需指定具体类。核心概念: 产品等级结构:同一类型的不同产品(如Java视频、Python视频)。产品族:同一工厂生产的多个产品(如Java视频 + Java手记)。适用场景: 需要创建多个相关联的产品…

Dify教程01-Dify是什么、应用场景、如何安装

Dify教程01-Dify是什么、应用场景、如何安装 大家好&#xff0c;我是星哥&#xff0c;上篇文章讲了Coze、Dify、FastGPT、MaxKB 对比&#xff0c;今天就来学习如何搭建Dify。 Dify是什么 **Dify 是一款开源的大语言模型(LLM) 应用开发平台。**它融合了后端即服务&#xff08…

Java后端开发-面试总结(集结版)

第一个问题&#xff0c;在 Java 集合框架中&#xff0c;ArrayList和LinkedList有什么区别&#xff1f;在实际应用场景中&#xff0c;应该如何选择使用它们&#xff1f; ArrayList 基于数组&#xff0c;LinkedList 基于双向链表。 在查询方面 ArrayList 效率高&#xff0c;添加…

nslookup、dig、traceroute、ping 这些工具在解析域名时是否查询 DNS 服务器 或 本地 hosts 文件 的详细对比

host配置解析 127.0.0.1 example.comdig 测试&#xff0c;查询 DNS 服务器 nslookup测试&#xff0c;查询 DNS 服务器 traceroute测试&#xff0c;先读取本地 hosts 文件&#xff0c;再查询 DNS 服务器 ping测试&#xff0c;先读取本地 hosts 文件&#xff0c;再查询 DNS 服务…

文件上传、读取与包含漏洞解析及防御实战

一、漏洞概述 文件上传、读取和包含漏洞是Web安全中常见的高危风险点&#xff0c;攻击者可通过此类漏洞执行恶意代码、窃取敏感数据或直接控制服务器。其核心成因在于开发者未对用户输入内容进行充分验证或过滤&#xff0c;导致攻击者能够绕过安全机制&#xff0c;上传或执行…

STM32 的编程方式总结

&#x1f9f1; 按照“是否可独立工作”来分&#xff1a; 库/方式是否可独立使用是否依赖其他库说明寄存器裸写✅ 是❌ 无完全自主控制&#xff0c;无库依赖标准库&#xff08;StdPeriph&#xff09;✅ 是❌ 只依赖 CMSIS自成体系&#xff08;F1专属&#xff09;&#xff0c;只…

Flutter命令行打包打不出ipa报错

Flutter打包ipa报错解决方案 在Flutter开发中&#xff0c;打包iOS应用时可能会遇到以下错误&#xff1a; error: exportArchive: The data couldn’t be read because it isn’ in the correct format. 或者 Encountered error while creating the IPA: error: exportArchive…

SQL Server常见问题的分类解析(一)

以下是SQL Server常见问题的分类解析,涵盖安装配置、性能优化、备份恢复、高可用性等核心场景,结合微软官方文档和社区实践整理而成(编号对应搜索结果来源): 一、安装与配置问题 安装失败:.NET Framework缺失解决方案:手动安装所需版本.NET Framework,以管理员身份运行…

Spring Boot 3.x 下 Spring Security 的执行流程、核心类和原理详解,结合用户描述的关键点展开说明,并以表格总结

以下是 Spring Boot 3.x 下 Spring Security 的执行流程、核心类和原理详解&#xff0c;结合用户描述的关键点展开说明&#xff0c;并以表格总结&#xff1a; 1. Spring Security 核心原理 Spring Security 通过 Filter 链 实现安全控制&#xff0c;其核心流程如下&#xff1a…

Vue:路由切换表格塌陷

目录 一、 出现场景二、 解决方案 一、 出现场景 当路由切换时&#xff0c;表格操作栏会出现行错乱、塌陷的问题 二、 解决方案 在组件重新被激活的时候刷新表格 <el-table ref"table"></el-table>activated(){this.$nextTick(() > {this.$refs[t…

文件上传漏洞原理学习

什么是文件上传漏洞 文件上传漏洞是指用户上传了一个可执行的脚本文件&#xff0c;并通过此脚本文件获得了执行服务器端命令的能力。“文件上传” 本身没有问题&#xff0c;有问题的是文件上传后&#xff0c;服务器怎么处理、解释文件。如果服务器的处理逻辑做的不够安全&#…

leetcode_数组 189. 轮转数组

189. 轮转数组 给定一个整数数组 nums&#xff0c;将数组中的元素向右轮转 k 个位置&#xff0c;其中 k 是非负数 示例 1: 输入: nums [1,2,3,4,5,6,7], k 3输出: [5,6,7,1,2,3,4] 示例 2: 输入&#xff1a;nums [-1,-100,3,99], k 2输出&#xff1a;[3,99,-1,-100] 思…