NLP入门——基于梯度下降法分类的应用

问题分析

我们前面研究的都是基于统计的方法,通过不同的统计方法得到不同的准确率,通过改善统计的方式来提高准确率。现在我们要研究基于数学的方式来预测准确率。

假设我们有一个分词 s_{class,word},class是该对象的类别,word是该分词的词频。
则该分词在某类别中的分数为:s_(class) = sum(s{class,word})
那么我们欲求一个正确的类别:s_(class_g),如何使正确的类别的分数最高呢?
这里引入了损失函数 l o s s = s ( c l a s s m ) − s ( c l a s s g ) loss = s(class_m) - s(class_g) loss=s(classm)s(classg) 即用最大类别的分数 - 正确类别的分数。如果正确类别即是分数最大的类别,则loss为0;否则loss将为一个大于0的数。
如果loss函数越来越小,则说明正确的类别越来越多,准确率也越来越高。因此我们需要找到一个 s_(class),来使得loss最小。

一元函数的梯度下降法示例

图片来自吴恩达深度学习教程。
在这里插入图片描述
假定代价函数loss为J(w),则有 J ( w ) = J ( w ) − α ∗ d J ( w ) d w J(w)=J(w)-α*\frac {dJ(w)}{dw} J(w)=J(w)αdwdJ(w)
α为学习率,用来控制每步的步长,即向下走一步的长度dJ(w)/dw就是函数J(w)对w求导
在这里插入图片描述
如图所示,该点的导数就是这个点相切于J(w)的小三角的高除以宽,如果我们从图中点为初始点开始梯度下降算法,则该点斜率符号为正,即dJ(w)/dw > 0,因此接下来会向左走一步。
在这里插入图片描述
整个梯度下降法的迭代过程就是不断向左走,直到逼近最小值点
在这里插入图片描述
若我们以如图点为初始化点,则该点处斜率为负,即dJ(w)/dw < 0,所以接下来会向右走:
在这里插入图片描述
整个梯度下降法的迭代过程就是不断向右走,即朝着最小值点的方向走。

梯度下降过程中有两个问题:
1.步长 不能过大,可能会跳过最小值点
答:设置学习率不断减小
2.局部最优解未必是全局最优解
答:参数量越大,找到全局最小值的概率越小,局部最小值就越接近全局最小值

基于梯度下降法的分类

所以我们可以有如下分析:
若class_max == class_g:
loss = 0
若class_max > class_g:
l o s s = s u m ( s c l a s s m , w o r d ) − s u m ( s c l a s s g , w o r d ) loss = sum(s_{class_m,word}) - sum(s_{class_g,word}) loss=sum(sclassm,word)sum(sclassg,word)
则有
d ( l o s s ) d ( s c l a s s m , w o r d ) = 1 \frac {d( loss) }{d (s_{class_m,word})} = 1 d(sclassm,word)d(loss)=1
d ( l o s s ) d ( s c l a s s g , w o r d ) = − 1 \frac{d( loss) }{d (s_{class_g,word})} = -1 d(sclassg,word)d(loss)=1
我们将loss沿着导数方向移动:
s c l a s s m , w o r d − = l r ∗ 1 s_{class_m,word} -= lr*1 sclassm,word=lr1
s c l a s s g , w o r d − = l r ∗ ( − 1 ) s_{class_g,word} -= lr*(-1) sclassg,word=lr(1)
化简则有:
s c l a s s m , w o r d − = l r s_{class_m,word} -= lr sclassm,word=lr
s c l a s s g , w o r d + = l r s_{class_g,word} += lr sclassg,word+=lr
这会使得错误的s_{class_max,word}分数减小,使得正确的s_{class_g,word}分数增大,使得s(class_g)趋近于s(class_max),使loss函数值不断减小。
最初准确率会快速上升,这是调参的结果。当准确率逐渐稳定出现震荡时,我们得到的即为全局最优解,我们绕开统计的方法,找到一组权重求得最优解。

