简介
CodeGeeX2 是多语言代码生成模型 CodeGeeX(KDD'23)的第二代模型。不同于一代 CodeGeeX(完全在国产华为昇腾芯片平台训练),CodeGeeX2 基于 ChatGLM2 架构加入代码预训练实现;得益于 ChatGLM2 的更优性能,CodeGeeX2 在多项指标上取得性能提升(相比 CodeGeeX 提升 107%;仅 60 亿参数即超过 150 亿参数的 StarCoder-15B 近 10%)。
官方仓库地址:https://github.com/THUDM/CodeGeeX2/tree/main
部署步骤
安装依赖
# Create and activate an isolated conda environment for CodeGeeX2.
conda create --name CodeGeeX2 python=3.9
conda activate CodeGeeX2

# Clone the official repository (or download it manually and upload to the server).
git clone git@github.com:THUDM/CodeGeeX2.git
cd CodeGeeX2

# Install PyTorch 2.0.0 built against CUDA 11.7.
conda install -y pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.7 -c pytorch -c nvidia

# Usually unnecessary: the CUDA runtime shipped with the GPU driver is enough.
# conda install -y cuda-toolkit=11.7.0 -c nvidia

# Install the project's Python dependencies (Aliyun PyPI mirror for mainland China).
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
下载模型数据
由于huggingface国内无法访问,我们使用镜像站下载
# Install the Hugging Face Hub CLI (Aliyun PyPI mirror).
pip install -U huggingface_hub -i https://mirrors.aliyun.com/pypi/simple/
# Route all Hub traffic through the hf-mirror.com mirror (huggingface.co is
# unreachable from mainland China).
export HF_ENDPOINT=https://hf-mirror.com
# Download THUDM/codegeex2-6b under /mnt/data/CodeGeeX2/model.
# NOTE(review): --cache-dir stores files in the HF cache layout under that
# directory (matching the cache_dir= used by the Python snippets below);
# --local-dir-use-symlinks appears to have no effect without --local-dir —
# confirm against the huggingface-cli docs.
huggingface-cli download --resume-download --local-dir-use-symlinks False THUDM/codegeex2-6b --cache-dir /mnt/data/CodeGeeX2/model
启动
quick start - 单gpu
from transformers import AutoTokenizer, AutoModel

# Load the tokenizer and model from the local HF cache populated by
# `huggingface-cli download` (cache_dir must match the download step).
tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True, cache_dir="/mnt/data/CodeGeeX2/model")
model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True, device='cuda', cache_dir="/mnt/data/CodeGeeX2/model")
model = model.eval()

# Remember to add a language tag for better generation quality.
prompt = "# language: Python\n# write a bubble sort function\n"
inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
# top_k=1 makes decoding greedy/deterministic.
outputs = model.generate(inputs, max_length=256, top_k=1)
response = tokenizer.decode(outputs[0])
print(response)
quick start - 多gpu
demo.gpus 为 CodeGeeX2 官方仓库 demo 文件夹下的 gpus.py 文件,需要将 num_gpus 修改为机器实际的 GPU 数量,可以通过 nvidia-smi -L
查看 GPU 个数。
from transformers import AutoTokenizer, AutoModel
from demo.gpus import load_model_on_gpus


def get_model():
    """Load the tokenizer and shard the model across GPUs.

    Returns:
        (tokenizer, model) — model is in eval mode, split over `num_gpus`
        devices. Adjust num_gpus to the machine's actual GPU count
        (see `nvidia-smi -L`).
    """
    tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True, cache_dir="/mnt/data/CodeGeeX2/model")
    # gpus.py lives in the demo/ folder of the official repository.
    model = load_model_on_gpus("THUDM/codegeex2-6b", num_gpus=4, cache_dir="/mnt/data/CodeGeeX2/model")
    model = model.eval()
    return tokenizer, model


tokenizer, model = get_model()

