改进步骤
- 数据增强:使用GPT模型生成更多的训练数据。
- 使用更高级的模型:使用BERT或其他预训练的语言模型进行文本分类。
- 经验条和经验值显示:在网页端显示当前的经验值,并添加一个经验条。
数据增强和训练数据集
我们可以通过OpenAI的GPT模型来生成更多的训练数据。这里假设我们已经有一定数量的基础训练数据,并使用GPT模型生成更多样化的数据。
# Seed training examples, keyed by activity category.
initial_data = {
    '工作': ['完成工作任务', '解决问题', '编写代码', '参加会议'],
    '生活': ['锻炼身体', '学习新知识', '看电影', '休息'],
    '家务': ['打扫卫生', '洗衣服', '做饭', '整理房间'],
}
使用GPT生成更多数据
我们将使用OpenAI的API生成更多数据。请确保您已设置OpenAI API密钥:
import openai

openai.api_key = 'your_openai_api_key'


def generate_data(prompt, max_tokens=50):
    """Ask the OpenAI completion API for five variations of *prompt*.

    Returns a list of whitespace-stripped completion strings.
    """
    response = openai.Completion.create(
        engine="text-davinci-003",
        prompt=prompt,
        max_tokens=max_tokens,
        n=5,
        stop=None,
        temperature=0.7,
    )
    return [choice['text'].strip() for choice in response['choices']]


def augment_data(data):
    """Return a copy of *data* whose category lists are extended with
    GPT-generated variants of each seed example."""
    augmented = {}
    for category, examples in data.items():
        expanded = examples.copy()
        for seed in examples:
            prompt = f"给出一些类似于'{seed}'的{category}活动。"
            expanded.extend(generate_data(prompt))
        augmented[category] = expanded
    return augmented


# Augment the seed training data.
augmented_data = augment_data(initial_data)
使用BERT进行文本分类
我们将使用 transformers 库中的BERT模型进行分类。以下是完整的 app.py 代码:
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from flask import Flask, render_template, request, jsonify
import torch
import os
import pandas as pd
from datetime import datetime
import openai# 初始化nltk
# Download the NLTK resources used by preprocess_text().
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')

# Globals: the fine-tuned BERT model (assigned by train_classifier) and its tokenizer.
classifier = None
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Activity categories and the keyword -> points table for each one.
categories = ['工作', '生活', '家务']
score_criteria = {
    '工作': {'完成任务': 10, '解决问题': 5, '编程': 8, '参加会议': 4, '其他': 1},
    '生活': {'锻炼': 5, '学习': 3, '休息': 2, '看电影': 2, '其他': 1},
    '家务': {'打扫卫生': 5, '洗衣服': 3, '做饭': 4, '整理房间': 3, '其他': 1},
}

# Running experience score and the user's current level name.
base_score = 0
level = '初级'

# In-memory activity log (activity, category, score) and per-activity score history.
daily_log = []
scores = []

# Lemmatizer used by preprocess_text().
lemmatizer = WordNetLemmatizer()

# OpenAI API key for the data-augmentation calls.
openai.api_key = 'your_openai_api_key'
def create_example_files():
    """Write one seed log file per category ('<category>_log.txt'),
    skipping any file that already exists."""
    example_data = {
        '工作': ['完成工作任务\n', '解决问题\n', '编写代码\n', '参加会议\n', '其他工作相关活动\n'],
        '生活': ['锻炼身体\n', '学习新知识\n', '看电影\n', '休息\n', '其他生活相关活动\n'],
        '家务': ['打扫卫生\n', '洗衣服\n', '做饭\n', '整理房间\n', '其他家务相关活动\n'],
    }
    for category, lines in example_data.items():
        path = f'{category}_log.txt'
        if os.path.exists(path):
            continue  # never overwrite user-edited logs
        with open(path, 'w', encoding='utf-8') as handle:
            handle.writelines(lines)
def generate_data(prompt, max_tokens=50):
    """Return five GPT completions for *prompt*, whitespace-stripped."""
    response = openai.Completion.create(
        engine="text-davinci-003",
        prompt=prompt,
        max_tokens=max_tokens,
        n=5,
        stop=None,
        temperature=0.7,
    )
    choices = response['choices']
    return [choice['text'].strip() for choice in choices]