#learnw.py
#encoding: utf-8import sys
from json import dump
from math import log, sqrt, inf
from random import uniform, seed
from tqdm import tqdmdef build_model(srcf, tgtf):#{class: {word: freq}}_c, _w = set(), set() #_c 为所有的类别,_w 为所有的词with open(srcf,"rb") as fsrc, open(tgtf,"rb") as ftgt:for sline, tline in zip(fsrc, ftgt):_s, _t = sline.strip(), tline.strip() # s,t分别为句子和类别if _s and _t:_s, _class = _s.decode("utf-8"), _t.decode("utf-8")if _class not in _c:#如果_c中没有这个类别_c.add(_class)#添加到_c中for word in _s.split():#遍历一行里面空格隔开的每个分词if word not in _w:#如果_w中没有这个类别_w.add(word)#添加到_w中_ =sqrt( 2.0 / (len(_c) + len(_w))) #设置随机数,一定要小return {_class: {_word: uniform(-_, _) for _word in _w} for _class in _c}              #为每个类别创建一个词典,词典中是子词和随机到的初始点wdef compute_instance(model, lin):rs = {}_max_score, _max_class = -inf, Nonefor _class, v in model.items(): #对模型中的每一类和类型的词典 v:{word: freq}rs[_class] = _s = sum(v.get(_word, 0.0) for _word in lin) #这个类的分数即为该类中所有子词分数之和if _s > _max_score:_max_score = _s_max_class = _class#获取分数最高的类别return rs, _max_class, _max_score#返回每个类别的分数、最大分数的类别、最大的分数
#返回每一类的频率分布,在这行中所有分词分数之和最大的类,最大的分数 def train(srcf, tgtf, model, base_lr=0.1, max_run=128):#学习率初始为0.1,训练128轮with open(srcf,"rb") as fsrc, open(tgtf,"rb") as ftgt:for _epoch in tqdm(range(1,max_run+1)): #在每轮中fsrc.seek(0)    #对文件做复位,每次读取文件开头ftgt.seek(0)_lr = base_lr / sqrt(_epoch) #将学习率不断变小,控制步长越来越小total = err = 0for sline, tline in zip(fsrc, ftgt):#对每行数据_s, _t = sline.strip(), tline.strip()if _s and _t:_s, _class = _s.decode("utf-8").split(), _t.decode("utf-8")#_s为子词,_t为类别scores, max_class, max_score = compute_instance(model, _s)#算出每行中每个类别的分数、最大分数的类别、最大的分数if max_class != _class:#判断最大的类别和标准答案的类别是否相等for _word in _s:model[max_class][_word] -= _lr  #错误的类别model[_class][_word] += _lr     #正确的类别err += 1    #错误的数量+1total += 1      #总条数+1print("Epoch %d: %.2f" % (_epoch, (float(total - err) / total *100.0)))#返回每轮预测的准确率,转化为百分数return modeldef save(modin, frs):with open(frs, "w") as f:dump(modin, f) #用dump方法向文件写strif __name__=="__main__":seed(408) #固定随机数种子save(train(*sys.argv[1:3], build_model(*sys.argv[1:3])),sys.argv[3])#将第一个参数和第二个参数(训练集句子和分类)给build_model后,将句子、标签和返回的词典传入train中,最后保存模型

我们可以使用tqdm Python进度条来展示演示过程。在命令行执行:

:~/nlp/tnews$ python learnw.py src.train.bpe.txt tgt.train.s.txt learnw.model.txt

在这里插入图片描述
执行到100轮左右,准确率就在98%左右震荡。

对验证集做验证与早停

使用前面我们写好的脚本,在命令行输入:

:~/nlp/tnews$ python predict.py src.dev.bpe.txt learnw.model.txt pred.learnw.dev.txt
:~/nlp/tnews$ python acc.py pred.learnw.dev.txt tgt.dev.s.txt 
46.92

可以看到,准确率是46.92%,在验证集上表现一般。与我们前面提到的统计方法相比,可以大致认为是统计方法最优化的最高准确率。

