目录
一、代码解析
1.1 searchTest.ts
1.2 controller.ts
本文承接上一篇文章《FastGPT 知识库搜索测试功能解析》,对具体代码进行解析。
一、代码解析
FastGPT 知识库的搜索测试功能主要涉及两个文件,分别是 searchTest.ts 和 controller.ts 文件,下面分别进行介绍。
1.1 searchTest.ts
文件路径是 projects/app/src/pages/api/core/dataset/searchTest.ts,搜索测试功能的主文件,代码如下所示。
/**
 * API handler for the dataset "search test" feature.
 *
 * Reads the search settings from the request body, checks dataset permission
 * and team balance, optionally extends the query through an LLM, delegates the
 * actual recall to `searchDatasetData`, records the usage, and returns the
 * matched items together with the elapsed time.
 *
 * @param req - Incoming request; `req.body` is read as `SearchTestProps`.
 * @returns The search results (`list`), the elapsed time (`duration`), the
 *   query-extension model reported by the extension step (if any), and the
 *   remaining fields returned by `searchDatasetData`.
 *   Rejects with `CommonErrEnum.missingParams` when `datasetId` or `text` is empty.
 */
async function handler(req: NextApiRequest) {
  console.log("function handler(req: NextApiRequest)");

  const payload = req.body as SearchTestProps;
  const {
    datasetId, // dataset (knowledge base) id
    text, // text typed into the search-test box
    limit = 1500, // upper bound on quoted tokens
    similarity, // minimum relevance, forwarded unchanged to searchDatasetData
    searchMode, // retrieval mode
    usingReRank, // whether recalled texts should be re-ranked (needs a rerank model)
    datasetSearchUsingExtensionQuery = false, // whether query extension is enabled
    datasetSearchExtensionModel, // model used for query extension
    datasetSearchExtensionBg = '' // conversation background used for query extension
  } = payload;

  // Both the dataset id and the query text are mandatory.
  if (!datasetId || !text) {
    return Promise.reject(CommonErrEnum.missingParams);
  }

  // Start timing before any remote work happens.
  const startedAt = Date.now();

  // Auth dataset role: the caller must hold read permission on the dataset.
  const { dataset, teamId, tmbId, apikey } = await authDataset({
    req,
    authToken: true,
    authApiKey: true,
    datasetId,
    per: ReadPermissionVal
  });

  // Auth balance.
  await checkTeamAIPoints(teamId);

  // Resolve the extension model only when extension is enabled and a model is named.
  const extensionEnabled = datasetSearchUsingExtensionQuery && datasetSearchExtensionModel;
  const extensionModel = extensionEnabled ? getLLMModel(datasetSearchExtensionModel) : undefined;

  // Let the LLM extend/rewrite the query.
  const { concatQueries, rewriteQuery, aiExtensionResult } = await datasetSearchQueryExtension({
    query: text,
    extensionModel,
    extensionBg: datasetSearchExtensionBg
  });

  console.log("[test]: pre searchDatasetData");

  // Re-ranking is used only when requested AND permitted for the team.
  const reRankAllowed = usingReRank && (await checkTeamReRankPermission(teamId));

  // Run the actual recall.
  const { searchRes, tokens, ...result } = await searchDatasetData({
    teamId,
    reRankQuery: rewriteQuery,
    queries: concatQueries,
    model: dataset.vectorModel,
    limit: Math.min(limit, 20000),
    similarity,
    datasetIds: [datasetId],
    searchMode,
    usingReRank: reRankAllowed
  });

  // Push bill: extension usage is attached only when an extension actually ran.
  const extensionUsage =
    aiExtensionResult && extensionModel
      ? {
          extensionModel: extensionModel.name,
          extensionTokens: aiExtensionResult.tokens
        }
      : {};
  const { totalPoints } = pushGenerateVectorUsage({
    teamId,
    tmbId,
    tokens,
    model: dataset.vectorModel,
    source: apikey ? UsageSourceEnum.api : UsageSourceEnum.fastgpt,
    ...extensionUsage
  });

  // Record the consumed points against the API key, when one was used.
  if (apikey) {
    updateApiKeyUsage({
      apikey,
      totalPoints: totalPoints
    });
  }

  return {
    list: searchRes, // search results
    duration: `${((Date.now() - startedAt) / 1000).toFixed(3)}s`, // elapsed time in seconds
    queryExtensionModel: aiExtensionResult?.model,
    ...result
  };
}

