机器配置
- 阿里云
GPU
规格ecs.gn6i-c4g1.xlarge
NVIDIA T4
显卡*1GPU
显存16G
*1
准备训练数据
- 进入
/ChatGLM-6B/ptuning
mkdir AdvertiseGen
cd AdvertiseGen
- 上传
dev.json
和train.json
- 内容都是
{"content": "你是谁", "summary": "你好,我是V校人工智能,江湖人称V-Chat。"}
{"content": "V校", "summary": "全宇宙最牛的智慧校园产品"}
安装依赖
pip install fastapi uvicorn datasets jieba rouge_chinese nltk cpm_kernels
修改train.sh
- 修改模型参数文件位置:
--model_name_or_path ../THUDM/chatglm2-6b
- 修改后的
train.sh
PRE_SEQ_LEN=128
LR=2e-2
NUM_GPUS=1torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \--do_train \--train_file AdvertiseGen/train.json \--validation_file AdvertiseGen/dev.json \--preprocessing_num_workers 10 \--prompt_column content \--response_column summary \--overwrite_cache \--model_name_or_path ../THUDM/chatglm2-6b \--output_dir output/adgen-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \--overwrite_output_dir \--max_source_length 64 \--max_target_length 128 \--per_device_train_batch_size 1 \--per_device_eval_batch_size 1 \--gradient_accumulation_steps 16 \--predict_with_generate \--max_steps 3000 \--logging_steps 10 \--save_steps 1000 \--learning_rate $LR \--pre_seq_len $PRE_SEQ_LEN \--quantization_bit 4
开始训练
bash train.sh
训练进度
- 查看
GPU
使用:watch -n 0.5 nvidia-smi
推理
- 修改
evaluate.sh
- 修改模型参数文件位置:
--model_name_or_path ../THUDM/chatglm2-6b
- 修改后的
evaluate.sh
PRE_SEQ_LEN=128
CHECKPOINT=adgen-chatglm2-6b-pt-128-2e-2
STEP=3000
NUM_GPUS=1torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \--do_predict \--validation_file AdvertiseGen/dev.json \--test_file AdvertiseGen/dev.json \--overwrite_cache \--prompt_column content \--response_column summary \--model_name_or_path ../THUDM/chatglm2-6b \--ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \--output_dir ./output/$CHECKPOINT \--overwrite_output_dir \--max_source_length 64 \--max_target_length 64 \--per_device_eval_batch_size 1 \--predict_with_generate \--pre_seq_len $PRE_SEQ_LEN \--quantization_bit 4
- 开始推理:
sh evaluate.sh
在这里插入图片描述
评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在
./output/adgen-chatglm2-6b-pt-128-2e-2/generated_predictions.txt
。
运行
- 修改
web_demo.sh
- 修改模型参数文件位置:
--model_name_or_path ../THUDM/chatglm2-6b
- 修改后的
web_demo.sh
PRE_SEQ_LEN=128CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \--model_name_or_path ../THUDM/chatglm2-6b \--ptuning_checkpoint output/adgen-chatglm2-6b-pt-128-2e-2/checkpoint-3000 \--pre_seq_len $PRE_SEQ_LEN
- 修改
web_demo.sh
在末尾的位置修改如下
#demo.queue().launch(share=False, inbrowser=True)
demo.queue().launch(share=True, inbrowser=True, server_name = '0.0.0.0', server_port=7860)
- 启动:
sh web_demo.sh
- 浏览器访问:
http://xx.xx.xx.xx:7860