Spring Boot集成tensorflow实现图片检测服务

1.什么是tensorflow?

TensorFlow名字的由来就是张量(Tensor)在计算图(Computational Graph)里的流动(Flow),如图。它的基础就是前面介绍的基于计算图的自动微分,除了自动帮你求梯度之外,它也提供了各种常见的操作(op,也就是计算图的节点),常见的损失函数,优化算法。

tensorflow

  • TensorFlow 是一个开放源代码软件库,用于进行高性能数值计算。借助其灵活的架构,用户可以轻松地将计算工作部署到多种平台(CPU、GPU、TPU)和设备(桌面设备、服务器集群、移动设备、边缘设备等)。https://www.tensorflow.org/tutorials/?hl=zh-cnwww.tensorflow.org/tutorials/?hl=zh-cn(opens new window)
  • TensorFlow 是一个用于研究和生产的开放源代码机器学习库。TensorFlow 提供了各种 API,可供初学者和专家在桌面、移动、网络和云端环境下进行开发。
  • TensorFlow是采用数据流图(data flow graphs)来计算,所以首先我们得创建一个数据流流图,然后再将我们的数据(数据以张量(tensor)的形式存在)放在数据流图中计算. 节点(Nodes)在图中表示数学操作,图中的边(edges)则表示在节点间相互联系的多维数据数组, 即张量(tensor)。训练模型时tensor会不断的从数据流图中的一个节点flow到另一节点, 这就是TensorFlow名字的由来。 张量(Tensor):张量有多种. 零阶张量为 纯量或标量 (scalar) 也就是一个数值. 比如 [1],一阶张量为 向量 (vector), 比如 一维的 [1, 2, 3],二阶张量为 矩阵 (matrix), 比如 二维的 [[1, 2, 3],[4, 5, 6],[7, 8, 9]],以此类推, 还有 三阶 三维的 … 张量从流图的一端流动到另一端的计算过程。它生动形象地描述了复杂数据结构在人工神经网中的流动、传输、分析和处理模式。

在机器学习中,数值通常由4种类型构成: (1)标量(scalar):即一个数值,它是计算的最小单元,如“1”或“3.2”等。 (2)向量(vector):由一些标量构成的一维数组,如[1, 3.2, 4.6]等。 (3)矩阵(matrix):是由标量构成的二维数组。 (4)张量(tensor):由多维(通常)数组构成的数据集合,可理解为高维矩阵。

tensorflow的基本概念

  • 图:描述了计算过程,Tensorflow用图来表示计算过程
  • 张量:Tensorflow 使用tensor表示数据,每一个tensor是一个多维化的数组
  • 操作:图中的节点为op,一个op获得/输入0个或者多个Tensor,执行并计算,产生0个或多个Tensor
  • 会话:session tensorflow的运行需要再绘话里面运行

tensorflow写代码流程

  • 定义变量占位符
  • 根据数学原理写方程
  • 定义损失函数cost
  • 定义优化梯度下降 GradientDescentOptimizer
  • session 进行训练,for循环
  • 保存saver

2.环境准备

整合步骤

  1. 模型构建:首先,我们需要在TensorFlow中定义并训练深度学习模型。这可能涉及选择合适的网络结构、优化器和损失函数等。
  2. 训练数据准备:接下来,我们需要准备用于训练和验证模型的数据。这可能包括数据清洗、标注和预处理等步骤。
  3. REST API设计:为了与TensorFlow模型进行交互,我们需要在SpringBoot中创建一个REST API。这可以使用SpringBoot的内置功能来实现,例如使用Spring MVC或Spring WebFlux。
  4. 模型部署:在模型训练完成后,我们需要将其部署到SpringBoot应用中。为此,我们可以使用TensorFlow的Java API将模型导出为ONNX或SavedModel格式,然后在SpringBoot应用中加载并使用。

在整合过程中,有几个关键点需要注意。首先,防火墙设置可能会影响TensorFlow训练过程中的网络通信。确保你的防火墙允许TensorFlow访问其所需的网络资源,以免出现训练中断或模型性能下降的问题。其次,要关注版本兼容性。SpringBoot和TensorFlow都有各自的版本更新周期,确保在整合时使用兼容的版本可以避免很多不必要的麻烦。

模型下载

模型构建和模型训练这块设计到python代码,这里跳过,感兴趣的可以下载源代码自己训练模型,咱们直接下载训练好的模型

  • https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz

