deeplearning4j训练推理案例2023——手写数字识别

文章目录

  • 1.minist数据集
  • 2.依赖包
  • 3.手写数字训练与推理
  • 4. 扩展阅读deeplearning4j自带学习案例项目deeplearning4j-examples

1.minist数据集

下载链接 6W训练集,1W测试集

2.依赖包

主要是deeplearning4j、javacv的一些包,案例打出的jar包1.3G,pom来自github deeplearning子项目deeplearning4j-examples 的dl4j-examples模块

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd"><modelVersion>4.0.0</modelVersion><parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>2.7.9</version><relativePath/></parent><groupId>com.example</groupId><artifactId>demo</artifactId><version>0.0.1-SNAPSHOT</version><name>demo</name><description>demo</description><properties><dl4j-master.version>1.0.0-M2.1</dl4j-master.version><nd4j.backend>nd4j-native</nd4j.backend><java.version>17</java.version><maven-compiler-plugin.version>3.8.1</maven-compiler-plugin.version><maven.minimum.version>3.3.1</maven.minimum.version><exec-maven-plugin.version>1.4.0</exec-maven-plugin.version><maven-shade-plugin.version>2.4.3</maven-shade-plugin.version><jcommon.version>1.0.23</jcommon.version><jfreechart.version>1.0.13</jfreechart.version><logback.version>1.1.7</logback.version><project.build.sourceEncoding>UTF-8</project.build.sourceEncoding><junit.version>5.8.0-M1</junit.version><javacv.version>1.5.9</javacv.version></properties><dependencyManagement><dependencies><dependency><groupId>org.bytedeco</groupId><artifactId>javacv-platform</artifactId><version>${javacv.version}</version></dependency></dependencies></dependencyManagement><dependencies><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter</artifactId></dependency><dependency><groupId>org.projectlombok</groupId><artifactId>lombok</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-test</artifactId><scope>test</scope></dependency><dependency><groupId>org.nd4j</groupId><artifactId>${nd4j.backend}</artifactId><version>${dl4j-master.version}</version></dependency><dependency><groupId>org.datavec</groupId><artifactId>datavec-api</artifactId><version>${dl4j-master.version}</version></dependency><dependency><groupId>org.datavec</groupId><artifactId>datavec-data-image</artifactId><version>${dl4j-master.version}</version></dependency><dependency><groupId>org.datavec</groupId><artifactId>datavec-local</artifactId><version>${dl4j-master.version}</version></dependency><dependency><groupId>org.deeplearning4j</groupId><artifactId>deeplearning4j-datasets</artifactId><version>${dl4j-master.version}</version></dependency><dependency><groupId>org.deeplearning4j</groupId><artifactId>deeplearning4j-core</artifactId><version>${dl4j-master.version}</version></dependency><dependency><groupId>org.deeplearning4j</groupId><artifactId>resources</artifactId><version>${dl4j-master.version}</version></dependency><dependency><groupId>org.deeplearning4j</groupId><artifactId>deeplearning4j-ui</artifactId><version>${dl4j-master.version}</version></dependency><dependency><groupId>org.deeplearning4j</groupId><artifactId>deeplearning4j-zoo</artifactId><version>${dl4j-master.version}</version></dependency><!-- ParallelWrapper & ParallelInference live here --><dependency><groupId>org.deeplearning4j</groupId><artifactId>deeplearning4j-parallel-wrapper</artifactId><version>${dl4j-master.version}</version></dependency><!-- Used in the feedforward/classification/MLP* and feedforward/regression/RegressionMathFunctions example --><dependency><groupId>jfree</groupId><artifactId>jfreechart</artifactId><version>${jfreechart.version}</version></dependency><dependency><groupId>org.jfree</groupId><artifactId>jcommon</artifactId><version>${jcommon.version}</version></dependency><!-- Used for downloading data in some of the examples --><dependency><groupId>org.apache.httpcomponents</groupId><artifactId>httpclient</artifactId><version>4.3.5</version></dependency><dependency><groupId>ch.qos.logback</groupId><artifactId>logback-classic</artifactId><version>${logback.version}</version></dependency><dependency><groupId>org.bytedeco</groupId><artifactId>javacv-platform</artifactId></dependency><dependency><groupId>org.nd4j</groupId><artifactId>nd4j-api</artifactId><version>1.0.0-M2.1</version></dependency></dependencies><build><plugins><plugin><groupId>org.springframework.boot</groupId><artifactId>spring-boot-maven-plugin</artifactId></plugin><plugin><groupId>org.apache.maven.plugins</groupId><artifactId>maven-compiler-plugin</artifactId><configuration><source>17</source><target>17</target></configuration></plugin></plugins></build></project>

