java版本lstm_LSTM java 实现

由于实验室事情缘故,需要将Python写的神经网络转成Java版本的,但是python中的numpy等啥包也不知道在Java里面对应的是什么工具,所以索性直接寻找一个现成可用的Java神经网络框架,于是就找到了JOONE,JOONE是一个神经网络的开源框架,使用的是BP算法进行迭代计算参数,使用起来比较方便也比较实用,下面介绍一下JOONE的一些使用方法。

JOONE需要使用一些外部的依赖包,这在官方网站上有,也可以在这里下载。将所需的包引入工程之后,就可以进行编码实现了。

首先看下完整的程序,这个是上面那个超链接给出的程序,应该是官方给出的一个示例吧,因为好多文章都用这个,这其实是神经网络训练一个异或计算器:

import org.joone.engine.*;

import org.joone.engine.learning.*;

import org.joone.io.*;

import org.joone.net.*;

/*

*

* JOONE实现

*

* */

public class XOR_using_NeuralNet implements NeuralNetListener

{

private NeuralNet nnet = null;

private MemoryInputSynapse inputSynapse, desiredOutputSynapse;

LinearLayer input;

SigmoidLayer hidden, output;

boolean singleThreadMode = true;

// XOR input

private double[][] inputArray = new double[][]

{

{ 0.0, 0.0 },

{ 0.0, 1.0 },

{ 1.0, 0.0 },

{ 1.0, 1.0 } };

// XOR desired output

private double[][] desiredOutputArray = new double[][]

{

{ 0.0 },

{ 1.0 },

{ 1.0 },

{ 0.0 } };

/**

* @param args

*            the command line arguments

*/

public static void main(String args[])

{

XOR_using_NeuralNet xor = new XOR_using_NeuralNet();

xor.initNeuralNet();

xor.train();

xor.interrogate();

}

/**

* Method declaration

*/

public void train()

{

// set the inputs

inputSynapse.setInputArray(inputArray);

inputSynapse.setAdvancedColumnSelector(" 1,2 ");

// set the desired outputs

desiredOutputSynapse.setInputArray(desiredOutputArray);

desiredOutputSynapse.setAdvancedColumnSelector(" 1 ");

// get the monitor object to train or feed forward

Monitor monitor = nnet.getMonitor();

// set the monitor parameters

monitor.setLearningRate(0.8);

monitor.setMomentum(0.3);

monitor.setTrainingPatterns(inputArray.length);

monitor.setTotCicles(5000);

monitor.setLearning(true);

long initms = System.currentTimeMillis();

// Run the network in single-thread, synchronized mode

nnet.getMonitor().setSingleThreadMode(singleThreadMode);

nnet.go(true);

System.out.println(" Total time=  "

+ (System.currentTimeMillis() - initms) + "  ms ");

}

private void interrogate()

{

double[][] inputArray = new double[][]

{

{ 1.0, 1.0 } };

// set the inputs

inputSynapse.setInputArray(inputArray);

inputSynapse.setAdvancedColumnSelector(" 1,2 ");

Monitor monitor = nnet.getMonitor();

monitor.setTrainingPatterns(4);

monitor.setTotCicles(1);

monitor.setLearning(false);

MemoryOutputSynapse memOut = new MemoryOutputSynapse();

// set the output synapse to write the output of the net

if (nnet != null)

{

nnet.addOutputSynapse(memOut);

System.out.println(nnet.check());

nnet.getMonitor().setSingleThreadMode(singleThreadMode);

nnet.go();

for (int i = 0; i 

{

double[] pattern = memOut.getNextPattern();

System.out.println(" Output pattern # " + (i + 1) + " = "

+ pattern[0]);

}

System.out.println(" Interrogating Finished ");

}

}

/**

* Method declaration

*/

protected void initNeuralNet()

{

// First create the three layers

input = new LinearLayer();

hidden = new SigmoidLayer();

output = new SigmoidLayer();

// set the dimensions of the layers

input.setRows(2);

hidden.setRows(3);

output.setRows(1);

input.setLayerName(" L.input ");

hidden.setLayerName(" L.hidden ");

output.setLayerName(" L.output ");

// Now create the two Synapses

FullSynapse synapse_IH = new FullSynapse(); /* input -> hidden conn. */

FullSynapse synapse_HO = new FullSynapse(); /* hidden -> output conn. */

// Connect the input layer whit the hidden layer

input.addOutputSynapse(synapse_IH);

hidden.addInputSynapse(synapse_IH);

// Connect the hidden layer whit the output layer

hidden.addOutputSynapse(synapse_HO);

output.addInputSynapse(synapse_HO);

// the input to the neural net

inputSynapse = new MemoryInputSynapse();

input.addInputSynapse(inputSynapse);

// The Trainer and its desired output

desiredOutputSynapse = new MemoryInputSynapse();

TeachingSynapse trainer = new TeachingSynapse();

trainer.setDesired(desiredOutputSynapse);

// Now we add this structure to a NeuralNet object

nnet = new NeuralNet();

nnet.addLayer(input, NeuralNet.INPUT_LAYER);

nnet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);

nnet.addLayer(output, NeuralNet.OUTPUT_LAYER);

nnet.setTeacher(trainer);

output.addOutputSynapse(trainer);

nnet.addNeuralNetListener(this);

}

public void cicleTerminated(NeuralNetEvent e)

{

}

public void errorChanged(NeuralNetEvent e)

{

Monitor mon = (Monitor) e.getSource();

if (mon.getCurrentCicle() % 100 == 0)

System.out.println(" Epoch:  "

+ (mon.getTotCicles() - mon.getCurrentCicle()) + "  RMSE: "

+ mon.getGlobalError());

}

public void netStarted(NeuralNetEvent e)

{

Monitor mon = (Monitor) e.getSource();

System.out.print(" Network started for  ");

if (mon.isLearning())

System.out.println(" training. ");

else

System.out.println(" interrogation. ");

}

public void netStopped(NeuralNetEvent e)

{

Monitor mon = (Monitor) e.getSource();

System.out.println(" Network stopped. Last RMSE= "

+ mon.getGlobalError());

}

public void netStoppedError(NeuralNetEvent e, String error)

{

System.out.println(" Network stopped due the following error:  "

+ error);

}

}

