[Kaggle] Spam/Ham Email Classification 垃圾邮件分类(BERT)

文章目录

    • 1. 数据处理
    • 2. 下载预训练模型
    • 3. 加载数据
    • 4. 定义模型
    • 5. 训练
    • 6. 提交测试结果

练习地址:https://www.kaggle.com/c/ds100fa19
相关博文:
[Kaggle] Spam/Ham Email Classification 垃圾邮件分类(spacy)
[Kaggle] Spam/Ham Email Classification 垃圾邮件分类(RNN/GRU/LSTM)

本文使用 huggingface 上的预训练模型,在预训练模型的基础上,使用垃圾邮件数据集,进行训练 finetune,在kaggle提交测试结果

本文代码参考了《自然语言处理动手学Bert文本分类》

1. 数据处理

from datetime import timedelta
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
train = pd.read_csv("train.csv")
test_csv = pd.read_csv("test.csv")
train = train.fillna(" ")
test_csv = test_csv.fillna(" ")
train['all'] = train['subject'] + ' ' + train['email'] # 合并两个特征# 切分出一些验证集,分层抽样
from sklearn.model_selection import StratifiedShuffleSplit
splt = StratifiedShuffleSplit(n_splits=1,test_size=0.2,random_state=1)
for train_idx, valid_idx in splt.split(train, train['spam']):train_part = train.loc[train_idx]valid_part = train.loc[valid_idx]y_train = train_part['spam']
y_valid = valid_part['spam']
X_train = train_part['all']
X_valid = valid_part['all']X_test = test_csv['subject'] + ' ' + test_csv['email']
y_test = [0]*len(X_test) # 测试集没有标签,这么处理方便代码处理
y_test = torch.LongTensor(y_test) # 转成tensor

2. 下载预训练模型

预训练模型

模型下载很慢的话,我传到 csdn了,可以免费下载

以上模型文件放在一个文件夹里,如./bert_hugginggace/

提前安装包
pip install transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassificationtokenizer = AutoTokenizer.from_pretrained("./bert_hugginggace")
# distilbert-base-uncased-finetuned-sst-2-englishpretrain_model = AutoModelForSequenceClassification.from_pretrained("./bert_hugginggace")

一些使用的参数

PAD, CLS = '[PAD]', '[CLS]'
max_seq_len = 128
bert_hidden = 768
num_classes = 2
learning_rate = 1e-5
decay = 0.01
num_epochs = 5
early_stop_time = 2000
batch_size = 32
save_path = "./best_model.ckpt" # 最好的模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

3. 加载数据

  • 数据需要编码成 bert 需要的格式
    需要 token_ids, attention_mask
def load_dataset(texts, labels):contents = []for t, label in zip(texts, labels):token = tokenizer.tokenize(t)token = [CLS] + token# ['[CLS]', 'subject', ':', 'cell', 'phones', 'coming', 'soon', '<', 'html', '>', ...]seq_len = len(token)mask = []token_ids = tokenizer.convert_tokens_to_ids(token)# [101, 3395, 1024, 3526, 11640, 2746, 2574, 1026, 16129, 。。。]if len(token) < max_seq_len: # 长度不够的,pad 补齐mask = [1]*len(token) + [0]*(max_seq_len-len(token))token_ids = token_ids + [0]*(max_seq_len-len(token))else: # 超长的,截断mask = [1]*max_seq_lentoken_ids = token_ids[:max_seq_len]seq_len = max_seq_leny = [0]*num_classes y[label] = 1 # 处理下标签,方便后面计算 二元交叉熵损失contents.append((token_ids, y, seq_len, mask))return contents
  • 编写数据集迭代器,训练的时候,每次取出 batch_size 个样本来更新权重
class datasetIter():def __init__(self, datasets, batch_size, device):self.datasets = datasetsself.idx = 0self.device = deviceself.batch_size = batch_sizeself.batches = len(datasets)//batch_sizeself.residues = Falseif len(datasets)%batch_size != 0:self.residues = True # 剩余不足 batch_size 个的样本def __next__(self):if self.residues and self.idx==self.batches:batch_data = self.datasets[self.idx * self.batch_size : len(self.datasets)]self.idx += 1batch_data = self._to_tensor(batch_data)return batch_dataelif self.idx > self.batches:self.idx = 0raise StopIterationelse:batch_data = self.datasets[self.idx * self.batch_size : (self.idx+1) * self.batch_size]self.idx += 1batch_data = self._to_tensor(batch_data)return batch_datadef _to_tensor(self, datasets):x = torch.LongTensor([item[0] for item in datasets]).to(self.device)y = torch.FloatTensor([item[1] for item in datasets]).to(self.device)seq_len = torch.LongTensor([item[2] for item in datasets]).to(self.device)mask = torch.LongTensor([item[3] for item in datasets]).to(self.device)return (x, seq_len, mask), ydef __iter__(self):return selfdef __len__(self):if self.residues:return self.batches + 1else:return self.batches
def build_iter(datasets, batch_size, device):iter = datasetIter(datasets,batch_size,device)return iter

