问题:
用 torch.jit.trace 导出模型时，for 循环的执行次数被固定为追踪时示例输入所决定的次数，不会根据实际输入而改变。
解决方案:
使用 torch.jit.script（而不是 torch.jit.trace）编译模型，使 for 循环作为控制流保留在图中。
例如:
class LoopAdd(torch.nn.Module):
    """Return ``x`` with 1 added ``x.size(0)`` times, via a Python loop.

    The loop's trip count depends on the input's first dimension. This is
    what ``torch.jit.trace`` cannot capture: tracing runs ``forward`` once on
    the example input and records only the ops that executed. The loop is
    therefore unrolled into a fixed number of ``aten::add`` calls.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        h = x
        # `i` is unused; the loop only repeats the add x.size(0) times.
        for i in range(x.size(0)):
            h = h + 1
        return h
# Trace the plain nn.Module with a 3-row example input. Tracing records only
# the ops that actually ran, so the printed graph shows three unrolled adds.
model = LoopAdd()
input_1 = torch.ones(3, 16)
traced_model = torch.jit.trace(model, example_inputs=(input_1,))
print(traced_model.graph)
graph(%self : __torch__.LoopAdd,
      %x : Float(3, 16, strides=[16, 1], requires_grad=0, device=cpu)):
  %7 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %8 : int = prim::Constant[value=1]() # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %h.1 : Float(3, 16, strides=[16, 1], requires_grad=0, device=cpu) = aten::add(%x, %7, %8) # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %10 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %11 : int = prim::Constant[value=1]() # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %h : Float(3, 16, strides=[16, 1], requires_grad=0, device=cpu) = aten::add(%h.1, %10, %11) # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %13 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %14 : int = prim::Constant[value=1]() # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %15 : Float(3, 16, strides=[16, 1], requires_grad=0, device=cpu) = aten::add(%h, %13, %14) # /home/mark.yj/GPT-SoVITS/b.py:8:0
  return (%15)
改造后（继承 torch.jit.ScriptModule，并用 @torch.jit.script_method 修饰 forward）:
class LoopAdd(torch.jit.ScriptModule):
    """Return ``x`` with 1 added ``x.size(0)`` times, compiled by TorchScript.

    Subclassing ``torch.jit.ScriptModule`` and decorating ``forward`` with
    ``@torch.jit.script_method`` makes TorchScript compile the Python source
    of ``forward``. The ``for`` loop therefore becomes a ``prim::Loop`` whose
    trip count is ``x.size(0)``, computed at run time, instead of a fixed
    number of unrolled adds.

    An equivalent and more common approach is calling
    ``torch.jit.script(module)`` on a plain ``torch.nn.Module``.
    """

    def __init__(self):
        super().__init__()

    @torch.jit.script_method
    def forward(self, x):
        h = x
        # `i` is unused, but is kept so the compiled graph's loop block
        # still shows `%i : int`.
        for i in range(x.size(0)):
            h = h + 1
        return h
input_1 = torch.ones(3, 16)  # example input; not needed to compile a ScriptModule
model = LoopAdd()
# LoopAdd subclasses torch.jit.ScriptModule, so its forward is compiled by
# TorchScript as soon as the module is constructed. Passing it to
# torch.jit.trace is a no-op: PyTorch warns and returns the same object, so
# the result was never a trace. torch.jit.script returns an existing
# ScriptModule unchanged, without that warning, and matches the technique
# this example demonstrates.
scripted_model = torch.jit.script(model)
traced_model = scripted_model  # previous name, kept for any later references
print(scripted_model.graph)
graph(%self : __torch__.LoopAdd,
      %x.1 : Tensor):
  %8 : bool = prim::Constant[value=1]() # /home/mark.yj/GPT-SoVITS/b.py:18:8
  %4 : int = prim::Constant[value=0]() # /home/mark.yj/GPT-SoVITS/b.py:18:30
  %11 : int = prim::Constant[value=1]() # /home/mark.yj/GPT-SoVITS/b.py:19:20
  %5 : int = aten::size(%x.1, %4) # /home/mark.yj/GPT-SoVITS/b.py:18:23
  %h : Tensor = prim::Loop(%5, %8, %x.1) # /home/mark.yj/GPT-SoVITS/b.py:18:8
    block0(%i : int, %h.9 : Tensor):
      %h.3 : Tensor = aten::add(%h.9, %11, %11) # /home/mark.yj/GPT-SoVITS/b.py:19:16
      -> (%8, %h.3)
  return (%h)
可以看到图中出现了 prim::Loop，其循环次数 %5 由 aten::size(%x.1, %4)（即 x.size(0)）在运行时计算，说明这已经不再是把循环展开成固定次数的静态图了。