java(kotlin) ai框架djl

DJL(Deep Java Library)是一个开源的深度学习框架,由AWS推出,DJL支持多种深度学习后端,包括但不限于:

MXNet:由Apache软件基金会支持的开源深度学习框架。
PyTorch:广泛使用的开源机器学习库,由Facebook的AI研究团队开发。
TensorFlow:由Google开发的另一个流行的开源机器学习框架。
DJL与Java生态系统紧密集成,可以与Spring Boot、Quarkus等Java框架协同工作。

maven

 <!--        djl--><dependency><groupId>ai.djl</groupId><artifactId>api</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl.pytorch</groupId><artifactId>pytorch-engine</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl.pytorch</groupId><artifactId>pytorch-model-zoo</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl</groupId><artifactId>basicdataset</artifactId><version>0.28.0</version></dependency><dependency><groupId>ai.djl</groupId><artifactId>model-zoo</artifactId><version>0.28.0</version></dependency><!--        /djl-->

Java DJL 架构图

┌──────────────────────────────┐
│          ModelZoo            │
├──────────────────────────────┤
│            Model             │
└───────────────┬──────────────┘│┌─────────▼─────────┐│       Engine      │└───────┬─┬─────────┘│ │┌───────▼─▼─────────┐│     NDManager     │└───────┬─┬─────────┘│ │┌─────────▼─▼───────────┐│    Dataset └─────────┬─────────────┘│┌─────────▼─────────────┐│  Trainer / Predictor  │└───────────────────────┘

主要组件详细描述