下载好了,解压放在/resources/inception_v3目录下

3.代码工程

实验目的

实现图片检测

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"><parent><artifactId>springboot-demo</artifactId><groupId>com.et</groupId><version>1.0-SNAPSHOT</version></parent><modelVersion>4.0.0</modelVersion><artifactId>Tensorflow</artifactId><properties><maven.compiler.source>11</maven.compiler.source><maven.compiler.target>11</maven.compiler.target></properties><dependencies><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-autoconfigure</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-test</artifactId><scope>test</scope></dependency><dependency><groupId>org.tensorflow</groupId><artifactId>tensorflow-core-platform</artifactId><version>0.5.0</version></dependency><dependency><groupId>org.projectlombok</groupId><artifactId>lombok</artifactId></dependency><dependency><groupId>jmimemagic</groupId><artifactId>jmimemagic</artifactId><version>0.1.2</version></dependency><dependency><groupId>jakarta.platform</groupId><artifactId>jakarta.jakartaee-api</artifactId><version>9.0.0</version></dependency><dependency><groupId>commons-io</groupId><artifactId>commons-io</artifactId><version>2.16.1</version></dependency><dependency><groupId>org.springframework.restdocs</groupId><artifactId>spring-restdocs-mockmvc</artifactId><scope>test</scope></dependency></dependencies>
</project>

controller

package com.et.tf.api;import java.io.IOException;import com.et.tf.service.ClassifyImageService;
import net.sf.jmimemagic.Magic;
import net.sf.jmimemagic.MagicMatch;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;@RestController
@RequestMapping("/api")
public class AppController {@AutowiredClassifyImageService classifyImageService;@PostMapping(value = "/classify")@CrossOrigin(origins = "*")public ClassifyImageService.LabelWithProbability classifyImage(@RequestParam MultipartFile file) throws IOException {checkImageContents(file);return classifyImageService.classifyImage(file.getBytes());}@RequestMapping(value = "/")public String index() {return "index";}private void checkImageContents(MultipartFile file) {MagicMatch match;try {match = Magic.getMagicMatch(file.getBytes());} catch (Exception e) {throw new RuntimeException(e);}String mimeType = match.getMimeType();if (!mimeType.startsWith("image")) {throw new IllegalArgumentException("Not an image type: " + mimeType);}}}

service