export default NextAPI(handler);
函数 handler 主要是打辅助,主力在 searchDatasetData 函数中。
函数 handler 接收的参数,大多是在知识库搜索配置中设置的参数,如下所示。
1.2 controller.ts
主要处理逻辑在 searchDatasetData 函数中:其调用 getVectorsByText 将测试文本转换为向量,在 pgvector 中查询相似度高的向量,然后通过 mongodb 查询这些向量对应的原文。
type SearchDatasetDataProps = {
  teamId: string;
  model: string; // vector model name, resolved through getVectorModel
  similarity?: number; // minimum score; only applied when rerank or pure embedding mode is used
  limit: number; // max token limit (values below 50 are replaced by 1500)
  datasetIds: string[];
  searchMode?: `${DatasetSearchModeEnum}`;
  usingReRank?: boolean;
  reRankQuery: string;
  queries: string[];
};

/**
 * Recalls dataset items for one or more queries.
 *
 * Steps:
 *  1. Choose per-source recall limits from the search mode.
 *  2. For every query, run embedding recall and full-text recall in parallel,
 *     then merge the per-query lists with `datasetSearchResultConcat`.
 *  3. Optionally re-rank the merged candidates.
 *  4. Merge embedding / full-text / rerank lists, drop items with identical q+a text.
 *  5. Filter by `similarity` (rerank score when re-ranking is active, embedding
 *     score in pure embedding mode, no filter otherwise).
 *  6. Truncate the list to the token budget.
 *
 * @param props - See `SearchDatasetDataProps`.
 * @returns The filtered results (`searchRes`), the tokens reported by
 *   `getVectorsByText` summed over all queries (`tokens`), and the effective
 *   settings (`searchMode`, `limit`, `similarity`, `usingReRank`,
 *   `usingSimilarityFilter`).
 */