# Remember to add a language tag for better generation quality.
prompt = "# language: Python\n# write a bubble sort function\n"
inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
# top_k=1 makes decoding greedy/deterministic.
outputs = model.generate(inputs, max_length=256, top_k=1)
response = tokenizer.decode(outputs[0])
print(response)
quick start - http api
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
from demo.gpus import load_model_on_gpus
import uvicorn
import json
import datetime
import argparse


def add_code_generation_args(parser):
    """Register the CodeGeeX2 demo command-line options on *parser* and return it."""
    group = parser.add_argument_group(title="CodeGeeX2 DEMO")
    group.add_argument("--model-path", type=str, default="THUDM/codegeex2-6b")
    group.add_argument("--listen", type=str, default="127.0.0.1")
    group.add_argument("--port", type=int, default=7860)
    group.add_argument("--workers", type=int, default=1)
    group.add_argument("--cpu", action="store_true")
    group.add_argument("--half", action="store_true")
    group.add_argument("--quantize", type=int, default=None)
    group.add_argument("--chatglm-cpp", action="store_true")
    return parser


# Per-language comment prefix prepended to the prompt; tagging the target
# language noticeably improves generation quality.
LANGUAGE_TAG = {
    "Abap"          : "* language: Abap",
    "ActionScript"  : "// language: ActionScript",
    "Ada"           : "-- language: Ada",
    "Agda"          : "-- language: Agda",
    "ANTLR"         : "// language: ANTLR",
    "AppleScript"   : "-- language: AppleScript",
    "Assembly"      : "; language: Assembly",
    "Augeas"        : "// language: Augeas",
    "AWK"           : "// language: AWK",
    "Basic"         : "' language: Basic",
    "C"             : "// language: C",
    "C#"            : "// language: C#",
    "C++"           : "// language: C++",
    "CMake"         : "# language: CMake",
    "Cobol"         : "// language: Cobol",
    "CSS"           : "/* language: CSS */",
    "CUDA"          : "// language: Cuda",
    "Dart"          : "// language: Dart",
    "Delphi"        : "{language: Delphi}",
    "Dockerfile"    : "# language: Dockerfile",
    "Elixir"        : "# language: Elixir",
    "Erlang"        : "% language: Erlang",
    "Excel"         : "' language: Excel",
    "F#"            : "// language: F#",
    "Fortran"       : "!language: Fortran",
    "GDScript"      : "# language: GDScript",
    "GLSL"          : "// language: GLSL",
    "Go"            : "// language: Go",
    "Groovy"        : "// language: Groovy",
    "Haskell"       : "-- language: Haskell",
    "HTML"          : "<!--language: HTML-->",
    "Isabelle"      : "(*language: Isabelle*)",
    "Java"          : "// language: Java",
    "JavaScript"    : "// language: JavaScript",
    "Julia"         : "# language: Julia",
    "Kotlin"        : "// language: Kotlin",
    "Lean"          : "-- language: Lean",
    "Lisp"          : "; language: Lisp",
    "Lua"           : "// language: Lua",
    "Markdown"      : "<!--language: Markdown-->",
    "Matlab"        : "% language: Matlab",
    "Objective-C"   : "// language: Objective-C",
    "Objective-C++" : "// language: Objective-C++",
    "Pascal"        : "// language: Pascal",
    "Perl"          : "# language: Perl",
    "PHP"           : "// language: PHP",
    "PowerShell"    : "# language: PowerShell",
    "Prolog"        : "% language: Prolog",
    "Python"        : "# language: Python",
    "R"             : "# language: R",
    "Racket"        : "; language: Racket",
    "RMarkdown"     : "# language: RMarkdown",
    "Ruby"          : "# language: Ruby",
    "Rust"          : "// language: Rust",
    "Scala"         : "// language: Scala",
    "Scheme"        : "; language: Scheme",
    "Shell"         : "# language: Shell",
    "Solidity"      : "// language: Solidity",
    "SPARQL"        : "# language: SPARQL",
    "SQL"           : "-- language: SQL",
    "Swift"         : "// language: swift",
    "TeX"           : "% language: TeX",
    "Thrift"        : "/* language: Thrift */",
    "TypeScript"    : "// language: TypeScript",
    "Vue"           : "<!--language: Vue-->",
    "Verilog"       : "// language: Verilog",
    "Visual Basic"  : "' language: Visual Basic",
}

