- 模型 SQL 文件：https://pan.baidu.com/s/1hugrI9e
- 使用的数据：https://pan.baidu.com/s/1kWz8fNh
- 使用 Spark MLlib 训练朴素贝叶斯（NaiveBayes）模型
package com.xxx.xxx.xxx

import java.io.ObjectInputStream
import java.sql.{Connection, DriverManager, PreparedStatement}
import java.util.{Arrays, Date, Scanner}

import org.ansj.splitWord.analysis.ToAnalysis
import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.mllib.feature.{HashingTF, IDF}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.DataFrame
import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.mutable.ArrayBuffer

/**
 * Trains, stores and evaluates a Spark MLlib multinomial Naive Bayes classifier that assigns
 * Chinese articles to one of the categories listed in [[lables]].
 *
 * Training rows are read over JDBC from the MySQL table `industry_classify_tmp`
 * (columns `content` and `classify`). The trained model is saved to [[outputPath]] and its raw
 * parameters are also written to the MySQL table `ams_recommender_model`.
 *
 * Created by Zsh on 1/31.
 */
object WudiBayesModel {

  import scala.util.control.NonFatal

  // Shared JDBC state: MysqlConn assigns `conn`, the (de)serialize helpers assign `stmt`.
  var conn: Connection = null
  var stmt: PreparedStatement = null

  /** Directory the NaiveBayesModel is saved to by training() and loaded from by loadModel()/testModel(). */
  val outputPath = "/zshModel"

  // NOTE(review): credentials are hard-coded in the URL; consider reading them from configuration.
  // NOTE(review): autoDeserialize=true lets the driver deserialize Java objects sent by the server;
  // deserializeFromMysql reads its BLOB manually, so this flag may be unnecessary - confirm before removing.
  val driverUrl = "jdbc:mysql://192.168.2.107:3306/data_mining?user=xxx&password=xxx&zeroDateTimeBehavior=convertToNull&characterEncoding=utf-8&autoDeserialize=true"

  /** Rows labelled [[classify]] (column `content` only); populated by loadModel(). */
  var df: DataFrame = null

  /** Category evaluated by batchTesting() and used for the per-class precision/recall in training(). */
  val classify: String = "健康"

  /** Model used by inputTestModel() and batchTesting(); populated by loadModel(). */
  var model: NaiveBayesModel = null

  /** Category names separated by '|'; the i-th name becomes numeric label i.0 in training(). */
  val lables = "社会|体育|汽车|女性|新闻|科技|财经|军事|广告|娱乐|健康|教育|旅游|生活|文化"

  /**
   * Number of hashed term features used by training(). Larger values make the saved model
   * bigger, because theta has one column per feature.
   */
  val numFeatures: Int = math.pow(2, 16).toInt

  // tokenizer() drops tokens that fully match one of these: "@user:" mentions and plain http URLs.
  private val mentionPattern = "@\\w{2,20}:".r
  private val urlPattern = "http://[0-9a-zA-Z/\\?&#%$@\\=\\\\]+".r

  /** Entry point. Every action is commented out; uncomment the one you want to run. */
  def main(args: Array[String]): Unit = {
    // training(lables)
    // val text = "今年“五一”假期虽然缩短为三天,但来个“周边游”却正逢其时。昨日记者从保定市旅游部门了解到,5月1日—3日,该市满城汉墓景区将举办“2008全国滑翔伞俱乐部联赛”第一站比赛。届时将有来自全国各地的滑翔伞高手云集陵山,精彩上演自由飞翔。\uE40C滑翔伞俱乐部举办联赛为我国第一次,所有参赛运动员均在国际或国内大型比赛中取得过名次,并且所有运动员必须持2008年贴花的中国航空运动协会会员证书和B级以上滑翔伞运动证书,所使用的比赛动作均为我国滑翔伞最新动作。本届比赛的项目,除保留传统的“留空时间赛”和“精确着陆赛”以外,还增加了“盘升高度赛”等内容。届时,参赛运动员将冲击由保定市运动员韩均创造的1450米的盘升高度记录。截至目前,已有11个省市的50多名运动员报名参赛,其中包括多名外籍运动员和7名女运动员。 (来源:燕赵晚报)\uE40C(责任编辑:李妍)"
    // Test the model parameters read back from MySQL:
    // BayesUtils.testMysql(text, lables)
    // inputTestModel()
  }

