BERT的中文问答系统36-2

为了使聊天机器人在生成答案时不依赖于特定的训练数据集,我引入其他方法来生成答案。例如,可以使用预训练的语言模型(如BERT)直接生成答案,或者使用搜索引擎来获取答案。以下BERT的中文问答系统36-1改进后的代码
1.引入预训练模型生成答案:使用BERT模型直接生成答案。
2.使用搜索引擎获取答案:如果模型生成的答案不满意,可以使用搜索引擎(如百度)来获取答案。

import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer, BertForQuestionAnswering
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime
import requests
from bs4 import BeautifulSoup# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)def setup_logging():log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))logging.basicConfig(level=logging.INFO,format='%(asctime)s - %(levelname)s - %(message)s',handlers=[logging.FileHandler(log_file),logging.StreamHandler()])setup_logging()# 数据集类
class XihuaDataset(Dataset):def __init__(self, file_path, tokenizer, max_length=128):self.tokenizer = tokenizerself.max_length = max_lengthself.data = self.load_data(file_path)def load_data(self, file_path):data = []if file_path.endswith('.jsonl'):with jsonlines.open(file_path) as reader:for i, item in enumerate(reader):try:data.append(item)except jsonlines.jsonlines.InvalidLineError as e:logging.warning(f"跳过无效行 {i + 1}: {e}")elif file_path.endswith('.json'):with open(file_path, 'r') as f:try:data = json.load(f)except json.JSONDecodeError as e:logging.warning(f"跳过无效文件 {file_path}: {e}")return datadef __len__(self):return len(self.data)def __getitem__(self, idx):item = self.data[idx]question = item['question']human_answer = item['human_answers'][0]chatgpt_answer = item['chatgpt_answers'][0]try:inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)except Exception as e:logging.warning(f"跳过无效项 {idx}: {e}")return self.__getitem__((idx + 1) % len(self.data))return {'input_ids': inputs['input_ids'].squeeze(),'attention_mask': inputs['attention_mask'].squeeze(),'human_input_ids': human_inputs['input_ids'].squeeze(),'human_attention_mask': human_inputs['attention_mask'].squeeze(),'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),'human_answer': human_answer,'chatgpt_answer': chatgpt_answer}# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):dataset = XihuaDataset(file_path, tokenizer, max_length)return DataLoader(dataset, batch_size=batch_size, shuffle=True)# 模型定义
class XihuaModel(torch.nn.Module):def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):super(XihuaModel, self).__init__()self.bert = BertModel.from_pretrained(pretrained_model_name)self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)def forward(self, input_ids, attention_mask):outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)pooled_output = outputs.pooler_outputlogits = self.classifier(pooled_output)return logits# 训练函数
def train(model, data_loader, optimizer, criterion, device, progress_var=None):model.train()total_loss = 0.0num_batches = len(data_loader)for batch_idx, batch in enumerate(data_loader):try:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)human_input_ids = batch['human_input_ids'].to(device)human_attention_mask = batch['human_attention_mask'].to(device)chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)optimizer.zero_grad()human_logits = model(human_input_ids, human_attention_mask)chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)human_labels = torch.ones(human_logits.size(0), 1).to(device)chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)loss.backward()optimizer.step()total_loss += loss.item()if progress_var:progress_var.set((batch_idx + 1) / num_batches * 100)except Exception as e:logging.warning(f"跳过无效批次: {e}")return total_loss / len(data_loader)# 主训练函数
def main_train(retrain=False):device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')logging.info(f'使用设备: {device}')tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)if retrain:model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')if os.path.exists(model_path):model.load_state_dict(torch.load(model_path, map_location=device))logging.info("加载现有模型")else:logging.info("没有找到现有模型,将使用预训练模型")optimizer = optim.Adam(model.parameters(), lr=1e-5)criterion = torch.nn.BCEWithLogitsLoss()train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'), tokenizer, batch_size=8, max_length=128)num_epochs = 30for epoch in range(num_epochs):train_loss = train(model, train_data_loader, optimizer, criterion, device)logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.8f}')torch.save(model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))logging.info("模型训练完成并保存")# 网络搜索函数
def search_baidu(query):url = f"https://www.baidu.com/s?wd={query}"headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}response = requests.get(url, headers=headers)soup = BeautifulSoup(response.text, 'html.parser')results = soup.find_all('div', class_='c-abstract')if results:return results[0].get_text().strip()return "没有找到相关信息"# GUI界面
class XihuaChatbotGUI:def __init__(self, root):self.root = rootself.root.title("羲和聊天机器人")self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)self.load_model()self.model.eval()# 使用预训练的问答模型self.qa_model = BertForQuestionAnswering.from_pretrained('F:/models/bert-base-chinese').to(self.device)# 历史记录self.history = []self.create_widgets()def create_widgets(self):# 设置样式style = ttk.Style()style.theme_use('clam')# 顶部框架top_frame = ttk.Frame(self.root)top_frame.pack(pady=10)self.question_label = ttk.Label(top_frame, text="问题:", font=("Arial", 12))self.question_label.grid(row=0, column=0, padx=10)self.question_entry = ttk.Entry(top_frame, width=50, font=("Arial", 12))self.question_entry.grid(row=0, column=1, padx=10)self.answer_button = ttk.Button(top_frame, text="获取回答", command=self.get_answer, style='TButton')self.answer_button.grid(row=0, column=2, padx=10)# 中部框架middle_frame = ttk.Frame(self.root)middle_frame.pack(pady=10)self.chat_text = tk.Text(middle_frame, height=20, width=100, font=("Arial", 12), wrap='word')self.chat_text.grid(row=0, column=0, padx=10, pady=10)self.chat_text.tag_configure("user", justify='right', foreground='blue')self.chat_text.tag_configure("xihua", justify='left', foreground='green')# 底部框架bottom_frame = ttk.Frame(self.root)bottom_frame.pack(pady=10)self.correct_button = ttk.Button(bottom_frame, text="准确", command=self.mark_correct, style='TButton')self.correct_button.grid(row=0, column=0, padx=10)self.incorrect_button = ttk.Button(bottom_frame, text="不准确", command=self.mark_incorrect, style='TButton')self.incorrect_button.grid(row=0, column=1, padx=10)self.train_button = ttk.Button(bottom_frame, text="训练模型", command=self.train_model, style='TButton')self.train_button.grid(row=0, column=2, padx=10)self.retrain_button = ttk.Button(bottom_frame, text="重新训练模型", command=lambda: self.train_model(retrain=True), style='TButton')self.retrain_button.grid(row=0, column=3, padx=10)self.progress_var = tk.DoubleVar()self.progress_bar = ttk.Progressbar(bottom_frame, variable=self.progress_var, maximum=100, length=200, mode='determinate')self.progress_bar.grid(row=1, column=0, columnspan=4, pady=10)self.log_text = tk.Text(bottom_frame, height=10, width=70, font=("Arial", 12))self.log_text.grid(row=2, column=0, columnspan=4, pady=10)self.evaluate_button = ttk.Button(bottom_frame, text="评估模型", command=self.evaluate_model, style='TButton')self.evaluate_button.grid(row=3, column=0, padx=10, pady=10)self.history_button = ttk.Button(bottom_frame, text="查看历史记录", command=self.view_history, style='TButton')self.history_button.grid(row=3, column=1, padx=10, pady=10)self.save_history_button = ttk.Button(bottom_frame, text="保存历史记录", command=self.save_history, style='TButton')self.save_history_button.grid(row=3, column=2, padx=10, pady=10)def get_answer(self):question = self.question_entry.get()if not question:messagebox.showwarning("输入错误", "请输入问题")return# 使用预训练的问答模型生成答案context = "这是一个通用的上下文,可以根据需要进行扩展。"  # 可以根据实际需求提供更具体的上下文inputs = self.tokenizer(question, context, return_tensors='pt', padding='max_length', truncation=True, max_length=512)inputs = {k: v.to(self.device) for k, v in inputs.items()}with torch.no_grad():outputs = self.qa_model(**inputs)start_scores = outputs.start_logitsend_scores = outputs.end_logitsstart_index = torch.argmax(start_scores)end_index = torch.argmax(end_scores) + 1answer = self.tokenizer.convert_tokens_to_string(self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][start_index:end_index]))if not answer or answer == '[CLS]':answer = "这个我也不清楚,你问问零吧"else:answer = answer.strip()self.chat_text.insert(tk.END, f"用户: {question}\n", "user")self.chat_text.insert(tk.END, f"羲和: {answer}\n", "xihua")# 添加到历史记录self.history.append({'question': question,'answer_type': "预训练模型回答",'specific_answer': answer,'accuracy': None  # 初始状态为未评价})def load_model(self):model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')if os.path.exists(model_path):self.model.load_state_dict(torch.load(model_path, map_location=self.device))logging.info("加载现有模型")else:logging.info("没有找到现有模型,将使用预训练模型")def train_model(self, retrain=False):file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])if not file_path:messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")returntry:dataset = XihuaDataset(file_path, self.tokenizer)data_loader = DataLoader(dataset, batch_size=8, shuffle=True)# 加载已训练的模型权重if retrain:self.model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=self.device))self.model.to(self.device)self.model.train()optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)criterion = torch.nn.BCEWithLogitsLoss()num_epochs = 30for epoch in range(num_epochs):train_loss = train(self.model, data_loader, optimizer, criterion, self.device, self.progress_var)logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')self.log_text.insert(tk.END, f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}\n')self.log_text.see(tk.END)torch.save(self.model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))logging.info("模型训练完成并保存")self.log_text.insert(tk.END, "模型训练完成并保存\n")self.log_text.see(tk.END)messagebox.showinfo("训练完成", "模型训练完成并保存")except Exception as e:logging.error(f"模型训练失败: {e}")self.log_text.insert(tk.END, f"模型训练失败: {e}\n")self.log_text.see(tk.END)messagebox.showerror("训练失败", f"模型训练失败: {e}")def evaluate_model(self):# 这里可以添加模型评估的逻辑messagebox.showinfo("评估结果", "模型评估功能暂未实现")def mark_correct(self):if self.history:self.history[-1]['accuracy'] = Truemessagebox.showinfo("评价成功", "您认为这次回答是准确的")def mark_incorrect(self):if self.history:self.history[-1]['accuracy'] = Falsequestion = self.history[-1]['question']answer = search_baidu(question)self.chat_text.insert(tk.END, f"搜索引擎结果: {answer}\n", "xihua")messagebox.showinfo("评价成功", "您认为这次回答是不准确的")def view_history(self):history_window = tk.Toplevel(self.root)history_window.title("历史记录")history_text = tk.Text(history_window, height=20, width=80, font=("Arial", 12))history_text.pack(padx=10, pady=10)for entry in self.history:history_text.insert(tk.END, f"问题: {entry['question']}\n")history_text.insert(tk.END, f"回答类型: {entry['answer_type']}\n")history_text.insert(tk.END, f"具体回答: {entry['specific_answer']}\n")if entry['accuracy'] is None:history_text.insert(tk.END, "评价: 未评价\n")elif entry['accuracy']:history_text.insert(tk.END, "评价: 准确\n")else:history_text.insert(tk.END, "评价: 不准确\n")history_text.insert(tk.END, "-" * 50 + "\n")def save_history(self):file_path = filedialog.asksaveasfilename(defaultextension=".json", filetypes=[("JSON files", "*.json")])if not file_path:returnwith open(file_path, 'w') as f:json.dump(self.history, f, ensure_ascii=False, indent=4)messagebox.showinfo("保存成功", "历史记录已保存到文件")# 主函数
if __name__ == "__main__":# 启动GUIroot = tk.Tk()app = XihuaChatbotGUI(root)root.mainloop()

