对话模型Demo解读(使用代码解读原理)

文章目录

  • 前言
  • 一、数据加工
  • 二、模型搭建
  • 三、模型训练
    • 1、构建模型
    • 2、优化器与损失函数定义
    • 3、模型训练
  • 四、模型推理
  • 五、所有Demo源码


前言

对话模型是一种人工智能技术,旨在使计算机能够像人类一样进行对话和交流。这种模型通常基于深度学习和自然语言处理技术,能够理解自然语言并做出相应的回应。然而现有博客很少介绍对话模型内容,也很少用一个简单代码带领大家理解其原理。因此,我创建一个简单的对话模型,在不适用Hugging Face或LSTM结构,旨在使用一个简单的全连接神经网络来实现这个模型,且代码基于PyTorch框架搭建,意在帮助读者构建对话模型知识。当然,模型仅是一个简单模型,旨在帮助理解原理,不具备很好效果能力。


一、数据加工

文本数据最终都是转为对应字典索引代表其文本内容,输入模型加工,实现nlp任务,对话模型也不列外。因此,我们需要构建一个字典映射(可参考:点击这里)与文本数据,并按照字典映射转换为对应索引id,其代码如下:

# 定义一个简单的对话数据集
data = [("hi", "hello"),("how are you?", "I'm fine, thank you."),("what's your name?", "I'm a chatbot.")
]# 构建词汇表
vocab = list(set(" ".join([x[0] + " " + x[1] for x in data])))
vocab.append("<SOS>")
vocab.append("<EOS>")word_to_idx = {word: i for i, word in enumerate(vocab)}
idx_to_word = {i: word for i, word in enumerate(vocab)}# 将对话数据集转换为索引序列
def to_idx_seq(sentence):return [word_to_idx[word] for word in sentence]data_x = [to_idx_seq(x[0]) for x in data]
data_y = [to_idx_seq(x[1]) for x in data]data_y =[[word_to_idx["<SOS>"]]+list(x)+[word_to_idx["<EOS>"]] for x in data_y]

其字典内容如下:
在这里插入图片描述

二、模型搭建

这里,也是最重要内容,如何搭建对话模型,我是使用transformer结构搭建(之前模型使用LSTM模型搭建),创建一个简单的对话生成模型,也使用一个基于全连接层的神经网络来实现这个模型,其代码如下:


# 定义一个简单的Transformer生成式对话模型
class ChatbotTransformer(nn.Module):def __init__(self, input_dim, output_dim, nhead, num_encoder_layers, num_decoder_layers):super(ChatbotTransformer, self).__init__()self.input_dim = input_dimself.output_dim = output_dimself.embedding = nn.Embedding(len(vocab), input_dim)self.transformer = nn.Transformer(d_model=input_dim, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)self.fc = nn.Linear(input_dim, output_dim)def forward(self, src, tgt):src = self.embedding(src).permute(1, 0, 2)  # 调整输入张量维度tgt = self.embedding(tgt).permute(1, 0, 2)  # 调整输入张量维度output = self.transformer(src, tgt)output = self.fc(output)return output

从代码上看,我们需要输入src为前面提问数据,而答案生成输入,是每个tgt字输入,按照顺序输出预测。

三、模型训练

1、构建模型

我大概试了一下,使用更多层对于简单数据反而效果不佳,我的数据又比较简单,我构建了较少的层来预测模型。其模型构建代码如下:

# 创建模型实例
# model = ChatbotTransformer(input_dim=256, output_dim=len(vocab), nhead=8, num_encoder_layers=6, num_decoder_layers=6)
model = ChatbotTransformer(input_dim=16, output_dim=len(vocab), nhead=8, num_encoder_layers=1, num_decoder_layers=1)

2、优化器与损失函数定义

优化器定义我将不考虑介绍,只说下文本预测是交叉熵方式,实际文本基本都采用该方法作为loss计算。其代码如下:

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

3、模型训练

接下来,我们解读如何训练对话模型,我们获得对应输入数据与生成预测数据,我们开始训练模型,其代码如下:

# 模型训练
epochs = 800
for epoch in range(epochs):total_loss = 0for i in range(len(data_x)):optimizer.zero_grad()input_seq = torch.tensor(data_x[i])target_seq = torch.tensor(data_y[i])for j in range(len(target_seq)-1):if random.random() < 0.5:j = random.randint(0, len(target_seq)-2)output = model(input_seq.unsqueeze(0), target_seq[:j+1].unsqueeze(0))  # 使用目标序列的前n-1个词预测后n-1个词loss = criterion(output.view(-1, len(vocab)), target_seq[1:j+2].view(-1))  # 计算损失,需要将输出形状转换成二维loss.backward()optimizer.step()total_loss += loss.item()if (epoch+1) % 10 == 0:print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, total_loss / len(data_x)))

