Problem description
Training uses the following custom multi-layer nested network:
```python
import torch
import torch.nn as nn

class FC1_bot(nn.Module):
    def __init__(self):
        super(FC1_bot, self).__init__()
        self.embeddings = nn.Sequential(nn.Linear(10, 10))

    def forward(self, x):
        emb = self.embeddings(x)
        return emb


class FC1_top(nn.Module):
    def __init__(self, in_dim=10):
        super(FC1_top, self).__init__()
        self.prediction = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_dim, 10),
        )

    def forward(self, x):
        logit = self.prediction(x)
        return logit


class FC1(nn.Module):
    def __init__(self, num):
        super(FC1, self).__init__()
        self.num = num
        # A plain Python list does NOT register its submodules with the
        # parent nn.Module, so their parameters never show up in
        # model.parameters() -- this is the source of the problem below.
        self.bot = []
        for _ in range(num):
            self.bot.append(FC1_bot())
        # _aggregate concatenates num 10-dim embeddings, so the top
        # module must accept 10 * num input features.
        self.top = FC1_top(in_dim=10 * num)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = list(x)
        emb = []
        for i in range(self.num):
            emb.append(self.bot[i](x[i]))
        agg_emb = self._aggregate(emb)
        logit = self.top(agg_emb)
        pred = self.softmax(logit)
        return emb, pred

    def _aggregate(self, x):
        # Note: x is a list of tensors.
        return torch.cat(x, dim=1)
```
The training code is as follows:
```python
num = 4
model = FC1(num)
optimizer_entire = torch.optim.SGD(model.parameters(), lr=0.01)

def train(self):
    # train entire model
    self.model.train()
    for epoch in range(self.args.epochs):
        # forward returns both the embeddings and the prediction
        emb, pred = self.model(data)
        # nn.CrossEntropyLoss is a module and must be instantiated before
        # being called; note it also expects raw logits, so feeding it the
        # softmax output effectively applies softmax twice.
        loss = torch.nn.CrossEntropyLoss()(pred, labels)
        # zero grad for all optimizers
        optimizer_entire.zero_grad()
        loss.backward()
        # update parameters for all optimizers
        optimizer_entire.step()
```
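The failure mode can be reproduced in a single step. The snippet below is a minimal repro sketch, where `data` and `labels` are hypothetical stand-ins with the shapes the model expects:

```python
data = [torch.randn(8, 10) for _ in range(num)]  # hypothetical inputs, one per bottom module
labels = torch.randint(0, 10, (8,))              # hypothetical class labels

before = model.bot[0].embeddings[0].weight.clone()
emb, pred = model(data)
loss = torch.nn.CrossEntropyLoss()(pred, labels)
optimizer_entire.zero_grad()
loss.backward()
optimizer_entire.step()

# The bottom module did receive a gradient during backward ...
print(model.bot[0].embeddings[0].weight.grad is not None)      # True
# ... but optimizer_entire never updates it, because model.parameters()
# does not see parameters hidden inside a plain Python list.
print(torch.equal(before, model.bot[0].embeddings[0].weight))  # True
```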
Solution
Every parameter the model actually uses must be covered by an optimizer. Here, optimizer_entire was built from model.parameters(), but because the bottom modules are kept in a plain Python list, FC1 never registers them and model.parameters() does not include them. As a result, only the top part is trained: during loss.backward() the bottom layers still receive gradients (they are part of the computation graph), but optimizer_entire.step() never updates their parameters.
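The root cause is visible from the parameter count (a quick check, assuming the FC1 definition above):

```python
model = FC1(4)
# Only FC1_top's Linear weight and bias are registered; the four FC1_bot
# modules sit in a plain Python list, which nn.Module does not track.
print(len(list(model.parameters())))         # 2
# Each bottom module has its own parameters, invisible to the parent:
print(len(list(model.bot[0].parameters())))  # 2
```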
```python
num = 4
model = FC1(num)
optimizer_entire = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer_top = torch.optim.SGD(model.top.parameters(), lr=0.01)
# one dedicated optimizer per (unregistered) bottom module
optimizer_bot = []
for i in range(num):
    optimizer_bot.append(torch.optim.SGD(model.bot[i].parameters(), lr=0.01))

def train(self):
    # train entire model
    self.model.train()
    self.model.top.train()
    for i in range(self.args.num):
        self.model.bot[i].train()
    for epoch in range(self.args.epochs):
        emb, pred = self.model(data)
        loss = torch.nn.CrossEntropyLoss()(pred, labels)
        # zero grad for all optimizers
        optimizer_entire.zero_grad()
        optimizer_top.zero_grad()
        for i in range(num):
            optimizer_bot[i].zero_grad()
        loss.backward()
        # update parameters for all optimizers
        # (caveat: model.parameters() already contains top's parameters,
        # so stepping both optimizer_entire and optimizer_top updates the
        # top module twice per iteration)
        optimizer_entire.step()
        optimizer_top.step()
        for i in range(num):
            optimizer_bot[i].step()
```
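As an aside, the per-submodule optimizers become unnecessary if the bottom modules are registered in the first place. Below is a minimal sketch of that alternative using nn.ModuleList (the class name FC1Registered is made up here for illustration):

```python
import torch
import torch.nn as nn

class FC1Registered(nn.Module):
    """Same structure as FC1, but with the bottom modules registered."""
    def __init__(self, num):
        super(FC1Registered, self).__init__()
        self.num = num
        # nn.ModuleList registers each FC1_bot with the parent, so
        # model.parameters() now includes their weights.
        self.bot = nn.ModuleList([FC1_bot() for _ in range(num)])
        self.top = FC1_top(in_dim=10 * num)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        emb = [self.bot[i](xi) for i, xi in enumerate(x)]
        logit = self.top(torch.cat(emb, dim=1))
        return emb, self.softmax(logit)

model = FC1Registered(4)
# A single optimizer now covers the top module AND all bottom modules.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
```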