  /** Loads the model, then classifies every line typed on standard input until input ends. */
  def inputTestModel(): Unit = {
    val scan = new Scanner(System.in)
    val startTime = new Date().getTime
    loadModel()
    val loadedAt = new Date().getTime
    println("加载模型时间:" + (loadedAt - startTime))
    println("模型加载完毕-----")
    // hasNextLine ends the loop cleanly on EOF; a bare nextLine() would throw NoSuchElementException.
    while (scan.hasNextLine) {
      testData(model, scan.nextLine(), lables)
      println("---------------------------------")
    }
  }

  /**
   * Predicts every row loaded by loadModel() and prints the fraction predicted as [[classify]].
   * loadModel() selects only rows labelled [[classify]], so this is that class's recall.
   */
  def batchTesting(): Unit = {
    require(model != null && df != null, "loadModel() must be called before batchTesting()")
    // Copy to locals so the closure ships these values to executors, instead of reading object
    // fields there that are only populated in the driver JVM.
    val currentModel = model
    val labelNames = lables
    val target = classify
    val start = new Date().getTime
    val result = df.map(row =>
      testData(currentModel, Option(row.getAs[String]("content")).getOrElse(""), labelNames)).cache()
    // count() forces the (lazy) predictions, so the timing below covers the actual work.
    val total = result.count()
    val end = new Date().getTime
    println("预测需要时间:" + (end - start))
    println("准确率:" + result.filter(_ == target).count().toDouble / total)
  }

  /** Loads the saved model into [[model]] and the rows labelled [[classify]] into [[df]]. */
  def loadModel(): Unit = {
    val sc = SparkContext.getOrCreate(kryoConf())
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    // Assign the field (not a local) so inputTestModel()/batchTesting() can use the model.
    model = NaiveBayesModel.load(sc, outputPath)
    val jdbcDF = sqlContext.read
      .options(Map("url" -> driverUrl, "dbtable" -> "industry_classify_tmp"))
      .format("jdbc")
      .load()
    df = jdbcDF.filter(jdbcDF("classify") === classify).select("content")
  }

  /** Loads and returns the saved model without touching [[model]] or [[df]]. */
  def testModel(): NaiveBayesModel = {
    val sc = SparkContext.getOrCreate(kryoConf())
    NaiveBayesModel.load(sc, outputPath)
  }

  /**
   * Classifies one text, prints the result and returns the predicted category name.
   *
   * @param model       trained Naive Bayes model
   * @param text        raw article text
   * @param labels_name category names separated by '|', in label order
   * @return the category name at the predicted label's index
   */
  def testData(model: NaiveBayesModel, text: String, labels_name: String): String = {
    // Hash with the same width the model was trained with; otherwise predict() rejects the vector.
    val hashingTF = new HashingTF(featureDim(model))
    val tfVector = hashingTF.transform(tokenizer(text))
    // NOTE(review): training() feeds TF-IDF vectors to NaiveBayes.train, but the fitted IDF model is
    // not persisted, so predictions here use raw term frequencies. Saving the IDF weights would
    // make training and prediction consistent.
    val d = model.predict(tfVector)
    val list2 = labels_name.split("\\|").toList
    println("result:" + list2(d.toInt) + " " + d + " " + text)
    list2(d.toInt)
  }

