NLP入门——基于梯度下降法分类的应用

问题分析

我们前面研究的都是基于统计的方法,通过不同的统计方法得到不同的准确率,通过改善统计的方式来提高准确率。现在我们要研究基于数学的方式来预测准确率。

假设我们有一个分词 s_{class,word},class是该对象的类别,word是该分词的词频。
则该分词在某类别中的分数为:s_(class) = sum(s{class,word})
那么我们欲求一个正确的类别:s_(class_g),如何使正确的类别的分数最高呢?
这里引入了损失函数 l o s s = s ( c l a s s m ) − s ( c l a s s g ) loss = s(class_m) - s(class_g) loss=s(classm)s(classg) 即用最大类别的分数 - 正确类别的分数。如果正确类别即是分数最大的类别,则loss为0;否则loss将为一个大于0的数。
如果loss函数越来越小,则说明正确的类别越来越多,准确率也越来越高。因此我们需要找到一个 s_(class),来使得loss最小。

一元函数的梯度下降法示例

图片来自吴恩达深度学习教程。
在这里插入图片描述
假定代价函数loss为J(w),则有 J ( w ) = J ( w ) − α ∗ d J ( w ) d w J(w)=J(w)-α*\frac {dJ(w)}{dw} J(w)=J(w)αdwdJ(w)
α为学习率,用来控制每步的步长,即向下走一步的长度dJ(w)/dw就是函数J(w)对w求导
在这里插入图片描述
如图所示,该点的导数就是这个点相切于J(w)的小三角的高除以宽,如果我们从图中点为初始点开始梯度下降算法,则该点斜率符号为正,即dJ(w)/dw > 0,因此接下来会向左走一步。
在这里插入图片描述
整个梯度下降法的迭代过程就是不断向左走,直到逼近最小值点
在这里插入图片描述
若我们以如图点为初始化点,则该点处斜率为负,即dJ(w)/dw < 0,所以接下来会向右走:
在这里插入图片描述
整个梯度下降法的迭代过程就是不断向右走,即朝着最小值点的方向走。

梯度下降过程中有两个问题:
1.步长 不能过大,可能会跳过最小值点
答:设置学习率不断减小
2.局部最优解未必是全局最优解
答:参数量越大,找到全局最小值的概率越小,局部最小值就越接近全局最小值

基于梯度下降法的分类

所以我们可以有如下分析:
若class_max == class_g:
loss = 0
若class_max > class_g:
l o s s = s u m ( s c l a s s m , w o r d ) − s u m ( s c l a s s g , w o r d ) loss = sum(s_{class_m,word}) - sum(s_{class_g,word}) loss=sum(sclassm,word)sum(sclassg,word)
则有
d ( l o s s ) d ( s c l a s s m , w o r d ) = 1 \frac {d( loss) }{d (s_{class_m,word})} = 1 d(sclassm,word)d(loss)=1
d ( l o s s ) d ( s c l a s s g , w o r d ) = − 1 \frac{d( loss) }{d (s_{class_g,word})} = -1 d(sclassg,word)d(loss)=1
我们将loss沿着导数方向移动:
s c l a s s m , w o r d − = l r ∗ 1 s_{class_m,word} -= lr*1 sclassm,word=lr1
s c l a s s g , w o r d − = l r ∗ ( − 1 ) s_{class_g,word} -= lr*(-1) sclassg,word=lr(1)
化简则有:
s c l a s s m , w o r d − = l r s_{class_m,word} -= lr sclassm,word=lr
s c l a s s g , w o r d + = l r s_{class_g,word} += lr sclassg,word+=lr
这会使得错误的s_{class_max,word}分数减小,使得正确的s_{class_g,word}分数增大,使得s(class_g)趋近于s(class_max),使loss函数值不断减小。
最初准确率会快速上升,这是调参的结果。当准确率逐渐稳定出现震荡时,我们得到的即为全局最优解,我们绕开统计的方法,找到一组权重求得最优解。

