目录
问题描述
问题解决:
问题描述
在 bad case 中,错误的主要原因是边界定位不准确:抽取出的 sub、obj 过短。
因此先用 jieba 分词找出当前 span 的前一个词,再调用 GPT-4 的 API 判断拼接后的新 span 是否符合条件。
问题解决:
import json
from pdb import set_trace as stop

import jieba
import openai
from tqdm import tqdm

openai.api_key = "api_key"  # GPT4.0
openai.api_base = 'https://api.ngapi.top/v1'


def get_response(prompt, temperature=0.5, max_tokens=2048):
    """Send ``prompt`` as a single user message to GPT-4 and return the completion.

    The prompt is echoed to stdout before the request is made.

    Args:
        prompt: Text of the user message; it is formatted with ``f"{prompt}"``.
        temperature: Accepted for backward compatibility but not used; the
            request always sends ``temperature=0``.
        max_tokens: Accepted for backward compatibility but not used; no
            ``max_tokens`` value is sent with the request.

    Returns:
        The object returned by ``openai.ChatCompletion.create``. Callers in
        this file read the reply text from
        ``["choices"][0]["message"]["content"]``.
    """
    print(prompt)
    completion = openai.ChatCompletion.create(
        model="gpt-4",
        # temperature=0 and top_p=0 are hard-coded, independent of the
        # function's ``temperature`` argument.
        temperature=0,
        top_p=0,
        messages=[{"role": "user", "content": f"{prompt}"}],
    )
    return completion


llm_generated_path = "/public/home/hongy/qtxu/Qwen-main/results/Ele_lora/pred_20240101_instruction_0104.jsonl"
change_path = "/public/home/hongy/qtxu/Qwen-main/results/Ele_lora/pred_20240101_instruction_0104_post.txt"

# Maps the Chinese comparison label found in the model output to an English tag.
po_dict = {"相等": 'equal', "更好": 'better', '更差': 'worse', '不同': 'different'}
# Placeholder written for an empty quintuple slot.
pad_word = '无'
chinese_punctuation = [',', '。', '?', '!', ':', ';', '‘', '’', '“', '”', '(', ')', '【', '】', '{', '}', '《', '》', '、', '——', '-', '……', '~', '·']


def get_previsous_word(cur_span, cur_sent):
    """Ask the LLM which word directly precedes ``cur_span`` in ``cur_sent``.

    If the answer is the particle '的', the particle is prepended to the span
    and the question is asked once more; the second answer is returned.

    Args:
        cur_span: The extracted span.
        cur_sent: The sentence the span was extracted from.

    Returns:
        The text content of the LLM reply.
    """
    front_prompt = f"在输入语句({cur_sent})中,({cur_span})的前一个单词是什么?。直接给出答案即可。"
    front_result = get_response(front_prompt)['choices'][0]['message']['content']
    if front_result == '的':
        cur_span = front_result + cur_span
        front_prompt = f"在输入语句({cur_sent})中,({cur_span})的前一个单词是什么?直接给出答案即可。"
        front_result = get_response(front_prompt)['choices'][0]['message']['content']
    return front_result


def identify_nonu_phrase(front_result, cur_span, cur_sent):
    """Ask the LLM whether ``front_result + cur_span`` names a product or brand.

    Args:
        front_result: Candidate word(s) to prepend to the span.
        cur_span: The extracted span.
        cur_sent: The sentence the span was extracted from.

    Returns:
        The raw text of the LLM reply; the prompt requests 'yes' or 'no'.
    """
    identify_prompt = f"在输入语句({cur_sent})中,({front_result}{cur_span})是一个可以表示物品名称、物品品牌的名词或名词短语吗?直接回答'yes'或'no'"
    identify_result = get_response(identify_prompt)['choices'][0]['message']['content']
    return identify_result


def get_chinese_index(cur_span, cur_sent):
    """Return the character offset of ``cur_span`` in ``cur_sent``, or -1 if absent."""
    return cur_sent.find(cur_span)


