>- **🍨 This post is a learning-log entry for the [🔗365天深度学习训练营] (365-day deep learning training camp)**
>- **🍖 Original author: [K同学啊]**
Tasks:
● Read the ResNeXt paper and understand the authors' design ideas
● Compare it with the previously covered ResNet50V2 and DenseNet architectures
● Use ResNeXt-50 to complete the monkeypox recognition task
🏡 My environment:
- Language: Python 3.8
- IDE: Jupyter Notebook
- Deep learning framework: PyTorch
- torch==2.3.1+cu118
- torchvision==0.18.1+cu118
This post converts the content of 第J6周:ResNeXt-50实战解析(TensorFlow版) (Week J6: a hands-on ResNeXt-50 walkthrough, TensorFlow version) to PyTorch. The background material covered there is not repeated here; only the PyTorch version is described.
I. Preparation
1. Set Up the GPU
Use the GPU if one is available on the device; otherwise fall back to the CPU.
import warnings
import torch

warnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
Output:
device(type='cuda')
2. Import the Data
Also check how many images the dataset contains.
import pathlib

data_dir = r'D:\THE MNIST DATABASE\P4-data'
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("Total number of images:", image_count)
Output:
Total number of images: 2142
3. Inspect the Dataset Classes
data_paths = list(data_dir.glob('*'))
classNames = [str(path).split("\\")[3] for path in data_paths]   # the 4th path component is the class folder name
classNames
Output:
['Monkeypox', 'Others']
4. Preview Random Images
Randomly sample 10 images from the dataset and display them.
import random
import matplotlib.pyplot as plt
from PIL import Image

plt.rcParams['font.sans-serif'] = ['SimHei']   # display Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False     # display minus signs correctly

data_paths2 = list(data_dir.glob('*/*'))
plt.figure(figsize=(10, 4))
for i in range(10):
    plt.subplot(2, 5, i + 1)
    plt.axis("off")
    image = random.choice(data_paths2)      # pick a random image
    plt.title(image.parts[-2])              # the parent folder name is the class name
    plt.imshow(Image.open(str(image)))      # display the image
Output:
5. Image Preprocessing
from torchvision import transforms, datasets

train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),        # resize every image to 224x224
    transforms.RandomHorizontalFlip(),    # random horizontal flip
    transforms.RandomRotation(0.2),       # random rotation within ±0.2 degrees
    transforms.ToTensor(),                # convert the image to a tensor
    transforms.Normalize(                 # normalize so the model converges more easily
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
])

total_data = datasets.ImageFolder(
    r'D:\THE MNIST DATABASE\P4-data',
    transform=train_transforms
)
total_data
Output:
Dataset ImageFolder
    Number of datapoints: 2142
    Root location: D:\THE MNIST DATABASE\P4-data
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-0.2, 0.2], interpolation=nearest, expand=False, fill=0)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )
Print the class-to-index mapping of the dataset:
total_data.class_to_idx
Output:
{'Monkeypox': 0, 'Others': 1}
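Before splitting, it is worth checking whether the two classes are balanced. Here is a minimal sketch I added (not part of the original post); it relies on the fact that ImageFolder stores the integer labels in its .targets attribute.

from collections import Counter

counts = Counter(total_data.targets)                 # label index -> number of images
for name, idx in total_data.class_to_idx.items():
    print(f'{name}: {counts[idx]} images')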
6. Split the Dataset
train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    total_data, [train_size, test_size]
)
train_dataset, test_dataset
Output:
(<torch.utils.data.dataset.Subset at 0x207565a54d0>,<torch.utils.data.dataset.Subset at 0x2075514cf90>)
Check the number of samples in the training and test sets:
train_size,test_size
Output:
(1713, 429)
7. Load the Dataset
batch_size = 16

train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=1)
Check the shape of one batch from the training DataLoader:
for x, y in train_dl:
    print("Shape of x [N,C,H,W]:", x.shape)
    print("Shape of y:", y.shape, y.dtype)
    break
Output:
Shape of x [N,C,H,W]: torch.Size([16, 3, 224, 224])
Shape of y: torch.Size([16]) torch.int64
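The batches above are already normalized, so they cannot be shown with plt.imshow directly. As an optional sketch I added (reusing the mean/std from the preprocessing step), one batch can be un-normalized and displayed:

mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

imgs, labels = next(iter(train_dl))
plt.figure(figsize=(10, 4))
for i in range(10):
    img = (imgs[i] * std + mean).clamp(0, 1)   # undo the Normalize transform
    plt.subplot(2, 5, i + 1)
    plt.axis("off")
    plt.title(classNames[labels[i]])
    plt.imshow(img.permute(1, 2, 0))           # CHW -> HWC for imshow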
II. Building the Model
1. The Convolution Block
import torch.nn as nn
import torch.nn.functional as F

class BN_Conv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU"""
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 dilation=1, groups=1, bias=False):
        super(BN_Conv2d, self).__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, groups=groups, bias=bias),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        return F.relu(self.seq(x))
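As a quick sanity check I added, a dummy batch can be pushed through the block to confirm that the stride-2, 7x7 convolution used in the stem halves the spatial resolution:

x = torch.randn(1, 3, 224, 224)
print(BN_Conv2d(3, 64, 7, stride=2, padding=3)(x).shape)   # torch.Size([1, 64, 112, 112])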
2. The ResNeXt Block
class ResNeXt_Block(nn.Module):
    """ResNeXt block with grouped convolutions"""
    def __init__(self, in_channels, cardinality, group_depth, stride):
        super(ResNeXt_Block, self).__init__()
        self.group_channels = cardinality * group_depth
        self.conv1 = BN_Conv2d(in_channels, self.group_channels, 1, stride=1, padding=0)
        self.conv2 = BN_Conv2d(self.group_channels, self.group_channels, 3, stride=stride,
                               padding=1, groups=cardinality)
        self.conv3 = nn.Conv2d(self.group_channels, self.group_channels * 2, 1, stride=1, padding=0)
        self.bn = nn.BatchNorm2d(self.group_channels * 2)
        self.short_cut = nn.Sequential(
            nn.Conv2d(in_channels, self.group_channels * 2, 1, stride, 0, bias=False),
            nn.BatchNorm2d(self.group_channels * 2)
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.bn(self.conv3(out))
        out += self.short_cut(x)
        return F.relu(out)
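The channel arithmetic is easy to verify with another small check I added: with cardinality 32 and group_depth 4, group_channels is 128, so the block outputs 128 * 2 = 256 channels, and a stride of 2 halves the feature map:

block = ResNeXt_Block(64, cardinality=32, group_depth=4, stride=2)
print(block(torch.randn(1, 64, 56, 56)).shape)   # torch.Size([1, 256, 28, 28])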
3. The ResNeXt Model
class ResNeXt(nn.Module):
    """ResNeXt builder"""
    def __init__(self, layers: object, cardinality, group_depth, num_classes):
        super(ResNeXt, self).__init__()
        self.cardinality = cardinality
        self.channels = 64
        self.conv1 = BN_Conv2d(3, self.channels, 7, stride=2, padding=3)
        d1 = group_depth
        self.conv2 = self.__make_layers(d1, layers[0], stride=1)
        d2 = d1 * 2
        self.conv3 = self.__make_layers(d2, layers[1], stride=2)
        d3 = d2 * 2
        self.conv4 = self.__make_layers(d3, layers[2], stride=2)
        d4 = d3 * 2
        self.conv5 = self.__make_layers(d4, layers[3], stride=2)
        self.fc = nn.Linear(self.channels, num_classes)   # for a 224*224 input size

    def __make_layers(self, d, blocks, stride):
        strides = [stride] + [1] * (blocks - 1)
        layers = []
        for stride in strides:
            layers.append(ResNeXt_Block(self.channels, self.cardinality, d, stride))
            self.channels = self.cardinality * d * 2
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = F.max_pool2d(out, 3, 2, 1)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = F.avg_pool2d(out, 7)
        out = out.view(out.size(0), -1)
        out = F.softmax(self.fc(out), dim=1)
        return out
4. Inspect the ResNeXt-50 Model Parameters
model = ResNeXt([3, 4, 6, 3], 32, 4, 4)   # layers, cardinality, group_depth, num_classes
model.to(device)

# report the parameter count and per-layer output shapes
import torchsummary as summary
summary.summary(model, (3, 224, 224))
Output:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
         BN_Conv2d-3         [-1, 64, 112, 112]               0
            Conv2d-4          [-1, 128, 56, 56]           8,192
       BatchNorm2d-5          [-1, 128, 56, 56]             256
         BN_Conv2d-6          [-1, 128, 56, 56]               0
               ...                        ...                ...
     ResNeXt_Block-179          [-1, 2048, 7, 7]               0
          Linear-180                    [-1, 4]           8,196
================================================================
Total params: 37,574,724
Trainable params: 37,574,724
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 379.37
Params size (MB): 143.34
Estimated Total Size (MB): 523.28
----------------------------------------------------------------
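As a cross-check I added (not part of the original post), the total can also be computed directly from the model and should agree with the torchsummary figure above. Note that this hand-built network has more parameters than torchvision's resnext50_32x4d, mainly because it attaches a projection shortcut to every block rather than only where the shape changes.

total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,}')   # expected: 37,574,724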
III. Training the Model
1. The Training Function
# training loop
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)        # size of the training set
    num_batches = len(dataloader)         # number of batches (size / batch_size, rounded up)
    train_loss, train_acc = 0, 0          # running loss and accuracy

    for x, y in dataloader:               # fetch images and labels
        x, y = x.to(device), y.to(device)

        # compute the prediction error
        pred = model(x)                   # network output
        loss = loss_fn(pred, y)           # gap between predictions and targets

        # backpropagation
        optimizer.zero_grad()             # reset gradients
        loss.backward()                   # backpropagate
        optimizer.step()                  # update the weights

        # accumulate accuracy and loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches
    return train_acc, train_loss
2. The Test Function
The test function is largely the same as the training function, but since no gradient updates are applied to the weights, it does not need an optimizer.
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)        # size of the test set
    num_batches = len(dataloader)         # number of batches
    test_loss, test_acc = 0, 0

    # disable gradient tracking during evaluation to save memory
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # compute the loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss
3. Run the Training
import copy

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)   # optimizer with its learning rate
loss_fn = nn.CrossEntropyLoss()                             # loss function
epochs = 100

train_loss = []
train_acc = []
test_loss = []
test_acc = []
best_acc = 0    # best test accuracy so far, used to select the best model

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    # keep a copy of the best model in J6_model
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        J6_model = copy.deepcopy(model)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # current learning rate
    lr = optimizer.state_dict()['param_groups'][0]['lr']

    template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},'
                'Test_acc:{:.1f}%,Test_loss:{:.3f},Lr:{:.2E}')
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,
                          epoch_test_acc * 100, epoch_test_loss, lr))

# save the best model to disk
PATH = r'D:\THE MNIST DATABASE\J-series\J6_model.pth'
torch.save(J6_model.state_dict(), PATH)
print('Done')
Output:
Epoch: 1,Train_acc:57.0%,Train_loss:1.159,Test_acc:59.0%,Test_loss:1.152,Lr:1.00E-04
Epoch: 2,Train_acc:59.7%,Train_loss:1.133,Test_acc:64.6%,Test_loss:1.089,Lr:1.00E-04
Epoch: 3,Train_acc:64.0%,Train_loss:1.097,Test_acc:62.0%,Test_loss:1.117,Lr:1.00E-04
Epoch: 4,Train_acc:63.9%,Train_loss:1.095,Test_acc:63.9%,Test_loss:1.096,Lr:1.00E-04
Epoch: 5,Train_acc:64.2%,Train_loss:1.100,Test_acc:68.1%,Test_loss:1.067,Lr:1.00E-04
Epoch: 6,Train_acc:64.6%,Train_loss:1.094,Test_acc:61.5%,Test_loss:1.132,Lr:1.00E-04
Epoch: 7,Train_acc:65.5%,Train_loss:1.077,Test_acc:70.4%,Test_loss:1.032,Lr:1.00E-04
Epoch: 8,Train_acc:65.2%,Train_loss:1.088,Test_acc:66.4%,Test_loss:1.072,Lr:1.00E-04
Epoch: 9,Train_acc:67.2%,Train_loss:1.064,Test_acc:74.4%,Test_loss:1.008,Lr:1.00E-04
Epoch:10,Train_acc:66.1%,Train_loss:1.080,Test_acc:68.5%,Test_loss:1.052,Lr:1.00E-04
Epoch:11,Train_acc:65.6%,Train_loss:1.078,Test_acc:69.9%,Test_loss:1.040,Lr:1.00E-04
Epoch:12,Train_acc:66.9%,Train_loss:1.062,Test_acc:76.2%,Test_loss:0.982,Lr:1.00E-04
Epoch:13,Train_acc:65.9%,Train_loss:1.077,Test_acc:74.1%,Test_loss:1.002,Lr:1.00E-04
Epoch:14,Train_acc:65.5%,Train_loss:1.084,Test_acc:59.4%,Test_loss:1.144,Lr:1.00E-04
Epoch:15,Train_acc:62.5%,Train_loss:1.113,Test_acc:56.9%,Test_loss:1.171,Lr:1.00E-04
Epoch:16,Train_acc:66.5%,Train_loss:1.069,Test_acc:67.4%,Test_loss:1.065,Lr:1.00E-04
Epoch:17,Train_acc:68.0%,Train_loss:1.054,Test_acc:73.9%,Test_loss:1.005,Lr:1.00E-04
Epoch:18,Train_acc:67.5%,Train_loss:1.052,Test_acc:73.9%,Test_loss:0.989,Lr:1.00E-04
Epoch:19,Train_acc:68.6%,Train_loss:1.048,Test_acc:67.8%,Test_loss:1.049,Lr:1.00E-04
Epoch:20,Train_acc:70.0%,Train_loss:1.035,Test_acc:70.2%,Test_loss:1.033,Lr:1.00E-04
Epoch:21,Train_acc:70.6%,Train_loss:1.040,Test_acc:62.9%,Test_loss:1.107,Lr:1.00E-04
Epoch:22,Train_acc:71.0%,Train_loss:1.023,Test_acc:71.3%,Test_loss:1.036,Lr:1.00E-04
Epoch:23,Train_acc:72.5%,Train_loss:1.014,Test_acc:76.0%,Test_loss:0.981,Lr:1.00E-04
Epoch:24,Train_acc:70.9%,Train_loss:1.035,Test_acc:75.3%,Test_loss:0.993,Lr:1.00E-04
Epoch:25,Train_acc:72.5%,Train_loss:1.012,Test_acc:76.7%,Test_loss:0.974,Lr:1.00E-04
Epoch:26,Train_acc:70.8%,Train_loss:1.028,Test_acc:72.7%,Test_loss:1.004,Lr:1.00E-04
Epoch:27,Train_acc:72.7%,Train_loss:1.009,Test_acc:73.2%,Test_loss:1.011,Lr:1.00E-04
Epoch:28,Train_acc:73.8%,Train_loss:1.006,Test_acc:75.3%,Test_loss:0.991,Lr:1.00E-04
Epoch:29,Train_acc:74.5%,Train_loss:0.992,Test_acc:74.6%,Test_loss:0.986,Lr:1.00E-04
Epoch:30,Train_acc:73.3%,Train_loss:1.005,Test_acc:73.2%,Test_loss:1.004,Lr:1.00E-04
Epoch:31,Train_acc:75.7%,Train_loss:0.993,Test_acc:77.4%,Test_loss:0.968,Lr:1.00E-04
Epoch:32,Train_acc:74.6%,Train_loss:0.989,Test_acc:72.3%,Test_loss:1.016,Lr:1.00E-04
Epoch:33,Train_acc:76.6%,Train_loss:0.973,Test_acc:70.2%,Test_loss:1.042,Lr:1.00E-04
Epoch:34,Train_acc:75.2%,Train_loss:0.982,Test_acc:74.6%,Test_loss:0.992,Lr:1.00E-04
Epoch:35,Train_acc:71.5%,Train_loss:1.018,Test_acc:77.6%,Test_loss:0.977,Lr:1.00E-04
Epoch:36,Train_acc:74.4%,Train_loss:1.006,Test_acc:76.7%,Test_loss:0.973,Lr:1.00E-04
Epoch:37,Train_acc:72.0%,Train_loss:1.012,Test_acc:76.9%,Test_loss:0.978,Lr:1.00E-04
Epoch:38,Train_acc:71.5%,Train_loss:1.030,Test_acc:72.7%,Test_loss:1.017,Lr:1.00E-04
Epoch:39,Train_acc:75.1%,Train_loss:0.987,Test_acc:76.5%,Test_loss:0.979,Lr:1.00E-04
Epoch:40,Train_acc:75.4%,Train_loss:0.989,Test_acc:75.8%,Test_loss:0.979,Lr:1.00E-04
Epoch:41,Train_acc:78.1%,Train_loss:0.968,Test_acc:77.9%,Test_loss:0.963,Lr:1.00E-04
Epoch:42,Train_acc:77.2%,Train_loss:0.977,Test_acc:74.4%,Test_loss:0.987,Lr:1.00E-04
Epoch:43,Train_acc:77.9%,Train_loss:0.968,Test_acc:73.7%,Test_loss:0.994,Lr:1.00E-04
Epoch:44,Train_acc:79.1%,Train_loss:0.954,Test_acc:78.8%,Test_loss:0.953,Lr:1.00E-04
Epoch:45,Train_acc:79.6%,Train_loss:0.950,Test_acc:79.3%,Test_loss:0.949,Lr:1.00E-04
Epoch:46,Train_acc:80.2%,Train_loss:0.938,Test_acc:79.0%,Test_loss:0.948,Lr:1.00E-04
Epoch:47,Train_acc:80.6%,Train_loss:0.943,Test_acc:78.3%,Test_loss:0.962,Lr:1.00E-04
Epoch:48,Train_acc:75.9%,Train_loss:0.982,Test_acc:73.0%,Test_loss:1.013,Lr:1.00E-04
Epoch:49,Train_acc:77.3%,Train_loss:0.966,Test_acc:76.2%,Test_loss:0.977,Lr:1.00E-04
Epoch:50,Train_acc:79.9%,Train_loss:0.947,Test_acc:74.4%,Test_loss:0.991,Lr:1.00E-04
Epoch:51,Train_acc:80.4%,Train_loss:0.944,Test_acc:75.1%,Test_loss:0.986,Lr:1.00E-04
Epoch:52,Train_acc:79.2%,Train_loss:0.953,Test_acc:77.2%,Test_loss:0.970,Lr:1.00E-04
Epoch:53,Train_acc:80.0%,Train_loss:0.939,Test_acc:78.8%,Test_loss:0.951,Lr:1.00E-04
Epoch:54,Train_acc:79.0%,Train_loss:0.954,Test_acc:80.2%,Test_loss:0.944,Lr:1.00E-04
Epoch:55,Train_acc:82.7%,Train_loss:0.923,Test_acc:79.0%,Test_loss:0.945,Lr:1.00E-04
Epoch:56,Train_acc:81.9%,Train_loss:0.926,Test_acc:80.0%,Test_loss:0.939,Lr:1.00E-04
Epoch:57,Train_acc:82.8%,Train_loss:0.915,Test_acc:76.2%,Test_loss:0.973,Lr:1.00E-04
Epoch:58,Train_acc:81.7%,Train_loss:0.926,Test_acc:82.8%,Test_loss:0.918,Lr:1.00E-04
Epoch:59,Train_acc:83.2%,Train_loss:0.918,Test_acc:81.4%,Test_loss:0.931,Lr:1.00E-04
Epoch:60,Train_acc:82.5%,Train_loss:0.916,Test_acc:81.4%,Test_loss:0.926,Lr:1.00E-04
Epoch:61,Train_acc:79.6%,Train_loss:0.950,Test_acc:78.8%,Test_loss:0.946,Lr:1.00E-04
Epoch:62,Train_acc:83.4%,Train_loss:0.914,Test_acc:80.2%,Test_loss:0.940,Lr:1.00E-04
Epoch:63,Train_acc:86.0%,Train_loss:0.893,Test_acc:80.2%,Test_loss:0.940,Lr:1.00E-04
Epoch:64,Train_acc:84.1%,Train_loss:0.899,Test_acc:80.9%,Test_loss:0.921,Lr:1.00E-04
Epoch:65,Train_acc:84.2%,Train_loss:0.905,Test_acc:82.1%,Test_loss:0.917,Lr:1.00E-04
Epoch:66,Train_acc:85.5%,Train_loss:0.894,Test_acc:80.9%,Test_loss:0.934,Lr:1.00E-04
Epoch:67,Train_acc:83.7%,Train_loss:0.913,Test_acc:80.0%,Test_loss:0.942,Lr:1.00E-04
Epoch:68,Train_acc:83.4%,Train_loss:0.907,Test_acc:81.8%,Test_loss:0.913,Lr:1.00E-04
Epoch:69,Train_acc:85.2%,Train_loss:0.892,Test_acc:81.8%,Test_loss:0.926,Lr:1.00E-04
Epoch:70,Train_acc:86.1%,Train_loss:0.884,Test_acc:82.1%,Test_loss:0.928,Lr:1.00E-04
Epoch:71,Train_acc:82.5%,Train_loss:0.918,Test_acc:81.4%,Test_loss:0.929,Lr:1.00E-04
Epoch:72,Train_acc:85.9%,Train_loss:0.892,Test_acc:81.6%,Test_loss:0.920,Lr:1.00E-04
Epoch:73,Train_acc:85.2%,Train_loss:0.893,Test_acc:79.3%,Test_loss:0.944,Lr:1.00E-04
Epoch:74,Train_acc:87.2%,Train_loss:0.875,Test_acc:85.8%,Test_loss:0.884,Lr:1.00E-04
Epoch:75,Train_acc:86.7%,Train_loss:0.876,Test_acc:84.8%,Test_loss:0.893,Lr:1.00E-04
Epoch:76,Train_acc:86.5%,Train_loss:0.875,Test_acc:83.4%,Test_loss:0.903,Lr:1.00E-04
Epoch:77,Train_acc:87.0%,Train_loss:0.878,Test_acc:85.8%,Test_loss:0.884,Lr:1.00E-04
Epoch:78,Train_acc:88.3%,Train_loss:0.861,Test_acc:86.0%,Test_loss:0.888,Lr:1.00E-04
Epoch:79,Train_acc:87.2%,Train_loss:0.869,Test_acc:86.0%,Test_loss:0.883,Lr:1.00E-04
Epoch:80,Train_acc:87.1%,Train_loss:0.877,Test_acc:85.8%,Test_loss:0.886,Lr:1.00E-04
Epoch:81,Train_acc:88.4%,Train_loss:0.859,Test_acc:82.8%,Test_loss:0.913,Lr:1.00E-04
Epoch:82,Train_acc:88.9%,Train_loss:0.851,Test_acc:85.8%,Test_loss:0.878,Lr:1.00E-04
Epoch:83,Train_acc:88.4%,Train_loss:0.859,Test_acc:84.8%,Test_loss:0.893,Lr:1.00E-04
Epoch:84,Train_acc:89.0%,Train_loss:0.860,Test_acc:84.1%,Test_loss:0.900,Lr:1.00E-04
Epoch:85,Train_acc:89.9%,Train_loss:0.850,Test_acc:84.1%,Test_loss:0.899,Lr:1.00E-04
Epoch:86,Train_acc:89.5%,Train_loss:0.850,Test_acc:83.0%,Test_loss:0.913,Lr:1.00E-04
Epoch:87,Train_acc:88.7%,Train_loss:0.854,Test_acc:86.0%,Test_loss:0.885,Lr:1.00E-04
Epoch:88,Train_acc:91.2%,Train_loss:0.837,Test_acc:80.9%,Test_loss:0.928,Lr:1.00E-04
Epoch:89,Train_acc:91.7%,Train_loss:0.831,Test_acc:86.0%,Test_loss:0.883,Lr:1.00E-04
Epoch:90,Train_acc:87.4%,Train_loss:0.863,Test_acc:84.1%,Test_loss:0.900,Lr:1.00E-04
Epoch:91,Train_acc:90.1%,Train_loss:0.851,Test_acc:86.2%,Test_loss:0.878,Lr:1.00E-04
Epoch:92,Train_acc:88.3%,Train_loss:0.855,Test_acc:86.7%,Test_loss:0.871,Lr:1.00E-04
Epoch:93,Train_acc:90.5%,Train_loss:0.844,Test_acc:85.8%,Test_loss:0.884,Lr:1.00E-04
Epoch:94,Train_acc:92.4%,Train_loss:0.821,Test_acc:85.3%,Test_loss:0.881,Lr:1.00E-04
Epoch:95,Train_acc:91.4%,Train_loss:0.835,Test_acc:86.2%,Test_loss:0.878,Lr:1.00E-04
Epoch:96,Train_acc:92.2%,Train_loss:0.829,Test_acc:82.3%,Test_loss:0.917,Lr:1.00E-04
Epoch:97,Train_acc:90.0%,Train_loss:0.848,Test_acc:83.2%,Test_loss:0.913,Lr:1.00E-04
Epoch:98,Train_acc:90.8%,Train_loss:0.836,Test_acc:87.9%,Test_loss:0.868,Lr:1.00E-04
Epoch:99,Train_acc:89.6%,Train_loss:0.848,Test_acc:83.7%,Test_loss:0.908,Lr:1.00E-04
Epoch:100,Train_acc:91.0%,Train_loss:0.832,Test_acc:86.2%,Test_loss:0.881,Lr:1.00E-04
Done
IV. Visualizing the Results
1. Loss and Accuracy Curves
import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings("ignore")              # suppress warnings
plt.rcParams['font.sans-serif'] = ['SimHei']   # display Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False     # display minus signs correctly
plt.rcParams['figure.dpi'] = 300               # figure resolution (dpi)

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()
Output:
2. Predict a Specific Image
from PIL import Image

classes = list(total_data.class_to_idx)

def predict_one_image(image_path, model, transform, classes):
    test_img = Image.open(image_path).convert('RGB')
    plt.imshow(test_img)                  # show the image being predicted

    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)

    model.eval()
    output = model(img)

    _, pred = torch.max(output, 1)
    pred_class = classes[pred]
    print(f'Prediction: {pred_class}')
Run the prediction:
# predict one image from the dataset
predict_one_image(image_path=r'D:\THE MNIST DATABASE\P4-data\Others\NM01_01_00.jpg',
                  model=model,
                  transform=train_transforms,
                  classes=classes)
Output:
Prediction: Others
3. Model Evaluation
J6_model.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, J6_model, loss_fn)
epoch_test_acc, epoch_test_loss
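The original evaluation stops at the overall metrics. As an optional extension I added, per-class accuracy on the test set can also be computed, which is more informative than a single number for a two-class medical dataset:

correct = [0] * len(classes)
total = [0] * len(classes)

J6_model.eval()
with torch.no_grad():
    for imgs, target in test_dl:
        imgs, target = imgs.to(device), target.to(device)
        pred = J6_model(imgs).argmax(1)
        for t, p in zip(target, pred):
            total[t] += 1
            correct[t] += int(t == p)

for i, name in enumerate(classes):
    print(f'{name}: {correct[i] / total[i]:.2%}')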
V. Reflections
Building ResNeXt-50 by hand in the PyTorch environment gave me a much deeper understanding of how the architecture is constructed. The training results did not reach an ideal level, and for time reasons I did not tune further; in future experiments it would be worth adjusting the learning rate and other settings to see whether the results improve.
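One concrete way to experiment with the learning rate, sketched here as a suggestion rather than something I ran, is to attach a scheduler that lowers the rate when the test accuracy plateaus:

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
                                                       factor=0.5, patience=5)

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    scheduler.step(epoch_test_acc)   # reduce the LR when test accuracy stops improving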