流式请求 GPT，并将答案流式推送到前端页面
1)java流式获取gpt答案
1、读取文件流的方式
使用 POST 请求数据。由于 GPT 以 EventSource（SSE）的方式返回数据，每一行的格式是 `data: {...}`，需要手动去掉 `data:` 前缀后再解析 JSON。
/**
org.apache.http.client.methods
**/
@SneakyThrowsprivate void chatStream(List<ChatParamMessagesBO> messagesBOList) {CloseableHttpClient httpclient = HttpClients.createDefault();HttpPost httpPost = new HttpPost("https://api.openai.com/v1/chat/completions");httpPost.setHeader("Authorization","xxxxxxxxxxxx");httpPost.setHeader("Content-Type","application/json; charset=UTF-8");ChatParamBO build = ChatParamBO.builder().temperature(0.7).model("gpt-3.5-turbo").messages(messagesBOList).stream(true).build();System.out.println(JsonUtils.toJson(build));httpPost.setEntity(new StringEntity(JsonUtils.toJson(build),"utf-8"));CloseableHttpResponse response = httpclient.execute(httpPost);try {HttpEntity entity = response.getEntity();if (entity != null) {InputStream inputStream = entity.getContent();BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));String line;while ((line = reader.readLine()) != null) {// 处理 event stream 数据try {
// System.out.println(line);ChatResultBO chatResultBO = JsonUtils.toObject(line.replace("data:", ""), ChatResultBO.class);String content = chatResultBO.getChoices().get(0).getDelta().getContent();log.info(content);// System.out.println(chatResultBO.getChoices().get(0).getMessage().getContent());} catch (Exception e) {
// e.printStackTrace();}}}} finally {response.close();}}
2、通过 SSE 连接的方式获取数据
用到了okhttp
需要先引用相关maven:
<!-- OkHttp core client plus its server-sent-events (SSE) extension.
     No <version> is given. The versions are assumed to be managed by a parent POM or BOM
     (for example Spring Boot dependency management); add explicit versions otherwise. -->
<dependency>
    <groupId>com.squareup.okhttp3</groupId>
    <artifactId>okhttp</artifactId>
</dependency>
<dependency>
    <groupId>com.squareup.okhttp3</groupId>
    <artifactId>okhttp-sse</artifactId>
</dependency>
// 定义see接口Request request = new Request.Builder().url("https://api.openai.com/v1/chat/completions").header("Authorization","xxx").post(okhttp3.RequestBody.create(okhttp3.MediaType.parse("application/json; charset=utf-8"),param.toJSONString())).build();OkHttpClient okHttpClient = new OkHttpClient.Builder().connectTimeout(10, TimeUnit.MINUTES).readTimeout(10, TimeUnit.MINUTES)//这边需要将超时显示设置长一点,不然刚连上就断开,之前以为调用方式错误被坑了半天.build();// 实例化EventSource,注册EventSource监听器RealEventSource realEventSource = new RealEventSource(request, new EventSourceListener() {@Overridepublic void onOpen(EventSource eventSource, Response response) {log.info("onOpen");}@SneakyThrows@Overridepublic void onEvent(EventSource eventSource, String id, String type, String data) {
// log.info("onEvent");log.info(data);//请求到的数据}@Overridepublic void onClosed(EventSource eventSource) {log.info("onClosed");
// emitter.complete();}@Overridepublic void onFailure(EventSource eventSource, Throwable t, Response response) {log.info("onFailure,t={},response={}",t,response);//这边可以监听并重新打开
// emitter.complete();}});realEventSource.connect(okHttpClient);//真正开始请求的一步
2)流式推送答案
方法一:通过订阅式SSE/WebSocket
原理是先建立连接，然后通过该连接持续推送消息即可。
1、websocket
创建相关配置:
import javax.websocket.Session;import lombok.Data;/*** @description WebSocket客户端连接*/
@Data
public class WebSocketClient {// 与某个客户端的连接会话,需要通过它来给客户端发送数据private Session session;//连接的uriprivate String uri;}
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;

/**
 * Spring configuration that enables standard (JSR-356) WebSocket endpoints.
 */
@Configuration
public class WebSocketConfig {

    /**
     * Exposes a {@link ServerEndpointExporter}. This is the component that registers
     * {@code @ServerEndpoint}-annotated classes with the WebSocket container.
     *
     * <p>NOTE(review): Spring documents this bean as needed only with an embedded container.
     * When a WAR is deployed to an external container that scans endpoints itself, it should be
     * left out. Confirm this for your deployment.
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        return new ServerEndpointExporter();
    }
}
配置相关service
@Slf4j
@Component
@ServerEndpoint("/websocket/chat/{chatId}")
public class ChatWebsocketService {static final ConcurrentHashMap<String, List<WebSocketClient>> webSocketClientMap= new ConcurrentHashMap<>();private String chatId;/*** 连接建立成功时触发,绑定参数* @param session 与某个客户端的连接会话,需要通过它来给客户端发送数据* @param chatId 商户ID*/@OnOpenpublic void onOpen(Session session, @PathParam("chatId") String chatId){WebSocketClient client = new WebSocketClient();client.setSession(session);client.setUri(session.getRequestURI().toString());List<WebSocketClient> webSocketClientList = webSocketClientMap.get(chatId);if(webSocketClientList == null){webSocketClientList = new ArrayList<>();}webSocketClientList.add(client);webSocketClientMap.put(chatId, webSocketClientList);this.chatId = chatId;}/*** 收到客户端消息后调用的方法** @param message 客户端发送过来的消息*/@OnMessagepublic void onMessage(String message) {log.info("chatId = {},message = {}",chatId,message);// 回复消息this.chatStream(BaseUtil.newList(ChatParamMessagesBO.builder().content(message).role("user").build()));
// this.sendMessage(chatId,message+"233");}/*** 连接关闭时触发,注意不能向客户端发送消息了* @param chatId*/@OnClosepublic void onClose(@PathParam("chatId") String chatId){webSocketClientMap.remove(chatId);}/*** 通信发生错误时触发* @param session* @param error*/@OnErrorpublic void onError(Session session, Throwable error) {System.out.println("发生错误");error.printStackTrace();}/*** 向客户端发送消息* @param chatId* @param message*/public void sendMessage(String chatId,String message){try {List<WebSocketClient> webSocketClientList = webSocketClientMap.get(chatId);if(webSocketClientList!=null){for(WebSocketClient webSocketServer:webSocketClientList){webSocketServer.getSession().getBasicRemote().sendText(message);}}} catch (IOException e) {e.printStackTrace();throw new RuntimeException(e.getMessage());}}/*** 流式调用查询gpt* @param messagesBOList* @throws IOException*/@SneakyThrowsprivate void chatStream(List<ChatParamMessagesBO> messagesBOList) {// TODO 和GPT的访问请求}
}
测试：用 Postman 建立 WebSocket 连接。
2、SSE
本质也是基于订阅推送方式
前端:
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>SseEmitter</title>
</head>
<body>
<button onclick="closeSse()">关闭连接</button>
<div id="message"></div>
<script>
    // Base URL of the SSE backend; adjust it to your deployment.
    const SSE_BASE_URL = 'http://172.28.54.27:8902/api/sse';
    let source = null;
    // Simulated logged-in user id (a timestamp such as new Date().getTime() also works).
    const id = '7829083B42464C5B9C445A087E873C7D';

    if (window.EventSource) {
        // Open the connection.
        source = new EventSource(SSE_BASE_URL + '/connect?conversationId=' + encodeURIComponent(id));
        setMessageInnerHTML("连接用户=" + id);

        // Fired once the connection is established (equivalently: source.onopen = function (event) {}).
        source.addEventListener('open', function (e) {
            setMessageInnerHTML("建立连接。。。");
        }, false);

        // Fired for every message the server pushes (equivalently: source.onmessage = function (event) {}).
        source.addEventListener('message', function (e) {
            setMessageInnerHTML(e.data);
        });

        // Fires only if the server sends an event named "close".
        source.addEventListener("close", function (event) {
            console.log("Server closed the connection");
            source.close();
        });

        // Fired on communication errors, e.g. when the connection drops
        // (equivalently: source.onerror = function (event) {}).
        // The error Event has no readyState of its own, so check the EventSource's state instead.
        source.addEventListener('error', function (e) {
            console.log(e);
            if (source.readyState === EventSource.CLOSED) {
                setMessageInnerHTML("连接关闭");
            }
        }, false);
    } else {
        setMessageInnerHTML("你的浏览器不支持SSE");
    }

    // On page unload the SSE connection could be closed explicitly. If the server never expires
    // connections, this is where its state would be cleaned up after the browser closes.
    window.onbeforeunload = function () {
        // closeSse();
    };

    // Close the SSE connection and ask the server to drop it.
    function closeSse() {
        if (source) {
            source.close();
        }
        const httpRequest = new XMLHttpRequest();
        httpRequest.open('GET', SSE_BASE_URL + '/disconnection?conversationId=' + encodeURIComponent(id), true);
        httpRequest.send();
        console.log("close");
    }

    // Append one line of text to the page. The text goes in as a text node rather than through
    // innerHTML, so markup in server-pushed data is displayed literally instead of executed.
    function setMessageInnerHTML(text) {
        const container = document.getElementById('message');
        container.appendChild(document.createTextNode(text));
        container.appendChild(document.createElement('br'));
    }
</script>
</body>
</html>
后端:
controller
import org.springframework.cloud.context.config.annotation.RefreshScope;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;import java.util.Set;
import java.util.function.Consumer;import javax.annotation.Resource;import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;@Validated
@RestController
@RequestMapping("/api/sse")
@Slf4j
@RefreshScope // 会监听变化实时变化值
public class SseController {@Resourceprivate SseBizService sseBizService;/*** 创建用户连接并返回 SseEmitter** @param conversationId 用户ID* @return SseEmitter*/@SneakyThrows@GetMapping(value = "/connect", produces = "text/event-stream; charset=utf-8")public SseEmitter connect(String conversationId) {// 设置超时时间,0表示不过期。默认30秒,超过时间未完成会抛出异常:AsyncRequestTimeoutExceptionSseEmitter sseEmitter = new SseEmitter(0L);// 注册回调sseEmitter.onCompletion(completionCallBack(conversationId));sseEmitter.onError(errorCallBack(conversationId));sseEmitter.onTimeout(timeoutCallBack(conversationId));log.info("创建新的sse连接,当前用户:{}", conversationId);sseBizService.addConnect(conversationId,sseEmitter);sseBizService.sendMsg(conversationId,"链接成功");
// sseCache.get(conversationId).send(SseEmitter.event().reconnectTime(10000).data("链接成功"),MediaType.TEXT_EVENT_STREAM);return sseEmitter;}/*** 给指定用户发送信息 -- 单播*/@GetMapping(value = "/send", produces = "text/event-stream; charset=utf-8")public void sendMessage(String conversationId, String msg) {sseBizService.sendMsg(conversationId,msg);}/*** 移除用户连接*/@GetMapping(value = "/disconnection", produces = "text/event-stream; charset=utf-8")public void removeUser(String conversationId) {log.info("移除用户:{}", conversationId);sseBizService.deleteConnect(conversationId);}/*** 向多人发布消息 -- 组播* @param groupId 开头标识* @param message 消息内容*/public void groupSendMessage(String groupId, String message) {/* if (!BaseUtil.isNullOrEmpty(sseCache)) {*//*Set<String> ids = sseEmitterMap.keySet().stream().filter(m -> m.startsWith(groupId)).collect(Collectors.toSet());batchSendMessage(message, ids);*//*sseCache.forEach((k, v) -> {try {if (k.startsWith(groupId)) {v.send(message, MediaType.APPLICATION_JSON);}} catch (IOException e) {log.error("用户[{}]推送异常:{}", k, e.getMessage());removeUser(k);}});}*/}/*** 群发所有人 -- 广播*/public void batchSendMessage(String message) {/*sseCache.forEach((k, v) -> {try {v.send(message, MediaType.APPLICATION_JSON);} catch (IOException e) {log.error("用户[{}]推送异常:{}", k, e.getMessage());removeUser(k);}});*/}/*** 群发消息*/public void batchSendMessage(String message, Set<String> ids) {ids.forEach(userId -> sendMessage(userId, message));}/*** 获取当前连接信息*/
// public List<String> getIds() {
// return new ArrayList<>(sseCache.keySet());
// }/*** 获取当前连接数量*/
// public int getUserCount() {
// return count.intValue();
// }private Runnable completionCallBack(String userId) {return () -> {log.info("结束连接:{}", userId);removeUser(userId);};}private Runnable timeoutCallBack(String userId) {return () -> {log.info("连接超时:{}", userId);removeUser(userId);};}private Consumer<Throwable> errorCallBack(String userId) {return throwable -> {log.info("连接异常:{}", userId);removeUser(userId);};}
}
service
import org.springframework.cloud.context.config.annotation.RefreshScope;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;@Component
@Slf4j
@RefreshScope // 会监听变化实时变化值
public class SseBizService {/*** * 当前连接数*/private AtomicInteger count = new AtomicInteger(0);/*** 使用map对象,便于根据userId来获取对应的SseEmitter,或者放redis里面*/private Map<String, SseEmitter> sseCache = new ConcurrentHashMap<>();/*** 添加用户* @author pengbin <pengbin>* @date 2023/9/11 11:37* @param* @return*/public void addConnect(String id,SseEmitter sseEmitter){sseCache.put(id, sseEmitter);// 数量+1count.getAndIncrement();}/*** 删除用户* @author pengbin <pengbin>* @date 2023/9/11 11:37* @param* @return*/public void deleteConnect(String id){sseCache.remove(id);// 数量+1count.getAndDecrement();}/*** 发送消息* @author pengbin <pengbin>* @date 2023/9/11 11:38* @param* @return*/@SneakyThrowspublic void sendMsg(String id, String msg){if(sseCache.containsKey(id)){sseCache.get(id).send(msg, MediaType.TEXT_EVENT_STREAM);}}}
方法二：每次请求单独建立 SSE（EventSource）连接，回答结束后立即销毁
前端：收到结束标识 `[DONE]` 后立即关闭 EventSource
/**
 * Called for every message the server pushes (equivalently: source.onmessage = function (event) {}).
 * The server ends the stream with the sentinel "[DONE]". When it arrives, the EventSource is
 * closed at once so the browser does not reconnect, and the sentinel itself is not displayed.
 */
source.addEventListener('message', function (e) {
    if (e.data === '[DONE]') {
        source.close();
        return;
    }
    setMessageInnerHTML(e.data);
});
后端:
@SneakyThrows@GetMapping(value = "/stream/sse", produces = MediaType.TEXT_EVENT_STREAM_VALUE)public SseEmitter completionsStream(@RequestParam String conversationId){//List<ChatParamMessagesBO> messagesBOList =new ArrayList();// 获取内容信息ChatParamBO build = ChatParamBO.builder().temperature(0.7).stream(true).model("xxxx").messages(messagesBOList).build();SseEmitter emitter = new SseEmitter();// 定义see接口Request request = new Request.Builder().url("xxx").header("Authorization","xxxx").post(okhttp3.RequestBody.create(okhttp3.MediaType.parse("application/json; charset=utf-8"),JsonUtils.toJson(build))).build();OkHttpClient okHttpClient = new OkHttpClient.Builder().connectTimeout(10, TimeUnit.MINUTES).readTimeout(10, TimeUnit.MINUTES)//这边需要将超时显示设置长一点,不然刚连上就断开,之前以为调用方式错误被坑了半天.build();StringBuffer sb = new StringBuffer("");// 实例化EventSource,注册EventSource监听器RealEventSource realEventSource = null;realEventSource = new RealEventSource(request, new EventSourceListener() {@Overridepublic void onOpen(EventSource eventSource, Response response) {log.info("onOpen");}@SneakyThrows@Overridepublic void onEvent(EventSource eventSource, String id, String type, String data) {log.info(data);//请求到的数据try {ChatResultBO chatResultBO = JsonUtils.toObject(data.replace("data:", ""), ChatResultBO.class);String content = chatResultBO.getChoices().get(0).getDelta().getContent();sb.append(content);emitter.send(SseEmitter.event().data(JsonUtils.toJson(ChatContentBO.builder().content(content).build())));} catch (Exception e) {
// e.printStackTrace();}if("[DONE]".equals(data)){emitter.send(SseEmitter.event().data(data));emitter.complete();log.info("result={}",sb);}}@Overridepublic void onClosed(EventSource eventSource) {log.info("onClosed,eventSource={}",eventSource);//这边可以监听并重新打开
// emitter.complete();}@Overridepublic void onFailure(EventSource eventSource, Throwable t, Response response) {log.info("onFailure,t={},response={}",t,response);//这边可以监听并重新打开
// emitter.complete();}});realEventSource.connect(okHttpClient);//真正开始请求的一步return emitter;}