现在我会逐步解释上面的程序。

【1】 从main方法开始说起,首先第一步新建一个对象:

XOR_using_NeuralNet xor = new XOR_using_NeuralNet();

【2】然后初始化神经网络:

xor.initNeuralNet();

初始化神经网络的方法中:

// First create the three layers

input = new LinearLayer();

hidden = new SigmoidLayer();

output = new SigmoidLayer();

// set the dimensions of the layers

input.setRows(2);

hidden.setRows(3);

output.setRows(1);

input.setLayerName(" L.input ");

hidden.setLayerName(" L.hidden ");

output.setLayerName(" L.output ");

上面代码解释:

input=new LinearLayer()是新建一个输入层,因为神经网络的输入层并没有训练参数,所以使用的是线性层;

hidden = new SigmoidLayer();这里是新建一个隐含层,使用sigmoid函数作为激励函数,当然你也可以选择其他的激励函数,如softmax激励函数

output则是新建一个输出层

之后的三行代码是建立输入层、隐含层、输出层的神经元个数,这里表示输入层为2个神经元,隐含层是3个神经元,输出层是1个神经元

最后的三行代码是给每个输出层取一个名字。

// Now create the two Synapses

FullSynapse synapse_IH = new FullSynapse(); /* input -> hidden conn. */

FullSynapse synapse_HO = new FullSynapse(); /* hidden -> output conn. */

// Connect the input layer whit the hidden layer

input.addOutputSynapse(synapse_IH);

hidden.addInputSynapse(synapse_IH);