#learnw.py
#encoding: utf-8import sys
from json import dump
from math import log, sqrt, inf
from random import uniform, seed
from tqdm import tqdmdef build_model(srcf, tgtf):#{class: {word: freq}}_c, _w = set(), set() #_c 为所有的类别,_w 为所有的词with open(srcf,"rb") as fsrc, open(tgtf,"rb") as ftgt:for sline, tline in zip(fsrc, ftgt):_s, _t = sline.strip(), tline.strip() # s,t分别为句子和类别if _s and _t:_s, _class = _s.decode("utf-8"), _t.decode("utf-8")if _class not in _c:#如果_c中没有这个类别_c.add(_class)#添加到_c中for word in _s.split():#遍历一行里面空格隔开的每个分词if word not in _w:#如果_w中没有这个类别_w.add(word)#添加到_w中_ =sqrt( 2.0 / (len(_c) + len(_w))) #设置随机数,一定要小return {_class: {_word: uniform(-_, _) for _word in _w} for _class in _c}              #为每个类别创建一个词典,词典中是子词和随机到的初始点wdef compute_instance(model, lin):rs = {}_max_score, _max_class = -inf, Nonefor _class, v in model.items(): #对模型中的每一类和类型的词典 v:{word: freq}rs[_class] = _s = sum(v.get(_word, 0.0) for _word in lin) #这个类的分数即为该类中所有子词分数之和if _s > _max_score:_max_score = _s_max_class = _class#获取分数最高的类别return rs, _max_class, _max_score#返回每个类别的分数、最大分数的类别、最大的分数
#返回每一类的频率分布,在这行中所有分词分数之和最大的类,最大的分数 def train(srcf, tgtf, model, base_lr=0.1, max_run=128):#学习率初始为0.1,训练128轮with open(srcf,"rb") as fsrc, open(tgtf,"rb") as ftgt:for _epoch in tqdm(range(1,max_run+1)): #在每轮中fsrc.seek(0)    #对文件做复位,每次读取文件开头ftgt.seek(0)_lr = base_lr / sqrt(_epoch) #将学习率不断变小,控制步长越来越小total = err = 0for sline, tline in zip(fsrc, ftgt):#对每行数据_s, _t = sline.strip(), tline.strip()if _s and _t:_s, _class = _s.decode("utf-8").split(), _t.decode("utf-8")#_s为子词,_t为类别scores, max_class, max_score = compute_instance(model, _s)#算出每行中每个类别的分数、最大分数的类别、最大的分数if max_class != _class:#判断最大的类别和标准答案的类别是否相等for _word in _s:model[max_class][_word] -= _lr  #错误的类别model[_class][_word] += _lr     #正确的类别err += 1    #错误的数量+1total += 1      #总条数+1print("Epoch %d: %.2f" % (_epoch, (float(total - err) / total *100.0)))#返回每轮预测的准确率,转化为百分数return modeldef save(modin, frs):with open(frs, "w") as f:dump(modin, f) #用dump方法向文件写strif __name__=="__main__":seed(408) #固定随机数种子save(train(*sys.argv[1:3], build_model(*sys.argv[1:3])),sys.argv[3])#将第一个参数和第二个参数(训练集句子和分类)给build_model后,将句子、标签和返回的词典传入train中,最后保存模型

我们可以使用tqdm Python进度条来展示演示过程。在命令行执行:

:~/nlp/tnews$ python learnw.py src.train.bpe.txt tgt.train.s.txt learnw.model.txt

在这里插入图片描述
执行到100轮左右,准确率就在98%左右震荡。

对验证集做验证与早停

使用前面我们写好的脚本,在命令行输入:

:~/nlp/tnews$ python predict.py src.dev.bpe.txt learnw.model.txt pred.learnw.dev.txt
:~/nlp/tnews$ python acc.py pred.learnw.dev.txt tgt.dev.s.txt 
46.92

