A BERT-Based Chinese Question-Answering System, Part 52 (Project 1, split into separate .py files)

Project Directory Structure

XihuaChatbot/
├── data/
│   ├── train_data.jsonl
│   └── test_data.jsonl
├── logs/
├── models/
├── records/
├── src/
│   ├── main.py
│   ├── dataset.py
│   ├── model.py
│   ├── utils.py
│   └── gui.py
└── README.md

Project Files
README.md

# Xihua Chatbot (羲和聊天机器人)

## Introduction

The Xihua chatbot is a multi-domain chatbot built on top of a BERT model, supporting both Chinese and English. The project covers data processing, model training, model evaluation, and a graphical user interface (GUI).

## Directory Structure
XihuaChatbot/
├── data/
│   ├── train_data.jsonl
│   └── test_data.jsonl
├── logs/
├── models/
├── records/
├── src/
│   ├── main.py
│   ├── dataset.py
│   ├── model.py
│   ├── utils.py
│   └── gui.py
└── README.md
## Requirements

- Python 3.11.9+
- PyTorch 1.7+
- Transformers 4.0+
- Tkinter
- Requests
- BeautifulSoup4
- jsonlines
- tkcalendar

## Installing Dependencies

pip install torch transformers requests beautifulsoup4 jsonlines tkcalendar
## Running the Project

Make sure the data files train_data.jsonl and test_data.jsonl exist in the data/ directory, then run the main program:

python src/main.py
## Features

- Data processing: dataset.py reads and preprocesses the training/test data.
- Model definition: model.py defines the BERT-based model.
- Utilities: utils.py provides helper functions such as logging setup and web lookups.
- GUI: gui.py implements the graphical user interface, including question answering, model training, and model evaluation.
## Contact

For questions or suggestions, please contact 554687453@qq.com
src/main.py
import os
import logging
from datetime import datetime

import tkinter as tk

from gui import XihuaChatbotGUI

# Project root (one level above src/), so logs/ sits next to data/ and models/
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Logging configuration
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)


def setup_logging():
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
    )


setup_logging()

if __name__ == "__main__":
    # Launch the GUI
    root = tk.Tk()
    app = XihuaChatbotGUI(root)
    root.mainloop()

src/dataset.py

import json
import logging

import jsonlines
from torch.utils.data import Dataset


class XihuaDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                # skip_invalid drops lines that are not valid JSON
                for item in reader.iter(type=dict, skip_invalid=True):
                    data.append(item)
        elif file_path.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item.get('question', '')
        human_answer = item.get('human_answers', [''])[0]
        chatgpt_answer = item.get('chatgpt_answers', [''])[0]

        try:
            inputs = self.tokenizer(question, return_tensors='pt', padding='max_length',
                                    truncation=True, max_length=self.max_length)
            human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length',
                                          truncation=True, max_length=self.max_length)
            chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length',
                                            truncation=True, max_length=self.max_length)
        except Exception as e:
            # Fall back to the next sample if tokenization fails
            logging.warning(f"跳过无效项 {idx}: {e}")
            return self.__getitem__((idx + 1) % len(self.data))

        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'human_input_ids': human_inputs['input_ids'].squeeze(),
            'human_attention_mask': human_inputs['attention_mask'].squeeze(),
            'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),
            'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),
            'human_answer': human_answer,
            'chatgpt_answer': chatgpt_answer
        }
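For orientation, here is a minimal sketch of exercising the dataset class on its own. The checkpoint name, the relative data path, and running it from the src/ directory are assumptions, not part of the original project:

```python
# Hypothetical standalone check of XihuaDataset; run from the src/ directory.
from torch.utils.data import DataLoader
from transformers import BertTokenizer

from dataset import XihuaDataset

# 'bert-base-chinese' is a placeholder; a local BERT checkpoint path works as well.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
dataset = XihuaDataset('../data/train_data.jsonl', tokenizer, max_length=128)

loader = DataLoader(dataset, batch_size=8, shuffle=True)
batch = next(iter(loader))
print(batch['input_ids'].shape)   # e.g. torch.Size([8, 128]) when at least 8 samples exist
print(batch['human_answer'][:2])  # the first two reference answers in the batch
```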

src/model.py

import torch
import torch.nn as nn
from transformers import BertModel


class XihuaModel(nn.Module):
    def __init__(self, pretrained_model_name):
        super(XihuaModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits
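A quick sketch of a single forward pass through XihuaModel; the checkpoint name is a placeholder assumption, and the output is one logit per input sequence:

```python
# Hypothetical single forward pass through XihuaModel.
import torch
from transformers import BertTokenizer

from model import XihuaModel

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')  # placeholder checkpoint
model = XihuaModel(pretrained_model_name='bert-base-chinese')
model.eval()

inputs = tokenizer('你好,请介绍一下你自己', return_tensors='pt',
                   padding='max_length', truncation=True, max_length=128)
with torch.no_grad():
    logits = model(inputs['input_ids'], inputs['attention_mask'])

print(logits.shape)                  # torch.Size([1, 1]) -- one score per input sequence
print(torch.sigmoid(logits).item())  # probability-like score; the GUI thresholds the raw logit at 0
```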

src/utils.py

import os
import logging
import datetime

import requests
from bs4 import BeautifulSoup


def setup_logging():
    # Write logs to <project root>/logs, next to data/ and models/
    LOGS_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'logs')
    os.makedirs(LOGS_DIR, exist_ok=True)
    log_file = os.path.join(LOGS_DIR, datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
    )


def search_baidu(query):
    url = f"https://www.baidu.com/s?wd={query}"
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 '
                      '(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
    }
    response = requests.get(url, headers=headers)
    soup = BeautifulSoup(response.text, 'html.parser')
    results = soup.find_all('div', class_='c-abstract')
    if results:
        return results[0].get_text().strip()
    return "没有找到相关信息"


def search_baidu_baike(query):
    url = f"https://baike.baidu.com/item/{query}"
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 '
                      '(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
    }
    response = requests.get(url, headers=headers)
    soup = BeautifulSoup(response.text, 'html.parser')
    meta_description = soup.find('meta', attrs={'name': 'description'})
    if meta_description:
        return meta_description['content']
    return "没有找到相关信息"
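The lookup helpers can be tried from an interactive session. Network access and Baidu's current page markup are assumptions here; if the markup changes, the functions simply return the fallback string:

```python
# Quick interactive check of the lookup helpers; requires network access.
from utils import setup_logging, search_baidu_baike

setup_logging()
# Returns either the scraped summary or the fallback string "没有找到相关信息".
print(search_baidu_baike('长城'))
```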

src/gui.py

import os
import json
import logging
from datetime import datetime
from difflib import SequenceMatcher

import jsonlines
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer

import tkinter as tk
import tkcalendar
from tkinter import filedialog, messagebox, ttk

from dataset import XihuaDataset
from model import XihuaModel
from utils import setup_logging, search_baidu, search_baidu_baike

# Project root (one level above src/); data/, models/ and records/ live here
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Logging configuration
setup_logging()

# XihuaDataset and XihuaModel are reused from dataset.py and model.py via the imports above.
# Data loader helper
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Training loop: human answers are labelled 1, ChatGPT answers 0
def train(model, data_loader, optimizer, criterion, device, progress_var=None):
    model.train()
    total_loss = 0.0
    num_batches = len(data_loader)
    for batch_idx, batch in enumerate(data_loader):
        try:
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

            optimizer.zero_grad()
            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)

            human_labels = torch.ones(human_logits.size(0), 1).to(device)
            chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)

            loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            if progress_var:
                progress_var.set((batch_idx + 1) / num_batches * 100)
        except Exception as e:
            logging.warning(f"跳过无效批次: {e}")
    return total_loss / num_batches
# Evaluation: accuracy of separating human answers from ChatGPT answers
def evaluate_model(model, data_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in data_loader:
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)

            human_labels = torch.ones(human_logits.size(0), 1).to(device)
            chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)

            human_correct = (torch.sigmoid(human_logits) > 0.5).float() == human_labels
            chatgpt_correct = (torch.sigmoid(chatgpt_logits) > 0.5).float() == chatgpt_labels

            correct += human_correct.sum().item() + chatgpt_correct.sum().item()
            total += human_labels.size(0) + chatgpt_labels.size(0)
    return correct / total
# GUI
class XihuaChatbotGUI:
    def __init__(self, root):
        self.root = root
        self.root.title("羲和聊天机器人")
        self.language = tk.StringVar(value='zh')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Tokenizer for the currently selected language
        self.tokenizer = BertTokenizer.from_pretrained(self.get_pretrained_model_name())
        self.models = {}
        self.current_model_type = None
        self.load_models()
        self.load_data()

        # Conversation history
        self.history = []

        self.create_widgets()

    def create_widgets(self):
        # Styles
        style = ttk.Style()
        style.theme_use('clam')
        style.configure('TButton', font=('Arial', 12), padding=10)
        style.configure('TLabel', font=('Arial', 12), padding=10)
        style.configure('TEntry', font=('Arial', 12), padding=10)
        style.configure('TText', font=('Arial', 12), padding=10)

        # Top frame
        top_frame = ttk.Frame(self.root)
        top_frame.pack(pady=10)

        self.date_label = ttk.Label(top_frame, text="", font=("Arial", 12))
        self.date_label.grid(row=0, column=0, padx=10)
        self.update_date_label()

        language_frame = ttk.Frame(top_frame)
        language_frame.grid(row=0, column=1, padx=10)
        language_label = ttk.Label(language_frame, text="选择语言:", font=("Arial", 12))
        language_label.grid(row=0, column=0, padx=10)
        language_menu = ttk.Combobox(language_frame, textvariable=self.language, values=['zh', 'en'], state='readonly')
        language_menu.grid(row=0, column=1, padx=10)
        language_menu.bind('<<ComboboxSelected>>', self.change_language)

        self.question_label = ttk.Label(top_frame, text="问题:", font=("Arial", 12))
        self.question_label.grid(row=0, column=2, padx=10)
        self.question_entry = ttk.Entry(top_frame, width=50, font=("Arial", 12))
        self.question_entry.grid(row=0, column=3, padx=10)
        self.answer_button = ttk.Button(top_frame, text="获取回答", command=self.get_answer, style='TButton')
        self.answer_button.grid(row=0, column=4, padx=10)

        # Middle frame: chat transcript
        middle_frame = ttk.Frame(self.root)
        middle_frame.pack(pady=10)
        self.chat_text = tk.Text(middle_frame, height=20, width=100, font=("Arial", 12), wrap='word')
        self.chat_text.grid(row=0, column=0, padx=10, pady=10)
        self.chat_text.tag_configure("user", justify='right', foreground='blue')
        self.chat_text.tag_configure("xihua", justify='left', foreground='green')

        # Bottom frame: controls
        bottom_frame = ttk.Frame(self.root)
        bottom_frame.pack(pady=10)
        self.clear_button = ttk.Button(bottom_frame, text="清空聊天记录", command=self.clear_chat, style='TButton')
        self.clear_button.grid(row=0, column=0, padx=10)
        self.correct_button = ttk.Button(bottom_frame, text="准确", command=self.mark_correct, style='TButton')
        self.correct_button.grid(row=0, column=1, padx=10)
        self.incorrect_button = ttk.Button(bottom_frame, text="不准确", command=self.mark_incorrect, style='TButton')
        self.incorrect_button.grid(row=0, column=2, padx=10)
        self.train_button = ttk.Button(bottom_frame, text="训练模型", command=self.train_model, style='TButton')
        self.train_button.grid(row=0, column=3, padx=10)
        self.retrain_button = ttk.Button(bottom_frame, text="重新训练模型", command=lambda: self.train_model(retrain=True), style='TButton')
        self.retrain_button.grid(row=0, column=4, padx=10)

        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(bottom_frame, variable=self.progress_var, maximum=100, length=200, mode='determinate')
        self.progress_bar.grid(row=1, column=0, columnspan=5, pady=10)

        self.log_text = tk.Text(bottom_frame, height=10, width=70, font=("Arial", 12))
        self.log_text.grid(row=2, column=0, columnspan=5, pady=10)

        self.evaluate_button = ttk.Button(bottom_frame, text="评估模型", command=self.evaluate_model, style='TButton')
        self.evaluate_button.grid(row=3, column=0, padx=10, pady=10)
        self.history_button = ttk.Button(bottom_frame, text="查看历史记录", command=self.view_history, style='TButton')
        self.history_button.grid(row=3, column=1, padx=10, pady=10)
        self.save_history_button = ttk.Button(bottom_frame, text="保存历史记录", command=self.save_history, style='TButton')
        self.save_history_button.grid(row=3, column=2, padx=10, pady=10)

        # Calendar frame
        calendar_frame = ttk.Frame(self.root)
        calendar_frame.pack(pady=10)
        self.calendar = tkcalendar.Calendar(calendar_frame, selectmode='day', year=datetime.now().year,
                                            month=datetime.now().month, day=datetime.now().day)
        self.calendar.pack(pady=10)

    def update_date_label(self):
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.date_label.config(text=f"当前时间: {current_time}")
        self.root.after(1000, self.update_date_label)

    def clear_chat(self):
        self.chat_text.delete(1.0, tk.END)

    def get_answer(self):
        question = self.question_entry.get()
        if not question:
            messagebox.showwarning("输入错误", "请输入问题")
            return

        # Pick a domain model automatically based on the question
        model_type = self.detect_model_type(question)
        self.select_model(model_type)

        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        with torch.no_grad():
            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)
            logits = self.model(input_ids, attention_mask)

        answer_type = "羲和回答" if logits.item() > 0 else "零回答"
        specific_answer = self.get_specific_answer(question, answer_type)

        self.chat_text.insert(tk.END, f"用户: {question}\n", "user")
        self.chat_text.insert(tk.END, f"羲和: {specific_answer}\n", "xihua")

        # Append to the conversation history
        self.history.append({
            'question': question,
            'answer_type': answer_type,
            'specific_answer': specific_answer,
            'accuracy': None,       # not yet rated by the user
            'baidu_baike': None     # no Baidu Baike lookup yet
        })

    def get_specific_answer(self, question, answer_type):
        # Fuzzy-match the question against the loaded training data
        best_match = None
        best_ratio = 0.0
        for item in self.data:
            ratio = SequenceMatcher(None, question, item['question']).ratio()
            if ratio > best_ratio:
                best_ratio = ratio
                best_match = item
        if best_match:
            if answer_type == "羲和回答":
                return best_match['human_answers'][0]
            return best_match['chatgpt_answers'][0]
        return "这个我也不清楚,你问问零吧"

    def load_data(self):
        self.data = self.load_data_from_file(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))

    def load_data_from_file(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                # skip_invalid drops lines that are not valid JSON
                for item in reader.iter(type=dict, skip_invalid=True):
                    data.append(item)
        elif file_path.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        return data

    def load_models(self):
        MODELS_DIR = os.path.join(PROJECT_ROOT, 'models')
        model_types = ['历史', '聊天', '娱乐', '电脑', '军事', '汽车', '植物', '科技',
                       '名人', '生活', '法律', '企业', '标准']
        for model_type in model_types:
            model_path = os.path.join(MODELS_DIR, f'xihua_model_{model_type}_{self.language.get()}.pth')
            if os.path.exists(model_path):
                model = XihuaModel(pretrained_model_name=self.get_pretrained_model_name()).to(self.device)
                model.load_state_dict(torch.load(model_path, map_location=self.device))
                self.models[model_type] = model
                logging.info(f"加载 {model_type} 模型")
            else:
                logging.info(f"没有找到 {model_type} 模型,将使用预训练模型")
                self.models[model_type] = XihuaModel(pretrained_model_name=self.get_pretrained_model_name()).to(self.device)

    def get_pretrained_model_name(self):
        # Local path to the Chinese checkpoint; adjust to your environment
        if self.language.get() == 'zh':
            return 'F:/models/bert-base-chinese'
        if self.language.get() == 'en':
            return 'bert-base-uncased'
        return 'bert-base-uncased'

    def select_model(self, model_type):
        if model_type in self.models:
            self.model = self.models[model_type]
            self.current_model_type = model_type
            logging.info(f"选择 {model_type} 模型")
        else:
            logging.warning(f"没有找到 {model_type} 模型,使用默认模型")
            self.model = XihuaModel(pretrained_model_name=self.get_pretrained_model_name()).to(self.device)
            self.current_model_type = None

    def detect_model_type(self, question):
        # Simple keyword routing to a domain-specific model
        if "皇帝" in question or "朝代" in question:
            return '历史'
        if "娱乐" in question:
            return '娱乐'
        if "电脑" in question:
            return '电脑'
        if "军事" in question:
            return '军事'
        if "汽车" in question:
            return '汽车'
        if "植物" in question:
            return '植物'
        if "科技" in question:
            return '科技'
        if "名人" in question:
            return '名人'
        if "生活" in question or "出行" in question or "菜品" in question or "菜谱" in question or "居家" in question:
            return '生活'
        if "法律" in question:
            return '法律'
        if "企业" in question:
            return '企业'
        if "标准" in question:
            return '标准'
        return '聊天'

    def change_language(self, event):
        self.language.set(event.widget.get())
        self.tokenizer = BertTokenizer.from_pretrained(self.get_pretrained_model_name())
        self.load_models()
        self.load_data()

    def train_model(self, retrain=False):
        file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])
        if not file_path:
            messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")
            return

        model_type = self.detect_model_type(file_path)
        self.select_model(model_type)

        try:
            dataset = XihuaDataset(file_path, self.tokenizer)
            data_loader = DataLoader(dataset, batch_size=8, shuffle=True)

            # Optionally resume from previously trained weights
            if retrain:
                model_path = os.path.join(PROJECT_ROOT, 'models', f'xihua_model_{model_type}_{self.language.get()}.pth')
                self.model.load_state_dict(torch.load(model_path, map_location=self.device))
                self.model.to(self.device)

            self.model.train()
            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)
            criterion = torch.nn.BCEWithLogitsLoss()
            num_epochs = 30
            best_loss = float('inf')
            patience = 5
            no_improvement_count = 0

            for epoch in range(num_epochs):
                train_loss = train(self.model, data_loader, optimizer, criterion, self.device, self.progress_var)
                logging.info(f'第 {epoch + 1} 轮次, 损失: {train_loss:.10f}')
                self.log_text.insert(tk.END, f'第 {epoch + 1} 轮次, 损失: {train_loss:.10f}\n')
                self.log_text.see(tk.END)

                if train_loss < best_loss:
                    best_loss = train_loss
                    no_improvement_count = 0
                    model_path = os.path.join(PROJECT_ROOT, 'models', f'xihua_model_{model_type}_{self.language.get()}.pth')
                    torch.save(self.model.state_dict(), model_path)
                    logging.info("模型保存")
                else:
                    no_improvement_count += 1
                    if no_improvement_count >= patience:
                        logging.info("早停机制触发,停止训练")
                        break

            logging.info("模型训练完成并保存")
            self.log_text.insert(tk.END, "模型训练完成并保存\n")
            self.log_text.see(tk.END)
            messagebox.showinfo("训练完成", "模型训练完成并保存")
        except Exception as e:
            logging.error(f"模型训练失败: {e}")
            self.log_text.insert(tk.END, f"模型训练失败: {e}\n")
            self.log_text.see(tk.END)
            messagebox.showerror("训练失败", f"模型训练失败: {e}")

    def evaluate_model(self):
        test_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/test_data.jsonl'),
                                           self.tokenizer, batch_size=8, max_length=128)
        accuracy = evaluate_model(self.model, test_data_loader, self.device)
        logging.info(f"模型评估准确率: {accuracy:.4f}")
        self.log_text.insert(tk.END, f"模型评估准确率: {accuracy:.4f}\n")
        self.log_text.see(tk.END)
        messagebox.showinfo("评估结果", f"模型评估准确率: {accuracy:.4f}")

    def mark_correct(self):
        if self.history:
            self.history[-1]['accuracy'] = True
            messagebox.showinfo("评价成功", "您认为这次回答是准确的")

    def mark_incorrect(self):
        if self.history:
            self.history[-1]['accuracy'] = False
            question = self.history[-1]['question']
            self.show_reference_options(question)

    def show_reference_options(self, question):
        reference_window = tk.Toplevel(self.root)
        reference_window.title("参考答案")
        reference_label = ttk.Label(reference_window, text="请选择参考答案来源:", font=("Arial", 12))
        reference_label.pack(pady=10)
        baidu_button = ttk.Button(reference_window, text="百度百科",
                                  command=lambda: self.get_reference_answer(question, 'baidu_baike'), style='TButton')
        baidu_button.pack(pady=5)

    def get_reference_answer(self, question, source):
        if source == 'baidu_baike':
            baike_answer = self.search_baidu_baike(question)
            self.chat_text.insert(tk.END, f"百度百科结果: {baike_answer}\n", "xihua")
            self.history[-1]['baidu_baike'] = baike_answer
        messagebox.showinfo("参考答案", f"已获取{source}的结果")

    def search_baidu_baike(self, query):
        return search_baidu_baike(query)

    def view_history(self):
        history_window = tk.Toplevel(self.root)
        history_window.title("历史记录")
        history_text = tk.Text(history_window, height=20, width=80, font=("Arial", 12))
        history_text.pack(padx=10, pady=10)
        for entry in self.history:
            history_text.insert(tk.END, f"问题: {entry['question']}\n")
            history_text.insert(tk.END, f"回答类型: {entry['answer_type']}\n")
            history_text.insert(tk.END, f"具体回答: {entry['specific_answer']}\n")
            if entry['accuracy'] is None:
                history_text.insert(tk.END, "评价: 未评价\n")
            elif entry['accuracy']:
                history_text.insert(tk.END, "评价: 准确\n")
            else:
                history_text.insert(tk.END, "评价: 不准确\n")
            if entry['baidu_baike']:
                history_text.insert(tk.END, f"百度百科结果: {entry['baidu_baike']}\n")
            history_text.insert(tk.END, "-" * 50 + "\n")

    def save_history(self):
        RECORDS_DIR = os.path.join(PROJECT_ROOT, 'records')
        os.makedirs(RECORDS_DIR, exist_ok=True)

        # Plain-text transcript
        file_name = datetime.now().strftime('%Y-%m-%d_%H-%M-%S.txt')
        file_path = os.path.join(RECORDS_DIR, file_name)
        with open(file_path, 'w', encoding='utf-8') as f:
            for entry in self.history:
                f.write(f"用户: {entry['question']}\n")
                f.write(f"羲和: {entry['specific_answer']}\n")
                if entry['baidu_baike']:
                    f.write(f"百度百科结果: {entry['baidu_baike']}\n")
                f.write("-" * 50 + "\n")

        # JSON copy, in the same field layout as the training data
        json_records = []
        for entry in self.history:
            record = {
                "question": entry['question'],
                "human_answers": [entry['specific_answer']] if entry['answer_type'] == "羲和回答" else [],
                "chatgpt_answers": [entry['specific_answer']] if entry['answer_type'] == "零回答" else [],
                "baidu_baike": entry['baidu_baike']
            }
            json_records.append(record)
        json_file_name = datetime.now().strftime('%Y-%m-%d_%H-%M-%S.json')
        json_file_path = os.path.join(RECORDS_DIR, json_file_name)
        with open(json_file_path, 'w', encoding='utf-8') as f:
            json.dump(json_records, f, ensure_ascii=False, indent=4)

        messagebox.showinfo("保存成功", f"历史记录已保存到 {file_path} 和 {json_file_path}")


if __name__ == "__main__":
    # gui.py can also be launched directly
    root = tk.Tk()
    app = XihuaChatbotGUI(root)
    root.mainloop()
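Because train, evaluate_model, and get_data_loader are module-level functions in gui.py, the model can also be trained without opening the GUI. Below is a minimal headless sketch; the script name, relative paths, checkpoint name, and output file name are all assumptions, and importing gui also runs its module-level logging setup:

```python
# Hypothetical headless training run (e.g. saved as src/train_cli.py and run from src/).
import torch
from transformers import BertTokenizer

from model import XihuaModel
from gui import get_data_loader, train, evaluate_model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')  # placeholder checkpoint
model = XihuaModel(pretrained_model_name='bert-base-chinese').to(device)

train_loader = get_data_loader('../data/train_data.jsonl', tokenizer, batch_size=8)
test_loader = get_data_loader('../data/test_data.jsonl', tokenizer, batch_size=8)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
criterion = torch.nn.BCEWithLogitsLoss()

for epoch in range(3):  # a short run, just to exercise the pipeline
    loss = train(model, train_loader, optimizer, criterion, device)
    print(f"epoch {epoch + 1}: loss={loss:.4f}")

print(f"accuracy: {evaluate_model(model, test_loader, device):.4f}")
torch.save(model.state_dict(), '../models/xihua_model_聊天_zh.pth')  # placeholder file name
```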

Steps to Run the Project

Install the dependencies:

pip install torch transformers requests beautifulsoup4 jsonlines tkcalendar

Prepare the data:

Place the training data train_data.jsonl and the test data test_data.jsonl in the data/ directory (an example record is sketched below).
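The post does not include the data itself. Judging from the fields that dataset.py reads (question, human_answers, chatgpt_answers), one line of train_data.jsonl presumably looks like this; the record below is an illustrative assumption, not real project data:

```jsonl
{"question": "秦始皇是哪个朝代的皇帝?", "human_answers": ["秦始皇是秦朝的开国皇帝。"], "chatgpt_answers": ["秦始皇,名嬴政,是秦朝的第一位皇帝。"]}
```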
Run the main program:

python src/main.py

With that, the Xihua chatbot project is ready to run. Hopefully these refinements make the project cleaner and more complete; if you run into any problems or need further help, feel free to ask.

