VGG网络的代码实现

VGG网络的程序实现完全根据配置表来实现。

全连接层之前的部分属于特征提取部分,后三部分全连接层用来分类。

1、模型

import torch.nn as nn
import torch# official pretrain weights
#预训练的权重下载地址
model_urls = {'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth','vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth','vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth','vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
}#进行分类的代码
class VGG(nn.Module):def __init__(self, features, num_classes=1000, init_weights=False):super(VGG, self).__init__()self.features = featuresself.classifier = nn.Sequential(nn.Linear(512*7*7, 4096),nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(4096, 4096),nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(4096, num_classes))if init_weights:self._initialize_weights()def forward(self, x):# N x 3 x 224 x 224x = self.features(x)# N x 512 x 7 x 7x = torch.flatten(x, start_dim=1)# N x 512*7*7x = self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')nn.init.xavier_uniform_(m.weight)if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.xavier_uniform_(m.weight)# nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)#进行特征提取部分的代码
def make_features(cfg: list):layers = []in_channels = 3for v in cfg:if v == "M":layers += [nn.MaxPool2d(kernel_size=2, stride=2)]else:conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)layers += [conv2d, nn.ReLU(True)]in_channels = vreturn nn.Sequential(*layers)cfgs = {'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}def vgg(model_name="vgg16", **kwargs):assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)cfg = cfgs[model_name]model = VGG(make_features(cfg), **kwargs)return model

定义了VGG11、VGG13、VGG16和VGG19。调用的时候只需要输入模型名字就可以。比如model=vgg("vgg16")

2、预处理

    data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]),"val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])}

3、数据集及可视化

数据集使用的是眼睛疾病的数据集,分别包括四类:cataract(白内障)、diabetic_retinopathy(糖尿病性视网膜病变)、glaucoma(青光眼)、normal(正常)。

可视化:

代码:

    fig = plt.figure()for i in range(4):plt.subplot(1,4,i+1)# plt.tight_layout()# plt.imshow(test_image[i][0],cmap='CMRmap', interpolation='none')plt.imshow(test_image[i][0])# plt.title("Ground Truth: {}".format(test_label[i].item()))plt.xticks([])plt.yticks([])plt.show()

输出:

4、加载数据:

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root pathimage_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set pathassert os.path.exists(image_path), "{} path does not exist.".format(image_path)train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])train_num = len(train_dataset)# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}flower_list = train_dataset.class_to_idxcla_dict = dict((val, key) for key, val in flower_list.items())# write dict into json filejson_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 32nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=nw)validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=nw)print("using {} images for training, {} images for validation.".format(train_num,val_num))

5、加载模型

model_name = "vgg16"
net = vgg(model_name=model_name, num_classes=5, init_weights=True)
net.to(device)

6、损失函数:

loss_function = nn.CrossEntropyLoss()

7、优化器:

optimizer = optim.Adam(net.parameters(), lr=0.0001)

8、迁移学习

 #迁移学习model_weight_path = "./vgg16-pre.pth"assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)net.load_state_dict(torch.load(model_weight_path, map_location='cpu'),False)for param in net.parameters():param.requires_grad = Falsen_inputs=net.classifier[6].in_featureslast_layer=nn.Linear(n_inputs,4)net.classifier[6]=last_layer

9、训练:

    epochs = 30best_acc = 0.0save_path = './{}Net.pth'.format(model_name)train_steps = len(train_loader)for epoch in range(epochs):# trainnet.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')

结果:

对vgg11、vgg13、vgg16、vgg19分别进行测试。

vgg模型训练的时间都比较长,从损失值看vgg19效果好一些。从精度上看,vgg13、vgg11、vgg19都有不错的精确率。

完整代码:

