版本
0.28.0
源码
使用 langchain4j 时,可以通过 AiServices 封装聊天模型 API,实现会话记忆、工具调用、检索增强生成(RAG)、内容审查等功能,并对外提供简单灵活的用户接口。
DefaultAiServices 是 AiServices 的默认实现类,它通过 JDK 动态代理(java.lang.reflect.Proxy)实现用户定义的服务接口。
class DefaultAiServices<T> extends AiServices<T> {private static final int MAX_SEQUENTIAL_TOOL_EXECUTIONS = 10;DefaultAiServices(AiServiceContext context) {super(context);}// 校验使用提示词模板发送消息的方法参数static void validateParameters(Method method) {// 如果只有一个参数或者没有参数跳过检查(参数直接作为内容发送的方法/其他非发送内容的方法)Parameter[] parameters = method.getParameters();if (parameters == null || parameters.length < 2) {return;}for (Parameter parameter : parameters) {// 获取应用于提示词模板的参数(带有V注解)V v = parameter.getAnnotation(V.class);// 获取用户消息模板参数dev.langchain4j.service.UserMessage userMessage = parameter.getAnnotation(dev.langchain4j.service.UserMessage.class);// 获取记忆ID参数MemoryId memoryId = parameter.getAnnotation(MemoryId.class);// 获取用户名参数UserName userName = parameter.getAnnotation(UserName.class);// 如果没有任何模板参数则报错if (v == null && userMessage == null && memoryId == null && userName == null) {throw illegalConfiguration("Parameter '%s' of method '%s' should be annotated with @V or @UserMessage or @UserName or @MemoryId",parameter.getName(), method.getName());}}}public T build() {// 基本校验// 1. 校验chatModel/streamingChatModel是否有值// 2. 校验toolSpecifications有值时上下文是否启用记忆(使用工具调用至少需要在记忆中保存3个消息)performBasicValidation();// 校验方法使用了Moderate时是否同时指定了审查模型(moderationModel)for (Method method : context.aiServiceClass.getMethods()) {if (method.isAnnotationPresent(Moderate.class) && context.moderationModel == null) {throw illegalConfiguration("The @Moderate annotation is present, but the moderationModel is not set up. 
" +"Please ensure a valid moderationModel is configured before using the @Moderate annotation.");}}// 构造动态代理Object proxyInstance = Proxy.newProxyInstance(context.aiServiceClass.getClassLoader(),new Class<?>[]{context.aiServiceClass},new InvocationHandler() {private final ExecutorService executor = Executors.newCachedThreadPool();@Overridepublic Object invoke(Object proxy, Method method, Object[] args) throws Exception {// 直接执行Object类定义的方法if (method.getDeclaringClass() == Object.class) {// methods like equals(), hashCode() and toString() should not be handled by this proxyreturn method.invoke(this, args);}// 校验提示词模板参数validateParameters(method);// 获取系统消息Optional<SystemMessage> systemMessage = prepareSystemMessage(method, args);// 获取用户消息UserMessage userMessage = prepareUserMessage(method, args);// 获取记忆ID参数值,如果没有记忆ID参数则使用默认值“default”Object memoryId = memoryId(method, args).orElse(DEFAULT);// 使用检索增强生成(RAG),将检索结果内容与用户原始消息文本整合作为用户消息if (context.retrievalAugmentor != null) {List<ChatMessage> chatMemory = context.hasChatMemory()? 
context.chatMemory(memoryId).messages(): null;Metadata metadata = Metadata.from(userMessage, memoryId, chatMemory);userMessage = context.retrievalAugmentor.augment(userMessage, metadata);}// 用于提供客制化的输出解析,根据函数返回类型生成需要返回的消息格式的相关提示词,追加到用户消息里面// 如果返回类型为String,AiMessage,TokenStream,Response则不追加格式提示词// 如果返回类型为void则报错// 如果返回类型为enum枚举类型,则追加提示词“\nYou must answer strictly in the following format: one of value1,value2,value3...,valueN”// 如果返回类型是 boolean/byte/short/int/long/BigInteger/float/double/BigDecimal/Date/LocalDate/LocalTime/LocalDateTime 或其对应包装类型,则追加对应值类型提示词,例如“\nYou must answer strictly in the following format: one of [true, false]” ,“...format: integer number in range [-128, 127]”// 如果返回类型是List/Set,则追加提示词“You must put every item on a separate line.”// 否则追加提示词,以json形式返回 “You must answer strictly in the following JSON format: {...}”String outputFormatInstructions = outputFormatInstructions(method.getReturnType());userMessage = UserMessage.from(userMessage.text() + outputFormatInstructions);// 如果包含聊天记忆,则在聊天记忆中追加系统消息和用户消息if (context.hasChatMemory()) {ChatMemory chatMemory = context.chatMemory(memoryId);systemMessage.ifPresent(chatMemory::add);chatMemory.add(userMessage);}// 从记忆中获取消息清单或构建新的消息清单List<ChatMessage> messages;if (context.hasChatMemory()) {messages = context.chatMemory(memoryId).messages();} else {messages = new ArrayList<>();systemMessage.ifPresent(messages::add);messages.add(userMessage);}// 执行审查Future<Moderation> moderationFuture = triggerModerationIfNeeded(method, messages);// 以流式处理消息if (method.getReturnType() == TokenStream.class) {return new AiServiceTokenStream(messages, context, memoryId); // 尚未实现响应内容审查,也不支持工具调用}// 调用chatModel生成响应Response<AiMessage> response = context.toolSpecifications == null? 
context.chatModel.generate(messages): context.chatModel.generate(messages, context.toolSpecifications);// 获取token用量TokenUsage tokenUsageAccumulator = response.tokenUsage();// 校验审查结果verifyModerationIfNeeded(moderationFuture);// 执行工具调用// 工具调用的最大执行次数(10)int executionsLeft = MAX_SEQUENTIAL_TOOL_EXECUTIONS;while (true) {if (executionsLeft-- == 0) {throw runtime("Something is wrong, exceeded %s sequential tool executions",MAX_SEQUENTIAL_TOOL_EXECUTIONS);}// 获取AI响应消息,添加到记忆中AiMessage aiMessage = response.content();if (context.hasChatMemory()) {context.chatMemory(memoryId).add(aiMessage);}// 如果不存在工具调用请求则中断if (!aiMessage.hasToolExecutionRequests()) {break;}// 根据工具调用请求,依次调用工具,并将工具执行结果消息添加到记忆中ChatMemory chatMemory = context.chatMemory(memoryId); for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {ToolExecutor toolExecutor = context.toolExecutors.get(toolExecutionRequest.name());String toolExecutionResult = toolExecutor.execute(toolExecutionRequest, memoryId);ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(toolExecutionRequest,toolExecutionResult);chatMemory.add(toolExecutionResultMessage);}// 根据添加了工具执行结果的记忆再次调用模型生成response = context.chatModel.generate(chatMemory.messages(), context.toolSpecifications);// 累计token用量tokenUsageAccumulator = tokenUsageAccumulator.add(response.tokenUsage());}// 返回最终的响应response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason());// 将响应解析为方法对应的返回类型对象return parse(response, method.getReturnType());}private Future<Moderation> triggerModerationIfNeeded(Method method, List<ChatMessage> messages) {if (method.isAnnotationPresent(Moderate.class)) {return executor.submit(() -> {List<ChatMessage> messagesToModerate = removeToolMessages(messages);return context.moderationModel.moderate(messagesToModerate).content();});}return null;}});return (T) proxyInstance;}// 准备系统消息private Optional<SystemMessage> prepareSystemMessage(Method method, Object[] args) {// 
获取提示词模板变量Parameter[] parameters = method.getParameters();Map<String, Object> variables = getPromptTemplateVariables(args, parameters);dev.langchain4j.service.SystemMessage annotation = method.getAnnotation(dev.langchain4j.service.SystemMessage.class);if (annotation != null) {// 获取 SystemMessage 注解的系统消息提示词模板String systemMessageTemplate = getPromptText(method,"System",annotation.fromResource(), // 提示词资源文件,如果没有则取value值annotation.value(), // 提示词文本annotation.delimiter() // 换行符);// 根据模板和变量获取提示词Prompt prompt = PromptTemplate.from(systemMessageTemplate).apply(variables);return Optional.of(prompt.toSystemMessage());}return Optional.empty();}// 准备用户消息private static UserMessage prepareUserMessage(Method method, Object[] args) {Parameter[] parameters = method.getParameters();Map<String, Object> variables = getPromptTemplateVariables(args, parameters);// 获取用户名参数String userName = getUserName(parameters, args);dev.langchain4j.service.UserMessage annotation = method.getAnnotation(dev.langchain4j.service.UserMessage.class);if (annotation != null) {String userMessageTemplate = getPromptText(method,"User",annotation.fromResource(),annotation.value(),annotation.delimiter());// 如果模板中使用了{{it}}占位符,则只允许使用一个模板参数if (userMessageTemplate.contains("{{it}}")) {if (parameters.length != 1) {throw illegalConfiguration("Error: The {{it}} placeholder is present but the method does not have exactly one parameter. 
" +"Please ensure that methods using the {{it}} placeholder have exactly one parameter.");}variables = singletonMap("it", toString(args[0]));}Prompt prompt = PromptTemplate.from(userMessageTemplate).apply(variables);if (userName != null) {// 使用用户名构造用户消息return userMessage(userName, prompt.text());} else {return prompt.toUserMessage();}}// 方法如果没有UserMessage注解,查找使用UserMessage注解的参数,作为消息内容for (int i = 0; i < parameters.length; i++) {if (parameters[i].isAnnotationPresent(dev.langchain4j.service.UserMessage.class)) {String text = toString(args[i]);if (userName != null) {return userMessage(userName, text);} else {return userMessage(text);}}}// 如果完全没有参数则报错if (args == null || args.length == 0) {throw illegalConfiguration("Method should have at least one argument");}// 如果只有一个没有注解的参数,则作为消息内容if (args.length == 1) {String text = toString(args[0]);if (userName != null) {return userMessage(userName, text);} else {return userMessage(text);}}throw illegalConfiguration("For methods with multiple parameters, each parameter must be annotated with @V, @UserMessage, @UserName or @MemoryId");}// 根据方法提示词注解获取提示词文本// resource 提示词资源文件,如果没有则取value值// value 提示词文本// delimiter 分隔符(换行符)private static String getPromptText(Method method, String type, String resource, String[] value, String delimiter) {String messageTemplate;if (!resource.trim().isEmpty()) {messageTemplate = getResourceText(method.getDeclaringClass(), resource);if (messageTemplate == null) {throw illegalConfiguration("@%sMessage's resource '%s' not found", type, resource);}} else {messageTemplate = String.join(delimiter, value);}if (messageTemplate.trim().isEmpty()) {throw illegalConfiguration("@%sMessage's template cannot be empty", type);}return messageTemplate;}private static String getResourceText(Class<?> clazz, String name) {return getText(clazz.getResourceAsStream(name));}private static String getText(InputStream inputStream) {if (inputStream == null) {return null;}try (Scanner scanner = new Scanner(inputStream);Scanner s = 
scanner.useDelimiter("\\A")) {return s.hasNext() ? s.next() : "";}}private Optional<Object> memoryId(Method method, Object[] args) {Parameter[] parameters = method.getParameters();for (int i = 0; i < parameters.length; i++) {if (parameters[i].isAnnotationPresent(MemoryId.class)) {Object memoryId = args[i];if (memoryId == null) {throw illegalArgument("The value of parameter %s annotated with @MemoryId in method %s must not be null",parameters[i].getName(), method.getName());}return Optional.of(memoryId);}}return Optional.empty();}// 获取用户名参数private static String getUserName(Parameter[] parameters, Object[] args) {for (int i = 0; i < parameters.length; i++) {if (parameters[i].isAnnotationPresent(UserName.class)) {return args[i].toString();}}return null;}// 获取提示词模板变量// 遍历V注解的变量,返回变量名和变量值映射private static Map<String, Object> getPromptTemplateVariables(Object[] args, Parameter[] parameters) {Map<String, Object> variables = new HashMap<>();for (int i = 0; i < parameters.length; i++) {V varAnnotation = parameters[i].getAnnotation(V.class);if (varAnnotation != null) {String variableName = varAnnotation.value();Object variableValue = args[i];variables.put(variableName, variableValue);}}return variables;}private static String toString(Object arg) {if (arg.getClass().isArray()) {return arrayToString(arg);} else if (arg.getClass().isAnnotationPresent(StructuredPrompt.class)) {return StructuredPromptProcessor.toPrompt(arg).text();} else {return arg.toString();}}private static String arrayToString(Object arg) {StringBuilder sb = new StringBuilder("[");int length = Array.getLength(arg);for (int i = 0; i < length; i++) {sb.append(toString(Array.get(arg, i)));if (i < length - 1) {sb.append(", ");}}sb.append("]");return sb.toString();}
}