The Python typing Library and torch.types
- Introduction
- typing
- Sequence vs Iterable
- Callable
- Union
- Optional
- Functions
- Callable
- Iterator/generator
- Positional & keyword arguments
- Classes
- self
- User-defined classes
- ClassVar
- \_\_setattr\_\_ and \_\_getattr\_\_
- torch.types
- builtins
- The * before a parameter
Introduction
PyTorch's torch/_C/_VariableFunctions.pyi contains the following code:
@overload
def rand(size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(*size: _int, generator: Optional[Generator], names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(*size: _int, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(size: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(*size: _int, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(size: Sequence[Union[_int, SymInt]], *, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(*size: _int, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
What do Sequence, Optional, and Union, as well as _int and _bool, mean here? The imports in torch/_C/_VariableFunctions.pyi.in give a hint:
from torch import Tensor, Generator, strided, memory_format, contiguous_format, strided, inf
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, Literal, TypeVar
from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme, _size, _layout, SymInt, Device
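So Sequence, Optional, Union, and the like are imported from a library called typing, which also provides Iterable (discussed below). typing is part of the Python standard library; its purpose is to provide runtime support for type hints.
_int, _bool, and similar names are types that PyTorch defines itself.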
typing
Sequence vs Iterable
According to Type hints cheat sheet - Standard "duck types", Sequence represents sequence types that support __len__ and __getitem__, such as list, tuple, and str. dict and set do not belong to this category.
# Use Iterable for generic iterables (anything usable in "for"),
# and Sequence where a sequence (supporting "len" and "__getitem__") is
# required
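According to Python Iterable vs Sequence:
Iterable represents types that support __iter__ or __getitem__, such as range objects and the iterators returned by reversed. (Strictly speaking, an isinstance check against collections.abc.Iterable only looks for __iter__.)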
r = range(4)
r.__getitem__(0) # 0
r.__iter__() # <range_iterator object at 0x0000015AE7945D30>
l = [1, 2, 3]
rv = reversed(l)
rv.__iter__() # <list_reverseiterator object at 0x0000015AE7980E20>
rv.__getitem__()  # no __getitem__ method, but it still counts as an Iterable because it supports __iter__
# Traceback (most recent call last):
# File "<stdin>", line 1, in <module>
# AttributeError: 'list_reverseiterator' object has no attribute '__getitem__'
Because a Sequence also has __iter__ and __getitem__, every Sequence is by definition an Iterable.
l = []
l.__iter__ # <method-wrapper '__iter__' of list object at 0x7f15bb50b5c0>
l.__getitem__ # <built-in method __getitem__ of list object at 0x7f15bb50b5c0>
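The relationship between the two can also be checked at runtime. The sketch below is only an illustration: it uses the abstract base classes in collections.abc, which the typing names correspond to, and the helper name classify and the sample values are made up for this example.

```python
from collections.abc import Iterable, Sequence
from typing import Dict, Tuple


def classify(value: object) -> Tuple[bool, bool]:
    """Return (is a Sequence, is an Iterable) for the given value."""
    return isinstance(value, Sequence), isinstance(value, Iterable)


# Hypothetical sample values covering the types mentioned in the text
samples = {
    "list": [1, 2, 3],
    "tuple": (1, 2),
    "str": "abc",
    "range": range(4),
    "dict": {"a": 1},
    "set": {1, 2},
    "reversed": reversed([1, 2, 3]),
}

results: Dict[str, Tuple[bool, bool]] = {
    name: classify(value) for name, value in samples.items()
}

for name, (is_sequence, is_iterable) in results.items():
    print(f"{name:>8}: Sequence={is_sequence}, Iterable={is_iterable}")
```

Every sample is an Iterable, but only list, tuple, str, and range are reported as Sequence.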
Callable
typing - Callable
Callable
Frameworks expecting callback functions of specific signatures might be type hinted using Callable[[Arg1Type, Arg2Type], ReturnType].
The documentation is easy to follow; the one thing to watch is that the argument types must be wrapped in [].
Type hints cheat sheet - Functions gives an example:
# This is how you annotate a callable (function) value
x: Callable[[int, float], float] = f
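Ignoring the type hint for a moment, this statement is simply x = f: it binds the variable x to the function f. The annotation Callable[[int, float], float] says that f is a function that takes an int and a float and returns a float.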
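The cheat sheet snippet leaves f undefined. A minimal self-contained sketch might look like the following; the names scale and apply are hypothetical and exist only for this illustration.

```python
from typing import Callable


def scale(count: int, factor: float) -> float:
    """A function whose signature matches Callable[[int, float], float]."""
    return count * factor


# Same shape as "x: Callable[[int, float], float] = f" in the cheat sheet
x: Callable[[int, float], float] = scale


def apply(func: Callable[[int, float], float], count: int, factor: float) -> float:
    """Call func with the given arguments; func must accept (int, float)."""
    return func(count, factor)


result = apply(x, 3, 0.5)
print(result)
```

A type checker would reject passing something like len to apply, because its signature does not match the declared Callable.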
Union
typing - Union
typing.Union
Union type; Union[X, Y] is equivalent to X | Y and means either X or Y. To define a union, use e.g. Union[int, str] or the shorthand int | str. Using that shorthand is recommended.
Union[X, Y] means the type can be either X or Y. From Python 3.10 onward, the more concise X | Y form can be used instead.
The example given in Type hints cheat sheet - Useful built-in types:
# On Python 3.10+, use the | operator when something could be one of a few types
x: list[int | str] = [3, 5, "test", "fun"] # Python 3.10+
# On earlier versions, use Union
x: list[Union[int, str]] = [3, 5, "test", "fun"]
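A function that accepts a Union usually has to find out which member it actually received. The sketch below shows one common way to do that with isinstance; the function name describe is hypothetical.

```python
from typing import Union


def describe(value: Union[int, str]) -> str:
    """Handle each member of the Union separately."""
    if isinstance(value, int):
        # Inside this branch a type checker treats value as int
        return f"int {value + 1}"
    # Otherwise value must be str
    return f"str {value.upper()}"


print(describe(3))
print(describe("test"))
```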
Optional
typing - Optional
Optional type. Optional[X] is equivalent to X | None (or Union[X, None]).
Optional[X] means the variable can be either of type X or None.
Type hints cheat sheet - Useful built-in types gives a good example:
# Use Optional[X] for a value that could be None
# Optional[X] is the same as X | None or Union[X, None]
x: Optional[str] = "something" if some_condition() else None
Here x may be a string or None depending on the return value of some_condition(), so Optional[str] is the appropriate type hint.
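Optional is just as common in return types. The following is a small sketch with hypothetical names (find_index, describe) showing a function that may return None and a caller that checks for it explicitly.

```python
from typing import List, Optional


def find_index(items: List[str], target: str) -> Optional[int]:
    """Return the position of target in items, or None if it is absent."""
    for index, item in enumerate(items):
        if item == target:
            return index
    return None


def describe(position: Optional[int]) -> str:
    # Compare against None explicitly: index 0 is a valid result but is falsy
    if position is None:
        return "not found"
    return f"found at {position}"


print(describe(find_index(["a", "b"], "a")))
print(describe(find_index(["a", "b"], "z")))
```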
Functions
mypy - Functions
Annotating parameter and return types:
from typing import Callable, Iterator, Union, Optional

# This is how you annotate a function definition
def stringify(num: int) -> str:
    return str(num)
Multiple parameters:
# And here's how you specify multiple arguments
def plus(num1: int, num2: int) -> int:
    return num1 + num2
A function that returns nothing uses None as its return type, and a parameter's default value goes after its type annotation:
# If a function does not return a value, use None as the return type
# Default value for an argument goes after the type annotation
def show(value: str, excitement: int = 10) -> None:
    print(value + "!" * excitement)
A function that accepts arguments of any type does not need parameter annotations:
# Note that arguments without a type are dynamically typed (treated as Any)
# and that functions without any annotations are not checked
def untyped(x):
    x.anything() + 1 + "string"  # no errors
Callable
A function that takes a Callable as a parameter:
# This is how you annotate a callable (function) value
x: Callable[[int, float], float] = f
def register(callback: Callable[[str], int]) -> None: ...
Iterator/generator
A generator function is effectively a function that returns an Iterator:
# A generator function that yields ints is secretly just a function that
# returns an iterator of ints, so that's how we annotate it
def gen(n: int) -> Iterator[int]:
    i = 0
    while i < n:
        yield i
        i += 1
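Because the annotation is Iterator[int], the return value can be consumed like any other iterator. A brief sketch, with a hypothetical generator named countdown:

```python
from typing import Iterator, List


def countdown(start: int) -> Iterator[int]:
    """Yield start, start - 1, ..., 1."""
    current = start
    while current > 0:
        yield current
        current -= 1


numbers = countdown(3)
first: int = next(numbers)       # pull a single value
rest: List[int] = list(numbers)  # drain whatever is left
print(first, rest)
```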
Splitting a function annotation over multiple lines:
# You can of course split a function annotation over multiple lines
def send_email(address: Union[str, list[str]],
               sender: str,
               cc: Optional[list[str]],
               bcc: Optional[list[str]],
               subject: str = '',
               body: Optional[list[str]] = None
               ) -> bool:
    ...
Positional & keyword arguments
# Mypy understands positional-only and keyword-only arguments
# Positional-only arguments can also be marked by using a name starting with
# two underscores
def quux(x: int, /, *, y: int) -> None:
    pass

quux(3, y=5)  # Ok
quux(3, 5) # error: Too many positional arguments for "quux"
quux(x=3, y=5) # error: Unexpected keyword argument "x" for "quux"
Note the two symbols / and * in the parameter list here. See What Are Python Asterisk and Slash Special Parameters For?:
Left side | Divider | Right side |
---|---|---|
Positional-only arguments | / | Positional or keyword arguments |
Positional or keyword arguments | * | Keyword-only arguments |
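Python parameters fall into three kinds: positional-only, keyword-only, and positional-or-keyword (which can be passed either by position or by keyword).
Everything to the left of / must be passed positionally, and everything to the right of * must be passed by keyword.
So in the example above, x must be passed as a positional argument and y must be passed as a keyword argument.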
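These rules are enforced by the interpreter as well as by mypy. The sketch below reuses the quux signature (with a return value added for illustration, and requiring Python 3.8+ for the / syntax) and records which calls raise TypeError at runtime.

```python
from typing import Callable, List


def quux(x: int, /, *, y: int) -> int:
    """x is positional-only, y is keyword-only."""
    return x + y


ok = quux(3, y=5)  # valid call

# Each lambda wraps an invalid call so the error can be caught and recorded
invalid_calls: List[Callable[[], int]] = [
    lambda: quux(3, 5),      # y passed positionally
    lambda: quux(x=3, y=5),  # x passed by keyword
]

errors: List[str] = []
for call in invalid_calls:
    try:
        call()
    except TypeError as exc:
        errors.append(str(exc))

print(ok)
for message in errors:
    print("TypeError:", message)
```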
Annotating the types of many arguments at once:
# This says each positional arg and each keyword arg is a "str"
def call(self, *args: str, **kwargs: str) -> str:
    reveal_type(args)  # Revealed type is "tuple[str, ...]"
    reveal_type(kwargs)  # Revealed type is "dict[str, str]"
    request = make_request(*args, **kwargs)
    return self.do_api_query(request)
Classes
mypy - Classes
self
class BankAccount:
    # The "__init__" method doesn't return anything, so it gets return
    # type "None" just like any other method that doesn't return anything
    def __init__(self, account_name: str, initial_balance: int = 0) -> None:
        # mypy will infer the correct types for these instance variables
        # based on the types of the parameters.
        self.account_name = account_name
        self.balance = initial_balance

    # For instance methods, omit type for "self"
    def deposit(self, amount: int) -> None:
        self.balance += amount

    def withdraw(self, amount: int) -> None:
        self.balance -= amount
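The self parameter of an instance method does not need a type annotation.
User-defined classes
A variable's type can be annotated with a user-defined class: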
# User-defined classes are valid as types in annotations
account: BankAccount = BankAccount("Alice", 400)
def transfer(src: BankAccount, dst: BankAccount, amount: int) -> None:
    src.withdraw(amount)
    dst.deposit(amount)
# Functions that accept BankAccount also accept any subclass of BankAccount!
class AuditedBankAccount(BankAccount):
    # You can optionally declare instance variables in the class body
    audit_log: list[str]

    def __init__(self, account_name: str, initial_balance: int = 0) -> None:
        super().__init__(account_name, initial_balance)
        self.audit_log: list[str] = []

    def deposit(self, amount: int) -> None:
        self.audit_log.append(f"Deposited {amount}")
        self.balance += amount

    def withdraw(self, amount: int) -> None:
        self.audit_log.append(f"Withdrew {amount}")
        self.balance -= amount

audited = AuditedBankAccount("Bob", 300)
transfer(audited, account, 100) # type checks!
The first parameter of transfer is typed as BankAccount, and AuditedBankAccount is a subclass of BankAccount, so the call passes type checking.
ClassVar
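Variables on a Python class come in two kinds: class variables and instance variables. To mark a member as a class variable, use ClassVar[type].
from typing import ClassVar

# You can use the ClassVar annotation to declare a class variable
class Car:
    seats: ClassVar[int] = 4
    passengers: ClassVar[list[str]]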
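The practical difference is that a class variable is shared by every instance, while an instance variable belongs to a single object. The sketch below adds a hypothetical owner instance attribute to Car to make the contrast visible.

```python
from typing import ClassVar, Tuple


class Car:
    seats: ClassVar[int] = 4  # class variable, shared by all instances

    def __init__(self, owner: str) -> None:
        self.owner = owner    # instance variable, one per object


first = Car("Alice")
second = Car("Bob")

# Changing the class variable through the class is visible on every instance
Car.seats = 5
shared: Tuple[int, int] = (first.seats, second.seats)
owners: Tuple[str, str] = (first.owner, second.owner)
print(shared, owners)
```

A type checker such as mypy would additionally flag an assignment like first.seats = 2, because a ClassVar is not meant to be set through an instance.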
__setattr__ and __getattr__
# If you want dynamic attributes on your class, have it
# override "__setattr__" or "__getattr__"
class A:
    # This will allow assignment to any A.x, if x is the same type as "value"
    # (use "value: Any" to allow arbitrary types)
    def __setattr__(self, name: str, value: int) -> None: ...

    # This will allow access to any A.x, if x is compatible with the return type
    def __getattr__(self, name: str) -> int: ...

a = A()
a.foo = 42  # Works
a.bar = 'Ex-parrot' # Fails type checking
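Overriding __setattr__ allows arbitrary instance attributes to be assigned on objects of the class, and the type of its value parameter tells the type checker which values are acceptable.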
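The cheat sheet example uses ... as the method bodies, so nothing is actually stored at runtime. One possible working version, assuming the dynamic attributes are kept in a private dictionary, could look like this; the class name Settings and the _values attribute are hypothetical.

```python
from typing import Dict


class Settings:
    def __init__(self) -> None:
        # Bypass our own __setattr__ so the backing dict is a real attribute
        object.__setattr__(self, "_values", {})

    def __setattr__(self, name: str, value: int) -> None:
        # Every other assignment lands in the backing dict
        values: Dict[str, int] = self._values
        values[name] = value

    def __getattr__(self, name: str) -> int:
        # Only called when normal attribute lookup fails
        try:
            return self._values[name]
        except KeyError:
            raise AttributeError(name) from None


settings = Settings()
settings.retries = 3
settings.timeout = 30
print(settings.retries, settings.timeout)
```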
torch.types
Types defined by PyTorch itself.
torch/types.py
import torch
from typing import Any, List, Sequence, Tuple, Union

import builtins

# Convenience aliases for common composite types that we need
# to talk about in PyTorch

_TensorOrTensors = Union[torch.Tensor, Sequence[torch.Tensor]]

# In some cases, these basic types are shadowed by corresponding
# top-level values. The underscore variants let us refer to these
# types. See https://github.com/python/mypy/issues/4146 for why these
# workarounds is necessary
_int = builtins.int
_float = builtins.float
_bool = builtins.bool

_dtype = torch.dtype
_device = torch.device
_qscheme = torch.qscheme
_size = Union[torch.Size, List[_int], Tuple[_int, ...]]
_layout = torch.layout
_dispatchkey = Union[str, torch._C.DispatchKey]

class SymInt:
    pass

# Meta-type for "numeric" things; matches our docs
Number = Union[builtins.int, builtins.float, builtins.bool]

# Meta-type for "device-like" things. Not to be confused with 'device' (a
# literal device object). This nomenclature is consistent with PythonArgParser.
# None means use the default device (typically CPU)
Device = Union[_device, str, _int, None]

# Storage protocol implemented by ${Type}StorageBase classes

class Storage(object):
    _cdata: int
    device: torch.device
    dtype: torch.dtype
    _torch_load_uninitialized: bool

    def __deepcopy__(self, memo) -> 'Storage':
        ...

    def _new_shared(self, int) -> 'Storage':
        ...

    def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool, element_size: int) -> None:
        ...

    def element_size(self) -> int:
        ...

    def is_shared(self) -> bool:
        ...

    def share_memory_(self) -> 'Storage':
        ...

    def nbytes(self) -> int:
        ...

    def cpu(self) -> 'Storage':
        ...

    def data_ptr(self) -> int:
        ...

    def from_file(self, filename: str, shared: bool = False, nbytes: int = 0) -> 'Storage':
        ...

    def _new_with_file(self, f: Any, element_size: int) -> 'Storage':
        ...

    ...
_int, _float, and _bool in torch.types are simply Python's built-in builtins.int, builtins.float, and builtins.bool.
Number, as defined by PyTorch, is a Union of those three, so a Number is any one of _int, _float, or _bool.
builtins
builtins — Built-in objects
This module provides direct access to all ‘built-in’ identifiers of Python; for example, builtins.open is the full name for the built-in function open().
The builtins module gives access to Python's built-in identifiers; for example, the open() function can be reached as builtins.open.
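This matters most when a built-in name has been shadowed locally. A small sketch, in which the module-level open function is a deliberately hypothetical replacement:

```python
import builtins


def open(path: str) -> str:
    """Hypothetical function that shadows the built-in open in this module."""
    return f"custom open: {path}"


shadowed = open("data.txt")       # calls the function defined above
original_open = builtins.open     # the real built-in is still reachable
same_len = builtins.len is len    # names that are not shadowed are identical

print(shadowed, original_open is open, same_len)
```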
The * before a parameter
See What does the Star operator mean in Python?
Single asterisk as used in function declaration allows variable number of arguments passed from calling environment. Inside the function it behaves as a tuple.
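Prefixing a function parameter with * lets the function accept any number of positional arguments; inside the function, that parameter is a tuple.
def function(*arg):
    print(type(arg))
    for i in arg:
        print(i)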
function(1,2,3)
# <class 'tuple'>
# 1
# 2
# 3
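This is what lets the rand stubs at the top of the article accept either rand(2, 3) or rand([2, 3]): one overload takes *size and another takes a single Sequence. PyTorch handles this in its own argument parser; the sketch below only imitates the idea in plain Python with a hypothetical helper named normalize_size.

```python
from typing import List, Sequence, Tuple, Union


def normalize_size(*size: Union[int, Sequence[int]]) -> Tuple[int, ...]:
    """Accept either normalize_size(2, 3) or normalize_size([2, 3])."""
    # A single sequence argument is unpacked into its elements
    if len(size) == 1 and isinstance(size[0], Sequence):
        return tuple(size[0])
    # Otherwise every positional argument must itself be an int
    dims: List[int] = []
    for dim in size:
        if not isinstance(dim, int):
            raise TypeError(f"expected int, got {type(dim).__name__}")
        dims.append(dim)
    return tuple(dims)


from_varargs = normalize_size(2, 3)
from_sequence = normalize_size([2, 3])
print(from_varargs, from_sequence)
```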