[oneAPI] 使用字符级 RNN 生成名称

[oneAPI] 使用字符级 RNN 生成名称

  • oneAPI特殊写法
  • 使用字符级 RNN 生成名称
    • Intel® Optimization for PyTorch
    • 数据下载
    • 加载数据并对数据进行处理
    • 创建网络
    • 训练过程
      • 准备训练
      • 训练网络
    • 结果
  • 参考资料

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

oneAPI特殊写法

import intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')rnn = RNN(n_letters, 128, n_letters)
optim = torch.optim.SGD(rnn.parameters(), lr=0.01)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
rnn, optim = ipex.optimize(rnn, optimizer=optim)criterion = nn.NLLLoss()

使用字符级 RNN 生成名称

为了深入探索语言模型在分类和生成方面的卓越能力,我们特意设计了一个独特的任务。此任务的独特之处在于,它旨在综合学习多种语言的词义特征,以确保生成的内容与各种语言的词组相关性一致。

在任务的具体描述中,我们提供了一个多语言数据集,这个数据集包含多种语言的文本。通过这个数据集,我们的目标是使模型能够在生成名称时融合不同语言的特征。具体来说,我们会提供一个词的开头作为提示,然后模型将能够根据这个开头生成对应语言的名称,从而将不同语言的词意和语法特征进行完美融合。

通过这一任务,我们旨在实现一个在多语言环境中具有卓越生成和分类能力的语言模型。通过学习并融合不同语言的词义和语法特征,我们让使模型具备更广泛的应用潜力,能够在不同语境下生成准确、符合语法规则的名称。

> python sample.py Russian RUS
Rovakov
Uantov
Shavakov> python sample.py German GER
Gerren
Ereng
Rosher> python sample.py Spanish SPA
Salla
Parer
Allan> python sample.py Chinese CHI
Chan
Hang
Iun

Intel® Optimization for PyTorch

在本次实验中,我们利用PyTorch和Intel® Optimization for PyTorch的强大功能,对PyTorch进行了精心的优化和扩展。这些优化举措极大地增强了PyTorch在各种任务中的性能,尤其是在英特尔硬件上的表现更加突出。通过这些优化策略,我们的模型在训练和推断过程中变得更加敏捷和高效,显著地减少了计算时间,提高了整体效能。我们通过深度融合硬件和软件的精巧设计,成功地释放了硬件潜力,使得模型的训练和应用变得更加快速和高效。这一系列优化举措为人工智能应用开辟了新的前景,带来了全新的可能性。
在这里插入图片描述

数据下载

从这里下载数据 并将其解压到当前目录。

加载数据并对数据进行处理

简而言之,有一堆data/names/[Language].txt每行都有一个名称的纯文本文件。我们将行分割成一个数组,将 Unicode 转换为 ASCII,最后得到一个字典。{language: [names …]}

from io import open
import glob
import os
import unicodedata
import stringall_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1  # Plus EOS markerdef findFiles(path): return glob.glob(path)# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):return ''.join(c for c in unicodedata.normalize('NFD', s)if unicodedata.category(c) != 'Mn'and c in all_letters)# Read a file and split into lines
def readLines(filename):with open(filename, encoding='utf-8') as some_file:return [unicodeToAscii(line.strip()) for line in some_file]# Build the category_lines dictionary, a list of lines per category
category_lines = {}
all_categories = []
for filename in findFiles('data/names/*.txt'):category = os.path.splitext(os.path.basename(filename))[0]all_categories.append(category)lines = readLines(filename)category_lines[category] = linesn_categories = len(all_categories)if n_categories == 0:raise RuntimeError('Data not found. Make sure that you downloaded data ''from https://download.pytorch.org/tutorial/data.zip and extract it to ''the current directory.')print('# categories:', n_categories, all_categories)
print(unicodeToAscii("O'Néàl"))

Output:

# categories: 18 ['Arabic', 'Chinese', 'Czech', 'Dutch', 'English', 'French', 'German', 'Greek', 'Irish', 'Italian', 'Japanese', 'Korean', 'Polish', 'Portuguese', 'Russian', 'Scottish', 'Spanish', 'Vietnamese']
O'Neal

