Reference: https://talks.anghami.com/blazing-fast-approximate-nearest-neighbour-search-on-apache-spark-using-hnsw/
HNSW parameter tuning guide: https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md
Running an HNSW vector search on Spark consists of the following three steps:
1 Build the HNSW index and save it to disk
2 Distribute the saved index to every executor
3 Run the vector search
Using HNSW to build the index and Spark for the distributed vector search, building the index over 12 million vectors took 40 minutes and the vector search finished in 10 minutes. The times depend on the size of m and ef; I used m=30 and ef=1000, since smaller values kept failing with errors that m or ef was too small. For example, with m=30 and ef=1000, building the index for the 12 million vectors took 20 minutes and the search still took 10 minutes.
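For orientation, the same build/save/load/query cycle can be exercised on a single machine with hnswlib-jna alone, before bringing Spark into the picture. The sketch below is not from the original post; it only uses the Index calls that also appear in the full code later (initialize, addItem, save, load, knnQuery), with toy sizes chosen purely for illustration.

import com.stepstone.search.hnswlib.jna.{Index, SpaceName}
import java.nio.file.Paths

object HnswLocalSketch {
  def main(args: Array[String]): Unit = {
    val dim = 4
    val vectors = Array(
      Array(1.0f, 0.0f, 0.0f, 0.0f),
      Array(0.0f, 1.0f, 0.0f, 0.0f),
      Array(0.9f, 0.1f, 0.0f, 0.0f))

    // step 1: build the index and save it to disk
    val index = new Index(SpaceName.COSINE, dim)
    index.initialize(vectors.length, 30, 1000, 42) // maxElements, m, efConstruction, random seed
    vectors.zipWithIndex.foreach { case (v, i) => index.addItem(v, i) }
    index.save(Paths.get("index"))

    // step 2: load it back (on Spark this saved file is what gets shipped to every executor)
    val loaded = new Index(SpaceName.COSINE, dim)
    loaded.initialize(vectors.length, 30, 1000, 42)
    loaded.load(Paths.get("index"), vectors.length)

    // step 3: query the 2 nearest neighbours of the first vector
    val result = loaded.knnQuery(vectors(0), 2)
    println(result.getIds.zip(result.getCoefficients).mkString(", "))
  }
}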
1 Building the HNSW index
The input is a Spark Dataset made up of an id and a features column, where features is an Array[Float] vector. A toy usage sketch follows the code below.
import com.stepstone.search.hnswlib.jna.{Index, SpaceName}
import org.apache.spark.SparkFiles
import org.apache.spark.sql.{Dataset, Encoder, SparkSession}
import java.nio.file.Paths
import scala.reflect.runtime.universe.TypeTag
class annUtilsHnsw {

  /**
   * Builds an hnsw index.
   *
   * Default HNSW parameters are found to be good enough.
   *
   * HNSW index requires integer based object ids, so the builder re-indexes the original objects keys into integer
   * keys.
   *
   * For information on HNSW parameter tuning, [[https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md]]
   *
   * @param vectorSize features vector size
   * @param features objects features to build an index for
   * @param m a parameter for construction HNSW index
   * @param efConstruction a parameter for construction HNSW index
   * @tparam Key type of the object id in features objects
   * @return
   */
  def buildHnswIndex[Key : TypeTag : Encoder](spark: SparkSession,
                                              vectorSize: Int,
                                              features: Dataset[(Key, Array[Float])],
                                              m: Int = 100,
                                              efConstruction: Int = 200): HnswIndex[Key] = {
    // map objects keys to integer based index to be used in the HNSW index as it only accepts integer keys
    import spark.implicits._
    val featuresReindexed = features.rdd
      .zipWithIndex()
      .map(x => (x._1._1, x._1._2, x._2.toInt))
      .toDF("id", "features", "index_id")
      .select("index_id", "id", "features")
      .cache()

    // collect feature vectors
    val featuresList = featuresReindexed
      .select($"index_id", $"features".cast("array<float>"))
      .as[(Int, Array[Float])]
      .collect()

    val objectIDsMap = featuresReindexed
      .select("index_id", "id")
      .as[(Int, Key)]
      .repartition(100)

    // build index
    val index = new Index(SpaceName.COSINE, vectorSize)
    index.initialize(featuresList.length, m, efConstruction, (System.currentTimeMillis() / 1000).toInt)
    // index.initialize(indexLength, 16, 200, (System.currentTimeMillis() / 1000).toInt)
    println("featuresList length", featuresList.length)

    // add vectors in parallel using .par
    featuresList.par.foreach {
      case (id: Int, vector: Array[Float]) =>
        index.addItem(vector, id)
    }

    // return wrapped index
    new HnswIndex(vectorSize, index, objectIDsMap)
  }
}
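As a usage sketch (not from the original post), a toy input Dataset of the required (id, Array[Float]) shape can be built from a local Seq and passed straight to buildHnswIndex; the m and efConstruction values here just mirror the ones mentioned in the introduction.

import org.apache.spark.sql.SparkSession

object BuildIndexSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._

    // hypothetical toy input: (id, features) pairs with Array[Float] vectors of size 4
    val toyVectors = Seq(
      (1, Array(0.1f, 0.2f, 0.3f, 0.4f)),
      (2, Array(0.4f, 0.3f, 0.2f, 0.1f)),
      (3, Array(0.0f, 1.0f, 0.0f, 0.0f))
    ).toDS()

    // build the index with the class defined above
    val hnswIndex = new annUtilsHnsw().buildHnswIndex(spark, 4, toyVectors, m = 30, efConstruction = 1000)
  }
}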
2 Saving and querying the index
Save the index, load it and distribute it to every executor, then run the ANN query.
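The distribution step uses Spark's standard file shipping: the driver registers the saved index file with SparkContext.addFile, and each task resolves its local copy through SparkFiles inside mapPartitions. Below is a stripped-down sketch of just that mechanism (not from the original post; independent of the wrapper class, and the file name is only an example).

import org.apache.spark.SparkFiles
import org.apache.spark.sql.SparkSession

object SparkFilesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._

    // driver side: register a local file; Spark copies it to every executor's work directory
    spark.sparkContext.addFile("index")

    // executor side: resolve the local copy once per partition and use it there
    val paths = spark.range(0, 4).mapPartitions { it =>
      val localIndexPath = SparkFiles.get("index") // same as SparkFiles.getRootDirectory + "/index"
      it.map(_ => localIndexPath)
    }
    paths.show(truncate = false)
  }
}

The complete wrapper class used in the post follows.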
import com.stepstone.search.hnswlib.jna.{Index, SpaceName}
import org.apache.spark.SparkFiles
import org.apache.spark.sql.{Dataset, Encoder, SparkSession}
import java.nio.file.Paths
import scala.reflect.runtime.universe.TypeTag

class HnswIndex[DstKey : TypeTag : Encoder](vectorSize: Int,
                                            index: Index,
                                            objectIDsMap: Dataset[(Int, DstKey)]) {

  /**
   * Executes a KNN query using an HNSW index.
   *
   * @param queryFeatures features to generate recs for
   * @param minScoreThreshold minimum similarity/distance
   * @param topK number of top recommendations to generate per instance
   * @param ef HNSW search time parameter
   * @param queryNumPartitions number of partitions for query vectors
   * @return
   */
  def knnQuery[SrcKey: TypeTag : Encoder](spark: SparkSession,
                                          queryFeatures: Dataset[(SrcKey, Array[Float])],
                                          minScoreThreshold: Double,
                                          topK: Int,
                                          ef: Int,
                                          queryNumPartitions: Int = 200,
                                          indexSavePath: String,
                                          m: Int,
                                          efConstruction: Int): Dataset[(SrcKey, DstKey, Double)] = {
    import spark.implicits._

    // init tmp directory
    val indexLength = index.getLength
    val saveLocalPath = "index"
    val indexLocalLocation = Paths.get(saveLocalPath)
    val indexFileName = indexLocalLocation.getFileName.toString
    println("indexFileName", indexFileName)

    // saving index locally
    index.save(indexLocalLocation)
    println(index.getData(0).get().mkString(","))
    val saveAbsoluteLocalPath = saveLocalPath
    println("local path", indexLocalLocation.toAbsolutePath.toString)
    println("absolute path: ", saveAbsoluteLocalPath)

    // add file to spark context to be sent to running nodes
    spark.sparkContext.addFile(indexFileName, true)
    // spark.sparkContext.addFile(indexSavePath, true)
    println("context path: ", SparkFiles.getRootDirectory + "/" + indexFileName)

    // The current interface to HNSW misses the functionality of setting the ef query time
    // parameter, but it's lower bounded by topK as per https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md#search-parameters,
    // so set a large value of k as max(ef, topK) to get high recall, then cut off after getting the nearest neighbors.
    val k = math.max(topK, ef)

    // local scope vectorSize
    val vectorSizeLocal = vectorSize

    // execute querying
    queryFeatures
      .repartition(queryNumPartitions)
      .toDF("id", "features")
      .withColumn("features", $"features".cast("array<float>"))
      .as[(SrcKey, Array[Float])]
      .mapPartitions((it: Iterator[(SrcKey, Array[Float])]) => {
        // load index
        val index = new Index(SpaceName.COSINE, vectorSizeLocal)
        index.initialize(indexLength, m, efConstruction, (System.currentTimeMillis() / 1000).toInt)
        index.load(Paths.get(SparkFiles.getRootDirectory + "/" + indexFileName), indexLength)
        it.flatMap(x => {
          val idx = x._1
          val vector = x._2
          val queryTuple = index.knnQuery(vector, k)
          val result = queryTuple.getIds // queryTuple.getLabels
            .zip(queryTuple.getCoefficients)
            .map(qt => (idx, qt._1, 1.0 - qt._2.toDouble))
            .filter(_._3 >= minScoreThreshold)
            .sortBy(_._3)
            .reverse
            .slice(0, topK)
          result
        })
      })
      .as[(Int, Int, Double)]
      .toDF("src_id", "index_id", "score")
      .join(objectIDsMap.toDF("index_id", "dst_id"), Seq("index_id"))
      .select("src_id", "dst_id", "score")
      .repartition(400)
      .as[(SrcKey, DstKey, Double)]
  }
}
3 word2vec vector search example
- Train a word2vec model (a minimal training sketch follows this list)
- Pull the vectors out of the model and call buildHnswIndex above to build the index
- Run the distributed knnQuery vector search
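The example assumes a word2vec model has already been trained and saved to disk; here is a minimal, hypothetical training sketch with Spark ML (the column names, corpus, and parameters are placeholders, not taken from the original post). The full example follows it.

import org.apache.spark.ml.feature.Word2Vec
import org.apache.spark.sql.SparkSession

object TrainWord2VecSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._

    // hypothetical training corpus: each row is a sequence of tokens (e.g. item ids as strings)
    val sentences = Seq(
      "1 5 9 3".split(" ").toSeq,
      "2 5 7".split(" ").toSeq
    ).toDF("sentence")

    val model = new Word2Vec()
      .setInputCol("sentence")
      .setOutputCol("vector")
      .setVectorSize(64)
      .setMinCount(0)
      .fit(sentences)

    model.write.overwrite().save("graph/model/word2vecmodel")
  }
}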
import org.apache.spark.ml.feature.Word2VecModel
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.sql.SparkSession

object exampleWord2Vec {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().getOrCreate()
    val GraphInputModel = "graph/model/word2vecmodel"
    val indexPath = "graph/model/index"
    val savePathMl = "" // output base path; not shown in the original, set it to wherever results should be written

    spark.udf.register("denseVec2Array", (vec: DenseVector) => vec.toArray.map(_.toFloat))
    spark.udf.register("vectorSplit", (a: String) => a.split(',').map(_.toFloat))
    import spark.implicits._

    // load the trained word2vec model and inspect its vectors
    val word2vec = Word2VecModel.load(GraphInputModel)
    println(word2vec.getVectors.schema)
    word2vec.getVectors.show(10)
    println(word2vec.getVectors.count())

    // convert the (word, vector) table into the (Int, Array[Float]) pairs expected by buildHnswIndex
    val itemEmbeddings = word2vec.getVectors
      .selectExpr("cast(word as Int) as word", "denseVec2Array(vector) features")
      .as[(Int, Array[Float])]
    itemEmbeddings.show()
    println(itemEmbeddings.schema)

    // build the index, then run the distributed KNN query
    val vectorsize = itemEmbeddings.take(1)(0)._2.length
    val hnswIndex = new annUtilsHnsw().buildHnswIndex(spark, vectorsize, itemEmbeddings, 20)
    val queryDF = hnswIndex.knnQuery[Int](spark, itemEmbeddings.limit(20), 0.3, 20, 200, 160, indexPath, 20, 200)
    queryDF.write.mode("overwrite").save(savePathMl + "graph/muiscEmbedding")
  }
}
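After the job runs, the saved output is just a (src_id, dst_id, score) table. As a quick check (not in the original post, and assuming the same savePathMl prefix used when writing), it can be read back and the neighbours of a single source inspected:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.desc

object InspectResultsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().getOrCreate()

    val savePathMl = "" // the same output base path used in exampleWord2Vec
    val recs = spark.read.load(savePathMl + "graph/muiscEmbedding")

    // neighbours of one query item, best score first
    recs.filter("src_id = 1")
      .orderBy(desc("score"))
      .show(20, truncate = false)
  }
}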
4 HNSW pom dependency
hnswlib-jna
<dependency>
    <groupId>com.stepstone.search.hnswlib.jna</groupId>
    <artifactId>hnswlib-jna</artifactId>
    <version>1.4.2</version>
</dependency>
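For projects built with sbt rather than Maven, the same coordinate would be (a mechanical translation of the dependency above):

libraryDependencies += "com.stepstone.search.hnswlib.jna" % "hnswlib-jna" % "1.4.2"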