cifar10数据集测试有多少张图_pytorch VGG11识别cifar10数据集(训练+预测单张输入图片操作)...

首先这是VGG的结构图,VGG11则是红色框里的结构,共分五个block,如红框中的VGG11第一个block就是一个conv3-64卷积层:

一,写VGG代码时,首先定义一个 vgg_block(n,in,out)方法,用来构建VGG中每个block中的卷积核和池化层:

n是这个block中卷积层的数目,in是输入的通道数,out是输出的通道数

有了block以后,我们还需要一个方法把形成的block叠在一起,我们定义这个方法叫vgg_stack:

def vgg_stack(num_convs, channels): # vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))

net = []

for n, c in zip(num_convs, channels):

in_c = c[0]

out_c = c[1]

net.append(vgg_block(n, in_c, out_c))

return nn.Sequential(*net)

右边的注释

vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))

里,(1, 1, 2, 2, 2)表示五个block里,各自的卷积层数目,((3, 64), (64, 128), (128, 256), (256, 512), (512, 512))表示每个block中的卷积层的类型,如(3,64)表示这个卷积层输入通道数是3,输出通道数是64。vgg_stack方法返回的就是完整的vgg11模型了。

接着定义一个vgg类,包含vgg_stack方法:

#vgg类

class vgg(nn.Module):

def __init__(self):

super(vgg, self).__init__()

self.feature = vgg_net

self.fc = nn.Sequential(

nn.Linear(512, 100),

nn.ReLU(True),

nn.Linear(100, 10)

)

def forward(self, x):

x = self.feature(x)

x = x.view(x.shape[0], -1)

x = self.fc(x)

return x

最后:

net = vgg() #就能获取到vgg网络

那么构建vgg网络完整的pytorch代码是:

def vgg_block(num_convs, in_channels, out_channels):

net = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(True)]

for i in range(num_convs - 1): # 定义后面的许多层

net.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))

net.append(nn.ReLU(True))

net.append(nn.MaxPool2d(2, 2)) # 定义池化层

return nn.Sequential(*net)

# 下面我们定义一个函数对这个 vgg block 进行堆叠

def vgg_stack(num_convs, channels): # vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))

net = []

for n, c in zip(num_convs, channels):

in_c = c[0]

out_c = c[1]

net.append(vgg_block(n, in_c, out_c))

return nn.Sequential(*net)

#确定vgg的类型,是vgg11 还是vgg16还是vgg19

vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))

#vgg类

class vgg(nn.Module):

def __init__(self):

super(vgg, self).__init__()

self.feature = vgg_net

self.fc = nn.Sequential(

nn.Linear(512, 100),

nn.ReLU(True),

nn.Linear(100, 10)

)

def forward(self, x):

x = self.feature(x)

x = x.view(x.shape[0], -1)

x = self.fc(x)

return x

#获取vgg网络

net = vgg()

基于VGG11的cifar10训练代码:

import sys

import numpy as np

import torch

from torch import nn

from torch.autograd import Variable

from torchvision.datasets import CIFAR10

import torchvision.transforms as transforms

def vgg_block(num_convs, in_channels, out_channels):

net = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(True)]

for i in range(num_convs - 1): # 定义后面的许多层

net.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))

net.append(nn.ReLU(True))

net.append(nn.MaxPool2d(2, 2)) # 定义池化层

return nn.Sequential(*net)

# 下面我们定义一个函数对这个 vgg block 进行堆叠

def vgg_stack(num_convs, channels): # vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))

net = []

for n, c in zip(num_convs, channels):

in_c = c[0]

out_c = c[1]

net.append(vgg_block(n, in_c, out_c))

return nn.Sequential(*net)

#vgg类

class vgg(nn.Module):

def __init__(self):

super(vgg, self).__init__()

self.feature = vgg_net

self.fc = nn.Sequential(

nn.Linear(512, 100),

nn.ReLU(True),

nn.Linear(100, 10)

)