创建网络

序列到序列网络,或 seq2seq 网络,或编码器解码器网络,是由两个称为编码器和解码器的 RNN 组成的模型。编码器读取输入序列并输出单个向量,解码器读取该向量以产生输出序列。

我添加了第二个线性层o2o(在组合隐藏层和输出层之后)以赋予其更多的功能。还有一个 dropout 层,它以给定的概率(此处为 0.1)随机将部分输入归零,通常用于模糊输入以防止过度拟合。在这里,我们在网络末端使用它来故意添加一些混乱并增加采样多样性。
在这里插入图片描述

######################################################################
# Creating the Network
# ====================
import torch
import torch.nn as nnimport intel_extension_for_pytorch as ipexclass RNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(RNN, self).__init__()self.hidden_size = hidden_sizeself.i2h = nn.Linear(n_categories + input_size + hidden_size, hidden_size)self.i2o = nn.Linear(n_categories + input_size + hidden_size, output_size)self.o2o = nn.Linear(hidden_size + output_size, output_size)self.dropout = nn.Dropout(0.1)self.softmax = nn.LogSoftmax(dim=1)def forward(self, category, input, hidden):input_combined = torch.cat((category, input, hidden), 1)hidden = self.i2h(input_combined)output = self.i2o(input_combined)output_combined = torch.cat((hidden, output), 1)output = self.o2o(output_combined)output = self.dropout(output)output = self.softmax(output)return output, hiddendef initHidden(self):return torch.zeros(1, self.hidden_size)

训练过程

准备训练

import random# Random item from a list
def randomChoice(l):return l[random.randint(0, len(l) - 1)]# Get a random category and random line from that category
def randomTrainingPair():category = randomChoice(all_categories)line = randomChoice(category_lines[category])return category, line
# One-hot vector for category
def categoryTensor(category):li = all_categories.index(category)tensor = torch.zeros(1, n_categories)tensor[0][li] = 1return tensor# One-hot matrix of first to last letters (not including EOS) for input
def inputTensor(line):tensor = torch.zeros(len(line), 1, n_letters)for li in range(len(line)):letter = line[li]tensor[li][0][all_letters.find(letter)] = 1return tensor# ``LongTensor`` of second letter to end (EOS) for target
def targetTensor(line):letter_indexes = [all_letters.find(line[li]) for li in range(1, len(line))]letter_indexes.append(n_letters - 1) # EOSreturn torch.LongTensor(letter_indexes)

为了训练过程中的方便,我们将创建一个randomTrainingExample 函数来获取随机(类别、线)对并将它们转换为所需的(类别、输入、目标)张量。

# Make category, input, and target tensors from a random category, line pair
def randomTrainingExample():category, line = randomTrainingPair()category_tensor = categoryTensor(category)input_line_tensor = inputTensor(line)target_line_tensor = targetTensor(line)return category_tensor, input_line_tensor, target_line_tensor

训练网络

criterion = nn.NLLLoss()learning_rate = 0.0005def train(category_tensor, input_line_tensor, target_line_tensor):target_line_tensor.unsqueeze_(-1)hidden = rnn.initHidden()rnn.zero_grad()loss = torch.Tensor([0]) # you can also just simply use ``loss = 0``for i in range(input_line_tensor.size(0)):output, hidden = rnn(category_tensor, input_line_tensor[i], hidden)l = criterion(output, target_line_tensor[i])loss += lloss.backward()for p in rnn.parameters():p.data.add_(p.grad.data, alpha=-learning_rate)return output, loss.item() / input_line_tensor.size(0)

为了跟踪训练需要多长时间,我添加了一个 timeSince(timestamp)返回人类可读字符串的函数:

import time
import mathdef timeSince(since):now = time.time()s = now - sincem = math.floor(s / 60)s -= m * 60return '%dm %ds' % (m, s)

训练就像平常一样 - 多次调用训练并等待几分钟,打印当前时间和每个print_every 示例的损失,并存储每个plot_every示例的平均损失all_losses以供稍后绘制。