import os
import sys
import jsonimport torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdmfrom model import vggdef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root pathimage_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set pathassert os.path.exists(image_path), "{} path does not exist.".format(image_path)train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])train_num = len(train_dataset)# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}flower_list = train_dataset.class_to_idxcla_dict = dict((val, key) for key, val in flower_list.items())# write dict into json filejson_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 32nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=nw)validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=nw)print("using {} images for training, {} images for validation.".format(train_num,val_num))# test_data_iter = iter(validate_loader)# test_image, test_label = test_data_iter.next()model_name = "vgg16"net = vgg(model_name=model_name, num_classes=5, init_weights=True)net.to(device)loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.0001)epochs = 30best_acc = 0.0save_path = './{}Net.pth'.format(model_name)train_steps = len(train_loader)for epoch in range(epochs):# trainnet.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')if __name__ == '__main__':main()

参考资料:

在Pytorch中使用VGG16进行迁移学习-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/747170.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电子科技大学链时代工作室招新题C语言部分---题号E

1. 题目 这道题大概的意思是说,一座城市中被埋了许多雷(用一个只含0和1的字符串表示城市,1代表有雷,0代表无雷)。 你作为一个排雷兵,需要花最少的钱引爆所有的雷来使城市中不再有雷(太逆天了&a…

信创产品操作系统加固配置解决方案

1 口令策略配置方面,未设置口令生存周期(高)、未设置口令更改最小间隔天数(高)、未设置口令最小长度(高)、未设置口令复杂度策略(高)、未设置密码重复使用次数限制(中)。 1.1 口令策略未设置口令生存周期(高) 解决方法: 全局设置: 在/etc/login.defs中添加:PAS…

软件工程师,是时候了解下Rust编程语言了

背景 2024年年初,美国政府发布了一份网络安全报告,呼吁软件开发人员停止使用容易出现内存安全漏洞的编程语言,比如:C和C,转而使用内存安全的编程语言。这份报告由美国网络空间总监办公室 (ONCD) 发布,旨在落…

计算机行业在数字经济时代的角色和地位以及如何通过数字化转型提升行业竞争力

在数字经济时代,计算机行业扮演着至关重要的角色,并且拥有极高的地位。计算机行业是数字经济的基础设施和核心驱动力之一,为其他各个行业提供了关键的技术和解决方案。计算机行业通过数字化转型可以进一步提升自身的竞争力和地位。 首先&…

测试用例的设计(1)

目录 1. 测试用例的基本要素 2.测试用例的设计方法 2.1.基于需求设计 2.2根据功能需求测试 2.3非功能测试 3. 具体的设计方法 3.1等价类法 3.2边界值法 3.3判定表 1. 测试用例的基本要素 测试用例是为了实施测试而面向测试的系统提供的一组集合,这组集合包含:测试环境,…

netstat命令——查看网络状态统计信息

netstat是network statistics的缩写,其功能是显示各种网络相关统计信息,例如网络连接状态、路由表信息、接口状态、NAT、多播成员等。通用于Linux和Windows。 netstat命令的语法格式如下: netstat 选项 常用选项如下: 选项 …

Java Web开发从0到1

文章目录 总纲第1章 Java Web应用开发概述1.1 程序开发体系结构1.1.1 C/S体系结构介绍1.1.2 B/S体系结构介绍1.1.3 两种体系结构的比较1.2 Web应用程序的工作原理1.3 Web应用技术1.3.1 客服端应用技术1.3.2 服务端应用技术1.4 Java Web应用的开发环境变量1.5 Tomcat的安装与配置…

2024年3月2日~2024年3月15日周报

文章目录 一、前言二、 D 2 UNet \textrm{D}^{2}\textrm{UNet} D2UNet 阅读情况2.1 体系结构2.2 损失函数 三、遇到的问题及解决四、小结 一、前言 在上上周寻找改进网络框架与超参数的灵感,并跑代码查看了效果。 最近两周,继续修改网络框架结构&#xf…

【Unity】Tag、Layer、LayerMask

文章目录 层(Layer)什么是LayerLayer的应用场景Layer层的配置(Tags & Layers)Layer的数据结构LayerMaskLayer的选中和忽略Layer的管理(架构思路)层碰撞矩阵设置(Layer Collision Matrix&…

SpringBoot(拦截器+文件上传)

文章目录 1.拦截器1.基本介绍2.应用实例1.去掉Thymeleaf案例中使用session进行权限验证的部分2.编写自定义拦截器 LoginInterceptor.java 实现HandlerInterceptor接口的三个方法3.注册拦截器1.第一种方式 配置类直接实现WebMvcConfigurer接口,重写addInterceptors方…

C++语言学习(一)—— 认识C++语言

目录 一、C语言 二、C与C语言的区别 2.1 预处理器 2.2 标准库 2.3 类型 2.4 函数重载 2.5 内存管理 2.6 输入输出函数 2.7 关键字 三、C的基本结构 一、C语言 C语言是一种高级编程语言,由Bjarne Stroustrup在20世纪80年代初设计和开发。它是C语言的扩展&a…

综合小区管理系统|基于Springboot的综合小区管理系统设计与实现(源码+数据库+文档)

综合小区管理系统目录 目录 基于Springboot的综合小区管理系统设计与实现 一、前言 二、系统设计 三、系统功能设计 1、出入管理 2、报修管理 3、车位管理 4、公告管理 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取&#…

git 将某些代码一个分支转移到另一个分支

https://blog.csdn.net/huyongfu2004/article/details/122239102 假设A分支已经有修改的代码 提交过了 但是没有合并到master; 想要吧A分支这次的修改单独提交 已经commit 查看提交的commit号 git log切到新的分支,后将该commit id pick到正确分支 g…

day06、07-MySQL

文章目录 一、MySQL概述1.1 安装1.2 数据模型1.3 SQL简介1.3.1 SQL通用语法1.3.2 分类 二. 数据库设计-DDL2.1 项目开发流程2.2 数据库操作2.2.1 查询数据库2.2.2 创建数据库2.2.3 使用数据库2.2.4 删除数据库 2.3 图形化工具2.3.1 介绍2.3.2 安装2.3.3 使用2.2.3.1 连接数据库…

【English Learning】Day13

2024/03/14 和小录打卡的第13天 目录 Words & phrases Words & phrases incrredibly incredibly busy 超级忙merely not merely 不仅仅tragedy a terible tregedy 可怕的悲剧a personal tragedy 个人遭遇strive strive to be best 努力做最好的strive for peace 为和平…

【项目管理】进度管理

一、前言 小型项目中,定义活动、排列活动顺序、估算活动持续时间及制定进度模型形成进度计划等过程的联系非常密切,可以视为一个过程,可以由一个人在较短时间内完成。项目管理团队编制进度计划的一般步骤为:首先选择进度计划方法…

Qt+FFmpeg+opengl从零制作视频播放器-7.OpenGL播放视频

在上一节Qt+FFmpeg+opengl从零制作视频播放器-6.视频解码中,我们学到了如何将视频数据解码成YUV原始数据,并且保存到本地,最后使用工具来播放YUV文件。 本节使用QOpenGLWidget来渲染解码后的YUV视频数据。 首先简单介绍QOpenGLWidget的使用。 QOpenGLWidget类是用于渲染O…

HTML—标签的分类,span和div标签,不同的标签之间类型转换

标签的分类: ①块级标签:无论内容多少,会充满整个行。大小可自定义 例:p,h1,ul,ol,hr 等 ②行级标签:自身的大小就是标签的大小,不会占一整行。大小不可调 例…

密码保护小贴士:如何应对常见的网络钓鱼攻击?

网络钓鱼攻击是一种常见的网络欺诈手段,针对个人隐私和财产安全构成威胁。以下是一些密码保护的小贴士,帮助您应对常见的网络钓鱼攻击: 1.谨慎点击链接:收到来历不明的邮件、短信或社交媒体消息时,不要轻易点击其中的…

Python 基础语法:基本数据类型(字典)

为什么这个基本的数据类型被称作字典呢?这个是因为字典这种基本数据类型的一些行为和我们日常的查字典过程非常相似。 通过汉语字典查找汉字,首先需要确定这个汉字的首字母,然后再通过这个首字母找到我们所想要的汉字。这个过程其实就代表了…