def augment_data(data):
    """Extend every category's example list with GPT-generated variants.

    Returns a new dict; the input dict and its lists are not mutated.
    """
    augmented = {}
    for category, examples in data.items():
        expanded = examples.copy()
        for seed in examples:
            prompt = f"给出一些类似于'{seed}'的{category}活动。"
            expanded.extend(generate_data(prompt))
        augmented[category] = expanded
    return augmented
def load_and_augment_data():
    """Read the per-category log files (creating seed files if missing)
    and return the GPT-augmented training data."""
    create_example_files()
    raw = {'工作': [], '生活': [], '家务': []}
    for category in categories:
        with open(f'{category}_log.txt', 'r', encoding='utf-8') as handle:
            raw[category] = handle.read().splitlines()
    return augment_data(raw)
class ActivityDataset(Dataset):
    """Torch dataset pairing activity strings with integer category labels.

    Each item is a dict with the raw text, BERT input ids, attention mask,
    and the label as a long tensor.
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        sample = self.texts[item]
        target = self.labels[item]
        encoded = self.tokenizer.encode_plus(
            sample,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
            truncation=True,
        )
        return {
            'text': sample,
            'input_ids': encoded['input_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
            'labels': torch.tensor(target, dtype=torch.long),
        }
def train_classifier():
    """Fine-tune bert-base-uncased on the augmented activity data and
    publish the trained model via the module-level `classifier` global."""
    global classifier
    data = load_and_augment_data()
    texts, labels = [], []
    for category, examples in data.items():
        texts.extend(examples)
        labels.extend([categories.index(category)] * len(examples))
    # 90/10 train/validation split, fixed seed for reproducibility.
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels, test_size=0.1, random_state=42)
    train_dataset = ActivityDataset(train_texts, train_labels, tokenizer, max_len=128)
    val_dataset = ActivityDataset(val_texts, val_labels, tokenizer, max_len=128)
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        evaluation_strategy='epoch',
    )
    model = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased', num_labels=len(categories))
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    )
    trainer.train()
    classifier = model
def preprocess_text(data):
    """Tokenize, drop English stopwords/punctuation, and lemmatize each text.

    Takes an iterable of strings and returns a list of cleaned strings.

    NOTE(review): this pipeline uses English stopwords and the WordNet
    lemmatizer, but the app's activity strings are Chinese — and the helper
    is not called anywhere in this file. Confirm it is still needed.
    """
    # Fix: build the stopword set once instead of once per input text.
    stop_words = set(stopwords.words('english'))
    processed_data = []
    for text in data:
        tokens = word_tokenize(text.lower())
        # Keep alphanumeric, non-stopword tokens only.
        filtered = [tok for tok in tokens if tok.isalnum() and tok not in stop_words]
        lemmatized = [lemmatizer.lemmatize(tok) for tok in filtered]
        processed_data.append(' '.join(lemmatized))
    return processed_data
def classify_activity(activity):
    """Classify an activity string into one of `categories` with the
    trained BERT model (requires train_classifier() to have run)."""
    encoded = tokenizer(activity, return_tensors='pt', truncation=True,
                        padding=True, max_length=128)
    logits = classifier(**encoded).logits
    index = torch.argmax(logits, dim=1).item()
    return categories[index]
def score_activity(activity, category):
    """Sum the points of every scoring keyword (from `score_criteria`)
    that appears as a substring of *activity*."""
    return sum(
        points
        for keyword, points in score_criteria[category].items()
        if keyword in activity
    )
def log_daily_activity(activity):
    """Classify and score one activity, update the global score/level,
    append to the in-memory logs, and persist the record to Excel."""
    global base_score, level
    category = classify_activity(activity)
    score = score_activity(activity, category)
    base_score += score
    # Level thresholds: 100+ -> 高级, 50+ -> 中级; the level never drops.
    if base_score >= 100:
        level = '高级'
    elif base_score >= 50:
        level = '中级'
    daily_log.append((activity, category, score))
    scores.append(score)
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    log_to_excel(timestamp, activity, category, score)
def log_to_excel(timestamp, activity, category, score):
    """Append one activity record to activity_log.xlsx, creating the
    workbook on first use (full read-modify-write each call)."""
    filename = 'activity_log.xlsx'
    columns = ['Timestamp', 'Activity', 'Category', 'Score']
    if os.path.exists(filename):
        df = pd.read_excel(filename)
    else:
        df = pd.DataFrame(columns=columns)
    row = pd.DataFrame([[timestamp, activity, category, score]], columns=columns)
    df = pd.concat([df, row], ignore_index=True)
    df.to_excel(filename, index=False)
app = Flask(__name__)


@app.route('/')
def index():
    """Render the main page with the current score and level."""
    return render_template('index.html', base_score=base_score, level=level)


@app.route('/log_activity', methods=['POST'])
def log_activity():
    """Record one activity posted from the form; return updated stats as JSON."""
    activity = request.form['activity']
    log_daily_activity(activity)
    return jsonify({'base_score': base_score, 'level': level})


@app.route('/simulate')
def simulate():
    """Log a fixed batch of sample activities and return the resulting stats."""
    global base_score, level
    activities = [
        '完成工作任务', '锻炼身体', '解决问题', '学习新知识', '编写代码',
        '看电影', '休息', '参加会议', '打扫卫生', '洗衣服', '做饭',
        '整理房间', '其他活动',
    ]
    for activity in activities:
        log_daily_activity(activity)
    return jsonify({'base_score': base_score, 'level': level})


if __name__ == '__main__':
    # Train the classifier before serving any requests.
    train_classifier()
    app.run(debug=True)
HTML代码(templates/index.html)
我们将创建一个表单,让用户输入他们的活动,并通过Ajax发送到服务器,同时添加经验值和经验条的显示:
<!DOCTYPE html>
<html lang="en">
<!-- Head: inline CSS for the character card, level/upgrade labels, the activity
     form, and the experience progress bar. -->
<head><meta charset="UTF-8"><meta name="viewport" content="width=device-width, initial-scale=1.0"><title>升级动画</title><style>body {text-align: center;font-family: Arial, sans-serif;}.character {width: 200px;height: 200px;margin: 50px auto;border: 2px solid #000;border-radius: 10px;background-color: #f0f0f0;position: relative;}.character img {width: 100%;height: 100%;display: none;}.character img.active {display: block;}.level {font-size: 1.5em;margin-top: 20px;}.upgrade {font-size: 1.2em;color: green;}.form-container {margin: 20px auto;}.progress-bar {width: 80%;margin: 20px auto;background-color: #e0e0e0;border-radius: 10px;overflow: hidden;}.progress-bar-inner {height: 20px;background-color: #76c7c0;width: 0;}</style>
</head>
<!-- Body: three stacked level images (only the .active one is visible), the
     activity input form, and the XP bar. The inline script POSTs the form value
     to /log_activity and updates the level image, level text, experience text,
     and bar width from the JSON response.
     NOTE(review): the bar width is (baseScore % 50) * 2 percent, so it wraps
     every 50 points — confirm this matches the 50/100 level thresholds.
     NOTE(review): updateLevel shows the "升级到X!" message on every logged
     activity, not only when the level actually changes. -->
<body><div class="character"><img src="/static/character_level1.png" id="level1" class="active"><img src="/static/character_level2.png" id="level2"><img src="/static/character_level3.png" id="level3"></div><div class="level" id="level">当前等级: 初级</div><div class="upgrade" id="upgrade"></div><div class="form-container"><form id="activity-form"><input type="text" id="activity" name="activity" placeholder="输入今天的活动" required><button type="submit">记录活动</button></form></div><div class="progress-bar"><div class="progress-bar-inner" id="progress-bar-inner"></div></div><div id="experience">当前经验值: 0</div><script>let baseScore = {{ base_score }};let level = "{{ level }}";function updateLevel(newLevel) {const levelImages = {'初级': document.getElementById('level1'),'中级': document.getElementById('level2'),'高级': document.getElementById('level3')};for (let level in levelImages) {if (level === newLevel) {levelImages[level].classList.add('active');document.getElementById('level').textContent = `当前等级: ${level}`;document.getElementById('upgrade').textContent = `升级到${level}!`;} else {levelImages[level].classList.remove('active');}}}function updateExperience(newScore) {baseScore = newScore;document.getElementById('experience').textContent = `当前经验值: ${baseScore}`;let progressBar = document.getElementById('progress-bar-inner');progressBar.style.width = `${Math.min((baseScore % 50) * 2, 100)}%`;}function logActivity(event) {event.preventDefault();const activity = document.getElementById('activity').value;fetch('/log_activity', {method: 'POST',headers: {'Content-Type': 'application/x-www-form-urlencoded'},body: `activity=${encodeURIComponent(activity)}`}).then(response => response.json()).then(data => {updateLevel(data.level);updateExperience(data.base_score);document.getElementById('activity').value = '';});}document.getElementById('activity-form').addEventListener('submit', logActivity);</script>
</body>
</html>