可以看到,准确率是46.92%,在验证集上表现一般。与我们前面提到的统计方法相比,可以大致认为是统计方法最优化的最高准确率。

我们找到的参数是在训练集上表现最好的,但未必是在所有情况下表现最好的。我们需要的模型是在验证集上性能最好的,因此我们需要在模型对验证集的表现上做修改:

#encoding: utf-8def eva(srcvf, tgtvf, model):#采集对验证集的正确率total = corr = 0with open(srcvf,"rb") as fsrc, open(tgtvf,"rb") as ftgt:for sline, tline in zip(fsrc, ftgt):#对每行数据_s, _t = sline.strip(), tline.strip()if _s and _t:_s, _class = _s.decode("utf-8").split(), _t.decode("utf-8")max_class = compute_instance(model, _s)[1]  #仅需要预测的最大类别if max_class == _class:corr += 1total += 1return float(corr) / total * 100.0        #返回每轮验证集的准确率         def train(srcf, tgtf, srcvf, tgtvf, model, frs, base_lr=0.1, max_run=128):#学习率初始为0.1,训练128轮with open(srcf,"rb") as fsrc, open(tgtf,"rb") as ftgt:_best_acc = eva(srcvf, tgtvf, model)print("Epoch 0: dev %.2f" % _best_acc)  #获得每轮中验证集上的准确率for _epoch in range(1,max_run+1): #在每轮中fsrc.seek(0)    #对文件做复位,每次读取文件开头ftgt.seek(0)_lr = base_lr / sqrt(_epoch) #将学习率不断变小,控制步长越来越小total = err = 0for sline, tline in zip(fsrc, ftgt):#对每行数据_s, _t = sline.strip(), tline.strip()if _s and _t:_s, _class = _s.decode("utf-8").split(), _t.decode("utf-8")#_s为子词,_t为类别scores, max_class, max_score = compute_instance(model, _s)#算出每行中每个类别的分数、最大分数的类别、最大的分数if max_class != _class:#判断最大的类别和标准答案的类别是否相等for _word in _s:model[max_class][_word] -= _lr  #错误的类别model[_class][_word] += _lr     #正确的类别err += 1    #错误的数量+1total += 1      #总条数+1_eva_acc = eva(srcvf, tgtvf, model)print("Epoch %d: train %.2f, dev %.2f" % (_epoch, (float(total - err) /total *100.0), _eva_acc))   #返回每轮训练集、验证集的准确率,转化为百分数if _eva_acc >= _best_acc:   #如果当前在验证集的准确率是最好的,就保存这个模型,并更新最好准确率save(model, frs)_best_acc = _eva_accreturn modeldef save(modin, frs):with open(frs, "w") as f:dump(modin, f) #用dump方法向文件写strif __name__=="__main__":seed(408) #固定随机数种子train(*sys.argv[1:5], build_model(*sys.argv[1:3]),sys.argv[5])#传入训练句子、训练标签、验证句子、验证标签以及保存模型文件

在命令行执行:

:~/nlp/tnews$ python learnw.py src.train.bpe.txt tgt.train.s.txt src.dev.bpe.txt tgt.dev.s.txt learnw.model.txt 

在这里插入图片描述
我们可以看到,对验证集的测试上,在很早的轮次已经趋于稳定值了。因此我们没有必要跑完所有128轮,对验证集的准确率的提升没有意义。
我们引入early_stop参数,作为早停的停止条件。如果我们连续early_stop轮都没有找到更好的结果,我们就将训练停止 ,避免无效的训练:

if _eva_acc >= _best_acc:   #如果当前在验证集的准确率是最好的,就保存这个模型,并更新最好准确率save(model, frs)_best_acc = _eva_accanbest = 0
else:anbest += 1if anbest > earlystop:	#如果连续earlystop轮准确率都没有提高,则停止训练break  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/29883.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【漏洞复现】金和OA C6 download.jsp 任意文件读取漏洞

