1. Updating the best parameters with the parameter dictionary model.state_dict()
best_state_dict = model.state_dict()  # before training
best_state_dict = model.state_dict()  # update the best state_dict during training
Complete code:
# Initialize variables to track the best accuracy and the best state_dict
best_acc = 0
best_state_dict = model.state_dict()
for epoch in range(epochs):
    model.train()
    # Train the model weights on the training set
    for data, targets in tqdm.tqdm(train_dataloader):
        # Move the data to the GPU
        data = data.to(devices[0])
        targets = targets.to(devices[0])
        # Forward pass
        preds = model(data)
        loss = criterion(preds, targets)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate the model on the test set
    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for x, y in tqdm.tqdm(test_dataloader):
            x = x.to(devices[0])
            y = y.to(devices[0])
            preds = model(x)
            predictions = preds.max(1).indices  # .max(1) returns (values, indices) per row; .indices is the predicted class
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
    acc = (num_correct / num_samples).item()
    if acc > best_acc:
        best_acc = acc
        best_epoch = epoch + 1
        # Keep the parameters that achieved the best accuracy
        best_state_dict = model.state_dict()  # update the best state_dict
    model.train()

# Save the best parameters once training is finished
torch.save(best_state_dict, f"weights/{model_name}_{epochs}_{best_acc}.pth")
2. Saving the best parameters during training
if acc > best_acc:
    best_acc = acc
    best_epoch = epoch + 1
    # Save the current parameters directly whenever the accuracy improves
    torch.save(model.state_dict(), f"weights/{model_name}_{epochs}_{best_acc}.pth")
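Because best_acc is embedded in the file name, every improvement writes a new checkpoint file. Below is a minimal sketch of a variant (not from the original code, assuming the same model, acc, best_acc, and model_name variables as above) that overwrites a single fixed file, so only the best weights remain on disk:

# Variant sketch: overwrite one fixed file so only the latest best checkpoint is kept
if acc > best_acc:
    best_acc = acc
    best_epoch = epoch + 1
    torch.save(model.state_dict(), f"weights/{model_name}_best.pth")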
3. Saving the best model with a deep copy of the model
Introduction to the deep-copy approach
The copy module can be used to create a deep copy of an object. This means the copied model is completely independent of the original model, including its parameters.
import torch
import copy
import torch.nn as nn

# Suppose we have a model instance
original_model = nn.Sequential(
    nn.Linear(10, 5),
    nn.ReLU(),
    nn.Linear(5, 2),
)

# Copy the model
model_copy = copy.deepcopy(original_model)
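A minimal sketch (building on the example above) to verify that the deep copy really is independent: modifying the original model's weights in place does not affect the copy.

# Modify the original model's first-layer weights in place
with torch.no_grad():
    original_model[0].weight.add_(1.0)

# The deep copy keeps its old values, so the two weight tensors now differ
print(torch.equal(original_model[0].weight, model_copy[0].weight))  # prints: False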
Saving the best model with a deep copy
best_model = copy.deepcopy(model.state_dict())  # before training
best_model = copy.deepcopy(model.state_dict())  # update the best state_dict during training
Code example:
def fit_zsl(self):
    best_acc = 0
    mean_loss = 0
    last_loss_epoch = 1e8
    # Define best_model
    best_model = copy.deepcopy(self.model.state_dict())
    for epoch in range(self.nepoch):
        for i in range(0, self.ntrain, self.batch_size):
            self.model.zero_grad()
            batch_input, batch_label = self.next_batch(self.batch_size)
            self.input.copy_(batch_input)
            self.label.copy_(batch_label)
            inputv = Variable(self.input)
            labelv = Variable(self.label)
            output = self.model(inputv)
            loss = self.criterion(output, labelv)
            mean_loss += loss.item()
            loss.backward()
            self.optimizer.step()
        acc = self.val(
            self.test_unseen_feature,
            self.test_unseen_label,
            self.unseenclasses,
        )
        if acc > best_acc:
            best_acc = acc
            # Update best_model
            best_model = copy.deepcopy(self.model.state_dict())
    # Save locally once training is finished
    # best_model is already a state_dict, so save it directly (no .state_dict() call)
    torch.save(best_model, f"weights/{self.nepoch}_{best_acc}.pth")
    return best_acc, best_model
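For completeness, a minimal sketch of loading the saved best weights back for inference; build_model() and the checkpoint path are placeholders, to be replaced by the actual architecture and the file written by torch.save above.

# Rebuild the same architecture used during training (placeholder function)
model = build_model()
# Load the saved state_dict; map_location lets this work on a CPU-only machine
state_dict = torch.load("weights/best.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()  # switch to evaluation mode before inference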