3.手写数字训练与推理

1个epoch训练耗时100s,准确率达97%,详见代码注释,框架的api做得还比较好用

package ai;import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.common.io.Assert;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;import java.io.File;
import java.util.Random;@Slf4j
public class LeNetMNISTReLu {private static final String DATASET_PATH_BASE = "D:\\";public static void main(String[] args) throws Exception {int height = 28;int width = 28;// 黑白图片通道只有一个int channels = 1;// 0-9十种数字int outputNum = 10;int batchSize = 64;// 这里一个epoch耗时约100s,3次准确率99%int nEpochs = 1;Assert.isTrue(new File(DATASET_PATH_BASE + "/mnist_png").exists(), "请下载压缩包并解压到" + DATASET_PATH_BASE);// 该label生成器会将数据所在父目录名作为label,要求目录名必须为数值,这里mnist数据集正好是放在0-9文件夹的ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();// 归一化(0-1)DataNormalization normalization = new ImagePreProcessingScaler();Random random = new Random(12345);log.info("训练集6W张...");File trainData = new File(DATASET_PATH_BASE + "/mnist_png/training");FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, random);ImageRecordReader trainRecordReader = new ImageRecordReader(height, width, channels, labelMaker);trainRecordReader.initialize(trainSplit);DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRecordReader, batchSize, 1, outputNum);normalization.fit(trainIter);trainIter.setPreProcessor(normalization); // 先像素归一化log.info("验证集1W张...");File validateData = new File(DATASET_PATH_BASE + "/mnist_png/testing");FileSplit validateSplit = new FileSplit(validateData, NativeImageLoader.ALLOWED_FORMATS, random);ImageRecordReader validateRecordReader = new ImageRecordReader(height, width, channels, labelMaker);validateRecordReader.initialize(validateSplit);DataSetIterator validateIter = new RecordReaderDataSetIterator(validateRecordReader, batchSize, 1, outputNum);validateIter.setPreProcessor(normalization);// 训练集6W数据 每次迭代batchSize=64,故这里大概有1000次迭代// 学习率,每200个迭代更新一次学习率(步长),先大一点,还可以每个Epoch更新一次学习率MapSchedule mapSchedule = new MapSchedule.Builder(ScheduleType.ITERATION).add(0, 0.06).add(200, 0.05).add(600, 0.028).add(800, 0.006).add(1000, 0.001).build();// 超参MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(1).l2(0.0005).updater(new Nesterovs(mapSchedule))//.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) //该优化器导致长时间无法拟合.weightInit(WeightInit.XAVIER).list().layer(new ConvolutionLayer.Builder(5, 5).nIn(channels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build()).layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build()).layer(new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(50).activation(Activation.IDENTITY).build()).layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build()).layer(new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()).layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image.build();// 神经网络对象构建MultiLayerNetwork net = new MultiLayerNetwork(conf);net.init();// 训练监控,每次迭代打印损失函数值net.setListeners(new ScoreIterationListener(10));// WEB UI监控训练过程//UIServer uiServer = UIServer.getInstance();//FileStatsStorage statsStorage = new FileStatsStorage(new File("D:\\ai-webui.dat"));//uiServer.attach(statsStorage);//net.setListeners(new StatsListener(statsStorage));log.info("网络参数个数{}", net.numParams());long startTime = System.currentTimeMillis();// 训练epochs轮for (int i = 0; i < nEpochs; i++) {log.info("Epoch=" + i);net.fit(trainIter);Evaluation eval = net.evaluate(validateIter);log.info(eval.stats());trainIter.reset();validateIter.reset();}log.info("训练耗时{}毫秒", System.currentTimeMillis() - startTime);// 保存模型File ministModelPath = new File(DATASET_PATH_BASE + "/ministModel.zip");ModelSerializer.writeModel(net, ministModelPath, true);// 推理逻辑:加载网络(模型)——>加载测试图片——>预测MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(new File(DATASET_PATH_BASE + "/ministModel.zip"));NativeImageLoader imageLoader = new NativeImageLoader(height, width, channels);FileUtils.listFiles(new File("D:\\mnist_png\\testing"), null, true).parallelStream().forEach(file -> {try {INDArray matrix = imageLoader.asMatrix(file);INDArray output = network.output(matrix);// 取最可能的预测结果int predictedValue = Nd4j.argMax(output, 1).getInt(0);// 数字图片按数值放在每个文件夹的,故图片所在文件夹名即为真实值String realValue = file.getParentFile().getName();log.info("真实值:{},预测值:{}", realValue, predictedValue);Assert.isTrue(predictedValue == Integer.parseInt(realValue), file.getAbsolutePath() + "预测错误");} catch (Exception e) {log.warn(e.getMessage(), e);}});}
}

4. 扩展阅读deeplearning4j自带学习案例项目deeplearning4j-examples

deeplearning4j-examples 参考其readme文档

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/118625.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

超级强大!送你几款Linux 下终极SSH客户端

更多IT技术&#xff0c;请关注微信公众号:“运维之美” 超级强大&#xff01;送你几款Linux 下终极SSH客户端 1.MobaXterm2.Xshell3.SecureCRT4.PuTTY5.FinalShell6.Termius7.WindTerm 安全外壳协议&#xff08;Secure Shell&#xff0c;简称 SSH&#xff09;是一种网络连接协议…

【Gensim概念】02/3 NLP玩转 word2vec

第二部分 句法 六、句法模型&#xff08;类对象和参数&#xff09; 6.1 数据集的句子查看 classgensim.models.word2vec.BrownCorpus(dirname) Bases: object 迭代句子 Brown corpus (part of NLTK data). 6.2 数据集的句子和gram classgensim.models.word2vec.Heapitem(c…

【Docker】Docker数据的存储

默认情况下&#xff0c;在运行中的容器里创建的文件&#xff0c;被保存在一个可写的容器层里&#xff0c;如果容器被删除了&#xff0c;则对应的数据也随之删除了。 这个可写的容器层是和特定的容器绑定的&#xff0c;也就是这些数据无法方便的和其它容器共享。 Docker主要提…

智能井盖监测系统功能,万宾科技传感器效果

智能井盖传感器的出现是高科技产品的更新换代&#xff0c;同时也是智慧城市建设中的需求。在智慧城市建设过程之中&#xff0c;高科技产品的应用数不胜数&#xff0c;智能井盖传感器的出现&#xff0c;解决了城市道路安全保护着城市地下生命线&#xff0c;改善着传统井盖带来的…

责任链模式应用案例

前几天系统商品折扣功能优化&#xff0c;同事采用了责任链模式重构了代码&#xff0c;现整理如下。 一、概念 责任链模式是为请求创建一个处理者对象的链条&#xff0c;所有处理者&#xff08;除最末端&#xff09;都含有下一个对象的引用从而形成一条处理链&#xff0c;该模…

10月最新H5自适应樱花导航网站源码SEO增强版

10月最新H5自适应樱花导航网源码SEO增强版。非常强大的导航网站亮点就是对SEO优化比较好。 开发时PHP版本&#xff1a;7.3开发时MySQL版本&#xff1a;5.7.26 懂前端和PHP技术想更改前端页面的可以看&#xff1a;网站的前端页面不好看&#xff0c;你可以查看index目录&#x…

二、W5100S/W5500+RP2040树莓派Pico<DHCP>

文章目录 1 前言2 简介2 .1 什么是DHCP&#xff1f;2.2 为什么要使用DHCP&#xff1f;2.3 DHCP工作原理2.4 DHCP应用场景 3 WIZnet以太网芯片4 DHCP网络设置示例概述以及使用4.1 流程图4.2 准备工作核心4.3 连接方式4.4 主要代码概述4.5 结果演示 5 注意事项6 相关链接 1 前言 …

vue项目中将html转为pdf并下载

个人项目地址&#xff1a; SubTopH前端开发个人站 &#xff08;自己开发的前端功能和UI组件&#xff0c;一些有趣的小功能&#xff0c;感兴趣的伙伴可以访问&#xff0c;欢迎提出更好的想法&#xff0c;私信沟通&#xff0c;网站属于静态页面&#xff09; SubTopH前端开发个人…

C/C++不及格学生 2020年9月电子学会青少年软件编程(C/C++)等级考试一级真题答案解析

目录 C/C不及格学生 一、题目要求 1、编程实现 2、输入输出 二、算法分析 三、程序编写 四、程序说明 五、运行结果 六、考点分析 C/C不及格学生 2020年9月 C/C编程等级考试一级编程题 一、题目要求 1、编程实现 给出一名学生的语文和数学成绩&#xff0c;判断他是…

web3之链上情报平台Arkham

文章目录 web3之链上情报平台Arkham什么是Arkham链上情报交易所 Arkham Intel Exchange相较于传统情报交易方式,Arkham Intel Exchange下优势 web3之链上情报平台Arkham 什么是Arkham 官网&#xff1a;https://zh.arkhamintelligence.com/ 官方&#xff1a;https://platform.…

如何在 Chrome 中设置HTTP服务器?

首先&#xff0c;定义问题&#xff1a;在 Chrome 浏览器中设置HTTP服务器主要涉及到修改网络设置&#xff0c;使用HTTP服务器可以帮助用户访问网络内容&#xff0c;提高网络速度或者保护隐私。 亲身经验&#xff1a;我曾在使用 Chrome 浏览器时&#xff0c;为了访问一些受限的网…

使用Docker快速搭建服务器环境

简介 这篇文章也是方便自己记录搭建流程&#xff0c;服务器的购买啥的就不说了&#xff0c;最终目标就是在一个空白的Linux系统上&#xff0c;使用docker运行MySQL、TomcatJava、Nginx、Redis 的单机环境&#xff0c;以后方便自己快速的部署服务器。 安装Docker 首先需要安装…

python网络爬虫(二)基本库的使用urllib/requests

使用urllib 了解一下 urllib 库&#xff0c;它是 Python 内置的 HTTP 请求库&#xff0c;也就是说不需要额外安装即可使用。它包含如下 4 个模块。 request&#xff1a;它是最基本的 HTTP 请求模块&#xff0c;可以用来模拟发送请求。就像在浏览器里输入网址然后回车一样&…

06 MIT线性代数-列空间和零空间 Column space Nullspace

1. Vector space Vector space requirements vw and c v are in the space, all combs c v d w are in the space 但是“子空间”和“子集”的概念有区别&#xff0c;所有元素都在原空间之内就可称之为子集&#xff0c;但是要满足对线性运算封闭的子集才能成为子空间 中 2 …

【OpenCV实现图像阈值处理】

文章目录 概要简单阈值调整自适应阈值调整大津(Otsus)阈值法Otsus 二值化是如何工作的 概要 OpenCV库中的图像处理技术&#xff0c;主要分为几何变换、图像阈值调整和平滑处理三个部分。 在几何变换方面&#xff0c;OpenCV提供了cv.warpAffine和cv.warpPerspective函数&#…

(链表) 25. K 个一组翻转链表 ——【Leetcode每日一题】

❓ 25. K 个一组翻转链表 难度&#xff1a;困难 给你链表的头节点 head &#xff0c;每 k 个节点一组进行翻转&#xff0c;请你返回修改后的链表。 k 是一个正整数&#xff0c;它的值小于或等于链表的长度。如果节点总数不是 k 的整数倍&#xff0c;那么请将最后剩余的节点保…

Kotlin基础——函数、变量、字符串模板、类

函数、变量、字符串模板、类 函数变量字符串模板类 函数 函数组成为 fun 函数名(参数名: 参数类型, …): 返回值{} fun max(a: Int, b: Int): Int {return if (a > b) a else b }上面称为代码块函数体&#xff0c;当函数体由单个表达式构成时&#xff0c;可简化为表达式函…

FreeRTOS 计数型信号量 详解

目录 什么是计数型信号量&#xff1f; 计数型信号量相关 API 函数 1. 创建计数型信号量 2. 释放二值信号量 3. 获取二值信号量 计数型信号量实操 什么是计数型信号量&#xff1f; 计数型信号量相当于队列长度大于1 的队列&#xff0c;因此计数型信号量能够容纳多个资源&a…

Azure - 机器学习:创建机器学习所需资源,配置工作区

目录 一、Azure机器学习工作区与计算实例简要介绍工作区计算实例 二、创建工作区1. 登录到 Azure 机器学习工作室2. 选择“创建工作区”3. 提供以下信息来配置新工作区&#xff1a;4. 选择“创建”以创建工作区 三、创建计算实例四、工作室实战4.1 工作室快速导览4.2 从示例笔记…

css 雷达扫描图

html 代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>css 雷达扫描</title><style>* {margin: 0;padding: 0;}body {background: #000000;height: 100vh;display: flex;align-items…