基于星火大模型的群聊对话分角色要素提取挑战赛Task1笔记
跑通baseline
1、安装依赖
下载相应的数据库
!pip install --upgrade -q spark_ai_python
2、配置导入
导入必要的包。
from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import json
配置设置相关参数。
#星火认知大模型Spark3.5 Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
#星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看
SPARKAI_APP_ID = '2e001699'
SPARKAI_API_SECRET = 'ZmU2YTliYmU1YjViODlkMDYwOWZlOTc4'
SPARKAI_API_KEY = '52a07d74ef95aead407a958f448d4464'
#星火认知大模型Spark3.5 Max的domain值,其他版本大模型domain值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_DOMAIN = 'generalv3.5'
申领大模型API来自:https://console.xfyun.cn/app/myapp
3、模型测试
def get_completions(text):messages = [ChatMessage(role="user",content=text)]spark = ChatSparkLLM(spark_api_url=SPARKAI_URL,spark_app_id=SPARKAI_APP_ID,spark_api_key=SPARKAI_API_KEY,spark_api_secret=SPARKAI_API_SECRET,spark_llm_domain=SPARKAI_DOMAIN,streaming=False,)handler = ChunkPrintHandler()a = spark.generate([messages], callbacks=[handler])return a.generations[0][0].text# 测试模型配置是否正确
text = "你好"
get_completions(text)
4、数据读取
def read_json(json_file_path):"""读取json文件"""with open(json_file_path, 'r') as f:data = json.load(f)return datadef write_json(json_file_path, data):"""写入json文件"""with open(json_file_path, 'w') as f:json.dump(data, f, ensure_ascii=False, indent=4)# 读取数据
train_data = read_json("dataset/train.json")
test_data = read_json("dataset/test_data.json")
5、Prompt设计
4. 加载决策树模型进行训练
model = LGBMClassifier(verbosity=-1)
model.fit(train.iloc[:, 2:].values, train['Label'])
pred = model.predict(test.iloc[:, 1:].values, )
6、主函数启动
import jsonclass JsonFormatError(Exception):def __init__(self, message):self.message = messagesuper().__init__(self.message)def check_and_complete_json_format(data):required_keys = {"基本信息-姓名": str,"基本信息-手机号码": str,"基本信息-邮箱": str,"基本信息-地区": str,"基本信息-详细地址": str,"基本信息-性别": str,"基本信息-年龄": str,"基本信息-生日": str,"咨询类型": list,"意向产品": list,"购买异议点": list,"客户预算-预算是否充足": str,"客户预算-总体预算金额": str,"客户预算-预算明细": str,"竞品信息": str,"客户是否有意向": str,"客户是否有卡点": str,"客户购买阶段": str,"下一步跟进计划-参与人": list,"下一步跟进计划-时间点": str,"下一步跟进计划-具体事项": str}if not isinstance(data, list):raise JsonFormatError("Data is not a list")for item in data:if not isinstance(item, dict):raise JsonFormatError("Item is not a dictionary")for key, value_type in required_keys.items():if key not in item:item[key] = [] if value_type == list else ""if not isinstance(item[key], value_type):raise JsonFormatError(f"Key '{key}' is not of type {value_type.__name__}")if value_type == list and not all(isinstance(i, str) for i in item[key]):raise JsonFormatError(f"Key '{key}' does not contain all strings in the list")return data# Example usage:
json_data = '''
[{"基本信息-姓名": "张三","基本信息-手机号码": "12345678901","基本信息-邮箱": "zhangsan@example.com","基本信息-地区": "北京市","基本信息-详细地址": "朝阳区某街道","基本信息-性别": "男","基本信息-年龄": "30","基本信息-生日": "1990-01-01","咨询类型": ["询价"],"意向产品": ["产品A"],"购买异议点": ["价格高"],"客户预算-预算是否充足": "充足","客户预算-总体预算金额": "10000","客户预算-预算明细": "详细预算内容","竞品信息": "竞争对手B","客户是否有意向": "有意向","客户是否有卡点": "无卡点","客户购买阶段": "合同中","下一步跟进计划-参与人": ["客服A"],"下一步跟进计划-时间点": "2024-07-01","下一步跟进计划-具体事项": "沟通具体事项"}
]
'''try:data = json.loads(json_data)completed_data = check_and_complete_json_format(data)print("Completed JSON:", json.dumps(completed_data, ensure_ascii=False, indent=4))
except JsonFormatError as e:print(f"JSON format error: {e.message}")# 5. 保存结果文件到本地
pd.DataFrame({'uuid': test['uuid'],'Label': pred}
).to_csv('submit.csv', index=None)
7、生成提交文件
from tqdm import tqdmretry_count = 5 # 重试次数
result = []
error_data = []for index, data in tqdm(enumerate(test_data)):index += 1is_success = Falsefor i in range(retry_count):try:res = get_completions(PROMPT_EXTRACT.format(content=data["chat_text"]))infos = convert_all_json_in_text_to_dict(res)infos = check_and_complete_json_format(infos)result.append({"infos": infos,"index": index})is_success = Truebreakexcept Exception as e:print("index:", index, ", error:", e)continueif not is_success:data["index"] = indexerror_data.append(data)
8、保存输出
write_json("output.json", result)
print("index:", index, ", error:", e)continue
if not is_success:data["index"] = indexerror_data.append(data)
## 8、保存输出```python
write_json("output.json", result)