改进点:

使用预训练的问答模型:引入了 BertForQuestionAnswering 模型来生成答案。这样可以在不依赖特定训练数据集的情况下生成答案。
搜索引擎备用:如果预训练模型生成的答案不满意,可以使用搜索引擎(如百度)来获取答案。
通过这些改进,你的聊天机器人可以在不加载特定训练数据集的情况下生成准确的答案。希望这些改进对你有帮助!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/62317.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

升级智享 AI 直播三代:领航原生直播驶向自动化运营新航道

在瞬息万变的数字商业世界,直播行业恰似一艘破浪前行的巨轮,原生直播作为初始 “航船”,在历经风雨后,终于迎来智享 AI 直播三代这股强劲 “东风”,校准航向,开启自动化运营的全新航道,驶向一片…

鸿蒙多线程应用-taskPool

并发模型 并发模型是用来实现不同应用场景中并发任务的编程模型,常见的并发模型分为基于内存共享的并发模型和基于消息通信的并发模型。 Actor并发模型作为基于消息通信并发模型的典型代表,不需要开发者去面对锁带来的一系列复杂偶发的问题,同…

JavaScript实用工具lodash库

Lodash中文文档: Lodash 简介 | Lodash中文文档 | Lodash中文网 Lodash是一个功能强大、易于使用的JavaScript实用工具库,它提供了丰富的函数和工具,能够方便地处理集合、字符串、数值、函数等多种数据类型。通过使用Lodash,开发者可以大幅…

