这里写目录标题 pytorch保存模型的两种方式 开始:定义一个简单的模型 方式一:保存和加载模型的状态字典 方式二:2保存和加载整个模型
pytorch保存模型的两种方式
开始:定义一个简单的模型
import torch
import torch. nn as nn
import torch. optim as optim
class SimpleModel ( nn. Module) : def __init__ ( self) : super ( SimpleModel, self) . __init__( ) self. fc = nn. Linear( 10 , 1 ) def forward ( self, x) : return self. fc( x)
model = SimpleModel( )
criterion = nn. MSELoss( )
optimizer = optim. SGD( model. parameters( ) , lr= 0.001 )
for epoch in range ( 10 ) : inputs = torch. randn( 1 , 10 ) target = torch. randn( 1 , 1 ) output = model( inputs) loss = criterion( output, target) optimizer. zero_grad( ) loss. backward( ) optimizer. step( ) print ( f'Epoch [ { epoch+ 1 } /10], Loss: { loss. item( ) : .4f } ' )
方式一:保存和加载模型的状态字典
torch. save( model. state_dict( ) , 'model_state_dict1.pth' )
model = SimpleModel( )
model. load_state_dict( torch. load( 'model_state_dict1.pth' ) )
model. eval ( )
方式二:2保存和加载整个模型
torch. save( model, 'entire_model2.pth' )
model = torch. load( 'entire_model2.pth' )
model. eval ( )