我们找到的参数是在训练集上表现最好的,但未必是在所有情况下表现最好的。我们需要的模型是在验证集上性能最好的,因此我们需要在模型对验证集的表现上做修改:

#encoding: utf-8def eva(srcvf, tgtvf, model):#采集对验证集的正确率total = corr = 0with open(srcvf,"rb") as fsrc, open(tgtvf,"rb") as ftgt:for sline, tline in zip(fsrc, ftgt):#对每行数据_s, _t = sline.strip(), tline.strip()if _s and _t:_s, _class = _s.decode("utf-8").split(), _t.decode("utf-8")max_class = compute_instance(model, _s)[1]  #仅需要预测的最大类别if max_class == _class:corr += 1total += 1return float(corr) / total * 100.0        #返回每轮验证集的准确率         def train(srcf, tgtf, srcvf, tgtvf, model, frs, base_lr=0.1, max_run=128):#学习率初始为0.1,训练128轮with open(srcf,"rb") as fsrc, open(tgtf,"rb") as ftgt:_best_acc = eva(srcvf, tgtvf, model)print("Epoch 0: dev %.2f" % _best_acc)  #获得每轮中验证集上的准确率for _epoch in range(1,max_run+1): #在每轮中fsrc.seek(0)    #对文件做复位,每次读取文件开头ftgt.seek(0)_lr = base_lr / sqrt(_epoch) #将学习率不断变小,控制步长越来越小total = err = 0for sline, tline in zip(fsrc, ftgt):#对每行数据_s, _t = sline.strip(), tline.strip()if _s and _t:_s, _class = _s.decode("utf-8").split(), _t.decode("utf-8")#_s为子词,_t为类别scores, max_class, max_score = compute_instance(model, _s)#算出每行中每个类别的分数、最大分数的类别、最大的分数if max_class != _class:#判断最大的类别和标准答案的类别是否相等for _word in _s:model[max_class][_word] -= _lr  #错误的类别model[_class][_word] += _lr     #正确的类别err += 1    #错误的数量+1total += 1      #总条数+1_eva_acc = eva(srcvf, tgtvf, model)print("Epoch %d: train %.2f, dev %.2f" % (_epoch, (float(total - err) /total *100.0), _eva_acc))   #返回每轮训练集、验证集的准确率,转化为百分数if _eva_acc >= _best_acc:   #如果当前在验证集的准确率是最好的,就保存这个模型,并更新最好准确率save(model, frs)_best_acc = _eva_accreturn modeldef save(modin, frs):with open(frs, "w") as f:dump(modin, f) #用dump方法向文件写strif __name__=="__main__":seed(408) #固定随机数种子train(*sys.argv[1:5], build_model(*sys.argv[1:3]),sys.argv[5])#传入训练句子、训练标签、验证句子、验证标签以及保存模型文件

在命令行执行:

:~/nlp/tnews$ python learnw.py src.train.bpe.txt tgt.train.s.txt src.dev.bpe.txt tgt.dev.s.txt learnw.model.txt 

在这里插入图片描述
我们可以看到,对验证集的测试上,在很早的轮次已经趋于稳定值了。因此我们没有必要跑完所有128轮,对验证集的准确率的提升没有意义。
我们引入early_stop参数,作为早停的停止条件。如果我们连续early_stop轮都没有找到更好的结果,我们就将训练停止 ,避免无效的训练:

if _eva_acc >= _best_acc:   #如果当前在验证集的准确率是最好的,就保存这个模型,并更新最好准确率save(model, frs)_best_acc = _eva_accanbest = 0
else:anbest += 1if anbest > earlystop:	#如果连续earlystop轮准确率都没有提高,则停止训练break  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/29883.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【漏洞复现】金和OA C6 download.jsp 任意文件读取漏洞

免责声明&#xff1a; 本文内容旨在提供有关特定漏洞或安全漏洞的信息&#xff0c;以帮助用户更好地了解可能存在的风险。公布此类信息的目的在于促进网络安全意识和技术进步&#xff0c;并非出于任何恶意目的。阅读者应该明白&#xff0c;在利用本文提到的漏洞信息或进行相关测…

