pytorch中的面向对象编程方法

一、__xxx__形式的魔法方法

我们可以经常在python代码片段中看到类的定义,其中第一个被定义的方法往往是__init__,如下所示:

class Accumulator:  """在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]

我们知道__init__显然是类的构造函数,但为什么要在前后都加上双下划线呢?

原来这是python设计的一种特殊方法,它别称为魔法方法,它可以被运算符隐式调用,下面给出示例:

import torch
class Accumulator:  #@save"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]test=Accumulator(3)
test.add(0,1,2)
print(test.__getitem__(2)) #打印2.0
print(test[2])             #打印2.0

可以发现,所谓魔法方法,实质上实现的就是c++中运算符重载的功能。 它让test.__getitem__(2)和test[2]这两种语法都能调用该方法!

对于更多的python的魔法方法,可以看看下面这一篇文章:

Python 中的 `__xxx__` 特殊方法:介绍与使用-CSDN博客

其中,与pytorch紧密相关的一个方法比较重要,这便是__call__方法,它能通过以下方式被直接调用:

test=Accumulator(3)
test.__call__(param)
test(param)
'''两者等价'''

示例如下:

import torch
class Accumulator:  #@save"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __call__(self, idx):return self.data[idx]test=Accumulator(3)
test.add(0,1,2)
print(test.__call__(2))    #打印2.0
print(test(2))             #打印2.0

二、nn.Module类的使用以及其重要的两个类方法

神经网络较为复杂,可以被分为块和层。用类去分别定义神经网络的一个层、一个块会更为清晰。而nn.Module便是pytorch设计者为块类和层类设计的父类,如nn.Linear、nn.Flatten和nn.ReLU这样描述层的类,便是继承自nn.Module类。我们也可以继承nn.Module类,编写我们所需要的层和类。

nn.Module类有一个非常非常重要的特征,那便是其__call__函数里调用了forward方法。所以,在继承nn.Module类时,需要为子类写好forward方法,这样才能按照pytorch的习惯发挥子类作为神经网络中层和类的功能!

例如,我现在写一个类用来描述多层感知机这个块,那么,我需要写好forward和__init__两个方法

import torch
from torch import nn
from torch.nn import functional as funcclass MulLayerPerceptron(nn.Module):def __init__(self):super().__init__()self.hidden=nn.Linear(20,256)self.output=nn.Linear(256,10)def forward(self,X):return self.output(func.relu(self.hidden(X)))net=MulLayerPerceptron()
X=torch.arange(20.0)
print(net(X))

这样,在实例化这个类的对象net后,就可以直接使用net(X)这个函数调用forward方法!

三、Sequential子类

torch框架本省就提供了一个nn.Module的子类——Sequential,这是一个包含多个层的类,可以按顺序执行一系列函数

如上的多层感知机,我就可以这样定义:

net=nn.Sequential(nn.Linear(20,256),nn.Linear(256,10))
X=torch.arange(20.0)
print(net(X))

事实上,其源码的形式如下所示:

class MySequential(nn.Module):def __init__(self,*args):super().__init__()self.layers=[]for layer in args:self.layers.append(layer)def forward(self,X):for layer in self.layers:X=layer(X)return Xnet=MySequential(nn.Linear(20,256),nn.Linear(256,10))
X=torch.arange(20.0)
print(net(X))

当然,实际上Sequential类会更加复杂。 

使用对net使用add_module方法可以为为Sequential块类添加层,其中第一个参数是层的名称,第二个参数是层的类型。

net=nn.Sequential(nn.Linear(20,256))
net.add_module("mylayer",nn.Linear(256,10))

直接打印net可以查看整个Sequential块的组成情况:

print(net)

Sequential(
  (0): Linear(in_features=20, out_features=256, bias=True)
  (mylayer): Linear(in_features=256, out_features=10, bias=True)
)

可以看出,Sequential的层默认按照数字命名,我们添加的层是自主命名的!

四、参数访问

1、简单结构

对于Sequential类,可以使用类似数组下标访问的方法来访问其每一层:

net=nn.Sequential(nn.Linear(20,256),nn.Linear(256,10))
X=torch.arange(20.0)
print(net[0])
print(net[1])

Linear(in_features=20, out_features=256, bias=True)
Linear(in_features=256, out_features=10, bias=True)

可以用state_dict方法看看net[X]中的各个参数:

net=nn.Sequential(nn.Linear(20,256),nn.Linear(256,10))
print(net[0].state_dict())

一般,nn.Linear结构中的参数包括weight和bias,即权重和偏置,可以用.运算符访问:

net=nn.Sequential(nn.Linear(20,256),nn.Linear(256,10))
print(net[0].weight)
print(net[0].bias)

 其中,weight和bias都分别包含tensor数组(可用data属性访问)和控制是否求梯度的bool变量requires_grad,可以如下进行访问和设置

net=nn.Sequential(nn.Linear(20,256),nn.Linear(256,10))
print(net[0].weight.data[0])
net[0].weight.requires_grad=True

事实上,net[X]返回了nn库自定义的Parameter类型。 

2、嵌套结构
block=nn.Sequential(nn.Linear(20,256),nn.Linear(256,10))
net=nn.Sequential(block,nn.Linear(10,5))
X=torch.arange(20.0)
print(net(X))
print(net)
print(net[0][0])

 运行结果:

tensor([ 1.9557, -2.0318, -1.2498, -0.1433, -0.3641], grad_fn=<ViewBackward0>)
Sequential(
  (0): Sequential(
    (0): Linear(in_features=20, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=10, bias=True)
  )
  (1): Linear(in_features=10, out_features=5, bias=True)
)
Linear(in_features=20, out_features=256, bias=True)

如上所示,Sequential类可以进行嵌套。嵌套后,我们使用类似多维数组的方式访问

五、参数初始化

对于如nn.Linear这样的层,可以使用nn.init下带有的方法将参数初始化。

def init_normal(m):if type(m)==nn.Linear:nn.init.normal_(m.weight,mean=0.0,std=0.01)nn.init.zeros_(m.bias)

如代码所示,nn.Linear作为参数m传入,那么调用nn.init.normal_可以使weight向量得以初始化。第一个参数使所需要初始化的向量,mean参数规定平均值,std参数规定方差。

nn.init.zeros_方法用于将参数置于0,nn.init.constant_方法可以将参数置为某个常数。

nn.init.zeros_(m.weight)
nn.init.constant_(m.weight,1.0) #将weight向量全置为1

而nn.init.uniform_方法,可以使参数均匀分布。其第一个参数还是需要初始化的tensor向量,a代表均匀分布的下界(默认为0.0),b代表上界(默认为1.0)

nn.init.uniform_(m.weight,a=0.0,b=2.0)

 

对块net使用apply方法,可以让自定义初始化函数在所有层上作用一遍。

block=nn.Sequential(nn.Linear(20,256),nn.Linear(256,10))
net=nn.Sequential(block,nn.Linear(10,5))
X=torch.arange(20.0)def init_normal(m):if type(m)==nn.Linear:print("init")nn.init.normal_(m.weight,mean=0.0,std=0.01)nn.init.zeros_(m.bias)net.apply(init_normal)

init

init

init

注意,apply非常智能。纵使Sequential中出现嵌套,也可以层层访问,把所有的nn.Linear结构都进行初始化。

这是因为apply函数会把net的所有结构都访问一遍,打印m,就可以知晓其运作的规律:

block=nn.Sequential(nn.Linear(20,256),nn.Linear(256,10))
net=nn.Sequential(block,nn.Linear(10,5))
X=torch.arange(20.0)def init_normal(m):print(m)if type(m)==nn.Linear:nn.init.normal_(m.weight,mean=0.0,std=0.01)nn.init.zeros_(m.bias)net.apply(init_normal)

Linear(in_features=20, out_features=256, bias=True)
Linear(in_features=256, out_features=10, bias=True)
Sequential(
  (0): Linear(in_features=20, out_features=256, bias=True)
  (1): Linear(in_features=256, out_features=10, bias=True)
)
Linear(in_features=10, out_features=5, bias=True)
Sequential(
  (0): Sequential(
    (0): Linear(in_features=20, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=10, bias=True)
  )
  (1): Linear(in_features=10, out_features=5, bias=True)
)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/874813.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Android】ListView和RecyclerView知识总结

文章目录 ListView步骤适配器AdpterArrayAdapterSimpleAdapterBaseAdpter效率问题 RecyclerView具体实现不同布局形式的设置横向滚动瀑布流网格 点击事件 ListView ListView 是 Android 中的一种视图组件&#xff0c;用于显示可滚动的垂直列表。每个列表项都是一个视图对象&…

【JavaScript】前端路由

前端路由是指在前端⻚⾯内部实现⻚⾯之间的跳转&#xff0c;⽽不是像传统的⽹⻚跳转那样在后端进⾏⻚⾯跳转&#xff0c;从后端获取 html 页面。前端路由使⽤浏览器的 history 接⼝&#xff0c;通过改变浏览器的 URL&#xff0c;来更新⻚⾯的视图。 前端路由适合⽤于单⻚⾯应⽤…

Python教程(一):环境搭建及PyCharm安装

目录 引言1. Python简介1.1 编译型语言 VS 解释型语言 2. Python的独特之处3. Python应用全览4. Python版本及区别5. 环境搭建5.1 安装Python&#xff1a; 6. 开发工具&#xff08;IDE&#xff09;6.1 PyCharm安装教程6.2 永久使用教程 7. 编写第一个Hello World结语 引言 在当…

每日一题 LeetCode03 无重复字符的最长字串

1.题目描述 给定一个字符串 s &#xff0c;请你找出其中不含有重复字符的最长字串的长度。 2 思路 可以用两个指针, 滑动窗口的思想来做这道题,即定义两个指针.一个left和一个right 并且用一个set容器,一个length , 一个maxlength来记录, 让right往右走,并且用一个set容器来…

探索Prompt的世界

在人工智能&#xff08;AI&#xff09;和自然语言处理&#xff08;NLP&#xff09;的飞速发展中&#xff0c;prompt技术作为一种与语言模型交互的重要方式&#xff0c;正逐步占据中心舞台。为了对prompt这一概念进行全面介绍&#xff0c;我们将从其发展历史、运行原理、调试方式…

如何避免蓝屏?轻量部署,安全和业务连续性才能两不误

自19日起&#xff0c;因CrowdStrike软件更新的错误配置而导致的“微软全球蓝屏”&#xff0c;影响依然在持续。这场被称为“史上最大规模的IT故障”&#xff0c;由于所涉全球企业太多&#xff0c;专家估计“蓝屏”电脑全部恢复正常仍需时日。 尽管 CEO 乔治 库尔茨&#xff08…

2024年自动驾驶SLAM面试题及答案(更新中)

自动驾驶中的SLAM&#xff08;Simultaneous Localization and Mapping&#xff0c;即同步定位与地图构建&#xff09;是关键技术&#xff0c;它能够让车辆在未知环境中进行自主定位和地图建构。秋招来临之际&#xff0c;相信大家都已经在忙碌的准备当中了&#xff0c;尤其是应届…

Oracle星型查询转换解析

目录 一、星型查询转换原理二、配置星型查询转换三、性能考虑四、案例1、数据模型2、创建表和数据3、创建位图索引4、查询优化前5、查询优化后6、检查执行计划 Oracle的星型查询转换&#xff08;Star Query Transformation&#xff09;是Oracle数据库优化器的一个重要特性&…

Go语言入门之错误处理

Go语言入门之错误处理 错误处理是开发中必不可少的一个部分&#xff0c;go中的错误一般有两种&#xff0c;一种为error&#xff0c;一种为panic go语言通常返回一个错误值&#xff0c;然后检查错误值是否为nil&#xff0c;以此判断函数是否执行 1.Error Go使用error接口来表示一…

鸿蒙OpenHarmony Native API【drawing_pen.h】 头文件

drawing_pen.h Overview Related Modules: [Drawing] Description: 文件中定义了与画笔相关的功能函数 Since: 8 Version: 1.0 Summary Enumerations Enumeration NameDescription[OH_Drawing_PenLineCapStyle] { [LINE_FLAT_CAP], [LINE_SQUARE_CAP], [LINE_ROUND_…

Exchange Server 中 Exchange 虚拟目录的默认设置

Exchange Server 2016 和 Exchange Server 2019 在服务器安装过程中自动配置多个 Internet Information Services (IIS) 虚拟目录。 以下部分中的表显示了邮箱服务器上客户端访问 (前端) 服务的设置&#xff0c;以及默认的 IIS 身份验证和安全套接字层 (SSL) 设置。 有时为了调…

聚焦智慧出行,TDengine 与路特斯科技再度携手

在全球汽车行业向电动化和智能化转型的过程中&#xff0c;智能驾驶技术正迅速成为行业的焦点。随着消费者对出行效率、安全性和便利性的需求不断提升&#xff0c;汽车制造商们需要在全球范围内实现低延迟、高质量的数据传输和处理&#xff0c;以提升用户体验。在此背景下&#…

从零开始:神经网络(1)——什么是人工神经网络

声明&#xff1a;本文章是根据网上资料&#xff0c;加上自己整理和理解而成&#xff0c;仅为记录自己学习的点点滴滴。可能有错误&#xff0c;欢迎大家指正。 人工神经网络&#xff08;Artificial Neural Network&#xff0c;简称ANN&#xff09;是一种模仿生物神经网络结构和功…

Android SurfaceFlinger——GraphicBuffer初始化(二十九)

在 SurfaceFlinger 中,GraphicBuffer 是一个关键的数据结构,用于封装和管理图形数据的内存缓冲区。它不仅在 SurfaceFlinger 内部使用,也被其他组件如 GPU 驱动、摄像头服务、视频解码器等广泛利用,以实现高效的数据交换和图形渲染。 一、概述 GraphicBuffer 对象封装了一…

从dev分支合并到master分支

git命令从dev分支合并到master分支 1、拉取dev分支的代码 git checkout dev //切换成本地分支 git pull origin dev //拉取远程开发分支 git add . //暂存到本地仓库 git commit -m //增加备注信息 git push origin dev //推送到远程仓库 git checkout master // 切换到maste…

《500 Lines or Less》(5)异步爬虫

https://aosabook.org/en/500L/a-web-crawler-with-asyncio-coroutines.html ——A. Jesse Jiryu Davis and Guido van Rossum 介绍 网络程序消耗的不是计算资源&#xff0c;而是打开许多缓慢的连接&#xff0c;解决此问题的现代方法是异步IO。 本章介绍一个简单的网络爬虫&a…

STM32F0-标准库时钟配置指南

启动 从startup_stm32f0xx.s内的开头的Description可以看到 ;* Description : STM32F051 devices vector table for EWARM toolchain. ;* This module performs: ;* - Set the initial SP ;* - Set t…

使用sqlalchemy查询mysql的JSON字段

使用sqlalchemy查询mysql的JSON字段 在使用SQLAlchemy查询MySQL的JSON字段时,你可以按照以下步骤操作: 假设你有一个包含JSON字段的表格 假设你有一个名为 items 的表格,其中有一个名为 data 的JSON字段。我们来查询这个字段。 1. 定义模型类 首先,你需要定义一个与表…

【Leetcode】十八、动态规划:不同路径 + 最大正方形

文章目录 1、动态规划2、leetcode509&#xff1a;斐波那契数列3、leetcode62&#xff1a;不同路径4、leetcode121&#xff1a;买卖股票的最佳时机5、leetcode70&#xff1a;爬楼梯6、leetcode279&#xff1a;完全平方数7、leetcode221&#xff1a;最大正方形 1、动态规划 只能…

【Java语法基础】4.字符串

4.字符串 字符char无需多言&#xff0c;单引号。 String类 基本操作 String类的访问不能通过数组访问&#xff0c;只能通过API&#xff0c;并且只能访问不能修改&#xff0c;如&#xff1a; String a "hello"; for(int i 0; i < a.length(); i ) {//注意&…