PyTorch檔案生成機制中的FileManager.write_with_template
- 前言
- FileManager.write_with_template調用
- gen_pyi
- gen_nn_functional
- write_sharded
- FileManager.write_with_template實現
- torchgen/utils.py
- FileManager.write_with_template
- FileManager.substitute_with_template
- _read_template
- torchgen/code_template.py
- CodeTemplate
- CodeTemplate.from_file
- CodeTemplate.\__init__
- substitute
前言
PyTorch中有些檔案是在編譯過程中跑腳本生成的,如.pyi
檔是由.pyi.in
檔生成,torch/csrc/autograd/generated
目錄下的.cpp
檔則是由tools/autograd/templates
下的template .cpp
檔生成的。
它們底層都是調用FileManager.write_with_template
函數,其功能是對原檔案中的特定字串依照callback function所指示的方式做替換,進而生成對應的.pyi
或.cpp
檔。
本文會先查看FileManager.write_with_template
函數是如何被調用的,再細看它的實現。
FileManager.write_with_template調用
gen_pyi
tools/pyi/gen_pyi.py
fm.write_with_template("torch/_C/__init__.pyi","torch/_C/__init__.pyi.in",lambda: {"generated_comment": "@" + "generated from torch/_C/__init__.pyi.in",**env,},)fm.write_with_template("torch/_C/_VariableFunctions.pyi","torch/_C/_VariableFunctions.pyi.in",lambda: {"generated_comment": "@"+ "generated from torch/_C/_VariableFunctions.pyi.in",**env,},)fm.write_with_template("torch/_VF.pyi","torch/_C/_VariableFunctions.pyi.in",lambda: {"generated_comment": "@"+ "generated from torch/_C/_VariableFunctions.pyi.in",**env,},)fm.write_with_template("torch/return_types.pyi","torch/_C/return_types.pyi.in",lambda: {"generated_comment": "@" + "generated from torch/_C/return_types.pyi",**env,},)gen_nn_functional(fm)
此處的四個fm.write_with_template
會由torch/_C
資料夾下的四個.pyi.in
檔生成torch/_C
資料夾下的__init__.pyi
, _VariableFunctions.pyi
和torch
資料夾下的_VF.pyi
, return_types.pyi
。
gen_nn_functional
tools/pyi/gen_pyi.py
def gen_nn_functional(fm: FileManager) -> None:# ...fm.write_with_template("torch/nn/functional.pyi","torch/nn/functional.pyi.in",lambda: {"imported_hints": import_code,"dispatched_hints": dispatch_code,},)# ...fm.write_with_template("torch/_C/_nn.pyi","torch/_C/_nn.pyi.in",lambda: {"imported_hints": import_code,"dispatched_hints": dispatch_code,},)
此處的兩個fm.write_with_template
會由torch/nn/functional.pyi.in
及torch/_C/_nn.pyi.in
生成torch/nn/functional.pyi
和torch/_C/_nn.pyi.in
。
write_sharded
torchgen/utils.py
def write_sharded(self,filename: str,items: Iterable[T],*,key_fn: Callable[[T], str],env_callable: Callable[[T], Dict[str, List[str]]],num_shards: int,base_env: Optional[Dict[str, Any]] = None,sharded_keys: Set[str],) -> None:#...for shard in all_shards:shard_id = shard["shard_id"]self.write_with_template(f"{base_filename}{shard_id}{extension}", filename, lambda: shard)#...
其中的all_shards
為:
[{'shard_id': 'Everything'}, {'shard_id': '_0'}, {'shard_id': '_1'}, {'shard_id': '_2'}]
所以這裡的write_with_template
會由filename
即python_torch_functions.cpp
生成python_torch_functionsEverything.cpp
, python_torch_functions_0.cpp
, python_torch_functions_1.cpp
和python_torch_functions_2.cpp
四個檔案。
注意到上面三個例子中,write_with_template
的第三個參數(env_callable
)都是一個呼叫後會返回dict
的lambda函數。
FileManager.write_with_template實現
torchgen/utils.py
FileManager.write_with_template
write_with_template
除了self
以外有三個參數:
filename
:生成的.pyi
的檔名或.cpp
的檔名template_fn
:作為輸入的.pyi.in
的檔名或template.cpp
的檔名env_callable
:在做替換時會用到的callback function
def write_with_template(self,filename: str,template_fn: str,env_callable: Callable[[], Union[str, Dict[str, Any]]],) -> None:filename = "{}/{}".format(self.install_dir, filename)assert filename not in self.filenames, "duplicate file write {filename}"self.filenames.add(filename)if not self.dry_run:substitute_out = self.substitute_with_template(template_fn=template_fn,env_callable=env_callable,)self._write_if_changed(filename=filename, contents=substitute_out)
可以看到這段代碼最核心的內容就是調用substitute_with_template
生成substitute_out
。
之後再將替換後的結果,也就是substitute_out
寫入filename
(.pyi
檔)這個檔案中。
注:在做類型檢查時,callback function是由typing.Callable表示的,詳見Python typing函式庫和torch.types。
FileManager.substitute_with_template
torchgen/utils.py
除self
外有兩個參數:
template_fn
:作為輸入的.pyi.in
的檔名或template.cpp
的檔名env_callable
:在做替換時會用到的callback function
# Read from template file and replace pattern with callable (type could be dict or str).def substitute_with_template(self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> str:template_path = os.path.join(self.template_dir, template_fn)env = env_callable()if isinstance(env, dict):# TODO: Update the comment reference to the correct locationif "generated_comment" not in env:comment = "@" + "generated by torchgen/gen.py"comment += " from {}".format(os.path.basename(template_path))env["generated_comment"] = commenttemplate = _read_template(template_path)return template.substitute(env)elif isinstance(env, str):return envelse:assert_never(env)
env_callable
是一個呼叫後會返回dict
的lambda函數,所以會進入isinstance(env, dict)
這個分支,先由_read_template
讀入template檔案(.pyi.in
檔或template .cpp
檔)後調用template.substitute
。
_read_template
torchgen/utils.py
參數template_fn
為pyi
或template cpp
的檔名。
@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:return CodeTemplate.from_file(template_fn)
讀入template_fn
,生成CodeTemplate
物件並回傳。
torchgen/code_template.py
CodeTemplate
torchgen/code_template.py
先來看看CodeTemplate
類別的作用。
# match $identifier or ${identifier} and replace with value in env
# If this identifier is at the beginning of whitespace on a line
# and its value is a list then it is treated as
# block substitution by indenting to that depth and putting each element
# of the list on its own line
# if the identifier is on a line starting with non-whitespace and a list
# then it is comma separated ${,foo} will insert a comma before the list
# if this list is not empty and ${foo,} will insert one after.class CodeTemplate:substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"substitution = re.compile(substitution_str, re.MULTILINE)pattern: strfilename: str# ...
注釋裡說明了CodeTemplate
的功用是把模板中${identifier}
字樣替換成env
中對應的value。
在torch/_C/_VariableFunctions.pyi.in
中就有以下字樣:
# ${generated_comment}
# ...
${function_hints}${all_directive}
在python_torch_functions.cpp
中則有以下字樣:
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
$ops_headers
#endif// ...
// generated forward declarations start here${py_forwards}// ...
static PyMethodDef torch_functions_shard[] = {${py_method_defs}
};// ...
// generated methods start here${py_methods}
CodeTemplate.from_file
torchgen/code_template.py
class CodeTemplate:# ...@staticmethoddef from_file(filename: str) -> "CodeTemplate":with open(filename, "r") as f:return CodeTemplate(f.read(), filename)# ...
調用CodeTemplate
的建構子,傳入filename
的內容及名稱。
CodeTemplate._init_
filename
:作為輸入的.pyi.in
的檔名或template.cpp
的檔名pattern
:在CodeTemplate.from_file
中是以CodeTemplate(f.read(), filename)
調用CodeTemplate
建構子,所以pattern
成員變數會被設為從filename
檔案裡讀出來的東西
class CodeTemplate:# ...def __init__(self, pattern: str, filename: str = "") -> None:self.pattern = patternself.filename = filename# ...
substitute
torchgen/code_template.py
回顧torchgen/utils.py
的substitute_with_template
中的:
template = _read_template(template_path)
生成了CodeTemplate
物件template
後繼續調用:
return template.substitute(env)
其功能是做一些正則替換:
class CodeTemplate:# ...def substitute(self, env: Optional[Mapping[str, object]] = None, **kwargs: object) -> str:if env is None:env = {}def lookup(v: str) -> object:assert env is not Nonereturn kwargs[v] if v in kwargs else env[v]def indent_lines(indent: str, v: Sequence[object]) -> str:return "".join([indent + l + "\n" for e in v for l in str(e).splitlines()]).rstrip()def replace(match: Match[str]) -> str:indent = match.group(1)key = match.group(2)comma_before = ""comma_after = ""if key[0] == "{":key = key[1:-1]if key[0] == ",":comma_before = ", "key = key[1:]if key[-1] == ",":comma_after = ", "key = key[:-1]v = lookup(key)if indent is not None:if not isinstance(v, list):v = [v]return indent_lines(indent, v)elif isinstance(v, list):middle = ", ".join([str(x) for x in v])if len(v) == 0:return middlereturn comma_before + middle + comma_afterelse:return str(v)return self.substitution.sub(replace, self.pattern)
函數最後的self.substitution.sub(replace, self.pattern)
中的self.substitution
是CodeTemplate
的成員:
substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"substitution = re.compile(substitution_str, re.MULTILINE)
re.compile
後得到的substitution
是一個re.Pattern
物件。
先來看看re.Pattern.sub
是什麼,參考Passing a function to re.sub in Python及Python: re.compile and re.sub中給出的例子:
import re
substitution = re.compile(r'\d')
number_mapping = {'1': 'one', '2': 'two', '3': 'three'}
s = "1 testing 2 3"
substitution.sub(lambda x: number_mapping[x.group()], s) # 'one testing two three'
re.Pattern.sub
的第一個參數是做替換的函數,第二個參數則是欲處理的字串,它會尋找特定樣式的字串(此處是r'\d'
),對它們做替換後回傳。
所以self.substitution.sub(replace, self.pattern)
這句是在self.pattern
(也就是pyi.in
或template cpp
檔中的內容)中尋找substitution_str
樣式的字串,並用replace
這個函數所指定的方式做替換。
得到替換後的結果後,回到substitute_with_template
函數:
return template.substitute(env)
那裡繼續將結果回傳,來到write_with_template
函數:
substitute_out = self.substitute_with_template(template_fn=template_fn,env_callable=env_callable,)self._write_if_changed(filename=filename, contents=substitute_out)
在那裡會把替換結果substitute_out
寫入filename
,也就是生成的.pyi
的檔名或.cpp
的檔名。
來看看torch/_C/_VariableFunctions.pyi
中的${generated_comment}
。
回顧gen_pyi
函數中呼叫write_with_template
時,與env
一同傳入了generated_comment
的key value pair:
fm.write_with_template("torch/_C/_VariableFunctions.pyi","torch/_C/_VariableFunctions.pyi.in",lambda: {"generated_comment": "@"+ "generated from torch/_C/_VariableFunctions.pyi.in",**env,},)
所以到了substitute
函數,env
參數便是一個包含generated_comment
的key value pair的字典。
# ${generated_comment}
在做替換後,會變成生成的torch/_C/_VariableFunctions.pyi
檔案中的第一行:
# @generated from torch/_C/_VariableFunctions.pyi.in