AI数据分析:Excel表格智能判断数据起点来计算增长率

工作任务&#xff1a;计算Excel表格中2023年1月到2024年4月的总增长率和复合增长率。 如果数据都有的情况下&#xff0c;公式很简单&#xff1a; 总增长率 (O2-B2)/B2 复合增长率 POWER((O2/B2),1/13)-1 但是&#xff0c;2023年1月、2月、3月的数据&#xff0c;有些有&…

AI办公自动化:用通义千问批量翻译长篇英语TXT文档

在deepseek中输入提示词&#xff1a; 你是一个Python编程专家&#xff0c;现在要完成一个编写基于qwen-turbo模型API和dashscope库的程序脚本&#xff0c;具体步骤如下&#xff1a; 打开文件夹&#xff1a;F:\AI自媒体内容\待翻译&#xff1b; 获取里面所有TXT文档&#xff…

mac如何检测硬盘损坏 常用mac硬盘检测坏道工具推荐

mac有时候也出现一些问题&#xff0c;比如硬盘损坏。硬盘损坏会导致数据丢失、系统崩溃、性能下降等严重的后果&#xff0c;所以及时检测和修复硬盘损坏是非常必要的。那么&#xff0c;mac如何检测硬盘损坏呢&#xff1f;有哪些常用的mac硬盘检测坏道工具呢&#xff1f; 一、m…

Python 数据可视化 散点图

Python 数据可视化 散点图 import matplotlib.pyplot as plt import numpy as npdef plot_scatter(ref_info_dict, test_info_dict):# 绘制散点图&#xff0c;ref横&#xff0c;test纵plt.figure(figsize(80, 48))n 0# scatter_header_list [peak_insert_size, median_insert…

nginx反向代理动静分离和负载均衡

一.nginx 反向代理简要介绍 1.什么是反向代理 反向代理是一种服务器&#xff0c;在这种设置中&#xff0c;代理服务器接收客户端的请求&#xff0c;并将这些请求转发给一个或多个后端服务器&#xff08;例如应用服务器、数据库服务器等&#xff09;。然后&#xff0c;后端服务…

【LinkedList与链表】

目录 1&#xff0c;ArrayList的缺陷 2&#xff0c;链表 2.1 链表的概念及结构 2.2 链表的实现 2.2.1 无头单向非循环链表实现 3&#xff0c;LinkedList的模拟实现 3.1 无头双向链表实现 4&#xff0c;LinkedList的使用 4.1 什么是LinkedList 4.2 LinkedList的使用 5…

上位机图像处理和嵌入式模块部署(h750 mcu和ad/da电路)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 大部分同学学习mcu的时候&#xff0c;都会把重点放在232、485、can、usb、eth这些常规的通信接口上面。还有一部分同学&#xff0c;可能会对lcd、c…

Codeforces Round 953 (Div. 2 ABCDEF题) 视频讲解

A. Alice and Books Problem Statement Alice has n n n books. The 1 1 1-st book contains a 1 a_1 a1​ pages, the 2 2 2-nd book contains a 2 a_2 a2​ pages, … \ldots …, the n n n-th book contains a n a_n an​ pages. Alice does the following: She …

【HTML01】HTML基础-基本元素-附带案例-作业

文章目录 HTML 概述学HTML到底学什么HTML的基本结构HTML的注释的作用html的语法HTML的常用标签&#xff1a;相关单词参考资料 HTML 概述 英文全称&#xff1a;Hyper Text Markup Language 中文&#xff1a;超文本标记语言&#xff0c;就将常用的50多个标记嵌入在纯文本中&…

RabbitMQ 入门

目录 一&#xff1a;什么是MQ 二&#xff1a;安装RabbitMQ 三&#xff1a;在Java中如何实现MQ的使用 RabbitMQ的五种消息模型 1.基本消息队列&#xff08;BasicQueue&#xff09; 2.工作消息队列&#xff08;WorkQueue&#xff09; 3. 发布订阅&#xff08;Publish、S…