  /**
   * Trains a multinomial Naive Bayes model on the MySQL rows whose `classify` value is one of
   * `labels_name`. Saves it to [[outputPath]] and to MySQL, then prints validation metrics on a
   * 30% hold-out split.
   *
   * @param labels_name category names separated by '|'; the i-th name becomes label i.0
   */
  def training(labels_name: String): Unit = {
    val labelNames = labels_name.split("\\|").toList
    // Category name -> numeric label (0.0, 1.0, ...) in the order given.
    val tuples: Map[String, Double] =
      labelNames.zipWithIndex.map { case (name, i) => name -> i.toDouble }.toMap

    val conf = new SparkConf().setAppName("NaiveBayesExample1").setMaster("local[4]")
    val sc = SparkContext.getOrCreate(conf)
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    val jdbcDF = sqlContext.read
      .options(Map("url" -> driverUrl, "dbtable" -> "industry_classify_tmp"))
      .format("jdbc")
      .load()

    // Training rows from MySQL, restricted to the requested categories.
    val trainData = jdbcDF.filter(jdbcDF("classify").isin(labelNames: _*)).select("content", "classify")
    println("trcount:" + trainData.count())

    // (tokens of the article body, numeric label) for every row.
    val tokenized = trainData.map { row =>
      (tokenizer(Option(row.getAs[String]("content")).getOrElse("")), tuples(row.getAs[String]("classify")))
    }

    // Keep each label next to its own vector. Zipping two separately recomputed RDDs could
    // misalign them if the JDBC source returned rows in a different order.
    val hashTF = new HashingTF(numFeatures)
    val tf = tokenized.map { case (tokens, label) => (label, hashTF.transform(tokens)) }.cache()
    val idfModel = new IDF().fit(tf.map(_._2))
    val tData = tf.map { case (label, vector) => LabeledPoint(label, idfModel.transform(vector)) }

    // 70% training / 30% validation, with a fixed seed for reproducibility.
    val splits = tData.randomSplit(Array(0.7, 0.3), seed = 11L)
    val trData = splits(0).cache()
    val teData = splits(1).cache()
    val trained = NaiveBayes.train(trData, lambda = 1.0, modelType = "multinomial")
    trained.save(sc, outputPath)
    println("save model success !")

    // Also store the raw model parameters in MySQL (read back by deserializeFromMysql()).
    val data = BayesModelData2(trained.labels.toArray, trained.pi.toArray,
      trained.theta.map(_.toArray).toArray, "multinomial")
    serializeToMysql(data)

    // ===== Validation on the hold-out split: (predicted, actual) pairs =====
    val testAndLabel = teData.map(p => (trained.predict(p.features), p.label)).cache()
    val total = testAndLabel.count()
    // Per-class metrics are for `classify`. NaN never equals a label, so precision and recall
    // print NaN when `classify` is not one of labels_name.
    val positiveLabel = tuples.getOrElse(classify, Double.NaN)
    // Rows whose true label is the positive class (TP + FN).
    val totalPostiveNum = testAndLabel.filter(_._2 == positiveLabel).count()
    // Rows predicted as the positive class (TP + FP).
    val totalTrueNum = testAndLabel.filter(_._1 == positiveLabel).count()
    // Correct predictions of the positive class (TP).
    val testRealTrue = testAndLabel.filter(x => x._1 == x._2 && x._2 == positiveLabel).count()
    // All correct predictions.
    val testReal = testAndLabel.filter(x => x._1 == x._2).count()
    val testAccuracy = 1.0 * testReal / total
    val testPrecision = 1.0 * testRealTrue / totalTrueNum
    val testRecall = 1.0 * testRealTrue / totalPostiveNum
    println("统计分类准确率:============================")
    println("准确率:" + testAccuracy)  // correct / all = (TP + TN) / (TP + FP + TN + FN)
    println("精确度:" + testPrecision) // TP / (TP + FP)
    println("召回率:" + testRecall)    // TP / (TP + FN)
    println("模型准确度============================")
  }

  /** Segments `line` with Ansj and drops empty tokens, "@user:" mentions and http URLs. */
  def tokenizer(line: String): Seq[String] =
    AnsjSegment(line)
      .split(",")
      .filter { token =>
        token.nonEmpty &&
          !mentionPattern.pattern.matcher(token).matches &&
          !urlPattern.pattern.matcher(token).matches
      }
      .toSeq

  /**
   * Segments `line` with Ansj and returns the kept words joined by ','. A word is kept when its
   * part-of-speech tag starts with one of n/v/a/m/t and the word has at least 2 characters.
   */
  def AnsjSegment(line: String): String = {
    val KeepNatures = List("n", "v", "a", "m", "t")
    val words = ToAnalysis.parse(line)
    val word = ArrayBuffer[String]()
    for (i <- 0 until words.size()) {
      val term = words.get(i)
      // take(1) rather than substring(0, 1), so an empty or missing tag cannot throw.
      val nature = Option(term.getNatureStr).getOrElse("").take(1)
      if (KeepNatures.contains(nature) && term.getName.length() >= 2)
        word += term.getName
    }
    word.mkString(",")
  }