package com.et.tf.service;import jakarta.annotation.PreDestroy;
import java.util.Arrays;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.op.OpScope;
import org.tensorflow.op.Scope;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TString;
import org.tensorflow.types.family.TType;//Inspired from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
@Service
@Slf4j
public class ClassifyImageService {private final Session session;private final List<String> labels;private final String outputLayer;private final int W;private final int H;private final float mean;private final float scale;public ClassifyImageService(Graph inceptionGraph, List<String> labels, @Value("${tf.outputLayer}") String outputLayer,@Value("${tf.image.width}") int imageW, @Value("${tf.image.height}") int imageH,@Value("${tf.image.mean}") float mean, @Value("${tf.image.scale}") float scale) {this.labels = labels;this.outputLayer = outputLayer;this.H = imageH;this.W = imageW;this.mean = mean;this.scale = scale;this.session = new Session(inceptionGraph);}public LabelWithProbability classifyImage(byte[] imageBytes) {long start = System.currentTimeMillis();try (Tensor image = normalizedImageToTensor(imageBytes)) {float[] labelProbabilities = classifyImageProbabilities(image);int bestLabelIdx = maxIndex(labelProbabilities);LabelWithProbability labelWithProbability =new LabelWithProbability(labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f, System.currentTimeMillis() - start);log.debug(String.format("Image classification [%s %.2f%%] took %d ms",labelWithProbability.getLabel(),labelWithProbability.getProbability(),labelWithProbability.getElapsed()));return labelWithProbability;}}private float[] classifyImageProbabilities(Tensor image) {try (Tensor result = session.runner().feed("input", image).fetch(outputLayer).run().get(0)) {final Shape resultShape = result.shape();final long[] rShape = resultShape.asArray();if (resultShape.numDimensions() != 2 || rShape[0] != 1) {throw new RuntimeException(String.format("Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",Arrays.toString(rShape)));}int nlabels = (int) rShape[1];FloatDataBuffer resultFloatBuffer = result.asRawTensor().data().asFloats();float[] dst = new float[nlabels];resultFloatBuffer.read(dst);return dst;}}private int maxIndex(float[] probabilities) {int best = 0;for (int i = 1; i < probabilities.length; ++i) {if (probabilities[i] > probabilities[best]) {best = i;}}return best;}private Tensor normalizedImageToTensor(byte[] imageBytes) {try (Graph g = new Graph();TInt32 batchTensor = TInt32.scalarOf(0);TInt32 sizeTensor = TInt32.vectorOf(H, W);TFloat32 meanTensor = TFloat32.scalarOf(mean);TFloat32 scaleTensor = TFloat32.scalarOf(scale);) {GraphBuilder b = new GraphBuilder(g);//Tutorial python here: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/label_image// Some constants specific to the pre-trained model at:// https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz//// - The model was trained with images scaled to 299x299 pixels.// - The colors, represented as R, G, B in 1-byte each were converted to//   float using (value - Mean)/Scale.// Since the graph is being constructed once per execution here, we can use a constant for the// input image. If the graph were to be re-used for multiple input images, a placeholder would// have been more appropriate.final Output input = b.constant("input", TString.tensorOfBytes(NdArrays.scalarOfObject(imageBytes)));final Output output =b.div(b.sub(b.resizeBilinear(b.expandDims(b.cast(b.decodeJpeg(input, 3), DataType.DT_FLOAT),b.constant("make_batch", batchTensor)),b.constant("size", sizeTensor)),b.constant("mean", meanTensor)),b.constant("scale", scaleTensor));try (Session s = new Session(g)) {return s.runner().fetch(output.op().name()).run().get(0);}}}static class GraphBuilder {final Scope scope;GraphBuilder(Graph g) {this.g = g;this.scope = new OpScope(g);}Output div(Output x, Output y) {return binaryOp("Div", x, y);}Output sub(Output x, Output y) {return binaryOp("Sub", x, y);}Output resizeBilinear(Output images, Output size) {return binaryOp("ResizeBilinear", images, size);}Output expandDims(Output input, Output dim) {return binaryOp("ExpandDims", input, dim);}Output cast(Output value, DataType dtype) {return g.opBuilder("Cast", "Cast", scope).addInput(value).setAttr("DstT", dtype).build().output(0);}Output decodeJpeg(Output contents, long channels) {return g.opBuilder("DecodeJpeg", "DecodeJpeg", scope).addInput(contents).setAttr("channels", channels).build().output(0);}Output<? extends TType> constant(String name, Tensor t) {return g.opBuilder("Const", name, scope).setAttr("dtype", t.dataType()).setAttr("value", t).build().output(0);}private Output binaryOp(String type, Output in1, Output in2) {return g.opBuilder(type, type, scope).addInput(in1).addInput(in2).build().output(0);}private final Graph g;}@PreDestroypublic void close() {session.close();}@Data@NoArgsConstructor@AllArgsConstructorpublic static class LabelWithProbability {private String label;private float probability;private long elapsed;}
}

application.yaml

tf:frozenModelPath: inception-v3/inception_v3_2016_08_28_frozen.pblabelsPath: inception-v3/imagenet_slim_labels.txtoutputLayer: InceptionV3/Predictions/Reshape_1image:width: 299height: 299mean: 0scale: 255logging.level.net.sf.jmimemagic: WARN
spring:servlet:multipart:max-file-size: 5MB

Application.java

package com.et.tf;import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.tensorflow.Graph;
import org.tensorflow.proto.framework.GraphDef;@SpringBootApplication
@Slf4j
public class Application {public static void main(String[] args) {SpringApplication.run(Application.class, args);}@Beanpublic Graph tfModelGraph(@Value("${tf.frozenModelPath}") String tfFrozenModelPath) throws IOException {Resource graphResource = getResource(tfFrozenModelPath);Graph graph = new Graph();graph.importGraphDef(GraphDef.parseFrom(graphResource.getInputStream()));log.info("Loaded Tensorflow model");return graph;}private Resource getResource(@Value("${tf.frozenModelPath}") String tfFrozenModelPath) {Resource graphResource = new FileSystemResource(tfFrozenModelPath);if (!graphResource.exists()) {graphResource = new ClassPathResource(tfFrozenModelPath);}if (!graphResource.exists()) {throw new IllegalArgumentException(String.format("File %s does not exist", tfFrozenModelPath));}return graphResource;}@Beanpublic List<String> tfModelLabels(@Value("${tf.labelsPath}") String labelsPath) throws IOException {Resource labelsRes = getResource(labelsPath);log.info("Loaded model labels");return IOUtils.readLines(labelsRes.getInputStream(), StandardCharsets.UTF_8).stream().map(label -> label.substring(label.contains(":") ? label.indexOf(":") + 1 : 0)).collect(Collectors.toList());}
}

以上只是一些关键代码,所有代码请参见下面代码仓库

代码仓库