【论文阅读】Multi-Camera Unified Pre-Training via 3D Scene Reconstruction

论文链接 代码链接 多摄像头三维感知已成为自动驾驶领域的一个重要研究领域&#xff0c;为基于激光雷达的解决方案提供了一种可行且具有成本效益的替代方案。具有成本效益的解决方案。现有的多摄像头算法主要依赖于单目 2D 预训练。然而&#xff0c;单目 2D 预训练忽略了多摄像…

【深度学习】GPT-3,Language Models are Few-Shot Learners(一)

论文&#xff1a; https://arxiv.org/abs/2005.14165 摘要 最近的研究表明&#xff0c;通过在大规模文本语料库上进行预训练&#xff0c;然后在特定任务上进行微调&#xff0c;可以在许多NLP任务和基准上取得显著的进展。虽然这种方法在结构上通常是任务无关的&#xff0c;但…

走进Web3时代的物联网领域:科技的无限可能

随着Web3技术的迅速发展&#xff0c;物联网&#xff08;IoT&#xff09;领域正迎来一场深刻的变革。本文将深入探讨Web3时代如何重新定义物联网的边界和未来发展的无限可能性&#xff0c;从技术原理到应用案例&#xff0c;为读者呈现一个充满挑战和机遇的全新科技景观。 1. Web…

mediasoup源码分析(三)channel创建及信令交互

mediasoup源码分析--channel创建及信令交互 概述跨职能图业务流程图代码剖析 概述 在golang实现mediasoup的tcp服务及channel通道一文中&#xff0c;已经介绍过信令服务中tcp和channel的创建&#xff0c;本文主要讲解c中mediasoup的channel创建&#xff0c;以及信令服务和medi…

如何避免接口重复请求(axios推荐使用AbortController)

前言&#xff1a; 我们日常开发中&#xff0c;经常会遇到点击一个按钮或者进行搜索时&#xff0c;请求接口的需求。 如果我们不做优化&#xff0c;连续点击按钮或者进行搜索&#xff0c;接口会重复请求。 以axios为例&#xff0c;我们一般以以下几种方法为主&#xff1a; 1…

【Pmac】PMAC QT联合开发中各种可能遇到的坑

目录 1. 错误 C2027 使用了未定义类型“PCOMMSERVERLib::DEVUPLOAD”2. 输入了正确的pmac的ip地址&#xff0c;没有显示可选的pmac设备3. Pmac DTC-28B无读数 使用QT编写PMAC上位机程序时&#xff0c;利用QT中的dump工具可以将pcommserver.exe转化为pcommserverlib.h和pcommser…

调度算法-内存页面置换算法

缺⻚异常&#xff08;缺⻚中断&#xff09; 与⼀般中断的主要区别在于&#xff1a; 缺⻚中断在指令执⾏「期间」产⽣和处理中断信号&#xff0c;⽽⼀般中断在⼀条指令执⾏「完成」后检查和处理中断信号。缺⻚中断返回到该指令的开始重新执⾏「该指令」&#xff0c;⽽⼀般中断返…

【HarmonyOS】鸿蒙应用模块化实现

【HarmonyOS】鸿蒙应用模块化实现 一、Module的概念 Module是HarmonyOS应用的基本功能单元&#xff0c;包含了源代码、资源文件、第三方库及应用清单文件&#xff0c;每一个Module都可以独立进行编译和运行。一个HarmonyOS应用通常会包含一个或多个Module&#xff0c;因此&am…

我主编的电子技术实验手册(08)——串联电阻分压

本专栏是笔者主编教材&#xff08;图0所示&#xff09;的电子版&#xff0c;依托简易的元器件和仪表安排了30多个实验&#xff0c;主要面向经费不太充足的中高职院校。每个实验都安排了必不可少的【预习知识】&#xff0c;精心设计的【实验步骤】&#xff0c;全面丰富的【思考习…