  /**
   * Stores `o` in the `model_Data` column of `ams_recommender_model` under model id "test",
   * replacing any existing row with that id. The statement and connection are always closed.
   */
  def serializeToMysql[T](o: BayesModelData2): Unit = {
    val model_Id = "test"
    openConnection()
    try {
      val query = "replace into " + "ams_recommender_model" + "(model_ID,model_Data) values (?,?)"
      stmt = conn.prepareStatement(query)
      stmt.setString(1, model_Id)
      // Presumably stored by the driver as Java-serialized bytes (deserializeFromMysql reads it
      // back with ObjectInputStream).
      stmt.setObject(2, o)
      // REPLACE is a data-manipulation statement: JDBC requires executeUpdate, not executeQuery.
      stmt.executeUpdate()
    } finally {
      closeJdbc()
    }
  }

  /**
   * Connector with a side effect: constructing an instance replaces [[conn]] with a new
   * connection. If connecting fails, it logs the error and leaves [[conn]] null.
   */
  class MysqlConn() {
    // Training and output currently share this URL; they could be split later.
    val trainning_url = "jdbc:mysql://192.168.2.107:3306/data_mining?user=xxx&password=xxx&zeroDateTimeBehavior=convertToNull&characterEncoding=utf-8"
    // Clear first so a failed attempt never leaves a previously closed connection behind.
    conn = null
    try {
      conn = DriverManager.getConnection(trainning_url, "xxx", "xxx")
    } catch {
      case NonFatal(e) =>
        println("mysql连接异常")
        e.printStackTrace()
    }
  }

  /**
   * Reads the model stored under id "test" in `ams_recommender_model` and deserializes it.
   *
   * @throws NoSuchElementException if no row with that id exists
   */
  def deserializeFromMysql[T](): BayesModelData2 = {
    val model_Id = "test"
    openConnection()
    try {
      val query = "select model_Data from " + "ams_recommender_model" + " where model_ID=?"
      stmt = conn.prepareStatement(query)
      stmt.setString(1, model_Id)
      val resultSet = stmt.executeQuery()
      if (!resultSet.next())
        throw new NoSuchElementException("no model stored with model_ID=" + model_Id)
      // SECURITY(review): Java deserialization of database content; only safe if this table is trusted.
      // The object is read before the connection is closed (in finally).
      val ois = new ObjectInputStream(resultSet.getBlob("model_Data").getBinaryStream())
      try ois.readObject().asInstanceOf[BayesModelData2] finally ois.close()
    } finally {
      closeJdbc()
    }
  }

  /** Opens a fresh connection into [[conn]]; throws IllegalStateException if it could not connect. */
  private def openConnection(): Unit = {
    new MysqlConn()
    if (conn == null) throw new IllegalStateException("mysql连接异常")
  }

  /** Best-effort close of [[stmt]] and [[conn]]; errors while closing are ignored deliberately. */
  private def closeJdbc(): Unit = {
    if (stmt != null) try stmt.close() catch { case NonFatal(_) => }
    if (conn != null) try conn.close() catch { case NonFatal(_) => }
  }

  /** Local Spark configuration with Kryo serialization, shared by loadModel() and testModel(). */
  private def kryoConf(): SparkConf =
    new SparkConf()
      .setAppName("NaiveBayesExample1")
      .setMaster("local")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.kryoserializer.buffer.max", "1024mb")

  /** Feature-vector width `m` was trained with (width of its theta rows). */
  private def featureDim(m: NaiveBayesModel): Int =
    m.theta.headOption.map(_.length).getOrElse(numFeatures)
}
- 调用 BayesUtils 类：该类所在的包必须是 org.apache.spark，因为 NaiveBayesModel 的构造器（以及 BLAS）是 private[spark] 的，只能在该包内访问。
- 参考：How to use BLAS library in Spark (Symbol BLAS is inaccessible from this space) - spark：http://note.youdao.com/noteshare?id=7f1eec90cc6e56303d06ff92422c29b6&sub=wcp151747625212826
调用代码
package org.apache.spark

import com.izhonghong.mission.learn.BayesModelData2
import com.izhonghong.mission.learn.WudiBayesModel.{deserializeFromMysql, tokenizer}
import org.apache.spark.mllib.classification.NaiveBayesModel
import org.apache.spark.mllib.feature.HashingTF
import org.apache.spark.mllib.linalg.{DenseMatrix, DenseVector, Vector}