export async function searchDatasetData(props: SearchDatasetDataProps) {
  console.log("function searchDatasetData");
  let {
    teamId,
    reRankQuery,
    queries,
    model,
    similarity = 0,
    limit: maxTokens,
    searchMode = DatasetSearchModeEnum.embedding,
    usingReRank = false,
    datasetIds = []
  } = props;

  /* init params */
  // Unknown search modes fall back to embedding.
  searchMode = DatasetSearchModeMap[searchMode] ? searchMode : DatasetSearchModeEnum.embedding;
  // Re-ranking needs at least one configured rerank model.
  usingReRank = usingReRank && global.reRankModels.length > 0;

  // Compatible with topk limit
  if (maxTokens < 50) {
    maxTokens = 1500;
  }
  let usingSimilarityFilter = false;

  /* function */
  // 0. Drop items whose q+a text is identical once punctuation/whitespace is removed.
  const removeDuplicateQA = (list: SearchDataResponseItemType[]) => {
    const seen = new Set<string>();
    return list.filter((item) => {
      // Keep only letters and numbers so that formatting differences do not matter.
      const key = hashStr(`${item.q}${item.a}`.replace(/[^\p{L}\p{N}]/gu, ''));
      if (seen.has(key)) return false;
      seen.add(key);
      return true;
    });
  };

  // 1. countRecallLimit: per-source recall limits for the three search modes.
  const countRecallLimit = () => {
    if (searchMode === DatasetSearchModeEnum.embedding) {
      // embedding (semantic) recall only
      return { embeddingLimit: 100, fullTextLimit: 0 };
    }
    if (searchMode === DatasetSearchModeEnum.fullTextRecall) {
      // full-text recall only
      return { embeddingLimit: 0, fullTextLimit: 100 };
    }
    // mixed recall
    return { embeddingLimit: 80, fullTextLimit: 60 };
  };

  // 2. embeddingRecall: query -> vector -> vector store -> MongoDB documents.
  const embeddingRecall = async ({ query, limit }: { query: string; limit: number }) => {
    // Turn the query text into vectors.
    const { vectors, tokens } = await getVectorsByText({
      model: getVectorModel(model),
      input: query,
      type: 'query'
    });

    // Look up similar vectors in the vector store.
    const { results } = await recallFromVectorStore({
      teamId,
      datasetIds,
      vector: vectors[0],
      limit
    });

    // get q and a: load the documents that own the recalled vectors.
    const dataList = (await MongoDatasetData.find(
      {
        teamId,
        datasetId: { $in: datasetIds },
        collectionId: { $in: Array.from(new Set(results.map((item) => item.collectionId))) },
        'indexes.dataId': { $in: results.map((item) => item.id?.trim()) }
      },
      'datasetId collectionId q a chunkIndex indexes'
    )
      .populate('collectionId', 'name fileId rawLink externalFileId externalFileUrl')
      .lean()) as DatasetDataWithCollectionType[];

    // add score to data (`results` is already sorted, so the first match is the best one).
    const concatResults = dataList.map((data) => {
      const dataIdList = data.indexes.map((item) => item.dataId);
      // Compare trimmed ids, exactly as the query above does; otherwise a padded
      // id would match in MongoDB but silently get score 0 here.
      const maxScoreResult = results.find((item) => dataIdList.includes(item.id?.trim()));
      return {
        ...data,
        score: maxScoreResult?.score || 0
      };
    });
    concatResults.sort((a, b) => b.score - a.score);

    const formatResult = concatResults.map((data, index) => {
      if (!data.collectionId) {
        console.log('Collection is not found', data);
      }
      const result: SearchDataResponseItemType = {
        id: String(data._id),
        q: data.q,
        a: data.a,
        chunkIndex: data.chunkIndex,
        datasetId: String(data.datasetId),
        collectionId: String(data.collectionId?._id),
        ...getCollectionSourceData(data.collectionId),
        score: [{ type: SearchScoreTypeEnum.embedding, value: data.score, index }]
      };
      return result;
    });

    return {
      embeddingRecallResults: formatResult,
      tokens
    };
  };

  // 3. fullTextRecall: MongoDB $text search over every dataset.
  const fullTextRecall = async ({
    query,
    limit
  }: {
    query: string;
    limit: number;
  }): Promise<{
    fullTextRecallResults: SearchDataResponseItemType[];
    tokenLen: number;
  }> => {
    if (limit === 0) {
      return {
        fullTextRecallResults: [],
        tokenLen: 0
      };
    }

    const perDatasetHits = (
      await Promise.all(
        datasetIds.map((id) =>
          MongoDatasetData.find(
            {
              teamId,
              datasetId: id,
              $text: { $search: jiebaSplit({ text: query }) }
            },
            {
              score: { $meta: 'textScore' },
              _id: 1,
              datasetId: 1,
              collectionId: 1,
              q: 1,
              a: 1,
              chunkIndex: 1
            }
          )
            .sort({ score: { $meta: 'textScore' } })
            .limit(limit)
            .lean()
        )
      )
    ).flat() as (DatasetDataSchemaType & { score: number })[];

    // resort across datasets, then keep only the best `limit` hits.
    // `slice` does not mutate, so its result must be used (the previous code
    // discarded it and therefore never truncated the merged list).
    perDatasetHits.sort((a, b) => b.score - a.score);
    const searchResults = perDatasetHits.slice(0, limit);

    const collections = await MongoDatasetCollection.find(
      {
        _id: { $in: searchResults.map((item) => item.collectionId) }
      },
      '_id name fileId rawLink'
    );

    return {
      fullTextRecallResults: searchResults.map((item, index) => {
        const collection = collections.find((col) => String(col._id) === String(item.collectionId));
        return {
          id: String(item._id),
          datasetId: String(item.datasetId),
          collectionId: String(item.collectionId),
          ...getCollectionSourceData(collection),
          q: item.q,
          a: item.a,
          chunkIndex: item.chunkIndex,
          indexes: item.indexes,
          score: [{ type: SearchScoreTypeEnum.fullText, value: item.score, index }]
        };
      }),
      tokenLen: 0
    };
  };

  // 4. reRankSearchResult: best effort; any failure disables re-ranking for this call.
  const reRankSearchResult = async ({
    data,
    query
  }: {
    data: SearchDataResponseItemType[];
    query: string;
  }): Promise<SearchDataResponseItemType[]> => {
    try {
      const results = await reRankRecall({
        query,
        documents: data.map((item) => ({
          id: item.id,
          text: `${item.q}\n${item.a}`
        }))
      });

      if (results.length === 0) {
        usingReRank = false;
        return [];
      }

      // add new score to data
      const mergeResult = results
        .map((item, index) => {
          const target = data.find((dataItem) => dataItem.id === item.id);
          if (!target) return null;
          const score = item.score || 0;
          return {
            ...target,
            score: [{ type: SearchScoreTypeEnum.reRank, value: score, index }]
          };
        })
        .filter(Boolean) as SearchDataResponseItemType[];

      return mergeResult;
    } catch (error) {
      // Still fall back to the non-reranked results, but leave a trace of the failure.
      console.log('ReRank failed, continue without rerank', error);
      usingReRank = false;
      return [];
    }
  };

  // 5. filterResultsByMaxTokens: keep items until the token budget is used up.
  const filterResultsByMaxTokens = async (
    list: SearchDataResponseItemType[],
    maxTokens: number
  ) => {
    const results: SearchDataResponseItemType[] = [];
    let totalTokens = 0;

    for (const item of list) {
      totalTokens += await countPromptTokens(item.q + item.a);

      // An item that overshoots the budget by more than 500 tokens is dropped.
      if (totalTokens > maxTokens + 500) {
        break;
      }
      results.push(item);
      // A smaller overshoot is tolerated, but nothing is added after it.
      if (totalTokens > maxTokens) {
        break;
      }
    }

    // Always return at least the first item (when there is one).
    return results.length === 0 ? list.slice(0, 1) : results;
  };

  // 6. multiQueryRecall: run both recalls for every query, then merge per source.
  const multiQueryRecall = async ({
    embeddingLimit,
    fullTextLimit
  }: {
    embeddingLimit: number;
    fullTextLimit: number;
  }) => {
    // multi query recall; using Promise.all's ordered result keeps the lists in
    // query order instead of completion order.
    const perQuery = await Promise.all(
      queries.map(async (query) => {
        const [{ tokens, embeddingRecallResults }, { fullTextRecallResults }] = await Promise.all([
          embeddingRecall({ query, limit: embeddingLimit }),
          fullTextRecall({ query, limit: fullTextLimit })
        ]);
        return { tokens, embeddingRecallResults, fullTextRecallResults };
      })
    );

    const totalTokens = perQuery.reduce((sum, item) => sum + item.tokens, 0);
    const embeddingRecallResList: SearchDataResponseItemType[][] = perQuery.map(
      (item) => item.embeddingRecallResults
    );
    const fullTextRecallResList: SearchDataResponseItemType[][] = perQuery.map(
      (item) => item.fullTextRecallResults
    );

    // rrf concat
    const rrfEmbRecall = datasetSearchResultConcat(
      embeddingRecallResList.map((list) => ({ k: 60, list }))
    ).slice(0, embeddingLimit);
    const rrfFTRecall = datasetSearchResultConcat(
      fullTextRecallResList.map((list) => ({ k: 60, list }))
    ).slice(0, fullTextLimit);

    return {
      tokens: totalTokens,
      embeddingRecallResults: rrfEmbRecall,
      fullTextRecallResults: rrfFTRecall
    };
  };

  /* main step */
  // count limit
  const { embeddingLimit, fullTextLimit } = countRecallLimit();

  // recall
  const { embeddingRecallResults, fullTextRecallResults, tokens } = await multiQueryRecall({
    embeddingLimit,
    fullTextLimit
  });

  // ReRank results
  const reRankResults = await (async () => {
    if (!usingReRank) return [];

    // Union of both recalls without repeating an id.
    const embeddingIds = new Set<string>(embeddingRecallResults.map((item) => item.id));
    const concatRecallResults = embeddingRecallResults.concat(
      fullTextRecallResults.filter((item) => !embeddingIds.has(item.id))
    );

    // remove same q and a data
    return reRankSearchResult({
      query: reRankQuery,
      data: removeDuplicateQA(concatRecallResults)
    });
  })();

  // embedding recall and fullText recall rrf concat
  const rrfConcatResults = datasetSearchResultConcat([
    { k: 60, list: embeddingRecallResults },
    { k: 60, list: fullTextRecallResults },
    { k: 58, list: reRankResults }
  ]);

  // remove same q and a data
  const filterSameDataResults = removeDuplicateQA(rrfConcatResults);

  // score filter
  // Must read `usingReRank` here, after re-ranking, because a failed rerank resets it.
  const scoreFilter = (() => {
    if (usingReRank) {
      usingSimilarityFilter = true;
      return filterSameDataResults.filter((item) => {
        const reRankScore = item.score.find((item) => item.type === SearchScoreTypeEnum.reRank);
        if (reRankScore && reRankScore.value < similarity) return false;
        return true;
      });
    }
    if (searchMode === DatasetSearchModeEnum.embedding) {
      usingSimilarityFilter = true;
      return filterSameDataResults.filter((item) => {
        const embeddingScore = item.score.find(
          (item) => item.type === SearchScoreTypeEnum.embedding
        );
        if (embeddingScore && embeddingScore.value < similarity) return false;
        return true;
      });
    }
    return filterSameDataResults;
  })();

  return {
    searchRes: await filterResultsByMaxTokens(scoreFilter, maxTokens),
    tokens,
    searchMode,
    limit: maxTokens,
    similarity,
    usingReRank,
    usingSimilarityFilter
  };
}
multiQueryRecall : 对每个 query 并行执行向量召回(embeddingRecall)与全文召回(fullTextRecall),再用 RRF 合并各 query 的结果。其中向量召回在 embeddingRecall 函数中实现:首先,将 query 转换为 vector,然后,在 pgvector 中检索相似向量,最后在 mongodb 查找 vector 对应的文本,处理后返回。
getVectorsByText : 负责将搜索的问题转换为向量表示;
recallFromVectorStore : 在 pg vector 中查找相似向量;
MongoDatasetData.find :将 recallFromVectorStore 查询出的相似向量在 mongodb 中找出原文本。
其他内容后面再详细展开介绍。
参考链接:
[1] FastGPT源码深度剖析:混合检索及语料召回逻辑 - 技术栈