一、matplotlib库
在我们自己训练模型时,常常会使用matplotlib库来绘制oss和accuracy的曲线图,帮助我们分析模型的训练表现。
matplotlib库安装:pip install matplotlib
二、代码
import matplotlib.pyplot as plt
import torch
import torch.optim as optim # 导入优化器模块#------------------------------------------------------------------#
# 定义损失函数
#------------------------------------------------------------------#
def loss_fn(y_true, y_pred):return torch.mean((y_true - y_pred)**2)#------------------------------------------------------------------#
# 定义模型
#------------------------------------------------------------------#
model = torch.nn.Linear(10, 1)#------------------------------------------------------------------#
# 定义训练,验证数据
#------------------------------------------------------------------#
x_train = torch.randn(1000, 10)
y_train = torch.randn(1000, 1)
x_val = torch.randn(1000, 10)
y_val = torch.randn(1000, 1)#------------------------------------------------------------------#
# 定义优化器
#------------------------------------------------------------------#
optimizer = optim.Adam(model.parameters(), lr=0.001) # 使用 Adam 优化器,学习率为 0.001#------------------------------------------------------------------#
# 定义损失函数
#------------------------------------------------------------------#
train_loss_list = []
val_loss_list = []#------------------------------------------------------------------#
# 开始训练
#------------------------------------------------------------------#
for epoch in range(10000):# ------------------------------------------------------------------## 训练# ------------------------------------------------------------------## ------------------------------------------------------------------## 前向传播# ------------------------------------------------------------------#y_pred = model(x_train)# ------------------------------------------------------------------## 计算损失# ------------------------------------------------------------------#training_loss = loss_fn(y_train, y_pred)train_loss_list.append(training_loss.item())# ------------------------------------------------------------------## 反向传播# ------------------------------------------------------------------#training_loss.backward()# ------------------------------------------------------------------## 更新参数# ------------------------------------------------------------------#optimizer.step()# ------------------------------------------------------------------## 展示训练损失# ------------------------------------------------------------------#if epoch % 10 == 0:print(f"epoch {epoch}:training loss {training_loss.item()}")# ------------------------------------------------------------------## 验证# ------------------------------------------------------------------## ------------------------------------------------------------------## 前向传播# ------------------------------------------------------------------#y_pred = model(x_val)# ------------------------------------------------------------------## 计算损失# ------------------------------------------------------------------#val_loss = loss_fn(y_val, y_pred)val_loss_list.append(val_loss.item())# ------------------------------------------------------------------## 展示验证损失# ------------------------------------------------------------------#if epoch % 10 == 0:print(f"epoch {epoch}:validate loss {val_loss.item()}")# ------------------------------------------------------------------#
# 记录训练,验证损失
# ------------------------------------------------------------------#
plt.plot(train_loss_list,color="red",label="training_loss")
plt.plot(val_loss_list,color="blue",label="val_loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend(loc='lower right')
plt.show()
运行结果
查看