def forward(self, x):

x = self.feature(x)

x = x.view(x.shape[0], -1)

x = self.fc(x)

return x

# 然后我们可以训练我们的模型看看在 cifar10 上的效果

def data_tf(x):

x = np.array(x, dtype='float32') / 255

x = (x - 0.5) / 0.5

x = x.transpose((2, 0, 1)) ## 将 channel 放到第一维,只是 pytorch 要求的输入方式

x = torch.from_numpy(x)

return x

transform = transforms.Compose([transforms.ToTensor(),

transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),

])

def get_acc(output, label):

total = output.shape[0]

_, pred_label = output.max(1)

num_correct = (pred_label == label).sum().item()

return num_correct / total

def train(net, train_data, valid_data, num_epochs, optimizer, criterion):

if torch.cuda.is_available():

net = net.cuda()

for epoch in range(num_epochs):

train_loss = 0

train_acc = 0

net = net.train()

for im, label in train_data:

if torch.cuda.is_available():

im = Variable(im.cuda())

label = Variable(label.cuda())

else:

im = Variable(im)

label = Variable(label)

# forward

output = net(im)

loss = criterion(output, label)

# forward

optimizer.zero_grad()

loss.backward()

optimizer.step()

train_loss += loss.item()

train_acc += get_acc(output, label)

if valid_data is not None:

valid_loss = 0

valid_acc = 0

net = net.eval()

for im, label in valid_data:

if torch.cuda.is_available():

with torch.no_grad():

im = Variable(im.cuda())

label = Variable(label.cuda())

else:

with torch.no_grad():

im = Variable(im)

label = Variable(label)

output = net(im)

loss = criterion(output, label)

valid_loss += loss.item()

valid_acc += get_acc(output, label)

epoch_str = (

"Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "

% (epoch, train_loss / len(train_data),

train_acc / len(train_data), valid_loss / len(valid_data),

valid_acc / len(valid_data)))

else:

epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %

(epoch, train_loss / len(train_data),

train_acc / len(train_data)))

# prev_time = cur_time

print(epoch_str)

if __name__ == '__main__':

# 作为实例,我们定义一个稍微简单一点的 vgg11 结构,其中有 8 个卷积层

vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))

print(vgg_net)

train_set = CIFAR10('./data', train=True, transform=transform, download=True)

train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)

test_set = CIFAR10('./data', train=False, transform=transform, download=True)

test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = vgg()

optimizer = torch.optim.SGD(net.parameters(), lr=1e-1)

criterion = nn.CrossEntropyLoss() #损失函数为交叉熵

train(net, train_data, test_data, 50, optimizer, criterion)

torch.save(net, 'vgg_model.pth')

结束后,会出现一个模型文件vgg_model.pth

二,然后网上找张图片,把图片缩成32x32,放到预测代码中,即可有预测结果出现,预测代码如下:

import torch

import cv2

import torch.nn.functional as F

from vgg2 import vgg ##重要,虽然显示灰色(即在次代码中没用到),但若没有引入这个模型代码,加载模型时会找不到模型

from torch.autograd import Variable

from torchvision import datasets, transforms

import numpy as np