假设input_seq =tensor([17, 6, 0, 15, 1, 8, 9, 15, 7, 6, 11, 4]),target_seq=tensor([20, 12, 18, 13, 15, 5, 16, 10, 9, 2, 15, 3, 17, 1, 10, 19, 15, 7, 6, 11, 14, 21]),vocab为字典映射,共22个映射。我将试图模型连续2轮解释一下,训练时候相关变化。
第二轮输出结果:
在这里插入图片描述
第三轮输出结果:
在这里插入图片描述
后面以此类推迭代。

很明显,提问每次都是全部输入,而输出则是第一个20开始与输入共同进模型,分别是模型src与tgt,不断重复与预测完成训练。而loss计算都是往后取一个target文本,也发现并不会计算20索引""文本。

四、模型推理

最后,让我们使用训练好的模型进行推理,实际和上面训练讲到方法类似,我们开始""文本开始,不断给出生成对应文本,也就是对话内容。

# 进行推理
def generate_response(input_sentence):model.eval()input_seq = torch.tensor(to_idx_seq(input_sentence))target_seq = torch.tensor([word_to_idx["<SOS>"]])  # 在开始时使用特殊的起始标记with torch.no_grad():for i in range(20):  # 限制生成的句子长度为20个词output = model(input_seq.unsqueeze(0), target_seq.unsqueeze(0))output_token = output.argmax(2)[-1].item()print(idx_to_word[output_token], end=" ")target_seq = torch.cat((target_seq, torch.tensor([output_token])), dim=0)res = [idx_to_word[int(k)] for k in target_seq]print(res[1:-1])return res[1:-1]# 进行对话生成
generate_response("hi")

其结果如下:
在这里插入图片描述


五、所有Demo源码

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
# 定义一个简单的对话数据集
data = [# ("hi", "hello"),("how are you?", "I'm fine, thank you."),# ("what's your name?", "I'm a chatbot.")
]# 构建词汇表
vocab = list(set(" ".join([x[0] + " " + x[1] for x in data])))
vocab.append("<SOS>")
vocab.append("<EOS>")word_to_idx = {word: i for i, word in enumerate(vocab)}
idx_to_word = {i: word for i, word in enumerate(vocab)}# 将对话数据集转换为索引序列
def to_idx_seq(sentence):return [word_to_idx[word] for word in sentence]data_x = [to_idx_seq(x[0]) for x in data]
data_y = [to_idx_seq(x[1]) for x in data]data_y =[[word_to_idx["<SOS>"]]+list(x)+[word_to_idx["<EOS>"]] for x in data_y]# 定义一个简单的Transformer生成式对话模型
class ChatbotTransformer(nn.Module):def __init__(self, input_dim, output_dim, nhead, num_encoder_layers, num_decoder_layers):super(ChatbotTransformer, self).__init__()self.input_dim = input_dimself.output_dim = output_dimself.embedding = nn.Embedding(len(vocab), input_dim)self.transformer = nn.Transformer(d_model=input_dim, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)self.fc = nn.Linear(input_dim, output_dim)def forward(self, src, tgt):src = self.embedding(src).permute(1, 0, 2)  # 调整输入张量维度tgt = self.embedding(tgt).permute(1, 0, 2)  # 调整输入张量维度output = self.transformer(src, tgt)output = self.fc(output)return output# 创建模型实例
# model = ChatbotTransformer(input_dim=256, output_dim=len(vocab), nhead=8, num_encoder_layers=6, num_decoder_layers=6)
model = ChatbotTransformer(input_dim=16, output_dim=len(vocab), nhead=8, num_encoder_layers=1, num_decoder_layers=1)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)# 模型训练
epochs = 800
for epoch in range(epochs):total_loss = 0for i in range(len(data_x)):optimizer.zero_grad()input_seq = torch.tensor(data_x[i])target_seq = torch.tensor(data_y[i])for j in range(len(target_seq)-1):if random.random() < 0.5:j = random.randint(0, len(target_seq)-2)output = model(input_seq.unsqueeze(0), target_seq[:j+1].unsqueeze(0))  # 使用目标序列的前n-1个词预测后n-1个词loss = criterion(output.view(-1, len(vocab)), target_seq[1:j+2].view(-1))  # 计算损失,需要将输出形状转换成二维loss.backward()optimizer.step()total_loss += loss.item()if (epoch+1) % 10 == 0:print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, total_loss / len(data_x)))# 进行推理
def generate_response(input_sentence):model.eval()input_seq = torch.tensor(to_idx_seq(input_sentence))target_seq = torch.tensor([word_to_idx["<SOS>"]])  # 在开始时使用特殊的起始标记with torch.no_grad():for i in range(20):  # 限制生成的句子长度为20个词output = model(input_seq.unsqueeze(0), target_seq.unsqueeze(0))output_token = output.argmax(2)[-1].item()print(idx_to_word[output_token], end=" ")target_seq = torch.cat((target_seq, torch.tensor([output_token])), dim=0)res = [idx_to_word[int(k)] for k in target_seq]print(res[1:-1])return res[1:-1]# 进行对话生成
generate_response("how are you?")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/678523.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HTTP 超文本传送协议

