VGG网络的代码实现

VGG网络的程序实现完全根据配置表来实现。

全连接层之前的部分属于特征提取部分,后三部分全连接层用来分类。

1、模型

import torch.nn as nn
import torch# official pretrain weights
#预训练的权重下载地址
model_urls = {'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth','vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth','vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth','vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
}#进行分类的代码
class VGG(nn.Module):def __init__(self, features, num_classes=1000, init_weights=False):super(VGG, self).__init__()self.features = featuresself.classifier = nn.Sequential(nn.Linear(512*7*7, 4096),nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(4096, 4096),nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(4096, num_classes))if init_weights:self._initialize_weights()def forward(self, x):# N x 3 x 224 x 224x = self.features(x)# N x 512 x 7 x 7x = torch.flatten(x, start_dim=1)# N x 512*7*7x = self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')nn.init.xavier_uniform_(m.weight)if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.xavier_uniform_(m.weight)# nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)#进行特征提取部分的代码
def make_features(cfg: list):layers = []in_channels = 3for v in cfg:if v == "M":layers += [nn.MaxPool2d(kernel_size=2, stride=2)]else:conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)layers += [conv2d, nn.ReLU(True)]in_channels = vreturn nn.Sequential(*layers)cfgs = {'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}def vgg(model_name="vgg16", **kwargs):assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)cfg = cfgs[model_name]model = VGG(make_features(cfg), **kwargs)return model

定义了VGG11、VGG13、VGG16和VGG19。调用的时候只需要输入模型名字就可以。比如model=vgg("vgg16")

2、预处理

    data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]),"val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])}

3、数据集及可视化

数据集使用的是眼睛疾病的数据集,分别包括四类:cataract(白内障)、diabetic_retinopathy(糖尿病性视网膜病变)、glaucoma(青光眼)、normal(正常)。

可视化:

代码:

    fig = plt.figure()for i in range(4):plt.subplot(1,4,i+1)# plt.tight_layout()# plt.imshow(test_image[i][0],cmap='CMRmap', interpolation='none')plt.imshow(test_image[i][0])# plt.title("Ground Truth: {}".format(test_label[i].item()))plt.xticks([])plt.yticks([])plt.show()

输出:

4、加载数据:

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root pathimage_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set pathassert os.path.exists(image_path), "{} path does not exist.".format(image_path)train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])train_num = len(train_dataset)# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}flower_list = train_dataset.class_to_idxcla_dict = dict((val, key) for key, val in flower_list.items())# write dict into json filejson_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 32nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=nw)validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=nw)print("using {} images for training, {} images for validation.".format(train_num,val_num))

5、加载模型

model_name = "vgg16"
net = vgg(model_name=model_name, num_classes=5, init_weights=True)
net.to(device)

6、损失函数:

loss_function = nn.CrossEntropyLoss()

7、优化器:

optimizer = optim.Adam(net.parameters(), lr=0.0001)

8、迁移学习

 #迁移学习model_weight_path = "./vgg16-pre.pth"assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)net.load_state_dict(torch.load(model_weight_path, map_location='cpu'),False)for param in net.parameters():param.requires_grad = Falsen_inputs=net.classifier[6].in_featureslast_layer=nn.Linear(n_inputs,4)net.classifier[6]=last_layer

9、训练:

    epochs = 30best_acc = 0.0save_path = './{}Net.pth'.format(model_name)train_steps = len(train_loader)for epoch in range(epochs):# trainnet.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')

结果:

对vgg11、vgg13、vgg16、vgg19分别进行测试。

vgg模型训练的时间都比较长,从损失值看vgg19效果好一些。从精度上看,vgg13、vgg11、vgg19都有不错的精确率。

完整代码:

