Microsoft recently open-sourced the GraphRAG project, and it is an excellent piece of work. In the spirit of research and learning I downloaded it to try out, and found that it currently only works with OpenAI ChatGPT or Azure OpenAI ChatGPT, which means your documents have to be uploaded to a third-party service for processing. To let a local Ollama instance be used as well, I built a small proxy tool that translates OpenAI chat requests into Ollama requests; a single Python class is enough to convert the OpenAI chat format into the format used by a local Ollama server.
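At its core the conversion is small: OpenAI's chat payload and Ollama's /api/chat payload share the same role/content message structure, so the proxy mainly has to rewrite the model name and forward the body. The following sketch is purely illustrative (the helper openai_to_ollama is hypothetical and not part of the proxy code further down); it just shows the field mapping being relied on:

def openai_to_ollama(openai_body: dict, local_model: str = 'qwen2:7b') -> dict:
    """Map an OpenAI /v1/chat/completions request body onto an Ollama /api/chat body."""
    return {
        'model': local_model,                         # force the local model name
        'messages': openai_body.get('messages', []),  # role/content messages are compatible
        'stream': openai_body.get('stream', False),   # both APIs use the same flag
    }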
Python code that converts the OpenAI chat format into the local Ollama format:
from http.server import BaseHTTPRequestHandler, HTTPServer
import json
from socketserver import ThreadingMixIn
from urllib.parse import urlparse, parse_qs
from queue import Queue
import requests
import argparse
from ascii_colors import ASCIIColors

# Directly defining server configurations
servers = [
    ("server1", {'url': 'http://localhost:11434', 'queue': Queue()}),
    # Add more servers if needed
]

# Define the Ollama model to use
ollama_model = 'qwen2:7b'


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--port', type=int, default=8000, help='Port number for the server')
    args = parser.parse_args()
    ASCIIColors.red("Ollama Proxy server")

    class RequestHandler(BaseHTTPRequestHandler):
        def _send_response(self, response):
            # Relay the upstream response back to the client using chunked transfer encoding.
            self.send_response(response.status_code)
            for key, value in response.headers.items():
                if key.lower() not in ['content-length', 'transfer-encoding', 'content-encoding']:
                    self.send_header(key, value)
            self.send_header('Transfer-Encoding', 'chunked')
            self.end_headers()
            try:
                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:
                        self.wfile.write(b"%X\r\n%s\r\n" % (len(chunk), chunk))
                        self.wfile.flush()
                self.wfile.write(b"0\r\n\r\n")
            except BrokenPipeError:
                pass

        def do_GET(self):
            self.log_request()
            self.proxy()

        def do_POST(self):
            self.log_request()
            self.proxy()

        def proxy(self):
            url = urlparse(self.path)
            path = url.path
            get_params = parse_qs(url.query) or {}
            post_data = None
            if self.command == "POST":
                content_length = int(self.headers['Content-Length'])
                post_data = self.rfile.read(content_length)
                post_data_str = post_data.decode('utf-8')
                try:
                    post_params = json.loads(post_data_str)
                except json.JSONDecodeError:
                    post_params = {}
                # Force every request onto the local Ollama model, whatever model the client asked for.
                post_params['model'] = ollama_model
                post_params = json.dumps(post_params).encode('utf-8')
            else:
                post_params = {}

            # Find the server with the lowest number of queue entries.
            min_queued_server = servers[0]
            for server in servers:
                cs = server[1]
                if cs['queue'].qsize() < min_queued_server[1]['queue'].qsize():
                    min_queued_server = server

            if path == '/api/generate' or path == '/api/chat':
                que = min_queued_server[1]['queue']
                que.put_nowait(1)
                try:
                    post_data_dict = {}
                    if isinstance(post_data, bytes):
                        post_data_str = post_data.decode('utf-8')
                        post_data_dict = json.loads(post_data_str)
                    response = requests.request(self.command, min_queued_server[1]['url'] + path,
                                                params=get_params, data=post_params,
                                                stream=post_data_dict.get("stream", False))
                    self._send_response(response)
                except Exception:
                    pass
                finally:
                    que.get_nowait()
            else:
                # For other endpoints, just mirror the request.
                response = requests.request(self.command, min_queued_server[1]['url'] + path,
                                            params=get_params, data=post_params)
                self._send_response(response)

    class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
        pass

    print('Starting server')
    server = ThreadedHTTPServer(('', args.port), RequestHandler)  # Set the entry port here.
    print(f'Running server on port {args.port}')
    server.serve_forever()


if __name__ == "__main__":
    main()
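To smoke-test the proxy you can point any OpenAI-style client at it. The snippet below is a minimal sketch, assuming the openai Python package is installed, the proxy above is running on localhost:8000, and the local Ollama build exposes its OpenAI-compatible /v1/chat/completions endpoint; the api_key is a dummy value and the model argument does not matter, because the proxy overwrites it with ollama_model:

from openai import OpenAI

# Point the client at the proxy instead of api.openai.com; Ollama ignores the API key.
client = OpenAI(base_url='http://localhost:8000/v1', api_key='ollama')

reply = client.chat.completions.create(
    model='gpt-4',  # rewritten by the proxy to ollama_model ('qwen2:7b')
    messages=[{'role': 'user', 'content': 'Say hello from the local model.'}],
)
print(reply.choices[0].message.content)

GraphRAG itself can then be pointed at the same address by setting the chat model's api_base to http://localhost:8000/v1 in its configuration (the exact setting name depends on your GraphRAG version).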