1. ModelZoo 和 Model
2. Dataset
  • 常见的数据集类型:

    1. RandomAccessDataset:
      • RandomAccessDataset 是一种基本的数据集接口,适用于数据可以随机访问的情况,如数组或列表。
      • 它支持批处理(batching)、数据切片(slicing)等操作,适合大多数监督学习任务。
    2. IterableDataset:
      • IterableDataset 适用于数据不能随机访问的情况,如流数据或实时生成的数据。
      • 它通过迭代器(iterator)提供数据,适用于需要动态生成或处理的数据源。
    3. RecordDataset:
      • RecordDataset 是基于记录文件(record file)的数据集格式,常用于大规模数据处理。
      • 它可以高效地加载和处理数据记录,适用于分布式训练和大数据集的处理。

    DJL 的数据集组件提供的功能包括:

    1. 数据加载和预处理:
      • 支持从多种数据源加载数据,如本地文件、远程服务器、数据库等。
      • 提供数据预处理功能,如归一化、数据增强、特征提取等。
    2. 批处理(Batching):
      • 支持将数据分成小批次进行处理,适用于大规模数据集的训练。
      • 提供灵活的批处理策略,可根据需要进行自定义。
    3. 数据变换(Transformations):
      • 提供多种数据变换功能,如图像变换、文本处理、数值处理等。
      • 支持链式调用,将多个变换操作组合在一起,形成数据处理管道。
    4. 数据加载器(DataLoader):
      • DataLoader 负责将数据集打包成批次,并在训练过程中按需提供数据。
      • 支持多线程数据加载,提高数据处理效率。
  • Dataset:定义数据集的抽象类,用户可以继承该类来实现自定义的数据集。

    • import ai.djl.Model;
      import ai.djl.ModelException;
      import ai.djl.inference.Predictor;
      import ai.djl.modality.Classifications;
      import ai.djl.modality.cv.Image;
      import ai.djl.modality.cv.ImageFactory;
      import ai.djl.repository.zoo.Criteria;
      import ai.djl.repository.zoo.ModelZoo;
      import ai.djl.translate.TranslateException;import java.io.IOException;
      import java.nio.file.Paths;public class DjlExample {public static void main(String[] args) throws IOException, ModelException, TranslateException {// 加载模型Criteria<Image, Classifications> criteria = Criteria.builder().optEngine("TensorFlow") // 选择引擎.setTypes(Image.class, Classifications.class).optModelPath(Paths.get("path/to/model")).build();try (Model model = ModelZoo.loadModel(criteria);Predictor<Image, Classifications> predictor = model.newPredictor()) {// 加载图像Image img = ImageFactory.getInstance().fromFile(Paths.get("path/to/image.jpg"));// 进行推理Classifications result = predictor.predict(img);System.out.println(result);}}
      }
    • import ai.djl.Application;
      import ai.djl.Model;
      import ai.djl.basicdataset.cv.classification.FashionMnist;
      import ai.djl.engine.Engine;
      import ai.djl.metric.Metrics;
      import ai.djl.ndarray.NDArray;
      import ai.djl.ndarray.NDManager;
      import ai.djl.training.DefaultTrainingConfig;
      import ai.djl.training.EasyTrain;
      import ai.djl.training.Trainer;
      import ai.djl.training.dataset.Batch;
      import ai.djl.training.dataset.Dataset;
      import ai.djl.training.listener.TrainingListener;
      import ai.djl.training.loss.Loss;
      import ai.djl.training.optimizer.Optimizer;
      import ai.djl.training.tracker.Tracker;
      import ai.djl.translate.TranslateException;
      import ai.djl.util.Pair;import java.io.IOException;public class DJLDatasetExample {public static void main(String[] args) throws IOException, TranslateException {NDManager manager = NDManager.newBaseManager();FashionMnist fashionMnist = FashionMnist.builder().optUsage(Dataset.Usage.TRAIN).setSampling(32, true) // 32 is the batch size.optLimit(Long.MAX_VALUE) // Use this to limit the number of samples.build();fashionMnist.prepare();Model model = Model.newInstance("fashion-mnist-model");TrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()).optOptimizer(Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build()).addTrainingListeners(TrainingListener.Defaults.logging());try (Trainer trainer = model.newTrainer(config)) {trainer.initialize(new long[]{1, 28, 28}); // Example shape for image dataMetrics metrics = new Metrics();trainer.setMetrics(metrics);for (Batch batch : trainer.iterateDataset(fashionMnist)) {EasyTrain.trainBatch(trainer, batch);trainer.step();batch.close();}trainer.notifyListeners(listener -> listener.onTrainingEnd(trainer));}}
      }

3. Engine 和 NDManager
  • Engine:DJL支持多个深度学习引擎,如MXNet、PyTorch、ONNX、TensorFlow,Engine接口提供统一的抽象,方便切换底层引擎。

  • NDManager:管理NDArray,用于处理多维数组,封装了底层的数组操作。

    Using DJL Engine
    
    import ai.djl.Model
    import ai.djl.ModelException
    import ai.djl.ndarray.NDArray
    import ai.djl.ndarray.NDList
    import ai.djl.ndarray.types.Shape
    import ai.djl.translate.Batchifier
    import ai.djl.translate.TranslateException
    import ai.djl.translate.Translator
    import ai.djl.translate.TranslatorContext
    import java.io.IOException
    import java.nio.file.Pathsobject DJLEngineExample {@Throws(ModelException::class, TranslateException::class, IOException::class)@JvmStaticfun main(args: Array<String>) {// Initialize the modelval model = Model.newInstance("model-name", "ai.djl.pytorch") // Assuming "model-name" is valid and using PyTorch engine// Load a pre-trained modelmodel.load(Paths.get("path/to/your/model")) // Ensure the path is correct// Define a translator for data preprocessing and postprocessingval translator: Translator<Array<Float>, Float> = object : Translator<Array<Float>, Float> {override fun processInput(ctx: TranslatorContext, input: Array<Float>): NDList {val manager = ctx.ndManagerval array: NDArray = manager.create(input.toFloatArray()).reshape(Shape(1, input.size.toLong())) // Reshape might be necessaryreturn NDList(array)}override fun processOutput(ctx: TranslatorContext, list: NDList): Float {// Assuming the output is a single scalar valuereturn list[0].getFloat() // Use getFloat() to get the scalar value}override fun getBatchifier(): Batchifier? {return null // Or implement batching if needed}}model.newPredictor(translator).use { predictor ->val input = arrayOf(1.0f, 2.0f, 3.0f) // Input should match the model's expected input shapeval output = predictor.predict(input)println("Prediction: $output")}}
    }
    Overview of NDManager
    Key Features of NDManager:
    1. Memory Management: Automates the process of memory allocation and deallocation for NDArrays.
    2. Resource Scope: NDArrays created by an NDManager are tied to the lifecycle of that manager. When the manager is closed, all associated NDArrays are also released.
    3. Hierarchical Structure: NDManagers can create child managers, which can further manage their own NDArrays. This is useful for managing resources in complex workflows.
    Using NDManager
    
    import ai.djl.ndarray.NDManagerobject NDManagerExample {@JvmStaticfun main(args: Array<String>) {NDManager.newBaseManager().use { manager ->val array = manager.create(floatArrayOf(1.0f, 2.0f, 3.0f))println("Array: $array")// Perform operationsval result = array.add(2.0f)println("Result: $result")}// No need to explicitly free the memory, it's handled by the NDManager}
    }
    
4. Trainer 和 Predictor
  • Trainer 类

    提供训练模型的接口,包含优化器、损失函数和训练循环等功能。用于训练深度学习模型。它封装了训练过程中的一些常见操作,如前向传播、反向传播和参数更新。

    主要功能包括:

    • 模型的训练和验证
    • 管理优化器和损失函数
    • 提供易于使用的训练循环
    代码演示

    以下是使用 DJL 的 Trainer 类训练一个简单神经网络的示例代码:

    
    import ai.djl.Model
    import ai.djl.basicdataset.cv.classification.FashionMnist
    import ai.djl.basicmodelzoo.basic.Mlp
    import ai.djl.ndarray.types.Shape
    import ai.djl.training.DefaultTrainingConfig
    import ai.djl.training.TrainingConfig
    import ai.djl.training.dataset.Dataset
    import ai.djl.training.dataset.RandomAccessDataset
    import ai.djl.training.listener.LoggingTrainingListener
    import ai.djl.training.listener.TrainingListener
    import ai.djl.training.loss.Loss
    import ai.djl.training.optimizer.Optimizer
    import ai.djl.training.tracker.FixedPerVarTracker
    import ai.djl.training.util.ProgressBar
    import ai.djl.translate.TranslateException
    import java.io.IOException
    import java.nio.file.Pathsobject DjlTrainerDemo {@Throws(IOException::class, TranslateException::class)@JvmStaticfun main(args: Array<String>) {// Load datasetval trainDataset: RandomAccessDataset =FashionMnist.builder().optUsage(Dataset.Usage.TRAIN).setSampling(32, true).build()trainDataset.prepare(ProgressBar())// Define modelval model = Model.newInstance("mlp")model.block = Mlp(28 * 28, 10, intArrayOf(128, 64))// Define training configurationval config: TrainingConfig = DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()).optOptimizer(Optimizer.sgd().setLearningRateTracker(FixedPerVarTracker.builder().setDefaultValue(0.01f).build()).build()).addTrainingListeners(LoggingTrainingListener())model.newTrainer(config).use { trainer ->trainer.initialize(Shape(1, (28 * 28).toLong()))for (epoch in 0..9) {for (batch in trainer.iterateDataset(trainDataset)) {trainer.step()batch.close()}trainer.notifyListeners { listener: TrainingListener ->listener.onEpoch(trainer)}}model.save(Paths.get("model"), "mlp")}}
    }
    Predictor 类

    用于模型推理,接收输入数据并返回预测结果。用于对训练好的模型进行推理。它提供了一个简单的接口,用于将输入数据传递给模型并获取预测结果。

    主要功能包括:

    • 加载模型进行推理
    • 处理输入和输出数据的转换
    代码演示
    
    import ai.djl.Model
    import ai.djl.modality.Classifications
    import ai.djl.ndarray.NDArray
    import ai.djl.ndarray.NDList
    import ai.djl.ndarray.NDManager
    import ai.djl.ndarray.types.Shape
    import ai.djl.translate.Batchifier
    import ai.djl.translate.TranslateException
    import ai.djl.translate.Translator
    import ai.djl.translate.TranslatorContext
    import java.io.IOException
    import java.nio.file.Pathsobject DjlPredictorDemo {@Throws(IOException::class, TranslateException::class)@JvmStaticfun main(args: Array<String>) {// Load modelval model = Model.newInstance("mlp")model.load(Paths.get("model"), "mlp")// Define Translatorval translator: Translator<NDArray, Classifications> = object : Translator<NDArray, Classifications> {override fun processInput(ctx: TranslatorContext, input: NDArray): NDList {return NDList(input.reshape(Shape(1, (28 * 28).toLong())))}override fun processOutput(ctx: TranslatorContext, list: NDList): Classifications {// Assuming the output NDArray is the first element in NDListval probabilities = list.singletonOrThrow()return Classifications(listOf("Label1", "Label2"), probabilities) // Example labels}override fun getBatchifier(): Batchifier {return Batchifier.STACK}}model.newPredictor(translator).use { predictor ->val manager = NDManager.newBaseManager()val array = manager.ones(Shape(1, (28 * 28).toLong()))val classifications = predictor.predict(array)println(classifications)}}
    }

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/28476.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python进阶:从函数到文件的编程艺术!!!

第二章&#xff1a;Python进阶 模块概述 函数是一段可重复使用的代码块&#xff0c;它接受输入参数并返回一个结果。函数可以用于执行特定的任务、计算结果、修改数据等&#xff0c;使得代码更具模块化和可重用性。 模块是一组相关函数、类和变量的集合&#xff0c;它们被封…

第 2 章:Spring Framework 中的 IoC 容器

控制反转&#xff08;Inversion of Control&#xff0c;IoC&#xff09;与 面向切面编程&#xff08;Aspect Oriented Programming&#xff0c;AOP&#xff09;是 Spring Framework 中最重要的两个概念&#xff0c;本章会着重介绍前者&#xff0c;内容包括 IoC 容器以及容器中 …

Yum安装LAMP

查看当前80端口是否被占用 ss -tulanp | grep 80查询httpd是否在yum源中 yum info httpd安装httpd yum -y install httpd启动httpd服务&#xff0c;设置开机自启 systemctl enable httpd --now systemctl start httpd查看当前进程 ps aux | grep httpd查看当前IP&#xff…

【机器学习】大模型环境下的应用:计算机视觉的探索与实践

引言 随着数据量的爆炸性增长和计算能力的提升&#xff0c;机器学习&#xff08;Machine Learning&#xff0c;ML&#xff09;在计算机视觉&#xff08;Computer Vision&#xff0c;CV&#xff09;领域的应用日益广泛。特别是大模型&#xff08;Large Models&#xff09;如深度…

【Qt 6.3 基础教程 03】第一个Qt应用:Hello World

文章目录 前言创建项目编写代码main.cppmainwindow.cpp 编译和运行结果和调试扩展你的应用总结 前言 Qt编程之旅的第一个里程碑通常是构建一个简单的"Hello World"应用程序。在这个教程中&#xff0c;我们将指导你如何创建一个基本的Qt应用程序&#xff0c;它将显示…

自动化技术如何影响企业数据分析的发展

当今时代&#xff0c;企业普遍面临着转型的压力&#xff0c;这些挑战主要源于在科技和市场的双重压力下如何实现增长。当前&#xff0c;企业发展的趋势是紧追自动化的浪潮&#xff0c;并通过优化预算管理流程&#xff0c;推进系统和数据分析的现代化。在这一过程中&#xff0c;…

LoRA用于高效微调的基本原理

Using LoRA for efficient fine-tuning: Fundamental principles — ROCm Blogs (amd.com) 大型语言模型的低秩适配&#xff08;LoRA&#xff09;用于解决微调大型语言模型&#xff08;LLMs&#xff09;的挑战。GPT和Llama等拥有数十亿参数的模型&#xff0c;特定任务或领域的微…

怎样搭建serveru ftp个人服务器

首先说说什么是ftp&#xff1f; FTP协议是专门针对在两个系统之间传输大的文件这种应用开发出来的&#xff0c;它是TCP/IP协议的一部分。FTP的意思就是文件传输协议&#xff0c;用来管理TCP/IP网络上大型文件的快速传输。FTP早也是在Unix上开发出来的&#xff0c;并且很长一段…

Vue54-浏览器的本地存储webStorage

一、本地存储localStorage的作用 二、本地存储的代码实现 2-1、存储数据 注意&#xff1a; localStorage是window上的函数&#xff0c;所以&#xff0c;可以把window.localStorage直接写成localStorage&#xff08;直接调用&#xff01;&#xff09; 默认调了p.toString()方…

curl命令行发送post/get请求

文章目录 curl概述post请求get请求 curl概述 curl 是一个命令行实用程序&#xff0c;允许用户创建网络请求curl 在Windows、 Linux 和 Mac 上皆可使用 post请求 一个简单的 POST 请求 -X&#xff1a;指定与远程服务器通信时将使用哪种 HTTP 请求方法 curl -X POST http://ex…

中小企业使用CRM系统的优势有哪些

中小企业如何在竞争激烈的市场中脱颖而出&#xff1f;除了优秀的产品和服务&#xff0c;一个高效的管理工具也是必不可少的。而客户关系管理&#xff08;CRM&#xff09;系统正是这样一个能帮助企业提升客户体验、优化内部管理流程的重要工具。接下来&#xff0c;让我们一起探讨…

【Android面试八股文】Java的泛型中super 和 extends 有什么区别?

文章目录 Java的泛型中super 和 extends 有什么区别?这道题想考察什么?考察的知识点考生应该如何回答一、 extends二、super三、PECS原则3.1 解释 PECS 原则3.2 PECS原则的总结3.3 PECS原则的应用场景Java的泛型中super 和 extends 有什么区别? 这道题想考察什么? 掌握PE…

主流框架选择:React、Angular、Vue的详细比较

目前前端小伙伴经常使用三种广泛使用的开发框架&#xff1a;React、Angular、Vue - 来设计网站 Reactjs&#xff1a;效率和多功能性而闻名 Angularjs&#xff1a;创建复杂的应用程序提供了完整的解决方案&#xff0c;紧凑且易于使用的框架 Vuejs&#xff1a;注重灵活性和可重用…

Prometheus之图形化界面grafana与服务发现

前言 上一篇文章中我们介绍了Prometheus的组件&#xff0c;监控作用&#xff0c;部署方式&#xff0c;以及如何通过在客户机安装exporter再添加监控项的操作。 但是不免会发现原生的Prometheus的图像化界面对于监控数据并不能其他很好的展示效果。所以本次我们将介绍一…

Cookie-SameSite属性 前端请求不带cookie的问题解决方案

最近遇到了前端请求后端不带cookie的问题&#xff0c; 请求时header里面就是没有cookie 查看响应应该是这个问题 SameSite是一个cookie属性&#xff0c;用于控制浏览器是否在跨站点请求中发送cookie。它有三个可能的值&#xff1a; 1. Strict&#xff08;严格模式&#xff09…

ubuntu安装和应用以及要点难点

Ubuntu是一个基于Linux的免费开源操作系统,它以桌面应用为主,但同样适用于服务器和其他特定用途。以下是关于Ubuntu的详细介绍: 起源与名称: Ubuntu的名称来源于非洲南部祖鲁语或豪萨语的“ubuntu”一词,意思是“人性”、“我的存在是因为大家的存在”,体现了非洲传统的…

Python中的自定义异常类与异常处理机制深度解析

Python中的自定义异常类与异常处理机制深度解析 在Python编程中&#xff0c;异常处理是一种重要的编程范式&#xff0c;它允许我们在程序运行时检测并处理错误。Python内置了一些常见的异常类&#xff0c;但有时候我们可能需要定义自己的异常类&#xff0c;以更精确地描述和处…

2024华为OD机试真题-出租车计费 、靠谱的车-(C++/Python)-C卷D卷-100分

2024华为OD机试题库-(C卷+D卷)-(JAVA、Python、C++) 题目描述: 程序员小明打了一辆出租车去上班。出于职业敏感,他注意到这辆出租车的计费表有点问题,总是偏大。 出租车司机解释说他不喜欢数字4,所以改装了计费表,任何数字位置遇到数字4就直接跳过,其余功能都正常。 比如…

超硬核五千字!彻底讲明白JavaScript中的异步和同步,以及JavaScript代码执行顺序

同步操作和异步操作是编程中处理任务的两种不同方式&#xff0c;它们主要区别在于控制流和对程序执行的影响。不知道大家是怎么理解JavaScript中的同步和异步的&#xff1f;JavaScript的代码执行顺序是怎么样&#xff1f;下面这段代码是同步还是异步的&#xff1f; console.log…

浙大版PTA Python程序设计 题目与知识点整理(综合版)

目录 第一章 一、高级语言程序的执行方式 二、变量赋值与内存地址 三、字符编码 3.1 Unicode 3.2 ASCII&#xff08;American Standard Code for Information Interchange&#xff09; 四、编程语言分类按照编程范式分类 4.1 面向过程语言 4.2 面向对象语言 五、原码…