def get_front_end_word(text, span):
    """Return the jieba tokens adjacent to ``span`` inside ``text``.

    Both strings are segmented with jieba and joined with single spaces; the
    segmented span is then located inside the segmented text with ``str.find``.

    Args:
        text: The full sentence.
        span: The extracted span.

    Returns:
        A ``(front_word, end_word)`` tuple. ``front_word`` is the token before
        the match; when that token is '的' and another token precedes it, the
        two are concatenated. ``end_word`` is the token after the match. A
        side with no neighbouring token yields ``''``, and both are ``''``
        when the segmented span does not occur in the segmented text.
    """
    text_result = " ".join(jieba.cut(text, cut_all=False))
    span_result = " ".join(jieba.cut(span, cut_all=False))
    index = text_result.find(span_result)
    if index < 0:
        # The span was segmented differently in isolation than in context, so
        # there is no reliable neighbour to report.
        return '', ''

    before = text_result[:index].split()
    after = text_result[index + len(span_result):].split()

    front_word = before[-1] if before else ''
    if front_word == '的' and len(before) >= 2:
        # A bare particle is not informative; include the word it attaches to.
        front_word = before[-2] + front_word
    end_word = after[0] if after else ''
    return front_word, end_word


def _is_affirmative(answer):
    """Return True when an LLM reply starts with 'yes', ignoring case, quotes and padding."""
    return answer.strip().strip("'\"").lower().startswith('yes')


def post_processing(cur_span, cur_sent, pad_word):
    """Extend ``cur_span`` leftwards by one word when the LLM approves it.

    Args:
        cur_span: The extracted span, or ``pad_word`` for an empty slot.
        cur_sent: The sentence the span was extracted from.
        pad_word: Placeholder marking an empty slot.

    Returns:
        ``front_word + cur_span`` when the LLM confirms the extended phrase;
        otherwise ``cur_span`` unchanged. The span is returned unchanged
        without any LLM call when it equals ``pad_word``, starts the
        sentence, does not occur in the sentence, or has no preceding token.
    """
    if cur_span == pad_word:
        return pad_word
    # 0: already at the sentence start; -1: not in the sentence at all.
    if get_chinese_index(cur_span, cur_sent) <= 0:
        return cur_span

    front_result, _ = get_front_end_word(cur_sent, cur_span)
    if not front_result:
        return cur_span

    identify_result = identify_nonu_phrase(front_result, cur_span, cur_sent)
    print("identify_result结果是:", identify_result)
    if _is_affirmative(identify_result):
        return front_result + cur_span
    return cur_span


def _parse_quintuple(raw_line):
    """Split one output line into ``(sub, obj, asp, op, polarity)``.

    Returns ``None`` when the line holds fewer than four comma-separated fields.
    """
    # The first 7 characters are dropped as a label prefix
    # (NOTE(review): exact label text is not visible here; confirm).
    cur_quintuple = raw_line[7:].strip()
    # Drop the surrounding parentheses, then split on commas. Fields that
    # themselves contain a comma are not handled.
    fields = [field.strip() for field in cur_quintuple[1:-1].split(',')]
    if len(fields) < 4:
        return None
    sub, obj, asp, op = (field or pad_word for field in fields[:4])
    # Polarity is read from the last field, as in the original indexing.
    polarity = fields[-1]
    # Unknown labels are passed through unchanged so a new label cannot
    # abort the whole run.
    polarity = po_dict.get(polarity, polarity) if polarity else pad_word
    return sub, obj, asp, op, polarity


def main():
    """Post-process the sub/obj spans of every comparative sentence.

    Reads ``llm_generated_path`` line by line as JSON. For each record whose
    ``type`` is 1, writes the sentence followed by one post-processed
    quintuple per line to ``change_path``. Malformed quintuple lines are
    reported on stdout and skipped.
    """
    with open(llm_generated_path, 'r', encoding='utf-8') as fr, \
            open(change_path, 'w', encoding='utf-8') as fw:
        for line in fr:
            if not line.strip():
                continue
            cur_line = json.loads(line)
            # Only records flagged as comparative sentences are processed.
            if cur_line['type'] != 1:
                continue

            # Fixed-width slice that trims the prompt template around the
            # sentence (NOTE(review): offsets 7 / -52 depend on the
            # instruction template in use; confirm).
            cur_sent = cur_line['query'].split('\n\n')[1][7:-52].strip()
            fw.write(cur_sent + "\n")

            result = cur_line['output'].strip().split('\n')
            # Only every second line is parsed as a quintuple
            # (NOTE(review): the in-between lines presumably hold position
            # info; confirm).
            for raw_line in result[::2]:
                parsed = _parse_quintuple(raw_line)
                if parsed is None:
                    print("skip malformed quintuple:", raw_line)
                    continue
                sub, obj, asp, op, polarity = parsed

                post_sub = post_processing(sub, cur_sent, pad_word)
                post_obj = post_processing(obj, cur_sent, pad_word)
                post_final_quintuple = (
                    '(' + post_sub + ',' + post_obj + ',' + asp + ',' + op + ',' + polarity + ')'
                )
                fw.write(post_final_quintuple + "\n")


if __name__ == "__main__":
    main()