/**
 * Helpers that use the Naive Bayes parameters stored in MySQL by WudiBayesModel.
 *
 * This object lives in package `org.apache.spark` because the NaiveBayesModel constructor and
 * `org.apache.spark.mllib.linalg.BLAS` are only accessible from inside that package.
 *
 * Created by Zsh on 2/1.
 */
object BayesUtils {

  /**
   * Rebuilds the model stored in MySQL, classifies `text` and prints the predicted category.
   *
   * @param text        raw article text
   * @param labels_name category names separated by '|', in label order
   */
  def testMysql(text: String, labels_name: String): Unit = {
    val modelData = deserializeFromMysql()
    // Hash into exactly as many features as the stored model has (one theta column per feature).
    // A no-arg HashingTF uses a fixed default width that need not match, and predict() then fails.
    val hashingTF = modelData.theta.headOption.fold(new HashingTF())(row => new HashingTF(row.length))
    val tfVector = hashingTF.transform(tokenizer(text))
    val model = new NaiveBayesModel(modelData.labels, modelData.pi, modelData.theta, modelData.modelType)
    val d = model.predict(tfVector)
    // Equivalent for multinomial models: val d = predict(modelData, tfVector)
    val list2 = labels_name.split("\\|").toList
    println("result:" + list2(d.toInt) + " " + d + " " + text)
  }

  /**
   * Predicts directly from stored parameters: returns the label with the largest theta * x + pi.
   *
   * This was extracted from NaiveBayesModel's source before it turned out the model can simply be
   * constructed from inside `org.apache.spark`. It presumably only matches the "multinomial" model
   * type. `tfVector` must have as many entries as each theta row.
   */
  def predict(bayesModel: BayesModelData2, tfVector: Vector): Double = {
    val numFeatures = bayesModel.theta(0).length
    // Row-major (isTransposed = true): one row per label, one column per feature.
    val thetaMatrix = new DenseMatrix(bayesModel.labels.length, numFeatures, bayesModel.theta.flatten, true)
    val piVector = new DenseVector(bayesModel.pi)
    val prob = thetaMatrix.multiply(tfVector)
    // prob += pi
    org.apache.spark.mllib.linalg.BLAS.axpy(1.0, piVector, prob)
    bayesModel.labels(prob.argmax)
  }
}
- 主要依赖（pom.xml 片段）
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-mllib_2.10</artifactId>
    <version>1.6.0</version>
</dependency>
<dependency>
    <groupId>org.ansj</groupId>
    <artifactId>ansj_seg</artifactId>
    <version>5.0.4</version>