数据结构-最短路径问题

一.问题分类 二.无权图单源最短路算法 dist[]数组记录的是个个顶点到源点的距离这个数组的下标表示顶点 源点到自己的距离是0,dist[s]0 path[]数组记录的是这个顶点的前驱,可以同过这个数组找到源点到个个顶点的距离 代码如下 void Unweighted(MGraph Graph, Ver…

ulimit -n是1024无法启动redis

ulimit -n 命令显示的是当前 shell 会话中可以打开的最大文件描述符数。如果这个值设置得太低,可能会导致 Redis 无法启动,因为 Redis 需要大量的文件描述符来处理客户端连接、持久化文件等。 默认情况下,Redis 可能需要更多的文件描述符&am…

Vue.js 实现用户注册功能

在本篇博客中&#xff0c;我们将通过一个简单的例子来展示如何使用 Vue.js 来实现一个用户注册功能。我们将创建一个包含用户名、邮箱和密码输入的表单&#xff0c;并在用户点击“创建账号”按钮时进行简单的验证。 完整代码 <!DOCTYPE html> <html lang"en&q…

transformers训练(NLP)阅读理解(多项选择)

简介 在阅读理解任务中&#xff0c;有一种通过多项选择其中一个答案来训练机器的阅读理解。比如&#xff1a;给定一个或多个文档h,以及一个问题S和对应的多个答案候选&#xff0c;输出问题S的答案E&#xff0c;E是答案候选中的某一个选项。 这样的目的就是通过文档&#xff0c…

【Java 学习】面向程序的三大特性:封装、继承、多态

引言 1. 封装1.1 什么是封装呢&#xff1f;1.2 访问限定符1.3 使用封装 2. 继承2.1 为什么要有继承&#xff1f;2.2 继承的概念2.3 继承的语法2.4 访问父类成员2.4.1 子类中访问父类成员的变量2.4.2 访问父类的成员方法 2.5 super关键字2.6 子类的构造方法 3. 多态3.1 多态的概…

impala入门与实践

1.impala基本介绍 impala是cloudera提供的一款高效率的sql查询工具&#xff0c;提供实时的查询效果&#xff0c;官方测试性能比hive快10到100倍&#xff0c;其sql查询比sparkSQL还要更加快速&#xff0c;号称是当前大数据领域最快的查询sql工具。impala是参照谷歌的新三篇论文…

shell查看服务器的内存和CPU,实时使用情况

要查看服务器的内存和 CPU 实时使用情况&#xff0c;可以使用以下方法和命令&#xff1a; 1. 使用 top 运行 top 命令以显示实时的系统性能信息&#xff0c;包括 CPU 和内存使用情况。 top按 q 退出。输出内容包括&#xff1a; CPU 使用率&#xff1a;位于顶部&#xff0c;标…

java中链表的数据结构的理解

在 Java 中&#xff0c;链表是一种常见的数据结构&#xff0c;可以通过类的方式实现自定义链表。以下是关于 Java 中链表的数据结构和实现方式的详细介绍。 1. 自定义链表结构 Java 中链表通常由一个节点类 (ListNode) 和可能的链表操作类构成。 节点类 (ListNode) 这是链表…

结构方程模型(SEM)入门到精通:lavaan VS piecewiseSEM、全局估计/局域估计;潜变量分析、复合变量分析、贝叶斯SEM在生态学领域应用

目录 第一章 夯实基础 R/Rstudio简介及入门 第二章 结构方程模型&#xff08;SEM&#xff09;介绍 第三章 R语言SEM分析入门&#xff1a;lavaan VS piecewiseSEM 第四章 SEM全局估计&#xff08;lavaan&#xff09;在生态学领域高阶应用 第五章 SEM潜变量分析在生态学领域…

2.mybatis整体配置

文章目录 mybatis-config.xml介绍SqlSessionFactoryBuilderXMLConfigBuilderpropertiessetting类型别名&#xff08;typeAliases&#xff09;扫描插件(plugins)解析objectFactory(对象工厂)解析objectWrapperFactory解析reflectorFactorysettingsElement()方法环境配置&#xf…

把本地新项目初始化传到github

在本地项目根目录下初始化Git仓库 git init将项目文件添加到Git仓库,接下来&#xff0c;你需要将项目中的文件添加到Git仓库中。可以使用git add命令来添加文件或目录。如果你想要添加所有文件&#xff0c;可以使用.来表示当前目录中的所有文件&#xff1a; git add .提交项目…

软件测试丨Pytest 第三方插件与 Hook 函数

Pytest不仅是一个用于编写简单和复杂测试的框架&#xff0c;还有大量的第三方插件以及灵活的Hook函数供我们使用&#xff0c;这些功能大大增强了其在软件测试中的应用。通过使用Pytest&#xff0c;测试开发变得简便、安全、高效&#xff0c;同时也能帮助我们更快地修复Bug&…

小米PC电脑手机互联互通,小米妙享,小米电脑管家,老款小米笔记本怎么使用,其他品牌笔记本怎么使用,一分钟教会你

说在前面 之前我们体验过妙享中心&#xff0c;里面就有互联互通的全部能力&#xff0c;现在有了小米电脑管家&#xff0c;老款的笔记本竟然用不了&#xff0c;也可以理解&#xff0c;毕竟老款笔记本做系统研发的时候没有预留适配的文件补丁&#xff0c;至于其他品牌的winPC小米…

python爬虫案例——猫眼电影数据抓取之字体解密,多套字体文件解密方法(20)

文章目录 1、任务目标2、网站分析3、代码编写1、任务目标 目标网站:猫眼电影(https://www.maoyan.com/films?showType=2) 要求:抓取该网站下,所有即将上映电影的预约人数,保证能够获取到实时更新的内容;如下: 2、网站分析 进入目标网站,打开开发者模式,经过分析,我…

一分钟食用前端测试框架Jest

安装 其实食用Jest是很简单的,我们只需要安装Jest即可 npm install --save-dev jestyarn add --dev jestpnpm add --save-dev jest ESmodule 本身来说,Jest是不支持Esmodule的,他支持CommonJS,我们需要Babel改一下 npm i --save-dev babel-jest babel/core babel/preset-env …

MySQL中的ROW_NUMBER窗口函数简单了解下

ROW_NUMBER() 是 MySQL8引入的窗口函数之一&#xff0c;它为查询结果集中的每一行分配一个唯一的顺序号&#xff08;行号&#xff09;。这个顺序号是基于窗口函数的 ORDER BY 子句进行排序的&#xff0c;可以根据指定的排序顺序生成连续的整数值。 ROW_NUMBER() 在分页、去重、…

从 App Search 到 Elasticsearch — 挖掘搜索的未来

作者&#xff1a;来自 Elastic Nick Chow App Search 将在 9.0 版本中停用&#xff0c;但 Elasticsearch 拥有你构建强大的 AI 搜索体验所需的一切。以下是你需要了解的内容。 生成式人工智能的最新进展正在改变用户行为&#xff0c;激励开发人员创造更具活力、更直观、更引人入…