rnn = RNN(n_letters, 128, n_letters)n_iters = 100000
print_every = 5000
plot_every = 500
all_losses = []
total_loss = 0 # Reset every ``plot_every`` ``iters``start = time.time()for iter in range(1, n_iters + 1):output, loss = train(*randomTrainingExample())total_loss += lossif iter % print_every == 0:print('%s (%d %d%%) %.4f' % (timeSince(start), iter, iter / n_iters * 100, loss))if iter % plot_every == 0:all_losses.append(total_loss / plot_every)total_loss = 0

结果

在这里插入图片描述

参考资料

https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html#

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/44565.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

3:Ubuntu上配置QT交叉编译环境并编译QT程序到Jetson Orin Nano(ARM)

1.Ubuntu Qt 配置交叉编译环境 1.1 ubuntu 20.04安装Qt sudo apt-get install qtcreator 1.2 配置QT GCC配置同上 最后配置Kits 上面设置完成之后 ,设置Kits 中的Device(这是为了能够直接把项目部署到arm设备上) 点击NEXT之后会出现连接被拒绝,不用担…

函数极限与连续性——张宇老师学习笔记

Latex 源代码以及成品PDF(Debug版本):https://wwsk.lanzouc.com/itaDI15vddcb Latex编译Debug版本: $ xelatex 函数极限与连续性.texLatex编译Relese版本(无例题、习题,只有概念定义)&#xf…

小程序 CSS-in-JS 和原子化的另一种选择

小程序 CSS-in-JS 和原子化的另一种选择 小程序 CSS-in-JS 和原子化的另一种选择 介绍快速开始 pandacss 安装和配置 0. 安装和初始化 pandacss1. 配置 postcss2. 检查你的 panda.config.ts3. 修改 package.json 脚本4. 全局 css 注册 pandacss5. 配置的优化与别名 weapp-pand…

Log4Qt日志框架(1)- 引入到QT中

Log4Qt日志框架(1)- 引入到QT中 1 下载源码2 简介3 加入到自己的项目中3.1 使用库文件3.2 引入源文件 4 说明 1 下载源码 github:https://github.com/MEONMedical/Log4Qt 官方(版本较老):https://sourceforge.net/projects/log4q…

希望计算机专业同学都知道这些博主

湖科大教书匠——计算机网络 “宝藏老师”、“干货满满”、“羡慕湖科大”…这些都是网友对这门网课的评价,可见网课质量之高!最全面的面试网站 湖南科技大学《计算机网络》微课堂是该校高军老师精心制作的视频课程,用简单的语言描述复杂的…

【开发】视频云存储EasyCVR视频汇聚平台AI智能算法定制

安防视频集中存储EasyCVR视频汇聚平台,可支持海量视频的轻量化接入与汇聚管理。平台能提供视频存储磁盘阵列、视频监控直播、视频轮播、视频录像、云存储、回放与检索、智能告警、服务器集群、语音对讲、云台控制、电子地图、平台级联、H.265自动转码等功能。为了便…

idea使用docker生成镜像(打包镜像,导入镜像,导出镜像)

1:先下载安装dockerdesktop,安装成功后 2: 在cmd执行docker -v,查看安装的docker版本 C:\Users\dell>docker -v Docker version 24.0.5, build ced09963:需要启动 dockerdesktop应用,才算启动docker&a…

openai多模态大模型:clip详解及实战

引言 CLIP全称Constrastive Language-Image Pre-training,是OpenAI推出的采用对比学习的文本-图像预训练模型。CLIP惊艳之处在于架构非常简洁且效果好到难以置信,在zero-shot文本-图像检索,zero-shot图像分类,文本→图像生成任务…

Windows 11 下使用 VMWare Workstation 17 Pro 新建 CentOS Stream 9 64位 虚拟机 并配置网络

文章目录 为什么选择 CentOS Stream 9下载安装访问连接快照克隆网络配置 为什么选择 CentOS Stream 9 CentOS Linux 8: 已经过了 End-of-life (EOL)CentOS Linux 7: EOL Jun 30th, 2024CentOS Stream 8: EOL May 31st, 2024CentOS Stream 9: End of RHEL9 full support phase …