import os
import sys
import jsonimport torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdmfrom model import vggdef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root pathimage_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set pathassert os.path.exists(image_path), "{} path does not exist.".format(image_path)train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])train_num = len(train_dataset)# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}flower_list = train_dataset.class_to_idxcla_dict = dict((val, key) for key, val in flower_list.items())# write dict into json filejson_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 32nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=nw)validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=nw)print("using {} images for training, {} images for validation.".format(train_num,val_num))# test_data_iter = iter(validate_loader)# test_image, test_label = test_data_iter.next()model_name = "vgg16"net = vgg(model_name=model_name, num_classes=5, init_weights=True)net.to(device)loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.0001)epochs = 30best_acc = 0.0save_path = './{}Net.pth'.format(model_name)train_steps = len(train_loader)for epoch in range(epochs):# trainnet.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')if __name__ == '__main__':main()

参考资料:

在Pytorch中使用VGG16进行迁移学习-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/747170.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电子科技大学链时代工作室招新题C语言部分---题号E

1. 题目 这道题大概的意思是说,一座城市中被埋了许多雷(用一个只含0和1的字符串表示城市,1代表有雷,0代表无雷)。 你作为一个排雷兵,需要花最少的钱引爆所有的雷来使城市中不再有雷(太逆天了&a…

测试用例的设计(1)

目录 1. 测试用例的基本要素 2.测试用例的设计方法 2.1.基于需求设计 2.2根据功能需求测试 2.3非功能测试 3. 具体的设计方法 3.1等价类法 3.2边界值法 3.3判定表 1. 测试用例的基本要素 测试用例是为了实施测试而面向测试的系统提供的一组集合,这组集合包含:测试环境,…

Java Web开发从0到1

文章目录 总纲第1章 Java Web应用开发概述1.1 程序开发体系结构1.1.1 C/S体系结构介绍1.1.2 B/S体系结构介绍1.1.3 两种体系结构的比较1.2 Web应用程序的工作原理1.3 Web应用技术1.3.1 客服端应用技术1.3.2 服务端应用技术1.4 Java Web应用的开发环境变量1.5 Tomcat的安装与配置…

2024年3月2日~2024年3月15日周报

文章目录 一、前言二、 D 2 UNet \textrm{D}^{2}\textrm{UNet} D2UNet 阅读情况2.1 体系结构2.2 损失函数 三、遇到的问题及解决四、小结 一、前言 在上上周寻找改进网络框架与超参数的灵感,并跑代码查看了效果。 最近两周,继续修改网络框架结构&#xf…

【Unity】Tag、Layer、LayerMask

文章目录 层(Layer)什么是LayerLayer的应用场景Layer层的配置(Tags & Layers)Layer的数据结构LayerMaskLayer的选中和忽略Layer的管理(架构思路)层碰撞矩阵设置(Layer Collision Matrix&…

SpringBoot(拦截器+文件上传)

文章目录 1.拦截器1.基本介绍2.应用实例1.去掉Thymeleaf案例中使用session进行权限验证的部分2.编写自定义拦截器 LoginInterceptor.java 实现HandlerInterceptor接口的三个方法3.注册拦截器1.第一种方式 配置类直接实现WebMvcConfigurer接口,重写addInterceptors方…

综合小区管理系统|基于Springboot的综合小区管理系统设计与实现(源码+数据库+文档)

综合小区管理系统目录 目录 基于Springboot的综合小区管理系统设计与实现 一、前言 二、系统设计 三、系统功能设计 1、出入管理 2、报修管理 3、车位管理 4、公告管理 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取&#…

day06、07-MySQL

文章目录 一、MySQL概述1.1 安装1.2 数据模型1.3 SQL简介1.3.1 SQL通用语法1.3.2 分类 二. 数据库设计-DDL2.1 项目开发流程2.2 数据库操作2.2.1 查询数据库2.2.2 创建数据库2.2.3 使用数据库2.2.4 删除数据库 2.3 图形化工具2.3.1 介绍2.3.2 安装2.3.3 使用2.2.3.1 连接数据库…

【项目管理】进度管理

一、前言 小型项目中,定义活动、排列活动顺序、估算活动持续时间及制定进度模型形成进度计划等过程的联系非常密切,可以视为一个过程,可以由一个人在较短时间内完成。项目管理团队编制进度计划的一般步骤为:首先选择进度计划方法…

HTML—标签的分类,span和div标签,不同的标签之间类型转换

标签的分类: ①块级标签:无论内容多少,会充满整个行。大小可自定义 例:p,h1,ul,ol,hr 等 ②行级标签:自身的大小就是标签的大小,不会占一整行。大小不可调 例…

密码保护小贴士:如何应对常见的网络钓鱼攻击?

网络钓鱼攻击是一种常见的网络欺诈手段,针对个人隐私和财产安全构成威胁。以下是一些密码保护的小贴士,帮助您应对常见的网络钓鱼攻击: 1.谨慎点击链接:收到来历不明的邮件、短信或社交媒体消息时,不要轻易点击其中的…

Python 基础语法:基本数据类型(字典)

为什么这个基本的数据类型被称作字典呢?这个是因为字典这种基本数据类型的一些行为和我们日常的查字典过程非常相似。 通过汉语字典查找汉字,首先需要确定这个汉字的首字母,然后再通过这个首字母找到我们所想要的汉字。这个过程其实就代表了…

YoloV8改进策略:下采样改进|HWD改进下采样

摘要 本文使用HWD改进下采样,在YoloV8的测试中实现涨点。 论文解读 在卷积神经网络(CNNs)中,极大池化或跨行卷积等下采样操作被广泛用于聚合局部特征、扩大感受野和最小化计算开销。然而,对于语义分割任务&#xff…

谷粒商城——分布式基础(全栈开发篇第一部分)

文章目录 一、服务治理网路数据支撑日志处理ELK应用监控集成工具开发工具 二、环境创建1、虚拟机创建2、虚拟机安装docker等1. 安装docker1. 配置阿里docker3.docker安装mysql错误 4、docker安装redis 3、软件1.Maven 阿里云镜像1.8jdk2、idea lombokmybatisX ,3、 …

熔断降级的方案实现

熔断降级的方案实现 Spring Cloud Netflix Hystrix 提供线程隔离、服务降级、请求缓存、请求合并等功能可与Spring Cloud其他组件无缝集成官方已宣布停止维护,推荐使用Resilience4j代替 Spring Cloud Resilience4j 轻量级服务熔断库 提供类似于Hystrix的功能 具有更…

C++手写链表、反转链表、删除链表节点、遍历、为链表增加迭代器

本篇博客介绍如何使用C实现链表,首先编写一个简单的链表,然后增加模板,再增加迭代器。 简单链表的实现 链表的结构如下: 首先需要定义链表的节点: struct ListNode {int data;ListNode* pNext;ListNode(int value …

【C++算法模板】图论-拓扑排序,超详细注释带例题

文章目录 0)概述1)Kahn算法1:数据结构2:建图3:Kanh算法 2)DFS染色1:数据结构2:建图3:DFS 3)算法对比【例题】洛谷 B3644 推荐视频链接:D01 拓扑排…

4核8g服务器能支持多少人访问?价格感人,不知道性能如何

腾讯云轻量4核8G12M服务器配置446元一年,646元12个月,腾讯云轻量应用服务器具有100%CPU性能,系统盘为180GB SSD盘,12M带宽下载速度1536KB/秒,月流量2000GB,折合每天66.6GB流量,超出月流量包的流…

Linux下Arthas(阿尔萨斯)的简单使用-接口调用慢排查

使用环境 k8s容器内运行了一个springboot服务,服务的启动方法是main()方法 下载并启动 arthas curl -O https://arthas.aliyun.com/arthas-boot.jar java -jar arthas-boot.jar选择应用 java 进程 就一个进程org.apache.catalina.startup.Bootstrap,输…