读取和存储数据
我们可以使用pt文件存储Tensor数据:
import torch
from torch import nnx = torch.ones(3)
torch.save(x, 'x.pt')
这样我们就将数据存储在名为x.pt的文件中了
我们可以从文件中将该数据读入内存:
x2 = torch.load('x.pt')
print(x2)
还可以存储Tensor列表到文件中,并读取:
y = torch.zeros(4)
torch.save([x, y], "xy.pt")
xy_list = torch.load("xy.pt")
print(xy_list)
不仅如此,还可以存储一个键值为Tensor变量的字典:
torch.save({'x':x, 'y':y}, "xy_dict")
xy_dict = torch.load("xy_dict")
print(xy_dict)
对模型参数进行读写:
对于Module类的对象,我们可以使用model.parameters()函数来访问模型的参数。而state_dict函数将会返回一个模型的参数名称到参数Tensor对象的一个字典对象。
class my_module(mm.Module):def __init__(self):super(my_module, self)self.hidden = nn.Linear(3, 2)self.action = nn.ReLU()self.output = nn.Linear(2, 1)def forward(self, x):middle = self.action(self.hidden(x))return self.output(middle) net = my_module()
net.state_dict()
输出:
OrderedDict([('hidden.weight', tensor([[ 0.2448, 0.1856, -0.5678],[ 0.2030, -0.2073, -0.0104]])),('hidden.bias', tensor([-0.3117, -0.4232])),('output.weight', tensor([[-0.4556, 0.4084]])),('output.bias', tensor([-0.3573]))])
但是,只有具有可变参数(可学习参数)的网络层才会在state_dict中,
同样的,优化器(optim)也有一个state_dict,这个函数返回一个字典,该字典包含优化器的状态以及其超参数信息:
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer.state_dict()
输出:
{'param_groups': [{'dampening': 0,'lr': 0.001,'momentum': 0.9,'nesterov': False,'params': [4736167728, 4736166648, 4736167368, 4736165352],'weight_decay': 0}],'state': {}}
那么就可以通过保存模型的state_dict来保存模型:
torch.save(net.state_dict(), PATH)model = my_module(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
还可以直接保存整个模型:
torch.save(model, PATH)
model = torch.load(PATH)