app = FastAPI()


@app.post("/")
async def create_item(request: Request):
    """Generate code for the posted JSON prompt and return it as JSON.

    Request body: {"lang": ..., "prompt": ..., "max_length": ..., "top_p": ...,
    "temperature": ..., "top_k": ...} — all but "prompt" optional.
    """
    global model, tokenizer
    json_post_raw = await request.json()
    # dumps/loads round-trip kept from the original demo (normalizes the payload).
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    lang = json_post_list.get('lang')
    prompt = json_post_list.get('prompt')
    max_length = json_post_list.get('max_length', 128)
    top_p = json_post_list.get('top_p', 0.95)
    temperature = json_post_list.get('temperature', 0.2)
    top_k = json_post_list.get('top_k', 1)
    # Prepend the language tag when a known language is given; skip the tag for
    # a missing/unknown lang instead of raising KeyError (original crashed here).
    if lang and lang != "None" and lang in LANGUAGE_TAG:
        prompt = LANGUAGE_TAG[lang] + "\n" + prompt
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        inputs,
        max_length=max_length,
        do_sample=temperature > 0,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
    )
    # BUG FIX: the original put the raw generate() tensor in the response,
    # which is not JSON-serializable; decode it back to text first.
    response = tokenizer.decode(outputs[0])
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "lang": lang,
        "status": 200,
        "time": time,
    }
    log = "[" + time + "] " + 'prompt:"' + prompt + '", response:"' + repr(response) + '"'
    print(log)
    return answer


def get_model():
    """Load the tokenizer and shard the model across GPUs (eval mode)."""
    tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True, cache_dir="/mnt/data/CodeGeeX2/model")
    # gpus.py lives in the demo/ folder of the official repository;
    # set num_gpus to this machine's actual GPU count.
    model = load_model_on_gpus("THUDM/codegeex2-6b", num_gpus=3, cache_dir="/mnt/data/CodeGeeX2/model")
    model = model.eval()
    return tokenizer, model


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser = add_code_generation_args(parser)
    args, _ = parser.parse_known_args()
    tokenizer, model = get_model()
    # NOTE(review): uvicorn ignores workers > 1 when given an app object
    # rather than an import string ("module:app") — confirm against uvicorn docs.
    uvicorn.run(app, host=args.listen, port=args.port, workers=args.workers)
启动命令
# Start the API server (fix: the flag is --workers, not --workders).
python web_quick_start.py --listen '0.0.0.0' --port 7860 --workers 100

# In another terminal, call the API:
curl -X POST "http://127.0.0.1:7860" \
  -H 'Content-Type: application/json' \
  -d '{"lang": "Python", "prompt": "# Write a quick sort function"}'
参考资料
校验torch安装结果
import torch

# True if a CUDA-capable GPU (and a working driver) is visible to PyTorch.
print(torch.cuda.is_available())
# Number of visible GPUs.
print(torch.cuda.device_count())
if torch.cuda.is_available():
    # Guarded: these calls raise on CPU-only machines.
    # Name of GPU 0 (device indices start at 0).
    print(torch.cuda.get_device_name(0))
    # Index of the currently selected device.
    print(torch.cuda.current_device())
# To check the NVIDIA CUDA toolkit version, run `nvcc -V` in a shell.
nvcc -V
版本查询
cuda toolkit版本:https://anaconda.org/nvidia/cuda-toolkit
pytorch与cuda对应版本:https://pytorch.org/get-started/previous-versions/