使用Caffe进行手写数字识别执行流程解析

之前在 http://blog.csdn.net/fengbingchun/article/details/50987185 中仿照Caffe中的examples实现对手写数字进行识别,这里详细介绍下其执行流程并精简了实现代码,使用Caffe对MNIST数据集进行train的文章可以参考  http://blog.csdn.net/fengbingchun/article/details/68065338 :

1.   先注册所有层,执行layer_factory.hpp中类LayerRegisterer的构造函数,类LayerRegistry的AddCreator和Registry静态函数;关于Caffe中Layer的注册可以参考: http://blog.csdn.net/fengbingchun/article/details/54310956 

2.  指定执行mode是采用CPU还是GPU;

3.   指定需要的.prototxt和.caffemodel文件:注意此处的.prototxt文件(lenet_train_test_.prototxt)与train时.prototxt文件(lenet_train_test.prototxt)在内容上的差异。.caffemodel文件即是train后最终生成的二进制文件lenet_iter_10000.caffemodel,里面存放着所有层的权值和偏置。lenet_train_test_.prototxt文件内容如下:

name: "LeNet" # net名
layer { # memory required: (784+1)*4=3140
  name: "data" # layer名字
  type: "MemoryData" # layer类型,Data enters Caffe through data layers,read data directly from memory
  top: "data" # top名字, shape: 1 1 28 28 (784)
  top: "label"  # top名字, shape: 1 (1) #感觉并无实质作用,仅用于增加一个top blob,不可去掉
  memory_data_param { # 内存数据参数
    batch_size: 1 # 指定待识别图像一次的数量
    channels: 1 # 指定待识别图像的通道数
    height: 28 # 指定待识别图像的高度
    width: 28 # 指定待识别图像的宽度
  }
  transform_param { # 图像预处理参数
    scale: 0.00390625 # 对图像像素值进行scale操作,范围[0, 1)
  }
}
layer { # memory required: 11520*4=46080
  name: "conv1" # layer名字
  type: "Convolution" # layer类型,卷积层
  bottom: "data" # bottom名字
  top: "conv1" # top名字, shape: 1 20 24 24 (11520)
  param { # Specifies training parameters
    lr_mult: 1 # The multiplier on the global learning rate
  }
  param { # Specifies training parameters
    lr_mult: 2 # The multiplier on the global learning rate
  }
  convolution_param { # 卷积参数
    num_output: 20 # 输出特征图(feature map)数量
    kernel_size: 5 # 卷积核大小(卷积核其实就是权值)
    stride: 1 # 滑动步长
    weight_filler { # The filler for the weight
      type: "xavier" # 权值使用xavier滤波
    }
    bias_filler { # The filler for the bias
      type: "constant" # 偏置使用常量滤波
    }
  }
}
layer { # memory required: 2880*4=11520
  name: "pool1" # layer名字
  type: "Pooling" # layer类型,Pooling层
  bottom: "conv1" # bottom名字
  top: "pool1" # top名字, shape: 1 20 12 12 (2880)
  pooling_param { # pooling parameter,pooling层参数
    pool: MAX # pooling方法:最大值采样
    kernel_size: 2 # 滤波器大小
    stride: 2 # 滑动步长
  }
}
layer { # memory required: 3200*4=12800
  name: "conv2" # layer名字
  type: "Convolution" # layer类型,卷积层
  bottom: "pool1" # bottom名字
  top: "conv2" # top名字, shape: 1 50 8 8 (3200)
  param { # Specifies training parameters
    lr_mult: 1 # The multiplier on the global learning rate
  }
  param { # Specifies training parameters
    lr_mult: 2 # The multiplier on the global learning rate
  }
  convolution_param { # 卷积参数
    num_output: 50 # 输出特征图(feature map)数量
    kernel_size: 5 # 卷积核大小(卷积核其实就是权值)
    stride: 1 # 滑动步长
    weight_filler { # The filler for the weight
      type: "xavier" # 权值使用xavier滤波
    }
    bias_filler { # The filler for the bias
      type: "constant" # 偏置使用常量滤波
    }
  }
}
layer { # memory required: 800*4=3200
  name: "pool2" # layer名字
  type: "Pooling" # layer类型,Pooling层
  bottom: "conv2" # bottom名字
  top: "pool2" # top名字, shape: 1 50 4 4 (800)
  pooling_param { # pooling parameter,pooling层参数
    pool: MAX # pooling方法:最大值采样
    kernel_size: 2 # 滤波器大小
    stride: 2 # 滑动步长
  }
}
layer { # memory required: 500*4=2000
  name: "ip1" # layer名字
  type: "InnerProduct" # layer类型,全连接层
  bottom: "pool2" # bottom名字
  top: "ip1" # top名字, shape: 1 500 (500)
  param { # Specifies training parameters
    lr_mult: 1 # The multiplier on the global learning rate
  }
  param { # Specifies training parameters
    lr_mult: 2 # The multiplier on the global learning rate
  }
  inner_product_param { # 全连接层参数
    num_output: 500 # 输出特征图(feature map)数量
    weight_filler { # The filler for the weight
      type: "xavier" # 权值使用xavier滤波
    }
    bias_filler { # The filler for the bias
      type: "constant" # 偏置使用常量滤波
    }
  }
}
# ReLU: Given an input value x, The ReLU layer computes the output as x if x > 0 and 
# negative_slope * x if x <= 0. When the negative slope parameter is not set,
# it is equivalent to the standard ReLU function of taking max(x, 0).
# It also supports in-place computation, meaning that the bottom and
# the top blob could be the same to preserve memory consumption
layer { # memory required: 500*4=2000
  name: "relu1" # layer名字
  type: "ReLU" # layer类型
  bottom: "ip1" # bottom名字
  top: "ip1" # top名字 (in-place), shape: 1 500 (500)
}
layer { # memory required: 10*4=40
  name: "ip2" # layer名字
  type: "InnerProduct" # layer类型,全连接层
  bottom: "ip1" # bottom名字
  top: "ip2" # top名字, shape: 1 10 (10)
  param { # Specifies training parameters
    lr_mult: 1 # The multiplier on the global learning rate
  }
  param { # Specifies training parameters
    lr_mult: 2 # The multiplier on the global learning rate
  }
  inner_product_param {
    num_output: 10 # 输出特征图(feature map)数量
    weight_filler { # The filler for the weight
      type: "xavier" # 权值使用xavier滤波
    }
    bias_filler { # The filler for the bias
      type: "constant" # 偏置使用常量滤波
    }
  }
}
layer { # memory required: 10*4=40
  name: "prob" # layer名字
  type: "Softmax" # layer类型
  bottom: "ip2" # bottom名字
  top: "prob" # top名字, shape: 1 10 (10)
}
# 占用总内存大小为:3140+46080+11520+12800+3200+2000+2000+40+40=80820
lenet_train_test_.prototxt可视化结果( http://ethereon.github.io/netscope/quickstart.html )如下图:

train时lenet_train_test.prototxt与识别时用到的lenet_train_test_.prototxt差异:

(1)、数据层:训练时用Data,是以lmdb数据存储方式载入网络的,而识别时用MemoryData方式直接从内存载入网络;

(2)、Accuracy层:仅训练时用到,用以计算test集的准确率;

(3)、输出层Softmax/SoftmaxWithLoss层:训练时用SoftmaxWithLoss,输出loss值,识别时用Softmax输出10类数字的概率值。

4.   创建Net对象并初始化,有两种方法:一个是通过传入string类型(.prototxt文件)参数创建,一个是通过传入NetParameter参数;

5.   调用Net的CopyTrainedLayersFrom函数加载在train时生成的二进制文件.caffemodel即lenet_iter_10000.caffemodel,有两种方法,一个是通过传入string类型(.caffemodel文件)参数,一个是通过传入NetParameter参数;

6.   获取Net相关参数在后面识别时需要用到:

(1)、通过调用Net的blob_by_name函数获得待识别图像所要求的通道数、宽、高;

(2)、通过调用Net的output_blobs函数获得输出blob的数目及大小,注:这里输出2个blob,第一个是label,count为1,第二个是prob,count为10,即表示数字识别结果的概率值。

7.   开始进行手写数字识别:

(1)、通过opencv的imread函数读入图像;

(2)、根据从Net中获得的需要输入图像的要求对图像进行颜色空间转换和缩放;

(3)、因为MNIST train时,图像为前景为白色,背景为黑色,而现在输入图像为前景为黑色,背景为白色,因此需要对图像进行取反操作;

(4)、将图像数据传入Net,有两种方法:一种是通过MemoryDataLayer类的Reset函数,一种是通过MemoryDataLayer类的AddMatVector函数传入Mat参数;

(5)、调用Net的ForwardPrefilled函数进行前向计算;

(6)、输出识别结果,注,前向计算完返回的Blob有两个,第二个Blob中的数据才是最终的识别结果的概率值,其中最大值的索引即是识别结果。

8.   通过lenet_train_test_.prototxt文件分析各层的权值、偏置和神经元数量,共9层:

(1)、data数据层:无权值和偏置,神经元数量为1*1*28*28+1=785;

(2)、conv1卷积层:卷积窗大小为5*5,输出特征图数量为20,卷积窗种类为20,输出特征图大小为24*24,可训练参数(权值+阈值(偏置))为 20*1*5*5+20=520,神经元数量为1*20*24*24=11520;

(3)、pool1降采样层:滤波窗大小为2*2,输出特征图数量为20,滤波窗种类为20,输出特征图大小为12*12,可训练参数(权值+偏置)为1*20+20=40,神经元数量为1*20*12*12=2880;

(4)、conv2卷积层:卷积窗大小为5*5,输出特征图数量为50,卷积窗种类为50*20,输出特征图大小为8*8,可训练参数(权值+偏置)为50*20*5*5+50=25050,神经元数量为1*50*8*8=3200;

(5)、pool2降采样层:滤波窗大小为2*2,输出特征图数量为50,滤波窗种类为50,输出特征图大小为4*4,可训练参数(权值+偏置)为1*50+50=100,神经元数量为1*50*4*4=800;

(6)、ip1全连接层:滤波窗大小为1*1,输出特征图数量为500,滤波窗种类为500*800,输出特征图大小为1*1,可训练参数(权值+偏置)为500*800*1*1+500=400500,神经元数量为1*500*1*1=500;

(7)、relu1层:in-placeip1;

(8)、ip2全连接层:滤波窗大小为1*1,输出特征图数量为10,滤波窗种类为10*500,输出特征图大小为1*1,可训练参数(权值+偏置)为10*500*1*1+10=5010,神经元数量为1*10*1*1=10;

(9)、prob输出层:神经元数量为1*10*1*1+1=11。

精简后的手写数字识别测试代码如下:

int mnist_predict()
{caffe::Caffe::set_mode(caffe::Caffe::CPU);const std::string param_file{ "E:/GitCode/Caffe_Test/test_data/model/mnist/lenet_train_test_.prototxt" };const std::string trained_filename{ "E:/GitCode/Caffe_Test/test_data/model/mnist/lenet_iter_10000.caffemodel" };const std::string image_path{ "E:/GitCode/Caffe_Test/test_data/images/" };// 有两种方法可以实例化net// 1. 通过传入参数类型为std::stringcaffe::Net<float> caffe_net(param_file, caffe::TEST);caffe_net.CopyTrainedLayersFrom(trained_filename);// 2. 通过传入参数类型为caffe::NetParameter//caffe::NetParameter net_param1, net_param2;//caffe::ReadNetParamsFromTextFileOrDie(param_file, &net_param1);//net_param1.mutable_state()->set_phase(caffe::TEST);//caffe::Net<float> caffe_net(net_param1);//caffe::ReadNetParamsFromBinaryFileOrDie(trained_filename, &net_param2);//caffe_net.CopyTrainedLayersFrom(net_param2);int num_inputs = caffe_net.input_blobs().size(); // 0 ??const boost::shared_ptr<caffe::Blob<float> > blob_by_name = caffe_net.blob_by_name("data");int image_channel = blob_by_name->channels();int image_height = blob_by_name->height();int image_width = blob_by_name->width();int num_outputs = caffe_net.num_outputs();const std::vector<caffe::Blob<float>*> output_blobs = caffe_net.output_blobs();int require_blob_index{ -1 };const int digit_category_num{ 10 };for (int i = 0; i < output_blobs.size(); ++i) {if (output_blobs[i]->count() == digit_category_num)require_blob_index = i;}if (require_blob_index == -1) {fprintf(stderr, "ouput blob don't match\n");return -1;}std::vector<int> target{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };std::vector<int> result;for (auto num : target) {std::string str = std::to_string(num);str += ".png";str = image_path + str;cv::Mat mat = cv::imread(str.c_str(), 1);if (!mat.data) {fprintf(stderr, "load image error: %s\n", str.c_str());return -1;}if (image_channel == 1)cv::cvtColor(mat, mat, CV_BGR2GRAY);else if (image_channel == 4)cv::cvtColor(mat, mat, CV_BGR2BGRA);cv::resize(mat, mat, cv::Size(image_width, image_height));cv::bitwise_not(mat, mat);// 将图像数据载入Net网络,有2种方法boost::shared_ptr<caffe::MemoryDataLayer<float> > memory_data_layer =boost::static_pointer_cast<caffe::MemoryDataLayer<float>>(caffe_net.layer_by_name("data"));// 1. 通过MemoryDataLayer类的Reset函数mat.convertTo(mat, CV_32FC1, 0.00390625);float dummy_label[1] {0};memory_data_layer->Reset((float*)(mat.data), dummy_label, 1);// 2. 通过MemoryDataLayer类的AddMatVector函数//std::vector<cv::Mat> patches{mat}; // set the patch for testing//std::vector<int> labels(patches.size());//memory_data_layer->AddMatVector(patches, labels); // push vector<Mat> to data layerfloat loss{ 0.0 };const std::vector<caffe::Blob<float>*>& results = caffe_net.ForwardPrefilled(&loss); // Net forwardconst float* output = results[require_blob_index]->cpu_data();float tmp{ -1 };int pos{ -1 };fprintf(stderr, "actual digit is: %d\n", target[num]);for (int j = 0; j < 10; j++) {printf("Probability to be Number %d is: %.3f\n", j, output[j]);if (tmp < output[j]) {pos = j;tmp = output[j];}}result.push_back(pos);}for (auto i = 0; i < 10; i++)fprintf(stderr, "actual digit is: %d, result digit is: %d\n", target[i], result[i]);fprintf(stderr, "predict finish\n");return 0;
}

 

测试结果如下:

GitHub:https://github.com/fengbingchun/Caffe_Test
--------------------- 
作者:fengbingchun 
来源:CSDN 
原文:https://blog.csdn.net/fengbingchun/article/details/69001433 
版权声明:本文为博主原创文章,转载请附上博文链接!

--------------------- 
作者:fengbingchun 
来源:CSDN 
原文:https://blog.csdn.net/fengbingchun/article/details/69001433 
版权声明:本文为博主原创文章,转载请附上博文链接!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/458064.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

obs可以装手机吗?_原神PC和手机数据互通吗 PC和手机可以一起玩吗

在原神中&#xff0c;很多玩家都在PC端创建了角色&#xff0c;那么疑问来了&#xff0c;PC端与手机端的账号会是互通的吗&#xff1f;下面小编就为大家带来原神PC和手机数据互通吗的相关内容&#xff0c;一起来看看吧&#xff01;更多攻略&#xff1a;原神攻略大全PC和手机数据…

三维点云目标提取总结(续)

三维点云目标提取&#xff08;续&#xff09; 3.三维点云目标提取 3.1一般流程 先根据个人认识总结一下目标提取的一般性步骤&#xff1a; 如上所示&#xff0c;三维点云的目标提取关键性的两步即为&#xff1a;特征提取与选择、分类&#xff0c;是不是整个方法流程与图像中的目…

安卓高手之路之java层Binder

很多人一提到Binder就说代理模式&#xff0c;人云亦云的多&#xff0c;能理解精髓的少。 本篇文章就从设计角度分析一下java层BInder的设计目标&#xff0c;以及设计思路&#xff0c;设计缺陷&#xff0c;从而驾驭它。 对于【邦德儿】的理解, 从通信的角度来看&#xff0c;就是…

ftp改为sftp_浅谈 FTP、FTPS 与 SFTP

二狗子最近搭建了一个图片分享网站&#xff0c;每天都有好多人在他的网站上传许多照片&#xff0c;这些照片还会通过内部的逻辑同步到又拍云存储中&#xff0c;非常方便。但不久后问题就来了&#xff0c;由于刚开始的用户照片管理规划没有做好&#xff0c;随着用户上传的图片越…

如何解决秒杀的性能问题和超卖的讨论

2019独角兽企业重金招聘Python工程师标准>>> 最近业务试水电商&#xff0c;接了一个秒杀的活。之前经常看到淘宝的同行们讨论秒杀&#xff0c;讨论电商&#xff0c;这次终于轮到我们自己理论结合实际一次了。 ps&#xff1a;进入正文前先说一点个人感受&#xff0c;…

C# 从Excel中读取时间数据

之前写到从Excel中读取时间数据 //读取Excel数据Excel.Application xapp new Excel.Application();string filepath txt_Excel.Text;Excel.Workbook xbook xapp.Workbooks._Open(filepath, Missing.Value, Missing.Value,Missing.Value, Missing.Value, Missing.Value, Miss…

grid autosport额外内容下载慢_清理大王app下载-清理大王v1.0安卓下载

清理大王&#xff0c;下面由小编给大家介绍一下这款软件&#xff0c;该软件是一款非常不错的手机清理服务应用软件&#xff0c;清理大王app为用户提供了手机垃圾清理&#xff0c;内存加速&#xff0c;优化手机&#xff0c;解决手机卡顿的情况。感兴趣的朋友欢迎使用微侠下载&am…

怎么看cudnn的版本好_祖坟风水怎么看,好祖坟有什么征兆?

人们之所以看重祖坟的风水&#xff0c;是因为祖坟的风水与后代子孙的运势密切相关&#xff0c;可以说祖坟的风水好不好关系着子孙后代的运势顺不顺&#xff0c;因此对于祖坟的风水好坏人们是非常在意的&#xff0c;那么祖坟风水怎么看,好祖坟有什么征兆呢&#xff1f;下面是小编…

Spark 宽依赖和窄依赖

2019独角兽企业重金招聘Python工程师标准>>> 我们知道RDD就是一个不可变的带分区的记录集合&#xff0c;Spark提供了RDD上的两类操作&#xff0c;转换和动作。转换是用来定义一个新的RDD&#xff0c;包括map, flatMap, filter, union, sample, join, groupByKey, co…

smart gesture安装失败_WinCC flexible SMART V3 SP2安装步骤以及常见错误解决方法

1安装配置1. win7和win10系统都可以装2. 运行内存至少要2G。3. 硬盘储存空间至少要3G。2安装注意事项1.安装本软件之前必须要关闭所有杀毒软件(例如360安全卫士/360杀毒/电脑管家)等。2.其它西门子软件不要使用或者打开。3.安装之前确保硬盘空间充足。3下载地址https://bbs.jcp…

启动页面和各设备的宽高比及像素

2019独角兽企业重金招聘Python工程师标准>>> iOS7只能用LaunchImage来布置启动画面&#xff0c;只能用图片。iOS8以后支持LaunchScreen.xib来布置&#xff0c;可以自己添加控件。iOS8以及以后的用LaunchScreen来配置启动页。iOS8以后的会走这个设置&#xff0c;而io…

cc压力测试_中小型网站如何防范CC攻击?

大公司就不说了&#xff0c;付费CDN&#xff0c;防火墙&#xff0c;WAF&#xff0c;大流量&#xff0c;一般也会配置专门的安全问题响应团队。今天侧重讨论一下中小型网站如何&#xff08;优雅&#xff09;防范CC攻击。中小站点安全问题通病&#xff1a;对安全问题不重视&#…

泛型复习

回顾泛型类 泛型类&#xff1a;具有一个或多个泛型变量的类被称之为泛型类1、class A<T>{} 2、在创建泛型实例时&#xff0c;需要为其类型变量赋值A<String> anew A<String>(); *如果创建实例时&#xff0c;不给类型变量赋值&#xff0c;那么会有一个警告&am…

.net core EPPlus npoi_2020 ASP.NET界面开发:DevExpress v20.1支持.NET Core设计时

DevExpress ASP.NET Web Forms Controls拥有针对Web表单(包括报表)的110种UI控件&#xff0c;DevExpress ASP.NET MVC Extensions是服务器端MVC扩展或客户端控件&#xff0c;由轻量级JavaScript小部件提供支持的70个高性能DevExpress ASP.NET Core Controls&#xff0c;包含功能…

mac电脑如何与手机同步复制粘贴_如何将电脑里的文件同步到手机里?

由于PDF的特殊性&#xff0c;一般很少有适用于手机编辑的软件&#xff0c;所以我们都习惯于使用电脑来修改PDF文档后&#xff0c;再发送到手机微信发送给其他人&#xff0c;那么如何快速将电脑里的PDF文件同步到手机里面呢&#xff1f;可能很多人会想到使用各种云盘&#xff0c…

走进缓存的世界(一) - 开篇

系列文章 走进缓存的世界&#xff08;一&#xff09; - 开篇走进缓存的世界&#xff08;二&#xff09; - 缓存设计走进缓存的世界&#xff08;三&#xff09; - Memcache概述 对于程序员来说多多少少都懂一点算法&#xff0c;算法是什么&#xff1f;算法是“时间”与“空间”的…

an 转换器_400V耐压场效应管替代IRF730B型号参数,使用在DC-DC电源转换器。_场效应管吧...

DC-DC电源转化器的应用场景逐渐广泛&#xff0c;那么适用于DC-DC电源模块的场效应管需求也随之越来越高&#xff0c;这时候电源转化器厂的电子工程师就要留意了&#xff0c;国内是否有优质的场效应管能替代IRF730B型号呢&#xff0c;其实是有的&#xff0c;FHP840其实是可以跟I…

spring MVC中页面添加锚点

2019独角兽企业重金招聘Python工程师标准>>> 需要添加锚点的代码&#xff1a; <li><a href"main/index#page1">推荐车型</a></li> <li><a href"main/index#page2">热门车型</a></li>需要跳转的…

steam一键授权工具_半个东的时间让你省了一个亿 Steam免费游戏一键领取

总所周知Steam上有很多的免费游戏&#xff0c;但是有哪些游戏是免费的呢&#xff1f;这个一时半会儿也总结不出来&#xff0c;而且还得需要大量的时间添加到自己的游戏库&#xff0c;今天这个教程就教大家如何一键添加Steam上大量免费游戏&#xff0c;需要的小伙伴赶紧收藏哦。…

YModem协议

源&#xff1a;YModem协议 YModem协议&#xff1a; YModem协议是由XModem协议演变而来的&#xff0c;每包数据可以达到1024字节&#xff0c;是一个非常高效的文件传输协议。 下面先看下YModem协议传输的完整的握手过程&#xff1a;先看下图 SENDER:发送方。 RECEIVER:接收方。 …