</dependency>
完整的 pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.xxx</groupId>
    <artifactId>xxx-xxx-xxx</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <!-- Kept in line with the maven-compiler-plugin configuration below (1.8). -->
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <encoding>UTF-8</encoding>
        <scala.tools.version>2.10</scala.tools.version>
        <scala.version>2.10.6</scala.version>
        <hbase.version>1.2.2</hbase.version>
    </properties>

    <dependencies>
        <!--
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>2.1.0</version>
        </dependency>
        -->
        <!--
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>1.6.0</version>
        </dependency>
        -->
        <!--
        <dependency>
            <groupId>com.hankcs</groupId>
            <artifactId>hanlp</artifactId>
            <version>portable-1.5.0</version>
        </dependency>
        -->
        <!-- NOTE(review): spark-mllib is 1.6.0 while spark-core/sql/hive/streaming are 1.6.2;
             consider using one Spark version for all Spark modules. -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.10</artifactId>
            <version>1.6.0</version>
        </dependency>
        <dependency>
            <groupId>org.ansj</groupId>
            <artifactId>ansj_seg</artifactId>
            <version>5.0.4</version>
        </dependency>
        <!-- Declared once; every Scala artifact in this POM targets ${scala.version}. -->
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>${scala.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.kafka</groupId>
            <artifactId>kafka-clients</artifactId>
            <version>0.10.0.0</version>
        </dependency>
        <dependency>
            <groupId>net.sf.json-lib</groupId>
            <artifactId>json-lib</artifactId>
            <version>2.4</version>
            <classifier>jdk15</classifier>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming-kafka_2.10</artifactId>
            <version>1.6.2</version>
        </dependency>
        <!--
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming-kafka-0-10_2.10</artifactId>
            <version>2.1.1</version>
        </dependency>
        -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming_2.10</artifactId>
            <version>1.6.2</version>
            <exclusions>
                <exclusion>
                    <groupId>org.scala-lang</groupId>
                    <artifactId>scala-library</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <!--
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming_2.10</artifactId>
            <version>2.1.1</version>
            <scope>provided</scope>
        </dependency>
        -->
        <dependency>
            <groupId>com.huaban</groupId>
            <artifactId>jieba-analysis</artifactId>
            <version>1.0.2</version>
        </dependency>
        <!-- NOTE(review): fastjson releases before 1.2.25 have published deserialization
             vulnerabilities; consider upgrading. -->
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.14</version>
        </dependency>
        <dependency>
            <groupId>redis.clients</groupId>
            <artifactId>jedis</artifactId>
            <version>2.9.0</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hbase</groupId>
            <artifactId>hbase-server</artifactId>
            <version>${hbase.version}</version>
            <exclusions>
                <exclusion>
                    <groupId>org.mortbay.jetty</groupId>
                    <artifactId>servlet-api-2.5</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <!--
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.18</version>
        </dependency>
        -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.10</artifactId>
            <version>1.6.2</version>
            <!-- <version>2.1.1</version> -->
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-client</artifactId>
            <version>2.7.0</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-common</artifactId>
            <version>2.7.0</version>
        </dependency>
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-hdfs</artifactId>
            <version>2.7.0</version>
            <exclusions>
                <exclusion>
                    <groupId>javax.servlet.jsp</groupId>
                    <artifactId>*</artifactId>
                </exclusion>
                <exclusion>
                    <groupId>javax.servlet</groupId>
                    <artifactId>servlet-api</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.10</artifactId>
            <version>1.6.2</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-hive_2.10</artifactId>
            <version>1.6.2</version>
        </dependency>
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>5.1.39</version>
        </dependency>
        <!--
        <dependency>
            <groupId>org.apache.hbase</groupId>
            <artifactId>hbase-server</artifactId>
            <version>1.2.2</version>
        </dependency>
        -->

        <!-- Test -->
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.11</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.specs2</groupId>
            <artifactId>specs2_${scala.tools.version}</artifactId>
            <version>1.13</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.scalatest</groupId>
            <artifactId>scalatest_${scala.tools.version}</artifactId>
            <version>2.0.M6-SNAP8</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>net.alchim31.maven</groupId>
                <artifactId>scala-maven-plugin</artifactId>
                <version>3.2.0</version>
                <executions>
                    <execution>
                        <goals>
                            <goal>compile</goal>
                            <goal>testCompile</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-jar-plugin</artifactId>
                <configuration>
                    <archive>
                        <manifest>
                            <addClasspath>true</addClasspath>
                            <classpathPrefix>lib/</classpathPrefix>
                            <!-- NOTE(review): empty; set the entry-point class to make the jar
                                 runnable with java -jar. -->
                            <mainClass></mainClass>
                        </manifest>
                    </archive>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-dependency-plugin</artifactId>
                <executions>
                    <execution>
                        <id>copy</id>
                        <phase>package</phase>
                        <goals>
                            <goal>copy-dependencies</goal>
                        </goals>
                        <configuration>
                            <outputDirectory>${project.build.directory}/lib</outputDirectory>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
    <!-- Alternative build: a single fat jar via maven-assembly-plugin.
    <build>
        <plugins>
            <plugin>
                <artifactId>maven-assembly-plugin</artifactId>
                <configuration>
                    <archive>
                        <manifest>
                            (replace with the class that contains the jar's main method)
                            <mainClass>com.sf.pps.client.IntfClientCall</mainClass>
                        </manifest>
                        <manifestEntries>
                            <Class-Path>.</Class-Path>
                        </manifestEntries>
                    </archive>
                    <descriptorRefs>
                        <descriptorRef>jar-with-dependencies</descriptorRef>
                    </descriptorRefs>
                </configuration>
                <executions>
                    <execution>
                        <id>make-assembly</id> (this is used for inheritance merges)
                        <phase>package</phase> (run the jar merge during the package phase)
                        <goals>
                            <goal>single</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
    -->
</project>