免责声明&#xff1a; 本文内容旨在提供有关特定漏洞或安全漏洞的信息&#xff0c;以帮助用户更好地了解可能存在的风险。公布此类信息的目的在于促进网络安全意识和技术进步&#xff0c;并非出于任何恶意目的。阅读者应该明白&#xff0c;在利用本文提到的漏洞信息或进行相关测…

AI写作与个人写作:思考性的探究

在人工智能&#xff08;AI&#xff09;技术日益成熟的今天&#xff0c;AI写作已经成为现实&#xff0c;例如文心一言、kimi、研导AI写作等工具。然而&#xff0c;当机器开始涉足写作这一传统上被认为是人类独有的创造性活动时&#xff0c;我们不禁要问&#xff1a;AI写作是否能…

AI数据分析:Excel表格智能判断数据起点来计算增长率

工作任务&#xff1a;计算Excel表格中2023年1月到2024年4月的总增长率和复合增长率。 如果数据都有的情况下&#xff0c;公式很简单&#xff1a; 总增长率 (O2-B2)/B2 复合增长率 POWER((O2/B2),1/13)-1 但是&#xff0c;2023年1月、2月、3月的数据&#xff0c;有些有&…

AI办公自动化:用通义千问批量翻译长篇英语TXT文档

在deepseek中输入提示词&#xff1a; 你是一个Python编程专家&#xff0c;现在要完成一个编写基于qwen-turbo模型API和dashscope库的程序脚本&#xff0c;具体步骤如下&#xff1a; 打开文件夹&#xff1a;F:\AI自媒体内容\待翻译&#xff1b; 获取里面所有TXT文档&#xff…

Vue3搭载后端服务器开发文档

1 第8章 “微商城”后端服务器搭建 “微商城”后端服务器基于 ThinkJS MySQL &#xff0c;以下是环境搭建文档。 8.1 搭建 MySQL 环境 8.1.1 安装 MySQL 本项目基于 MySQL 5.7 社区版&#xff0c;如果您还没有安装&#xff0c;请继续阅读安装步骤。如果您 已经安…

mac如何检测硬盘损坏 常用mac硬盘检测坏道工具推荐

mac有时候也出现一些问题&#xff0c;比如硬盘损坏。硬盘损坏会导致数据丢失、系统崩溃、性能下降等严重的后果&#xff0c;所以及时检测和修复硬盘损坏是非常必要的。那么&#xff0c;mac如何检测硬盘损坏呢&#xff1f;有哪些常用的mac硬盘检测坏道工具呢&#xff1f; 一、m…

Python 数据可视化 散点图

Python 数据可视化 散点图 import matplotlib.pyplot as plt import numpy as npdef plot_scatter(ref_info_dict, test_info_dict):# 绘制散点图&#xff0c;ref横&#xff0c;test纵plt.figure(figsize(80, 48))n 0# scatter_header_list [peak_insert_size, median_insert…

nginx反向代理动静分离和负载均衡

一.nginx 反向代理简要介绍 1.什么是反向代理 反向代理是一种服务器&#xff0c;在这种设置中&#xff0c;代理服务器接收客户端的请求&#xff0c;并将这些请求转发给一个或多个后端服务器&#xff08;例如应用服务器、数据库服务器等&#xff09;。然后&#xff0c;后端服务…

【车载开发系列】IIC总线协议基础篇

【车载开发系列】IIC总线协议基础篇 【车载开发系列】IIC总线协议基础篇 【车载开发系列】IIC总线协议基础篇一. 什么是I2C二. I2C使用场景三. I2C的特点四. 传输速度四种模式五. IIC基本通讯规则1&#xff09;起始信号S2&#xff09;停止信号P3&#xff09;发送数据&#xff…

【LinkedList与链表】

目录 1&#xff0c;ArrayList的缺陷 2&#xff0c;链表 2.1 链表的概念及结构 2.2 链表的实现 2.2.1 无头单向非循环链表实现 3&#xff0c;LinkedList的模拟实现 3.1 无头双向链表实现 4&#xff0c;LinkedList的使用 4.1 什么是LinkedList 4.2 LinkedList的使用 5…