4. 定义模型

class myModel(nn.Module):def __init__(self):super(myModel, self).__init__()self.pretrain_model = pretrain_model # 预训练的bert模型for param in self.pretrain_model.parameters():param.requires_grad = True # 打开 finetune 开关def forward(self, x):context = x[0]mask = x[2]out = self.pretrain_model(context, attention_mask=mask)out = torch.sigmoid(out.logits) # sigmoid到 (0,1) 方便计算交叉熵return out

5. 训练

import time
import torch.nn.functional as Ffrom sklearn import metrics
from transformers.optimization import AdamW
  • 辅助计时函数
def get_time_dif(starttime):# calculate used timeendtime = time.time()return timedelta(seconds=int(round(endtime-starttime)))
  • 训练
def train(model, train_iter, dev_iter, test_iter):starttime = time.time() # 记录开始时间model.train()optimizer = AdamW(model.parameters(),lr=learning_rate,weight_decay=decay)total_batch = 0dev_best_loss = float("inf")last_improve = 0no_improve_flag = Falsemodel.train()for epoch in range(num_epochs):print("Epoch {}/{}".format(epoch+1, num_epochs))for i, (X, y) in enumerate(train_iter):outputs = model(X) # batch_size * num_classesmodel.zero_grad() # 清理梯度增量loss = F.binary_cross_entropy(outputs, y)loss.backward()optimizer.step()if total_batch%100 == 0: # 打印训练信息truelabels = torch.max(y.data, 1)[1].cpu()pred = torch.max(outputs, 1)[1].cpu()train_acc = metrics.accuracy_score(truelabels, pred)# 调用 评估函数 检查验证集上的效果dev_acc, dev_loss = evaluate(model, dev_iter) # 检查验证集上的效果, 保留效果最好的if dev_loss < dev_best_loss:dev_best_loss = dev_losstorch.save(model.state_dict(), save_path)improve = '*'last_improve = total_batchelse:improve = ' 'time_dif = get_time_dif(starttime)# 打印训练信息,id : >右对齐,n 宽度,.3 小数位数msg = 'Iter:{0:>6}, Train Loss:{1:>5.2}, Train Acc:{2:>6.2}, Val Loss:{3:>5.2}, val Acc :{4:>6.2%}, Time:{5} {6}'print(msg.format(total_batch, loss.item(),train_acc, dev_loss, dev_acc, time_dif, improve))model.train()total_batch += 1# 如果长时间没有改进,认为收敛,停止训练if total_batch - last_improve > early_stop_time:print("no improve after {} times, stop!".format(early_stop_time))no_improve_flag = Truebreakif no_improve_flag:break# 调用 测试函数,生成预测结果test(model, test_iter)
  • 评估函数
def evaluate(model, dev_iter):model.eval() # 评估模式loss_total = 0pred_all = np.array([], dtype=int)labels_all = np.array([], dtype=int)with torch.no_grad(): # 不记录图的操作,不更新梯度for X, y in dev_iter:outputs = model(X)loss = F.binary_cross_entropy(outputs, y)loss_total += losstruelabels = torch.max(y.data, 1)[1].cpu()pred = torch.max(outputs, 1)[1].cpu().numpy()labels_all = np.append(labels_all, truelabels)pred_all = np.append(pred_all, pred)acc = metrics.accuracy_score(labels_all, pred_all)return acc, loss_total/len(dev_iter)
  • 测试函数
def test(model, test_iter):model.load_state_dict(torch.load(save_path)) # 加载最佳模型model.eval() # 评估模式pred_all = np.array([], dtype=int)with torch.no_grad():for X, y in test_iter:outputs = model(X)pred = torch.max(outputs, 1)[1].cpu().numpy()pred_all = np.append(pred_all, pred)# 写入提交文件id = test_csv['id']output = pd.DataFrame({'id':id, 'Class': pred_all})output.to_csv("submission_bert.csv",  index=False)
  • 运行主程序
