目前的状态:人有点晕。好多细节的东西都不知道怎么来的。大方向有所把握:
1. 准备数据集:看起来很简单,其实不然。如何把文件读取进来,变成pytorch所需要的数据类型。
图片:你就需要ToTensor,Normalize转换为需要的数据类型
文字:对init,getitem,len进行重写
准备dataset,构建data_loader并返回
2. 构建模型:重写init和forward方法。在forward里对每一层进行处理。包括矩阵变换,激活函数等去得到输出
3. 训练:基本就是循环里面梯度归零,调用,loss,反向传播,更新
data_loader = get_dataloader()for idx,(input,traget) in enumerate(data_loader):optimizer.zero_grad() # 梯度归零output = model(input) # 调用模型得到预测值loss = F.nll_loss(output,traget) # 得到损失loss.backward() # 反向传播optimizer.step() # 梯度更新
4. 测试:pass