  • GitHub - Harries/springboot-demo: a simple springboot demo with some components for example: redis,solr,rockmq and so on.

4.测试

启动 Spring Boot应用程序

测试图片分类

访问http://127.0.0.1:8080/,上传一张图片,点击分类

 

5.引用

  • https://www.tensorflow.org/
  • Spring Boot集成tensorflow实现图片检测服务 | Harries Blog™

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/31070.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

蓝桥杯 经典算法题 Fizz Buzz 经典问题

题目&#xff1a; 给定一个整数 N&#xff0c;从 1 到 N 按照下面的规则返回每个数&#xff1a; 如果这个数被 3 整除&#xff0c;返回 Fizz。如果这个数被 5 整除&#xff0c;返回 Buzz如果这个数能同时被 3 和 5 整除&#xff0c;返回 FizzBuzz。如果这个数既不能被 3 也不…

python爬虫之selenium自动化操作

python爬虫之selenium自动化操作 需求&#xff1a;操作淘宝去掉弹窗广告搜索物品后进入百度回退又前进 selenium模块的基本使用 问题&#xff1a;selenium模块和爬虫之间具有怎样的关联? 1、便捷的获取网站中动态加载的数据 2、便捷实现模拟登录 什么是selenium模块&#x…

maven-jar-plugin maven打包插件笔记

文章目录 配置示例 其他官网文档问题maven打包插件是如何和打包动作关联在一起的?配置文件中 goal是必须的吗? maven自定义插件内容很多&#xff0c;也不易理解&#xff0c;这里把maven打包插件单拿出来&#xff0c;作为入口试着理解下。 配置示例 <plugin><groupI…

ubuntu22.04禁止自动休眠的几种方式

在Ubuntu 20.04中&#xff0c;您可以通过以下几种方式禁用自动休眠功能&#xff1a; 使用systemd设置: sudo systemctl mask sleep.target suspend.target hibernate.target hybrid-sleep.target 修改/etc/systemd/logind.conf文件: sudo nano /etc/systemd/logind.conf 找…

番外篇 | 基于YOLOv5-RCS的明火烟雾检测 | 源于RCS-YOLO

前言:Hello大家好,我是小哥谈。RCS-YOLO是一种目标检测算法,它是基于YOLOv3算法的改进版本。通过查看RCS-YOLO的整体架构可知,其中包括RCS-OSA模块。RCS-OSA模块在模型中用于堆叠RCS模块,以确保特征的复用并加强不同层之间的信息流动。本文就给大家详细介绍如何将RCS-YOLO…

<Linux> 基础IO

文章目录 基础IO文件描述符重定向重定向本质重定向系统调用 基础IO 文件描述符 系统底层提供打开文件(open)&#xff0c;读(read)&#xff0c;写(write)&#xff0c;关闭文件(close)的系统调用&#xff0c;如果想详细了解可以复制以下命令仔细阅读使用方法&#xff0c;这里不…

C++程序编译 错误提示和评测状态

编译常见错误提示 1.[Error] expected ; before cout。在cout前面&#xff0c;缺少一个分号。 2.[Error] arr was not declared in this scope。未定义变量名arr。 3.[Error] ld returned 1 exit status。重复运行错误(上一个运行的程序&#xff0c;输入窗口没有关掉)。 或者…

如何解决windows自动更新,释放C盘更新内存

第一步&#xff1a;首先关闭windows自动更新组件 没有更新windows需求&#xff0c;为了防止windows自动更新&#xff0c;挤占C盘空间&#xff0c;所以我们要采取停止Windows Update服务。按下WinR打开运行对话框&#xff0c;输入services.msc&#xff0c; 然后按Enter。在服务…

24上软考成绩预计6月底公布?附查分指南

最近&#xff0c;很多小伙伴都在问上半年成绩什么时候出来&#xff1f;每天学习群变成了祈祷群&#xff0c;都在祈祷45,45,45。按照上一次的成绩发布时间&#xff0c;从考试结束到成绩发布&#xff0c;间隔了32天。这次是不是会更快&#xff1f; 一般阅卷只要7-10天&#xff0c…

vb.net教程

下载地址&#xff1a;https://download.csdn.net/download/wgxds/89462547

内核模块的各种概念及示例

基本概念 (1)模块本身不被编译入内核映像&#xff0c;从而控制了内核镜像的大小。模块一旦insmod&#xff0c;它就和内核中的其他部分完全一样 (2)内核中已加载模块的信息也存在于/sys/module目录下&#xff1b;内核中将包含/sys/module/test_mod目录 (3)modprobe在加载某模…

系统架构设计师 - 数据库系统(1)

数据库系统 数据库系统数据库模式 ★分布式数据库 ★★★数据库设计阶段 ★★ER模型 ★关系模型 ★ ★结构约束条件完整性约束 关系代数 ★ ★ ★ ★概述自然连接 大家好呀&#xff01;我是小笙&#xff0c;本章我主要分享系统架构设计师 - 数据库系统(1)知识&#xff0c;希望内…

2024-06-20力扣每日一题

链接&#xff1a; 2748. 美丽下标对的数目 **废话&#xff1a;**彩笔做题家回归&#xff0c;要开始找工作噜 题意 在数组里&#xff0c;按i<j规则取两个数字nums[i]和nums[j]&#xff0c;只要nums[i]的第一位数字和nums[j]的最后一位数字互质&#xff0c;则结果加一 解…

RX8025/INS5T8025实时时钟-国产兼容RS4TC8025

该模块是一个符合I2C总线接口的实时时钟&#xff0c;包括一个32.768 kHz的DTCXO。 除了提供日历&#xff08;年、月、日、日、时、分、秒&#xff09;功能和时钟计数器功能外&#xff0c;该模块还提供了大量其他功能&#xff0c;包括报警功能、唤醒定时器功能、时间更新中断功能…

访问控制列表(Access Control Lists,ACL)与哈希查找的爱恨情怨

访问控制列表&#xff08;Access Control Lists&#xff0c;ACL&#xff09;与哈希查找 什么是访问控制列表ACL&#xff1f;直接说ACL是干啥的ACL概念为什么需要ACLACL类型ACL匹配机制使用例子 哈希查找什么是哈希查找&#xff1f;哈希查找的基本原理哈希查找的步骤 哈希查找在…

H3C防火墙抓包(命令行)

命令行 请按照如下步骤收集下设备的debug信息 1&#xff09; 创建一个空ACL 3XXX&#xff0c;写上两条明细rule&#xff0c;分别对应来回流量的源目地址 [FW]acl advanced 3XXX [FW-acl-ipv4-adv-3XXX]rule permit ip source x.x.x.x 0 destination y.y.y.y 0 [FW-acl…

js如何实现开屏弹窗

开屏弹窗是什么&#xff0c;其实就是第一次登录后进入页面给你的一种公告提示&#xff0c;此后再回到当前这个页面时弹窗是不会再出现的。也就是说这个弹窗只会出现一次。 <!DOCTYPE html> <html><head><meta charset"utf-8"><title>…

【绝对有用】C++ vector const函数和右值移动

std::vector 是 C 标准库中的动态数组&#xff0c;提供了许多方便的函数来操作数组。以下是 std::vector 的常用函数及其使用方法&#xff1a; 构造函数 vector()&#xff1a;默认构造函数&#xff0c;创建一个空的 vector。vector(size_t n)&#xff1a;创建一个包含 n 个默…

索引和深分页优化案例

一、初始状态没加索引 总数据100w左右 浅分页 查询10条需要1.5s左右 select * from timer_task where app hzhXtimer order by run_timer limit 0,10深分页查询10条需要1.7s左右 select * from timer_task where app hzhXtimer order by run_timer limit 100000,10看执…

视频采集概念

视频采集通常指的是将视频信号从视频源&#xff08;如摄像头、视频播放器等&#xff09;捕获并转换为数字格式&#xff0c;以便于计算机处理和存储。 步骤&#xff1a; 视频信号捕获&#xff1a;通过摄像头、网络摄像头、视频采集卡等设备将视频信号捕获。 信号转换&#xff…