# 确定随机数
np.random.seed(520)
torch.manual_seed(520)
torch.cuda.manual_seed_all(520)
torch.backends.cudnn.deterministic = True# 加载数据
train_data = load_dataset(X_train, y_train)
valid_data = load_dataset(X_valid, y_valid)
test_data = load_dataset(X_test, y_test)# 数据迭代器
train_iter = build_iter(train_data, batch_size, device)
valid_iter = build_iter(valid_data, batch_size, device)
test_iter = build_iter(test_data, batch_size, device)# 模型
model = myModel().to(device)# 训练、评估、测试
train(model, train_iter, valid_iter, test_iter)

6. 提交测试结果

Private Score:0.98714
Public Score:0.99000

没怎么调参,准确率接近99%,效果还是很不错的!

欢迎大家提出意见和指正!多谢!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/473199.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python 中main函数总结

Python使用缩进对齐组织代码的执行&#xff0c;所有没有缩进的代码&#xff08;非函数定义和类定义&#xff09;&#xff0c;都会在载入时自动执行&#xff0c;这些代码&#xff0c;可以认为是Python的main函数。 每个文件&#xff08;模块&#xff09;都可以任意写一些没有缩进…

python从图片里提取主要颜色

一、需求&#xff1a; 从一张图片里提取主要的几种颜色 二、效果&#xff1a; 三、代码 from PIL import Image, ImageDraw, ImageFontdef get_dominant_colors(infile):image Image.open(infile)# 缩小图片&#xff0c;否则计算机压力太大small_image image.resize((80, …

LeetCode 790. 多米诺和托米诺平铺(动态规划)

文章目录1. 题目2. 解题1. 题目 有两种形状的瓷砖&#xff1a; 一种是 2x1 的多米诺形&#xff0c; 另一种是形如 “L” 的托米诺形。 两种形状都可以旋转。 XX <- 多米诺XX <- "L" 托米诺 X给定 N 的值&#xff0c;有多少种方法可以平铺 2 x N 的面板&…

Django后端编辑图片提取主要颜色API

一、需求 前端页面需要调用后端API&#xff0c;实现获取主要颜色json数据 二、图片效果 三、代码实现&#xff1a; # Create your views here. import os from django.core.files.storage import default_storage from django.http import HttpResponse, JsonResponse from …

LeetCode 898. 子数组按位或操作(前缀和思想)

文章目录1. 题目2. 解题2.1 超时解2.2 正解1. 题目 我们有一个非负整数数组 A。 对于每个&#xff08;连续的&#xff09;子数组 B [A[i], A[i1], ..., A[j]] &#xff08; i < j&#xff09;&#xff0c;我们对 B 中的每个元素进行按位或操作&#xff0c;获得结果 A[i] …

天池 在线编程 回合制游戏(前缀和)

文章目录1. 题目2. 解题1. 题目 QW 是一个回合制游戏的玩家&#xff0c;今天他决定去打怪。 QW 在一场战斗中会碰到 n 个怪物&#xff0c;每个怪物有攻击力 atk[i]&#xff0c;每回合结束时如果第 i 个怪物还活着&#xff0c;就会对 QW 造成 atk[i] 的伤害。 QW 只能在每回合…

Python程序员的圣经——《Python编程快速上手:让繁琐工作自动化》尾末附下载地址

一、前言 如今&#xff0c;人们面临的大多数任务都可以通过编写计算机软件来完成。Python是一种解释型、面向对象、动态数据类型的高级程序设计语言。通过Python编程&#xff0c;我们能够解决现实生活中的很多任务。 今天给大家分享一份Python程序员的圣经——《Python编程快…

POJ 3608

1.计算P上y坐标值最小的顶点&#xff08;称为 yminP &#xff09;和Q上y坐标值最大的顶点&#xff08;称为 ymaxQ&#xff09;。 2.为多边形在 yminP 和 ymaxQ 处构造两条切线 LP 和 LQ 使得他们对应的多边形位于他们的右侧。 此时 LP 和 LQ 拥有不同的方向&#xff0c; 并且 y…

天池 在线编程 聪明的销售(计数+贪心)

文章目录1. 题目2. 解题1. 题目 销售主管的任务是出售一系列的物品&#xff0c;其中每个物品都有一个编号。 由于出售具有相同编号的商品会更容易&#xff0c;所以销售主管决定删除一些物品。 现在她知道她最多能删除多少物品&#xff0c;她想知道最终袋子里最少可以包含多少…