1 超文本传送协议 HTTP HTTP 是面向事务的 (transaction-oriented) 应用层协议。 使用 TCP 连接进行可靠的传送。 定义了浏览器与万维网服务器通信的格式和规则。 是万维网上能够可靠地交换文件&#xff08;包括文本、声音、图像等各种多媒体文件&#xff09;的重要基础。 H…

HarmonyOS 开发学习笔记

HarmonyOS 开发学习笔记 一、开发准备1.1、了解ArkTs语言1.2、TypeScript语法1.2.1、变量声明1.2.2、条件控制1.2.3、函数1.2.4、类和接口1.2.5、模块开发 1.3、快速入门 二、ArkUI组件2.1、Image组件2.2、Text文本显示组件2.3、TextInput文本输入框组件2.4、Button按钮组件2.5…

开源个人订阅跟踪器Wallos

本文软件由网友 P家单推人 推荐&#xff1b; 什么 Wallos &#xff1f; Wallos 是一款功能强大、开源且可自我托管的网络应用程序&#xff0c;旨在让您轻松管理财务。告别复杂的电子表格和昂贵的财务软件–Wallos简化了跟踪费用的过程&#xff0c;帮助您更好地控制财务生活。 软…

LeetCode---383周赛

题目列表 3028. 边界上的蚂蚁 3029. 将单词恢复初始状态所需的最短时间 I 3030. 找出网格的区域平均强度 3031. 将单词恢复初始状态所需的最短时间 II 一、边界上的蚂蚁 这题没什么好说的&#xff0c;模拟就行&#xff0c;本质就是看前缀和有几个为0。 代码如下 class S…

Spring Cloud Hystrix 参数配置、简单使用、DashBoard

Spring Cloud Hystrix 文章目录 Spring Cloud Hystrix一、Hystrix 服务降级二、Hystrix使用示例三、OpenFeign Hystrix四、Hystrix参数HystrixCommand.Setter核心参数Command PropertiesFallback降级配置Circuit Breaker 熔断器配置Metrix 健康统计配置Request Context 相关参数…

【java】12:封装

面向对象编程三大特征 1.基本介绍 面向对象编程有三大特征&#xff1a;封装、继承和多态。 2.封装介绍 封装(encapsulation)就是把抽象出的数据[属性]和对数据的操作[方法]封装在一起&#xff0c;数据被保护在内部&#xff0c;程序的其它部分只有通过被授权的操作[方法]&am…

开局一个破碗的故事例子

在一个寒冷的冬日&#xff0c;一个瘦弱的小姑娘拿着一个破碗&#xff0c;孤独地走在被白雪覆盖的街道上。她的名字叫小梅&#xff0c;她的父母早逝&#xff0c;留下她一个人在这个世界上艰难地生活。 小梅的破碗里只有几个铜板&#xff0c;那是她前一天沿街乞讨所得&#xff0c…

林浩然与杨凌云的Java世界奇遇记:垃圾回收大冒险

林浩然与杨凌云的Java世界奇遇记&#xff1a;垃圾回收大冒险 The Java Adventure Chronicles of Lin Haoran and Yang Lingyun: Garbage Collection Odyssey 在一个充满0和1代码森林的世界里&#xff0c;住着两位勇敢的程序员侠侣——林浩然和杨凌云。林浩然是个身怀Java绝技的…

sheng的学习笔记-docker部署数据库oracle,mysql

部署目录&#xff1a;sheng的学习笔记-部署-目录-CSDN博客 docker基础知识可参考 sheng的学习笔记-docker部署&#xff0c;原理图&#xff0c;命令&#xff0c;用idea设置docker docker安装数据库 mac版本 安装oracle 下载oracle镜像 打开终端&#xff0c;输入 docker s…