第4天:用户认证系统实现

第4天&#xff1a;用户认证系统实现 目标 实现用户认证系统&#xff0c;包括用户注册、登录、登出和密码管理。 任务概览 使用Django内置的用户认证系统。创建用户注册和登录表单。实现用户登出和密码重置功能。 详细步骤 1. 使用Django内置的用户认证系统 Django提供了…

上位机图像处理和嵌入式模块部署(h750 mcu和ad/da电路)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 大部分同学学习mcu的时候&#xff0c;都会把重点放在232、485、can、usb、eth这些常规的通信接口上面。还有一部分同学&#xff0c;可能会对lcd、c…

【Java基础5】JDK、JRE和JVM的区别与联系

JDK、JRE和JVM的区别与联系 Java是一种广泛使用的编程语言&#xff0c;它的跨平台特性得益于Java虚拟机&#xff08;JVM&#xff09;。然而&#xff0c;在Java的世界里&#xff0c;JDK、JRE和JVM这三个术语常常让人感到困惑。本文将阐述它们各自的功能&#xff0c;以及它们是如…

【設計モードの特性に基づく動的ルーティングマッピングモード】

ASP.NET Coreでは、HTTP要求を対応するコントローラ操作にマッピングするためのルーティングはコア機能の1つです。「ルーティング駆動設計モデル」は私が作りあげたばかりの設計モデル名ですが、ASPに基づくことができます。NET Coreのルーティング特性は、ルーティングを中心…

Codeforces Round 953 (Div. 2 ABCDEF题) 视频讲解

A. Alice and Books Problem Statement Alice has n n n books. The 1 1 1-st book contains a 1 a_1 a1​ pages, the 2 2 2-nd book contains a 2 a_2 a2​ pages, … \ldots …, the n n n-th book contains a n a_n an​ pages. Alice does the following: She …

【HTML01】HTML基础-基本元素-附带案例-作业

文章目录 HTML 概述学HTML到底学什么HTML的基本结构HTML的注释的作用html的语法HTML的常用标签&#xff1a;相关单词参考资料 HTML 概述 英文全称&#xff1a;Hyper Text Markup Language 中文&#xff1a;超文本标记语言&#xff0c;就将常用的50多个标记嵌入在纯文本中&…

spark常见面试题

文章目录 1.Spark 的运行流程&#xff1f;2.Spark 中的 RDD 机制理解吗&#xff1f;3.RDD 的宽窄依赖4.DAG 中为什么要划分 Stage&#xff1f;5.Spark 程序执行&#xff0c;有时候默认为什么会产生很多 task&#xff0c;怎么修改默认 task 执行个数&#xff1f;6.RDD 中 reduce…

从0到1上线小程序的步骤

文章目录 一、开发前的准备二、开发中三、开发完成的上线部署相关资料和网址 开发一个小程序&#xff08;例如微信小程序&#xff09;涉及到多个阶段&#xff0c;每个阶段都有特定的步骤和要求。以下是详细的步骤及相关资料和网址&#xff0c;帮助你在开发前、开发中和开发完成…

镜像源问题:pip,npm,git,Linux,docker

镜像源的作用 提高下载速度&#xff1a;镜像源通常位于全球不同的地理位置&#xff0c;用户可以选择离自己最近的镜像源下载软件或更新&#xff0c;从而大大提高下载速度和效率。 负载均衡&#xff1a;通过将下载请求分散到多个镜像源&#xff0c;可以减轻主服务器的负载&…

RabbitMQ 入门

目录 一&#xff1a;什么是MQ 二&#xff1a;安装RabbitMQ 三&#xff1a;在Java中如何实现MQ的使用 RabbitMQ的五种消息模型 1.基本消息队列&#xff08;BasicQueue&#xff09; 2.工作消息队列&#xff08;WorkQueue&#xff09; 3. 发布订阅&#xff08;Publish、S…