PySpark-核心编程

2. PySpark——RDD编程入门 文章目录 2. PySpark——RDD编程入门2.1 程序执行入口SparkContext对象2.2 RDD的创建2.2.1 并行化创建2.2.2 获取RDD分区数2.2.3 读取文件创建 2.3 RDD算子2.4 常用Transformation算子2.4.1 map算子2.4.2 flatMap算子2.4.3 reduceByKey算子2.4.4 Wor…

第 7 章 排序算法(2)(冒泡排序)

7.5冒泡排序 7.5.1基本介绍 冒泡排序(Bubble Sorting)的基本思想是:通过对待排序序列从前向后(从下标较小的元素开始),依次比较相邻元素的值,若发现逆序则交换,使值较大的元素逐渐从前移向后部…

工具推荐:Chat2DB一款开源免费的多数据库客户端工具

文章首发地址 Chat2DB是一款开源免费的多数据库客户端工具,适用于Windows和Mac操作系统,可在本地安装使用,也可以部署到服务器端并通过Web页面进行访问。 相较于传统的数据库客户端软件如Navicat、DBeaver,Chat2DB具备了与AIGC…

韩顺平Linux 四十四--

四十四、rwx权限 权限的基本介绍 输入指令 ls -l 显示的内容如下 -rwxrw-r-- 1 root 1213 Feb 2 09:39 abc0-9位说明 第0位确定文件类型(d , - , l , c , b) l 是链接,相当于 windows 的快捷方式- 代表是文件是普通文件d 是目录,相…

Spring Security OAuth2.0认证授权

(单体项目的认证,微服务项目的认证授权) 1.基本概念 1.1 什么是认证 进入移动互联网时代,大家每天都在刷手机,常用的软件有微信、支付宝、头条等,下边拿微信来举例子说明认证相关的基本概念,在…

腾讯云3年轻量应用服务器2核4G5M和2核2G4M详细介绍

腾讯云轻量应用服务器3年配置,目前可以选择三年的轻量配置为2核2G4M和2核4G5M,2核2G4M和2核4G5M带宽,当然也可以选择选一年,第二年xufei会比较gui,腾讯云百科分享腾讯云轻量应用服务器3年配置表: 目录 腾…

机器人制作开源方案 | 送餐机器人

作者:赖志彩、曹柳洲、王恩开、李雪儿、杨玉凯 单位:华北科技学院 指导老师:张伟杰、罗建国 一、作品简介 1. 场景调研 1.1项目目的 近年来,全国多地疫情频发,且其传染性极高,食品接触是传播途径之一。…

均线多头排列和突破前高形态叠加,只为抓取主升浪!股票量化分析工具QTYX-V2.6.9...

功能概述 我们的股票量化系统QTYX在实战中不断迭代升级,针对当前行情,主要聚焦在抓取主升浪的强势股。 单一的指标是用局限的,QTYX的选股框架,是把多指标结合起来一起过滤出强势股。 QTYX支持从市场4000多只票中过滤出强势股的流程…

利用“病毒制造机”实现脚本病毒的制造

一、脚本病毒的概念: 脚本病毒通常是 JavaScript 或 VBScript 等语言编写的恶意代码,一般广告性质,会修改 IE 首页、修改注册表等信息,对用户计算机造成破坏。 通过网页进行的传播的病毒较为典型,脚本病毒还会有如下前…

使用PyMuPDF库的PDF合并和分拆程序

PDF工具应用程序是一个使用wxPython和PyMuPDF库编写的简单工具,用于合并和分拆PDF文件。它提供了一个用户友好的图形界面,允许用户选择源文件夹和目标文件夹,并对PDF文件进行操作。 C:\pythoncode\blog\pdfmergandsplit.py 功能特点 选择文…

unity物体移动至指定位置

物体坐标与物体移动 世界坐标与局部坐标之间的转换物体移动至指定位置需求思路注意 世界坐标与局部坐标之间的转换 在Unity中,物体的坐标分为局部坐标和世界坐标。 局部坐标是相对于物体的父对象的坐标系,而世界坐标是相对于场景的整体坐标系。 使用tr…