Python网络通信

目录 基本的网络知识 TCP/IP IP地址 端口 HTTP/HTTPS HTTP HTTPS 搭建自己的Web服务器 urllib.request模块 发送GET请求 发送POST请求 JSON数据 JSON文档的结构 JSON数据的解码 下载图片示例 返回所有备忘录信息 此文章讲解如何通过Python访问互联网上的资源&a…

《CSS 简易速速上手小册》第7章:CSS 预处理器与框架(2024 最新版)

文章目录 7.1 Sass&#xff1a;更高效的 CSS 编写7.1.1 基础知识7.1.2 重点案例&#xff1a;主题颜色和字体管理7.1.3 拓展案例 1&#xff1a;响应式辅助类7.1.4 拓展案例 2&#xff1a;深色模式支持 7.2 Bootstrap&#xff1a;快速原型设计和开发7.2.1 基础知识7.2.2 重点案例…

ueransim关于ue侧nas层相关代码解读

一.在文件UERANSIM\UERANSIM-3.2.6\src\ue\nas中enc.cpp中完成了NAS&#xff08;非接入层&#xff09;信令的加密和解密是通过NAS_ENC模块实现的。NAS_ENC模块负责将NAS信令消息进行加密&#xff0c;以确保其传输过程中的安全性。 具体来说&#xff0c;当UE发送NAS信令消息时&…

零基础如何学习编曲,究竟需要准备什么?

初学者常常弄不清楚作曲和编曲的区别&#xff0c;在这里我为大家讲解一下两者的差别。狭义上来说&#xff1a;作曲可以理解为写旋律&#xff0c;而编曲就是写伴奏。那么接下来让我们一起看看零基础编曲&#xff0c;究竟需要准备些什么? 一、理论 众所周知,乐理是最基础的理论…

【JMX】JAVA监控的基石

目录 1.概述 2.MBean 2.1.Standard MBean 2.2.Dynamic MBean 2.3.Model Bean 2.4.Dynamic MBean和Model Bean的区别 2.5.MXBean 2.6.Open Bean 3.控制台 1.概述 什么是JMX&#xff0c;首先来看一段对话&#xff1a; Java Management Extensions&#xff08;JMX&#…

探索ChatGPT-4:智能会话的未来已来

深入了解ChatGPT-4&#xff1a;前沿AI的强大功能 ChatGPT-4是最先进的语言模型之一&#xff0c;由OpenAI开发&#xff0c;它在自然语言理解和生成方面的能力已经达到了新的高度。如今&#xff0c;ChatGPT-4已经被广泛应用于多个领域&#xff0c;从教育到企业&#xff0c;再到技…

Java学习第十一节之命令行传参和断更原因

package method;public class Demo03 {public static void main(String[] args) {//args.length数组长度for (int i 0; i < args.length; i) {System.out.println("args[" i "]:"args[i]);}}}为什么没更新了&#xff1f; 家里有长辈生病了不好在医院照…

【算法与数据结构】496、503、LeetCode下一个更大元素I II

文章目录 一、496、下一个更大元素 I二、503、下一个更大元素II三、完整代码 所有的LeetCode题解索引&#xff0c;可以看这篇文章——【算法和数据结构】LeetCode题解。 一、496、下一个更大元素 I 思路分析&#xff1a;本题思路和【算法与数据结构】739、LeetCode每日温度类似…

嵌入式硬件越老越吃香,确实没错!

不知不觉已经从事硬件设计7年多了&#xff0c;7年对于一个从事硬件设计来说能有几个完整的生涯。2016年毕业&#xff0c;2023年即将结束&#xff0c;我已经在汽车这行业“摸爬滚打”了7年的时光。 回顾这7年&#xff0c;自己真的成长了很多很多。有项目失败整改的经验收获&…

Linux进程间通信(IPC)

要想进程间通信&#xff0c;数据交换&#xff0c;必须通过内核&#xff1b; 一个进程将数据写到内核&#xff0c;然后另一个进程从内核读走数据。 IPC&#xff1a;进程间通信&#xff08;interprocess communication) 通信方式&#xff1a; 管道信号共享映射区&#xff08;…

假期day7

设计qq界面 代码 ui->lab1->setPixmap(QPixmap(":/pictrue/denglu.webp"));ui->lab1->setScaledContents(true);ui->lab2->setPixmap(QPixmap(":/pictrue/passwd.jpg"));ui->lab2->setScaledContents(true);ui->lab3->setP…