classes = ('plane', 'car', 'bird', 'cat',

'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

if __name__ == '__main__':

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = torch.load('vgg_model.pth') # 加载模型

model = model.to(device)

model.eval() # 把模型转为test模式

img = cv2.imread("horse.jpg") # 读取要预测的图片

trans = transforms.Compose(

[

transforms.ToTensor(),

transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

])

img = trans(img)

img = img.to(device)

img = img.unsqueeze(0) # 图片扩展多一维,因为输入到保存的模型中是4维的[batch_size,通道,长,宽],而普通图片只有三维,[通道,长,宽]

# 扩展后,为[1,1,28,28]

output = model(img)

prob = F.softmax(output,dim=1) #prob是10个分类的概率

print(prob)

value, predicted = torch.max(output.data, 1)

print(predicted.item())

print(value)

pred_class = classes[predicted.item()]

print(pred_class)

# prob = F.softmax(output, dim=1)

# prob = Variable(prob)

# prob = prob.cpu().numpy() # 用GPU的数据训练的模型保存的参数都是gpu形式的,要显示则先要转回cpu,再转回numpy模式

# print(prob) # prob是10个分类的概率

# pred = np.argmax(prob) # 选出概率最大的一个

# # print(pred)

# # print(pred.item())

# pred_class = classes[pred]

# print(pred_class)

缩成32x32的图片:

运行结果:

以上这篇pytorch VGG11识别cifar10数据集(训练+预测单张输入图片操作)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/538278.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

npm ERR! Please try running this command again as root/Administrator.

win10操作系统下 webstrom的控制台使用 npm install angular-file-upload 安装组件,报错:npm ERR! Please try running this command again as root/Administrator. 解决方法: 开始按钮右键---- windows powershell(管理员&…

map flatmap mappartition flatMapToPair四种用法区别

原文链接:http://blog.csdn.net/u013086392/article/details/55666912 ----------------------------------------------------------------------------------- map: 我们可以看到数据的每一行在map之后产生了一个数组,那么rdd存储的是一个数组的集合…

eve可以在linux运行吗,ubuntu下为eve游戏搭载 wine环境

援引该地址的参考,本文仅做整理:http://bbs.eve-china.com/thread-626756-1-1.htmllinux的显卡是否驱动成功,依次键入如下命令察看:sudo apt-get install mesa-utils /*安装 mesa-utils 的指令*/glxinfo | grep r…

自动飞行控制系统_波音公司将重设计737MAX自动飞行控制系统!力求十月前复飞...

据西雅图时报8月1日报道,美国联邦航空管理局(FAA)在6月份对波音737 MAX飞行控制系统进行新的严格测试时,发现了一个潜在的缺陷,该缺陷促使波音公司对其基本的软件设计进行变革。波音公司如今正在改变737 MAX的自动飞行控制系统软件&#xff0…

每日一题——LeetCode141.环形链表

个人主页:白日依山璟 专栏:Java|数据结构与算法|每日一题 文章目录 1. 题目描述示例1:示例2:示例3:提示: 2. 思路3. 代码 1. 题目描述 给你一个链表的头节点 head ,判断链表中是否有环。 如果链表中有某…

Android O 获取APK文件权限 Demo案例

1. 通过 aapt 工具查看 APK权限 C:\Users\zh>adb pull /system/priv-app/Settings . /system/priv-app/Settings/: 3 files pulled. 10.8 MB/s (48840608 bytes in 4.325s)C:\Users\zh>aapt d permissions C:\Users\zh\Settings\Settings.apk package: com.android.sett…

VBoxManage命令更详尽版

原文链接:http://418684644-qq-com.iteye.com/blog/1451000 ------------------------------------- VBoxManage命令详解(一) 本人对vboxmange命令按我个人的理解作了解释,由于本人水平有限难免有错误的地方,希望大…

linux make命令实现,Linux make命令主要参数详解

-C dir或者 --directoryDIR在读取makefile文件前,先切换到“dir”目录下,即把dir作为当前目录。如果存在多个-C选项,make的最终当前目录是第一个目录的相对路径,如“make –C /home/leowang –C document”,等价于“ma…

行人属性数据集pa100k_基于InceptionV3的多数据集联合训练的行人外观属性识别方法与流程...

本发明涉及模式识别技术、智能监控技术等领域,具体的说,是基于Inception V3的多数据集联合训练的行人外观属性识别方法。背景技术:近年来,视频监控系统已经被广泛应用于安防领域。安防人员通过合理的摄像头布局,实现对…

VBoxManage获取虚拟机IP地址

在宿主机Linux上安装VirtualBox,然后VirtualBox上安装linux虚拟机,在Virtualbox非界面启动虚拟机时,ip地址无法查看。怎么办? 使用命令: VBoxManage guestproperty enumerate 虚拟机名 | grep "Net.*V4.*IP"…

springboot系列(十)springboot整合shiro实现登录认证

关于shiro的概念和知识本篇不做详细介绍,但是shiro的概念还是需要做做功课的要不无法理解它的运作原理就无法理解使用shiro; 本篇主要讲解如何使用shiro实现登录认证,下篇讲解使用shiro实现权限控制 要实现shiro和springboot的整合需要以下几…

recyclerview item动画_这可能是你见过的迄今为止最简单的RecyclerView Item加载动画...

如何实现RecyclerView Item动画? 这个问题想必有很多人都会讲,我可以用ItemAnimator实现啊,这是RecyclerView官方定义的接口,专门扩展Item动画的,那我为什么要寻求另外一种方法实现呢?因为最近反思了一个问…

群晖编译LCD4Linux,LCD4LINUX配置文件一些参数使用解释。

#LCD显示配置Display dpf {Driver DPF #LCD驱动类型Port usb0 #连接端口Font 6x8 #字体大小Foreground ffffff #字体…

VBoxManage: error: Nonexistent host networking interface, name 'vboxnet0' (VERR_INTERNAL_ERROR)

错误: VBoxManage: error: Nonexistent host networking interface, name vboxnet0 (VERR_INTERNAL_ERROR) 原因: 原来配置的网卡发生了变更,找不到了,启动失败。 解决方法: 第一步,命令: V…

捷信达温泉管理软件员工卡SQL查询

捷信达温泉管理软件员工卡SQL查询 select * from snkey where v_name2 like %员工% 网名:浩秦; 邮箱:root#landv.pw; 只要我能控制一個國家的貨幣發行,我不在乎誰制定法律。金錢一旦作響,壞話隨之戛然而止。

Linux 软件安装到 /usr,/usr/local/ 还是 /opt 目录?

Linux 的软件安装目录是也是有讲究的,理解这一点,在对系统管理是有益的 /usr:系统级的目录,可以理解为C:/Windows/,/usr/lib理解为C:/Windows/System32。 /usr/local:用户级的程序目录,可以理解…

winpe装双系统linux_使用syslinux在u盘安装pubbylinux和winpe双系统

使用syslinux在u盘安装pubbylinux和winpe双系统1,在u盘里安装winpe,请参见"比较简单的制作U盘winpe启动盘方法"比较简单的制作U盘winpe启动盘方法 收藏1,下载一个深度winpev3.iso2,用winrar或ultraISO解压深度winpev3.iso3,进入解压出来的文件夹下,找到se…

esp32 嵌入式linux,初体验乐鑫 ESP32 AT 指令-嵌入式系统-与非网

乐鑫 AT 固件初体验初步体验 AT 指令下 TCP 数传,为了验证 AT 命令解析器。前往乐鑫官网 下载最新版本 AT 固件和 AT 指令集手册。硬件准备本文使用乐鑫的 ESP-WROOM-32(ESP-WROOM-32 是 ESP32-WROOM-32 的曾用名)模块,4MB Flash,无 PSRAM。E…

主机ping不通Virtualbox里的虚拟机

在redhat上安装了VirtualBox,虚拟了三台Linux机器。 宿主机网卡更换过了。三台虚拟机无法启动了,搭建虚拟机的运维离职了。 VirtualBox的图形界面坏了,启动不了。只能用命令行,今天时间就花在命令行上了。 第一个问题是&#xf…

python后端开发靠谱吗_【后端开发】python有这么强大吗

因为Python是一种代表简单主义思想的语言。除此之外,Python所拥有的标准库更是金融、营销类人群选择它的理由。Python 易于学习可靠且高效(推荐学习:Python视频教程)好吧,相较于其它许多你可以拿来用的编程语言而言,它“更容易一些…