关于计算机书籍的收集与整理(一)

本文来源&#xff1a;https://github.com/pinefor1983/CS-Growing-book 一、程序员技术、管理和认知 1、程序员技术&管理 关于程序员职场晋升&#xff0c;这是我的7点具体建议优秀程序员的7个特点对码农后浪的6点建议程序员百万年薪进阶指南做好技术管理&#xff0c;你必须…

天池 在线编程 放小球(动态规划)

文章目录1. 题目2. 解题2.1 动态规划1. 题目 n 个桶中小球的个数已知, 可以操作 k 次(每次从桶中取出一个球,或者添加一个球), 每个桶有规定的最大容量 W[i]。 求操作后两相邻桶之间的最大差值的平方的最小值。 n < 100 W[i] < 100样例 1: 输入: 5 6 [1,2,3,4,5] [15,…

LeetCode 1716. 计算力扣银行的钱(等差数列)

文章目录1. 题目2. 解题1. 题目 Hercy 想要为购买第一辆车存钱。他 每天 都往力扣银行里存钱。 最开始&#xff0c;他在周一的时候存入 1 块钱。 从周二到周日&#xff0c;他每天都比前一天多存入 1 块钱。 在接下来每一个周一&#xff0c;他都会比 前一个周一 多存入 1 块钱…

LeetCode 1717. 删除子字符串的最大得分

文章目录1. 题目2. 解题374 / 1631&#xff0c;前22.9%1215 / 7873&#xff0c;前15.4%1. 题目 给你一个字符串 s 和两个整数 x 和 y 。你可以执行下面两种操作任意次。 删除子字符串 "ab" 并得到 x 分。 比方说&#xff0c;从 “cabxbae” 删除 ab &#xff0c;得…

利用Python把四张图片按照顺序拼接起来

一、需求&#xff1a; 给出四张图片&#xff0c;按照一定的顺序拼接起来 二、图片&#xff1a; 左上角&#xff1a;&#xff08;像素512*512&#xff09; 右上角&#xff1a;&#xff08;像素284*512&#xff09; 左下角&#xff1a;&#xff08;像素284*512&#xff09; 右…

Linux:文件创建时间如何修改?

一、需求 修改文件创建时间 二、知识及方法步骤 touch命令用于创建空白文件或修改文件时间。 在Linux系统中一个文件有三种时间&#xff1a; 更改内容的时间 - mtime&#xff1a;当文件进行被写的时候&#xff0c;CTime就会更新更改权限的时间 - ctime&#xff1a;当文件的…

小案例:编写立方体六个面,合成一张全景图后端

一、需求&#xff1a; 给出立方体六个面&#xff0c;合成一张全景图 二、主要知识&#xff1a;py360convert 2.1、该项目的特点&#xff1a; 立方体贴图和等矩形之间的转换 等角于平面 纯python实现&#xff0c;仅依赖于numpy和scipy矢量化实施&#xff08;在大多数地…

LeetCode 1721. 交换链表中的节点(快慢指针)

文章目录1. 题目2. 解题1. 题目 给你链表的头节点 head 和一个整数 k 。 交换 链表正数第 k 个节点和倒数第 k 个节点的值后&#xff0c;返回链表的头节点&#xff08;链表 从 1 开始索引&#xff09;。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5], k 2 输出&am…

爬虫小案例:基于Bing关键词批量下载图片

一、需求&#xff1a; 基于Bing网站&#xff0c;输入关键词&#xff0c;批量下载图片保存到本地 二、演示&#xff1a; 三、直接上代码 import os import urllib.request import urllib.parse from bs4 import BeautifulSoup import re import time# 设置请求头 header {Us…

LeetCode 1722. 执行交换操作后的最小汉明距离(并查集)

文章目录1. 题目2. 解题1. 题目 给你两个整数数组 source 和 target &#xff0c;长度都是 n 。 还有一个数组 allowedSwaps &#xff0c;其中每个 allowedSwaps[i] [ai, bi] 表示你可以交换数组 source 中下标为 ai 和 bi&#xff08;下标从 0 开始&#xff09;的两个元素。…

线性表的顺序表示和实现

/* 顺序表存储结构容易实现随机存取线性表的第i 个数据元素的操作&#xff0c;但在实现插入、 删除的操作时要移动大量数据元素&#xff0c;所以&#xff0c;它适用于数据相对稳定的线性表&#xff0c;如职工工资 表、学生学籍表等。 c2-1.h 是动态分配的顺序表存储结构&#x…