// Connect the hidden layer whit the output layer

hidden.addOutputSynapse(synapse_HO);

output.addInputSynapse(synapse_HO);

上面代码解释:

上面代码的主要作用是将三个层连接起来,synapse_IH用来连接输入层和隐含层,synapse_HO用来连接隐含层和输出层

// the input to the neural net

inputSynapse = new MemoryInputSynapse();

input.addInputSynapse(inputSynapse);

// The Trainer and its desired output

desiredOutputSynapse = new MemoryInputSynapse();

TeachingSynapse trainer = new TeachingSynapse();

trainer.setDesired(desiredOutputSynapse);

上面代码解释:

上面的代码是在训练的时候指定输入层的数据和目的输出的数据,

inputSynapse = new MemoryInputSynapse();这里指的是使用了从内存中输入数据的方法,指的是输入层输入数据,当然还有从文件输入的方法,这点在文章后面再谈。同理,desiredOutputSynapse = new MemoryInputSynapse();也是从内存中输入数据,指的是从输入层应该输出的数据

// Now we add this structure to a NeuralNet object

nnet = new NeuralNet();

nnet.addLayer(input, NeuralNet.INPUT_LAYER);

nnet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);

nnet.addLayer(output, NeuralNet.OUTPUT_LAYER);

nnet.setTeacher(trainer);

output.addOutputSynapse(trainer);

nnet.addNeuralNetListener(this);

上面代码解释:

这段代码指的是将之前初始化的构件连接成一个神经网络,NeuralNet是JOONE提供的类,主要是连接各个神经层,最后一个nnet.addNeuralNetListener(this);这个作用是对神经网络的训练过程进行监听,因为这个类实现了NeuralNetListener这个接口,这个接口有一些方法,可以实现观察神经网络训练过程,有助于参数调整。

【3】然后我们来看一下train这个方法:

inputSynapse.setInputArray(inputArray);

inputSynapse.setAdvancedColumnSelector(" 1,2 ");

// set the desired outputs

desiredOutputSynapse.setInputArray(desiredOutputArray);

desiredOutputSynapse.setAdvancedColumnSelector(" 1 ");

上面代码解释:

inputSynapse.setInputArray(inputArray);这个方法是初始化输入层数据,也就是指定输入层数据的内容,inputArray是程序中给定的二维数组,这也就是为什么之前初始化神经网络的时候使用的是MemoryInputSynapse,表示从内存中读取数据

inputSynapse.setAdvancedColumnSelector(" 1,2 ");这个表示的是输入层数据使用的是inputArray的前两列数据。

desiredOutputSynapse这个也同理

Monitor monitor = nnet.getMonitor();

// set the monitor parameters

monitor.setLearningRate(0.8);

monitor.setMomentum(0.3);

monitor.setTrainingPatterns(inputArray.length);

monitor.setTotCicles(5000);

monitor.setLearning(true);

上面代码解释:

这个monitor类也是JOONE框架提供的,主要是用来调节神经网络的参数,monitor.setLearningRate(0.8);是用来设置神经网络训练的步长参数,步长越大,神经网络梯度下降的速度越快,monitor.setTrainingPatterns(inputArray.length);这个是设置神经网络的输入层的训练数据大小size,这里使用的是数组的长度;monitor.setTotCicles(5000);这个指的是设置迭代数目;monitor.setLearning(true);这个true表示是在训练过程。

nnet.getMonitor().setSingleThreadMode(singleThreadMode);

nnet.go(true);

上面代码解释:

nnet.getMonitor().setSingleThreadMode(singleThreadMode);这个指的是是不是使用多线程,但是我不太清楚这里的多线程指的是什么意思

nnet.go(true)表示的是开始训练。

【4】最后来看一下interrogate方法

double[][] inputArray = new double[][]

{

{ 1.0, 1.0 } };

// set the inputs

inputSynapse.setInputArray(inputArray);

