1. 准备好网络模型代码（保存为 net.py，供第 2 步的脚本导入）
import torch
import torch. nn as nn
import torch. optim as optim
class BP_36(nn.Module):
    """Two-layer fully connected network with a 36-unit hidden layer.

    Architecture: Linear(2, 36) -> ReLU -> Linear(36, 25).
    """

    def __init__(self):
        super().__init__()
        # Hidden layer first, then output layer.
        self.fc1 = nn.Linear(2, 36)
        self.fc2 = nn.Linear(36, 25)

    def forward(self, x):
        """Map inputs whose last dimension is 2 to outputs whose last dimension is 25."""
        hidden = torch.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return logits
class BP_64(nn.Module):
    """Two-layer fully connected network with a 64-unit hidden layer.

    Architecture: Linear(2, 64) -> ReLU -> Linear(64, 25).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 64)   # input features -> hidden units
        self.fc2 = nn.Linear(64, 25)  # hidden units -> outputs

    def forward(self, x):
        """Apply the ReLU hidden layer and return the 25-wide linear output."""
        return self.fc2(torch.relu(self.fc1(x)))
class Bi_LSTM(nn.Module):
    """Bidirectional single-layer LSTM followed by a per-time-step linear head.

    With the default arguments this is the original model:
    LSTM(2 -> 36, bidirectional, batch_first=True) followed by Linear(72, 25).

    Args:
        input_size: Number of features per time step (default 2).
        hidden_size: Hidden units per direction (default 36).
        output_size: Width of the linear head's output (default 25).
    """

    def __init__(self, input_size=2, hidden_size=36, output_size=25):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            bidirectional=True,
            batch_first=True,
        )
        # A bidirectional LSTM concatenates forward and backward outputs, so the
        # head's input width is 2 * hidden_size (72 for the defaults). Deriving it
        # keeps the two layers consistent if hidden_size changes.
        self.fc1 = nn.Linear(2 * hidden_size, output_size)

    def forward(self, x):
        """Return head outputs for every time step.

        Args:
            x: Tensor of shape (batch, seq_len, input_size) (batch_first=True).

        Returns:
            Tensor of shape (batch, seq_len, output_size).
        """
        seq_out, _ = self.lstm(x)  # final (h_n, c_n) states are not used
        return self.fc1(seq_out)
class Bi_GRU(nn.Module):
    """Bidirectional single-layer GRU followed by a per-time-step linear head.

    With the default arguments this is the original model:
    GRU(2 -> 36, bidirectional, batch_first=True) followed by Linear(72, 25).

    Args:
        input_size: Number of features per time step (default 2).
        hidden_size: Hidden units per direction (default 36).
        output_size: Width of the linear head's output (default 25).
    """

    def __init__(self, input_size=2, hidden_size=36, output_size=25):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            bidirectional=True,
            batch_first=True,
        )
        # Forward and backward outputs are concatenated, so the head sees
        # 2 * hidden_size features (72 for the defaults); derived rather than
        # hard-coded so it cannot drift from hidden_size.
        self.fc1 = nn.Linear(2 * hidden_size, output_size)

    def forward(self, x):
        """Return head outputs for every time step.

        Args:
            x: Tensor of shape (batch, seq_len, input_size) (batch_first=True).

        Returns:
            Tensor of shape (batch, seq_len, output_size).
        """
        seq_out, _ = self.gru(x)  # final hidden state h_n is not used
        return self.fc1(seq_out)
2. 运行以下脚本，计算模型的参数量和计算复杂度（ptflops 以乘加次数 MACs 计）
import torch
from net import Bi_GRU
from ptflops import get_model_complexity_info

# Measure on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = Bi_GRU()
model.to(device)

# input_res is the per-sample input shape WITHOUT the batch dimension; ptflops
# prepends a batch of 1, so the model receives a (1, 256, 2) tensor, i.e.
# 256 time steps x 2 features (Bi_GRU from step 1 uses batch_first=True).
# NOTE: ptflops counts multiply-accumulate operations (MACs, shown as
# "Mac"/"GMac"), not FLOPs; one MAC is roughly two FLOPs.
macs, params = get_model_complexity_info(
    model, (256, 2), as_strings=True, print_per_layer_stat=False
)

# Recent ptflops versions return (None, None) when their internal forward pass
# fails; report that explicitly instead of failing later with an unclear error.
if macs is None or params is None:
    raise RuntimeError("ptflops could not estimate the complexity of Bi_GRU")

print(f"模型参数量:{params}")
print(f"模型计算复杂度:{macs}")