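# 假设已有训练好的向量值,构建索引(nlist 和随机样本按需选取) / Build the index from precomputed embedding vectors (choose nlist and the training sample size as needed)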
import numpy as np
import faiss
import pickle
from tqdm import tqdm
import time
import os
import random
def read_embeddings(directory, batch_size=10000, encoding="utf-8", show_progress=True):
    """Yield batches of embeddings parsed from every file under *directory*.

    Every non-blank line is expected to hold an identifier and a bracketed,
    comma-separated vector, separated by a tab character, for example
    ``doc42<TAB>[0.1,0.2,0.3]``.

    Args:
        directory: Root directory that is walked recursively.
        batch_size: Maximum number of vectors per yielded batch. Batches never
            span two files, so the last batch of a file may be smaller.
        encoding: Text encoding used to open the embedding files.
        show_progress: When true, wrap the line iterator in ``tqdm``.

    Yields:
        ``(embeddings, ids)`` where ``embeddings`` is a float32 array with one
        row per vector and ``ids`` is the list of matching identifiers.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 or a non-blank line
            has no tab-separated vector column.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    for root, _dirs, files in os.walk(directory):
        for file_name in files:
            cur_file = os.path.join(root, file_name)
            print("Loading file >>>", cur_file)
            batch_ids = []
            batch_embeddings = []
            # Stream the file instead of readlines(): the text form of the
            # vectors is much larger than the parsed float32 arrays.
            with open(cur_file, "r", encoding=encoding) as handle:
                lines = tqdm(handle, ncols=100) if show_progress else handle
                for line_no, line in enumerate(lines, start=1):
                    line = line.strip()
                    if not line:
                        # Tolerate blank lines (e.g. a trailing newline).
                        continue
                    parts = line.split("\t")
                    if len(parts) < 2:
                        raise ValueError(
                            f"{cur_file}:{line_no}: expected "
                            f"'<id><TAB>[v1,v2,...]', got {line!r}"
                        )
                    identifier = parts[0]
                    # Drop the surrounding brackets before parsing the floats.
                    vector = np.fromstring(parts[1][1:-1], sep=",")
                    batch_ids.append(identifier)
                    batch_embeddings.append(vector)
                    if len(batch_ids) >= batch_size:
                        yield np.array(batch_embeddings, dtype="float32"), batch_ids
                        batch_ids = []
                        batch_embeddings = []
            if batch_embeddings:
                yield np.array(batch_embeddings, dtype="float32"), batch_ids


# Build script: load every embedding, train an IVF-PQ index on a random
# sample, add all vectors, then persist the index and the id mapping.
try:
    directory_path = './data'
    embeddings_batches = []
    ids = []
    for embeddings_batch, ids_batch in read_embeddings(directory_path):
        embeddings_batches.append(embeddings_batch)
        ids.extend(ids_batch)
    if not embeddings_batches:
        # Fail with a clear message instead of an IndexError further down.
        raise ValueError(f"No embeddings found under {directory_path!r}")
    print("Data loading complete, start building the index")

    N = sum(batch.shape[0] for batch in embeddings_batches)
    D = embeddings_batches[0].shape[1]
    print(f"Embeddings shape: {N} x {D}")

    # Index parameters; tune nlist and the training sample size to the data.
    nlist = 100000
    m = 32
    n_bits = 8
    if D % m != 0:
        raise ValueError(
            f"Vector dimension {D} must be divisible by the number of PQ sub-vectors m={m}"
        )
    quantizer = faiss.IndexFlatL2(D)
    index = faiss.IndexIVFPQ(quantizer, D, nlist, m, n_bits)

    print("Start training the index...")
    all_embeddings = np.vstack(embeddings_batches)
    train_start = time.time()
    # Train on a random subset so training time stays bounded for large N.
    sample_size = min(1000000, N)
    sample_indices = random.sample(range(N), sample_size)
    sample_embeddings = all_embeddings[sample_indices]
    print("随机选取样本训练")
    index.train(sample_embeddings)
    train_end = time.time()
    print(f"Training completed, time taken: {(train_end - train_start) / 3600:.2f} hours")

    print("Start adding embeddings to the index...")
    add_start = time.time()
    for flag, embeddings_batch in enumerate(embeddings_batches, start=1):
        if flag % 100 == 0:
            print(flag)
        index.add(embeddings_batch)
    add_end = time.time()
    print(f"Adding embeddings completed, time taken: {(add_end - add_start) / 3600:.2f} hours")

    print("Start saving the index...")
    save_start = time.time()
    faiss.write_index(index, "index_ivfpq_1b.faiss")
    save_end = time.time()
    print(f"Index saved, time taken: {(save_end - save_start) / 3600:.2f} hours")

    # Vectors were added in load order, so position i in `ids` is the id the
    # index assigns to the i-th vector; keys carry a "faiss_v1_" prefix.
    index_to_identifier = {
        "faiss_v1_" + str(i): identifier for i, identifier in enumerate(ids)
    }
    with open('index_to_identifier_1b.pkl', 'wb') as f:
        pickle.dump(index_to_identifier, f)
    print("Index to identifier mapping saved.")
except Exception as e:
    # Best-effort top-level boundary: report the failure without re-raising,
    # but keep the traceback so the failing step can be located.
    import traceback

    print("Error occurred during index construction:", str(e))
    traceback.print_exc()
# 向量查询 / Vector query
import time
import numpy as np
import faiss
import pickle
# Query script: load the persisted index and id mapping, run one nearest-
# neighbour search for a hard-coded query vector and print the top-k hits.
index = faiss.read_index("index_ivfpq_1b.faiss")
# NOTE(security): pickle.load can execute arbitrary code; only load a mapping
# file that was produced locally by the index-building step.
with open('index_to_identifier_1b.pkl', 'rb') as f:
    index_to_identifier = pickle.load(f)

# nprobe sets how many inverted lists are visited per query (recall vs. speed).
index.nprobe = 100
faiss.omp_set_num_threads(4)

query_embedding = np.array(
    [[
        -0.01962059736251831, 0.11334816366434097, -0.09471801668405533, 0.0641612783074379,
        0.016695162281394005, 0.03470868244767189, 0.059329044073820114, -0.024794576689600945,
        -0.012960868887603283, -0.0744692012667656, -0.07942882925271988, 0.19218777120113373,
        0.14370097219944, 0.11092912405729294, -0.06869585067033768, 0.08476870507001877,
        0.10311301797628403, -0.09529904276132584, 0.11519007384777069, 0.07435101270675659,
        -0.07236043363809586, 0.010397439822554588, -0.06027359142899513, -0.08405963331460953,
        0.031723152846097946, -0.1143064945936203, 0.18072178959846497, 0.07466364651918411,
        0.10553380101919174, -0.10898686945438385, -0.19313931465148926, 0.15539272129535675,
        -0.11933872103691101, -0.13383139669895172, 0.0754752978682518, 0.04579591378569603,
        0.07465954124927521, -0.0241111870855093, -0.06121497601270676, -0.10494254529476166,
        -0.01837378740310669, 0.1292468160390854, -0.0056768800131976604, 0.06756076216697693,
        -0.08115670830011368, 0.09304261207580566, 0.06945249438285828, -0.057487890124320984,
        0.07290451973676682, -0.01492359396070242, 0.14174117147922516, 0.0752357617020607,
        0.014304161071777344, -0.0023451936431229115, 0.08765687793493271, 0.10875667631626129,
        0.1779395043849945, -0.04857892543077469, 0.054570272564888, -0.15957848727703094,
        0.008002348244190216, 0.03754493221640587, 0.07620261609554291, 0.01903180405497551,
        0.14646433293819427, -0.07392526417970657, 0.02997334860265255, -0.04795815050601959,
        0.039741817861795425, -0.06323029100894928, -0.0361541248857975, 0.1155063807964325,
        -0.03679197281599045, 0.08797583729028702, -0.068557009100914, -0.14507029950618744,
        0.06844533234834671, 0.09862343966960907, 0.012137680314481258, -0.012296526692807674,
        0.05485907569527626, 0.08134670555591583, 0.06546603888273239, 0.10151205956935883,
        -0.1254400908946991, 0.06678715348243713, 0.015612985007464886, 0.03761797398328781,
        0.11426421254873276, -0.10608682036399841, 0.0054876371286809444, -0.13291053473949432,
        -0.1383194625377655, -0.060186877846717834, 0.040753982961177826, 0.025832200422883034,
        0.06087275967001915, 0.07576646655797958, -0.025103572756052017, 0.0819762796163559,
        0.06338494271039963, 0.09223338961601257, 0.11740309000015259, 0.16588829457759857,
        0.0016070181736722589, -0.11642675846815109, 0.06580012291669846, 0.07179497182369232,
        -0.11596480011940002, 0.05284847319126129, 0.018308958038687706, 0.2823641896247864,
        0.0026317911688238382, -0.013333271257579327, -0.07727757096290588, -0.06593139469623566,
        0.06467396765947342, 0.04348631948232651, 0.02083323895931244, -0.004868550691753626,
        -0.06408777832984924, -0.12004149705171585, 0.09156100451946259, 0.04209277778863907,
        0.04682828485965729, 0.06600149720907211, 0.014075364917516708, 0.02114858292043209,
    ]],
    dtype='float32',
)
query_id = "龙血王手串价格及图片"

num_queries, D = query_embedding.shape
if D != index.d:
    # Fail early with a readable message when the query and index disagree.
    raise ValueError(f"Query dimension {D} does not match index dimension {index.d}")
k = 10

# Time only the search call so printing the results does not skew the figure.
s = time.time()
distances, indices = index.search(query_embedding, k)
e = time.time()

print(f"Query ID: {query_id}")
print("Top k results:")
for rank, (idx, distance) in enumerate(zip(indices[0], distances[0]), start=1):
    if idx != -1:
        # Mapping keys were stored as "faiss_v1_<position>" at build time.
        key = "faiss_v1_" + str(int(idx))
        identifier = index_to_identifier.get(key, "Unknown")
        print(f"{rank}. ID: {identifier}, Distance: {distance}")
    else:
        # FAISS pads with -1 when fewer than k neighbours were found.
        print(f"{rank}. No result")
print(f"Time taken for search: {e - s} seconds")