inputSynapse.setAdvancedColumnSelector(" 1,2 ");

Monitor monitor = nnet.getMonitor();

monitor.setTrainingPatterns(4);

monitor.setTotCicles(1);

monitor.setLearning(false);

MemoryOutputSynapse memOut = new MemoryOutputSynapse();

// set the output synapse to write the output of the net

if (nnet != null)

{

nnet.addOutputSynapse(memOut);

System.out.println(nnet.check());

nnet.getMonitor().setSingleThreadMode(singleThreadMode);

nnet.go();

for (int i = 0; i 

{

double[] pattern = memOut.getNextPattern();

System.out.println(" Output pattern # " + (i + 1) + " = "

+ pattern[0]);

}

System.out.println(" Interrogating Finished ");

}

这个方法相当于测试方法,这里的inputArray是测试数据, 注意这里需要设置monitor.setLearning(false);,因为这不是训练过程,并不需要学习,monitor.setTrainingPatterns(4);这个是指测试的数量,4表示有4个测试数据(虽然这里只有一个)。这里还给nnet添加了一个输出层数据对象,这个对象mmOut是初始测试结果,注意到之前我们初始化神经网络的时候并没有给输出层指定数据对象,因为那个时候我们在训练,而且指定了trainer作为目的输出。

接下来就是输出结果数据了,pattern的个数和输出层的神经元个数一样大,这里输出层神经元的个数是1,所以pattern大小为1.

【5】我们看一下测试结果:

Output pattern # 1 = 0.018303527517809233

表示输出结果为0.01,根据sigmoid函数特性,我们得到的输出是0,和预期结果一致。如果输出层神经元个数大于1,那么输出值将会有多个,因为输出层结果是0|1离散值,所以我们取输出最大的那个神经元的输出值取为1,其他为0

【6】最后我们来看一下神经网络训练过程中的一些监听函数:

cicleTerminated:每个循环结束后输出的信息

errorChanged:神经网络错误率变化时候输出的信息

netStarted:神经网络开始运行的时候输出的信息

netStopped:神经网络停止的时候输出的信息

【7】好了,JOONE基本上内容就是这些。还有一些额外东西需要说明:

1,从文件中读取数据构建神经网络

2.如何保存训练好的神经网络到文件夹中,只要测试的时候直接load到内存中就行,而不用每次都需要训练。

【8】先看第一个问题:

从文件中读取数据:

文件的格式:

0;0;0

1;0;1

1;1;0

0;1;1

中间使用分号隔开,使用方法如下,也就是把上文的MemoryInputSynapse换成FileInputSynapse即可。

fileInputSynapse = new FileInputSynapse();

input.addInputSynapse(fileInputSynapse);

fileDisireOutputSynapse = new FileInputSynapse();

TeachingSynapse trainer = new TeachingSynapse();

trainer.setDesired(fileDisireOutputSynapse);

我们看下文件是如何输出数据的:

private File inputFile = new File(Constants.TRAIN_WORD_VEC_PATH);

fileInputSynapse.setInputFile(inputFile);

fileInputSynapse.setFirstCol(2);//使用文件的第2列到第3列作为输出层输入

fileInputSynapse.setLastCol(3);

fileDisireOutputSynapse.setInputFile(inputFile);

fileDisireOutputSynapse.setFirstCol(1);//使用文件的第1列作为输出数据

fileDisireOutputSynapse.setLastCol(1);

其余的代码和上文的是一样的。

【9】然后看第二个问题:

如何保存神经网络

其实很简单,直接序列化nnet对象就行了,然后读取该对象就是java的反序列化,这个就不多做介绍了,比较简单。但是需要说明的是,保存神经网络的时机一定是在神经网络训练完毕后,可以使用下面代码:

public void netStopped(NeuralNetEvent e) {

Monitor mon = (Monitor) e.getSource();

try {

if (mon.isLearning()) {

saveModel(nnet); //序列化对象

}

} catch (IOException ee) {

// TODO Auto-generated catch block

ee.printStackTrace();

}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/245686.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

easy excel date 类型解析报错_ptarchiver原理解析

pt-archiver原理解析作为MySQL DBA,可以说应该没有不知道pt-archiver了,作为pt-toolkit套件中的重要成员,往往能够轻松帮助DBA解决数据归档的问题。例如线上一个流水表,业务仅仅只需要存放最近3个月的流水数据,三个月前…

python中np没有定义_python中的np.empty_python – np.empty,np.zeros和np.one

我很好奇它使用np.empty而不是np.zeros实际上有多大差异,还有关于np.ones的差异.我运行这个小脚本来测试每个创建一个大型数组所需的时间:import numpy as npfrom timeit import timeitN 10_000_000dtypes [np.int8, np.int16, np.int32, np.int64,np.uint8, np.u…

设计师电脑推荐笔记本_笔记本电脑选购推荐全攻略

笔记本电脑现如今在我们的生活中出镜率如此之高,不论是学生党查阅资料还是上班族的忙碌办公,抑或是电竞爱好者的游戏体验,都要依靠笔记本电脑来完成,但对于它的选购方法你了解多少?一千个人就有一千种不同的标准&#…

java s1_转!!Java 基础面试题的剖析: short s1=1;s1 = s1 +1 报错? s1+=1 呢

package common;public class ShortTypeTest {/** param args*/public static void main(String[] args) {// TODO Auto-generated method stubshort s1 1; s1 (short) (s1 1);//简单类型short s21; s2 1;//复合类型,复合赋值操作符,System.out.println(s1";"s2);…

python内循环只执行一次_如何1分钟内从3600只股种选出黑马股,仅分享这一次(太透彻了)...

“1234”操盘定理中,1是指如何持牛股,2是指观念要正确,3是指看破十个现象,4是指抓热点。具体讲解下选股步骤盘后我们怎么去快速选出超短线需要密切关注的票呢?其实选股方法少说也有几千种,我想没有人统计过&#xff0…

java 运行时类型_Java基础之RTTI 运行时类型识别

运行时类型识别(RTTI, Run-Time Type Identification)是Java中非常有用的机制,在Java运行时,RTTI维护类的相关信息。多态(polymorphism)是基于RTTI实现的。RTTI的功能主要是由Class类实现的。Class类Class类是"类的类"(class of classes)。如果…

指定的服务已经标记为删除_你的电话号码被标记过吗?你知道这件事情还能赚钱吗?...

今天闲来无事和大家唠唠赚钱的小副业。“电话标记”,我被这事困惑了很多年,最近解决了,同时还发现,这个信息差能挣钱,文末还给到方法,执行力强的伙伴原样照做,0成本,马上开搞&#x…

mplab x ide 中文使用手册_SCI必备利器:翻译又快又准,强推这款超牛X的神器!...

随手转发给好友和朋友圈 编辑:科研小通再分享一款翻译神器,不用调用Google服务器,速度超快。实时翻译,服务器速度杠杠的。今天给大家安利一款超牛X的翻译神器:Mate Translate。官网首页https://gikken.co/mate-transla…

python流行趋势_Python流行度再创新高,学Python就从风变编程开始

10月初,全球编程语言社区TIOBE公布了2020年10月编程语言排行榜,排名情况相较前几个月变化不大,前十名分别为C、Java、Python、C 、C#、Visual Basic、JavaScript 、PHP、R和SQL。其中,Python继续稳居第三,且其受欢迎度…

8086汇编4位bcd码_238期中4头3尾,排列五第19239期爱我彩规

爱我彩规专业研究(七星彩、排列五) 前四位的铁码与定位规,有幸开通爱我彩规公众号,努力为大家提供稳定的号码参考。作者微信号awc1125。 逢星期二和星期天不在彩码课堂公众号转发,因星期二和星期天每日八篇巳排满,只发爱我彩规公…

python通过链接下载文件-如何使用Python通过HTTP下载文件?

import urllib urllib.urlretrieve ("http://www.example.com/songs/mp3.mp3", "mp3.mp3") (用于Python 3)import urllib.request和urllib.request.urlretrieve) 还有一个有“进度栏”的import urllib2 url "http://download.thinkbroadband.com/10M…

centos6.5 编译安装mysql_Centos6.5编译安装mysql 5.7.14详细教程

此文实例给亲们分享了CENTOS6.5 编译mysql 5.7.14安装配置方法,供大家参考,具体内容如下mysql5.7.14 编译安装在自定义文件路径下下载安装包配置安装环境编译安装cmake\-DCMAKE_INSTALL_PREFIX/data/db5714 \-DMYSQL_DATADIR/data/db5714/var \-DMYSQL_U…

时间插件只能选择整点和半点_我花一小时自制了三款PPT插件,不仅免费分享,还想手把手教你制作...

更准确的说,三顿花一小时给PPT里这个天天和你见面的功能区做了一次彻底的整容:我精简了好多根本用不到的功能,还添加了一大波可以让你效率翻倍的一键操作,比如一键拆分文字,一键美化图表等等。这样的改头换面操作起来一…

c主线程如何等待子线程结束 linux_使用互斥量进行同步 - Linux C进程与多线程入门_Linux编程_Linux公社-Linux系统门户网站...

互斥简单地理解就是,一个线程进入工作区后,如果有其他线程想要进入工作区,它就会进入等待状态,要等待工作区内的线程结束后才可以进入。基本函数(1) pthread_mutex_init函数原型:int pthread_mutex_init ( pthread_mut…

电脑声音太小如何增强_感觉手机音量太小了?教你这样设置,声音立马大上许多...

不管是打电话,还是看电视,如果觉得手机的声音太小了,总会感到听起来很吃力,那么我们遇到这种情况,可以怎么办呢?建议大家看看下面这个方法,或许会让你的手机音量瞬间变大。1、打开单声道音频如今…

java 8时间操作_Java8 时间日期类操作

Java8 时间日期类操作Java8的时间类有两个重要的特性线程安全不可变类,返回的都是新的对象显然,该特性解决了原来java.util.Date类与SimpleDateFormat线程不安全的问题。同时Java8的时间类提供了诸多内置方法,方便了对时间进行相应的操作。上…

java虚拟机_一文彻底读懂Java虚拟机!(JVM)

提到Java虚拟机(JVM),可能大部分人的第一印象是“难”,但当让我们真正走入“JVM世界”的时候,会发现其实问题并不像我们想象中的那么复杂。唯一真正令我们恐惧的,其实是恐惧本身。而作为整个JVM系列的首篇,本文将带你解…

java open course_关于开闭原则 JavaDiscountCourse 类的设计

亲爱的同学,你好,我是geely老师的助教。你这样挺不错的。和老师的设计有不同的思路,赞。我再修改一下,看看能不能还有不一样的想法。public class DiscountCourse implements ICourse{private ICourse course;//折扣private doubl…

如何把一个软件嵌入另一个软件_新增一个软件一个游戏

今后会不定时增加付费软件的试用,如果大家有希望选购的IOS软件可留言。如果各位觉得软件好用,请去App Store购买支持开发者。MaginNote 3 (¥88)简介:MarginNote 3,全新上线电子阅读器,助力更高效书籍阅读和学习.革新性整合阅读标注…

java图书管理系统技术难度_Java图书管理系统练习程序(一)

Java图书管理系统练习程序第一部分该部分主要实现命令行方式的界面与无数据库访问的练习,通过本练习、主要掌握Java的基础知识与面向对象程序设计思想、面向接口编程技术的知识与运用。一、练习程序功能分析该练习程序主要用于学习Java的基础编程知识与面向接口编程…