【SVG 生成系列论文(九)】如何通过文本生成 svg logo?IconShop 模型推理代码详解

  • SVG 生成系列论文(一) 和 SVG 生成系列论文(二) 分别介绍了 StarVector 的大致背景和详细的模型细节。
  • SVG 生成系列论文(三)和 SVG 生成系列论文(四)则分别介绍实验、数据集和数据增强细节。
  • SVG 生成系列论文(五)介绍了从光栅图像(如 PNG、JPG 格式)转换为矢量图形(如 SVG、EPS 格式)的关键技术-像素预过滤(pixel prefiltering), Diffvg 这篇论文也是 SVG 生成与编辑领域中 “基于优化”方法的开创性研究。
  • SVG 生成系列论文(六) 和 SVG 生成系列论文(七) 简要介绍了 IconShop 的背景、应用和部分细节。
  • SVG 生成系列论文(八)则介绍了模型架构和具体的训练技巧。

本文将详细拆解 IconShop(论文原文以及代码🔗)的模型结构和对应开源代码。上篇有提到过模型架构如下所示,本篇则从代码的逻辑进行解释,主要是 /path/to/IconShop/model/decoder.py 中的 sample 以及 forward 两个函数。

模型架构

架构整体分为 4 个部分:SVG 图标嵌入(SVG Icon Embedding),文本嵌入(Text Embedding),输入准备和输出生成。

在这里插入图片描述

sample 函数

在原项目中,调用模型推理是从 sample_pixels = sketch_decoder.sample(n_samples=BS, text=tokenized_text) 中进行的,因此重点解读这部分代码。

  1. 输入部分:
    self: 指代调用该方法的对象,意味着可以访问类的属性和方法。
    n_samples: 要生成的样本数量,这里是 batchsize。
    text: 输入的文本数据,用于条件生成。形状为[batch_size, text_len],其中text_len是文本序列的长度,默认是 50。
    pixel_seq: 已有的像素序列,初始化为None,形状为[batch_size, max_len]。
    xy_seq: 已有的坐标序列,初始化为None,形状为[batch_size, max_len, 2]。
def sample(self, n_samples, text, pixel_seq=None, xy_seq=None):""" sample from distribution (top-k, top-p) """pix_samples = []xy_samples = []# latent_ext_samples = []top_k = 0top_p = 0.5
  1. 初始化:
    定义了空列表pix_samplesxy_samples 以存储采样的像素和坐标序列。
    同时定义了 top_k 和 top_p 作为采样策略的参数,默认使用“Top-P”策略(top_p=0.5),不使用“Top-K”。了解更多,可参见大模型推理常见采样策略(Top-k, Top-p)
    # Mapping from pixel index to xy coordiantepixel2xy = {}x=np.linspace(0, BBOX-1, BBOX)y=np.linspace(0, BBOX-1, BBOX)xx,yy=np.meshgrid(x,y)xy_grid = (np.array((xx.ravel(), yy.ravel())).T).astype(int)for pixel, xy in enumerate(xy_grid):pixel2xy[pixel] = xy+COORD_PAD+SVG_END
  1. 像素到坐标映射:
    创建一个从像素索引到坐标(x, y)的映射。这里假设有一个固定大小的边界框(BBOX=200),通过网格生成所有可能的坐标组合,并将这些坐标与像素索引关联起来。其中,COORD_PAD = NUM_END_TOKEN + NUM_MASK_AND_EOM(NUM_END_TOKEN = 3,NUM_MASK_AND_EOM = 2),SVG_END = 1

  2. 处理文本输入: 确保文本序列的长度不超过预设的最大长度self.text_len

  3. 循环采样:

  • for k in range(text.shape[1] + pixlen, self.total_seq_len): 是开始对文本之后的 SVG 序列进行采样(预测)。
  • 初始化或更新像素pixel_seq和坐标序列xy_seq的长度。
  • 使用模型的forward方法计算当前状态下下一个token的概率分布。
  • 应用“Top-K”和“Top-P”过滤策略(top_k_top_p_filtering)到概率分布上,仅保留最可能的token。
  • 从过滤后的分布中多分类采样得到下一个像素值。
  • 将采样的像素索引转换为实际的坐标。其中 PIX_PAD = NUM_END_TOKEN + NUM_MASK_AND_EOM, SVG_END = 1
  1. 序列更新和早停条件:
  • 将生成的像素和坐标序列加入现有序列中。
  • 检查生成是否完成(例如生成到结束符),如果有完成的样本则记录并移除。
  • 如果所有样本都生成完毕,提前停止循环。
  1. 返回结果:
    返回生成的坐标序列xy_samples
    # Sample per tokentext = text[:, :self.text_len]pixlen = 0 if pixel_seq is None else pixel_seq.shape[1]for k in range(text.shape[1] + pixlen, self.total_seq_len):if k == text.shape[1]:pixel_seq = [None] * n_samplesxy_seq = [None, None] * n_samples# pass through modelwith torch.no_grad():p_pred = self.forward(pixel_seq, xy_seq, None, text)p_logits = p_pred[:, -1, :]next_pixels = []# Top-p sampling of next pixelfor logit in p_logits: filtered_logits = top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p)next_pixel = torch.multinomial(F.softmax(filtered_logits, dim=-1), 1)next_pixel -= self.num_text_tokennext_pixels.append(next_pixel.item())# Convert pixel index to xy coordinatenext_xys = []for pixel in next_pixels:if pixel >= PIX_PAD+SVG_END:xy = pixel2xy[pixel-PIX_PAD-SVG_END]else:xy = np.array([pixel, pixel]).astype(int)next_xys.append(xy)next_xys = np.vstack(next_xys)  # [BS, 2]next_pixels = np.vstack(next_pixels)  # [BS, 1]# Add next tokensnextp_seq = torch.LongTensor(next_pixels).view(len(next_pixels), 1).cuda()nextxy_seq = torch.LongTensor(next_xys).unsqueeze(1).cuda()if pixel_seq[0] is None:pixel_seq = nextp_seqxy_seq = nextxy_seqelse:pixel_seq = torch.cat([pixel_seq, nextp_seq], 1)xy_seq = torch.cat([xy_seq, nextxy_seq], 1)# Early stoppingdone_idx = np.where(next_pixels==0)[0]if len(done_idx) > 0:done_pixs = pixel_seq[done_idx] done_xys = xy_seq[done_idx]# done_ext = latent_ext[done_idx]# for pix, xy, ext in zip(done_pixs, done_xys, done_ext):for pix, xy in zip(done_pixs, done_xys):pix = pix.detach().cpu().numpy()xy = xy.detach().cpu().numpy()pix_samples.append(pix)xy_samples.append(xy)# latent_ext_samples.append(ext.unsqueeze(0))left_idx = np.where(next_pixels!=0)[0]if len(left_idx) == 0:break # no more jobs to doelse:pixel_seq = pixel_seq[left_idx]xy_seq = xy_seq[left_idx]text = text[left_idx]# return pix_samples, latent_ext_samplesreturn xy_samples

forward

这个函数的主要作用是前向传播(或称推理),根据输入的像素序列、坐标序列、掩码和文本,计算模型的输出。具体过程如下:

  1. 准备输入:
  • 如果需要计算损失return_loss,则去掉最后一个时间步的数据。
  • 计算上下文序列的长度c_seqlen,包括文本和像素序列的长度。
def forward(self, pix, xy, mask, text, return_loss=False):'''pix.shape  [batch_size, max_len]xy.shape   [batch_size, max_len, 2]mask.shape [batch_size, max_len]text.shape [batch_size, text_len]'''pixel_v = pix[:, :-1] if return_loss else pixxy_v = xy[:, :-1] if return_loss else xypixel_mask = mask[:, :-1] if return_loss else maskc_bs, c_seqlen, device = text.shape[0], text.shape[1], text.deviceif pixel_v[0] is not None:c_seqlen += pixel_v.shape[1]  
  1. 嵌入计算:

对输入的文本进行嵌入计算text_emb
如果有像素和坐标序列,则分别对它们进行嵌入计算(coord_embed_x, pixel_embed),并与文本嵌入拼接。

# Context embedding valuescontext_embedding = torch.zeros((1, c_bs, self.embed_dim)).to(device) # [1, bs, dim]# tokens.shape [batch_size, text_len, emb_dim]tokens = self.text_emb(text)# Data input embeddingif pixel_v[0] is not None:# coord_embed.shape [batch_size, max_len-1, emb_dim]# pixel_embed.shape [batch_size, max_len-1, emb_dim] coord_embed = self.coord_embed_x(xy_v[...,0]) + self.coord_embed_y(xy_v[...,1]) # [bs, vlen, dim]pixel_embed = self.pixel_embed(pixel_v)embed_inputs = pixel_embed + coord_embed# tokens.shape [batch_size, text_len+max_len-1, emb_dim]tokens = torch.cat((tokens, embed_inputs), dim=1)
  1. 位置编码和掩码计算:

计算位置编码,并与嵌入后的输入序列拼接。
生成用于Transformer的掩码nopeak_mask,确保模型不会看到未来的时间步。
如果有像素掩码pixel_mask,则将其扩展到与嵌入序列匹配的形状。

    # nopeak_mask.shape [c_seqlen+1, c_seqlen+1]nopeak_mask = torch.nn.Transformer.generate_square_subsequent_mask(c_seqlen+1).to(device)  # masked with -infif pixel_mask is not None:# pixel_mask.shape [batch_size, text_len+max_len]pixel_mask = torch.cat([(torch.zeros([c_bs, context_embedding.shape[0]+self.text_len])==1).to(device), pixel_mask], axis=1)  
  1. Transformer解码:

将处理后的输入序列送入Transformer解码器,得到输出序列。

    decoder_out = self.decoder(tgt=decoder_inputs, memory=memory_encode, memory_key_padding_mask=None,tgt_mask=nopeak_mask, tgt_key_padding_mask=pixel_mask)
  1. 计算logits和损失:

通过全连接层logit_fc计算输出的logits
如果需要计算损失,则分别计算文本和像素的损失cross_entropy,并返回总损失。
如果不需要计算损失,则直接返回logits。

    # Logits fclogits = self.logit_fc(decoder_out)  # [seqlen, bs, dim] logits = logits.transpose(1,0)  # [bs, textlen+seqlen, total_token] logits_mask = self.logits_mask[:, :c_seqlen+1]max_neg_value = -torch.finfo(logits.dtype).maxlogits.masked_fill_(logits_mask, max_neg_value)if return_loss:logits = rearrange(logits, 'b n c -> b c n')text_logits = logits[:, :, :self.text_len]pix_logits = logits[:, :, self.text_len:]pix_logits = rearrange(pix_logits, 'b c n -> (b n) c')pix_mask = ~mask.reshape(-1)pix_target = pix.reshape(-1) + self.num_text_tokentext_loss = F.cross_entropy(text_logits, text)pix_loss = F.cross_entropy(pix_logits[pix_mask], pix_target[pix_mask], ignore_index=MASK+self.num_text_token)loss = (text_loss + self.loss_img_weight * pix_loss) / (self.loss_img_weight + 1)return loss, pix_loss, text_losselse:return logits

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/21780.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024码蹄杯初赛 拔河(非二分解法)

AK选手前来补充一发邪典(水数据)写法 题面: 简单来说就是给你一个序列,让你选择一段连续区间,使得这个区间平均值最大,同时区间长度大于等于F。 很显然对于区间求和直接用前缀和优化到O(1),但是…

jar包部署到服务器,修改jar包配置文件

jar包部署到服务器 打包项目1.jar包分离2.整体打包配置文件配置文件分离整体打包修改配置文件 打包项目 maven项目打包有两种&#xff0c;一是将自己的项目和依赖包分离&#xff0c;二是打包成一个jar包 1.jar包分离 需要在pom文件中引入依赖 <build><finalName&…

Docker基础篇之将本地镜像发布到私有库

文章目录 1. Docker Registry简介2. 将本地镜像推送到私有库 1. Docker Registry简介 Docker Registry是官方提供的工具&#xff0c;可以用于构建私有镜像仓库。 2. 将本地镜像推送到私有库 下载Docker Registry docker pull registry现在我们可以从镜像中看到下载的Regist…

【轻松搞定形象照】助你打造编程等级考试、竞赛专属二寸靓照,报名无忧,展现最佳风采!

更多资源请关注纽扣编程微信公众号 ​ 在数字化时代&#xff0c;拍照似乎变得轻而易举&#xff0c;但当我们需要一张特定规格的一寸照片时&#xff0c;事情就变得复杂起来。随着编程等级考试和各类信息学竞赛的日益临近&#xff0c;不少考生都为了一张符合要求的一寸照片而忙…

抽屉式备忘录(共25041字)

Sing Me to Sleep <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>与妖为邻的备忘录</title&g…

pytorch学习day4

一、卷积层&#xff08;Convolution Layers&#xff09; 卷积层是卷积神经网络&#xff08;CNN&#xff09;中的核心组件&#xff0c;用于提取输入数据的特征。它们通过应用卷积运算来捕捉局部的空间特征&#xff0c;非常适合处理图像和视频等具有空间结构的数据。 1. 卷积层的…

创建模拟器

修改模拟器默认路径 由于模拟器文件比较大&#xff0c;默认路径在C:\Users\用户名.android\avd&#xff0c;可修改默认路径 创建修改后的路径文件 D:\A-software\Android\AVD添加系统变量ANDROID_SDK_HOME&#xff1a;D:\A-software\Android\AVD重启Android Studio 创建模拟…

【机器学习】机器学习与大模型在人工智能领域的融合应用与性能优化新探索

文章目录 引言机器学习与大模型的基本概念机器学习概述监督学习无监督学习强化学习 大模型概述GPT-3BERTResNetTransformer 机器学习与大模型的融合应用自然语言处理文本生成文本分类机器翻译 图像识别自动驾驶医学影像分析 语音识别智能助手语音转文字 大模型性能优化的新探索…

【android 9】【input】【7.发送按键事件1——InputReader线程】

系列文章目录 本人系列文章-CSDN博客 目录 系列文章目录 1.简介 1.1发送流程介绍 1.2 时序图 2.普通按键消息发送部分源码分析 2.1 设备的监听 2.2 inputreader线程阻塞等待事件发生 2.3 按键事件的产生 2.4 EventHub::getEvents 2.5 InputReader::loopOnce 2.6 process…

丢失的数字 ---- 位运算

题目链接 题目: 分析: 解法一: 哈希表解法二: 高斯求和解法三:位运算 异或运算根据运算的性质, 相同的两个a异或 0 以示例一为例: 数组中有0,1,3, 缺失的数字是2, 那么只要我们将数组与0,1,2,3 异或, 就会得到2 代码: class Solution {public int missingNumber(int[] num…

【Redis】 Java操作客户端命令——集合操作与有序集合操作

文章目录 &#x1f343;前言&#x1f333;集合操作&#x1f6a9;sadd 和 smembers&#x1f6a9;srem 和 sismember&#x1f6a9;scard&#x1f6a9;sinter&#x1f6a9;sunion&#x1f6a9;sdiff &#x1f332;有序集合操作&#x1f6a9;zadd 和 zrange&#x1f6a9;zrem 和 zc…

拖拽tableView

拖拽tableView&#xff0c;随手指移动&#xff0c;插入。demo地址github

JAVAEE之网络初识_协议、TCP/IP网络模型、封装、分用

前言 在这一节我们简单介绍一下网络的发展 一、通信网络基础 网络互连的目的是进行网络通信&#xff0c;也即是网络数据传输&#xff0c;更具体一点&#xff0c;是网络主机中的不同进程间&#xff0c;基于网络传输数据。那么&#xff0c;在组建的网络中&#xff0c;如何判断到…

迪丽热巴与大姐的璀璨友情

迪丽热巴与“大姐”的璀璨友情&#xff1a;星光熠熠&#xff0c;友谊长存在娱乐圈的繁华舞台上&#xff0c;有两位耀眼的女星&#xff0c;她们如同夜空中亮的两颗星&#xff0c;交相辉映&#xff0c;共同谱写着一段段动人的佳话。她们&#xff0c;一个是被亲切称为“迪迪”的迪…

HarmonyOS-9(stage模式)

配置文件 {"module": {"requestPermissions": [ //权限{"name": "ohos.permission.EXECUTE_INSIGHT_INTENT"}],"name": "entry", //模块的名称"type": "entry", //模块类型 :ability类型和…

AWR设置工程仿真频率、原理图仿真频率、默认单位

AWR设置工程仿真频率、原理图仿真频率、默认单位 生活不易&#xff0c;喵喵叹气。马上就要上班了&#xff0c;公司的ADS的版权紧缺&#xff0c;主要用的软件都是NI 的AWR&#xff0c;只能趁着现在没事做先学习一下子了&#xff0c;希望不要裁我。 最近稍微学习了一下AWR这个软…

UMG绝对坐标与局部空间

在 Unreal Engine 的 UMG&#xff08;Unreal Motion Graphics&#xff09;中&#xff0c;“绝对坐标”和“局部空间”是两个常见的概念&#xff0c;主要用于描述 UI 元素的位置和大小。 概念与区别 绝对坐标&#xff08;Absolute Coordinates&#xff09;&#xff1a;这是指相…

list~模拟实现

目录 list的介绍及使用 list的底层结构 节点类的实现 list的实现 构造函数 拷贝构造 方法一&#xff1a;方法二&#xff1a; 析构函数 赋值重载 insert / erase push_/pop_(尾插/尾删/头插/头删) begin和end&#xff08;在已建立迭代器的基础上&#xff09; 迭代…

kafka命令--简单粗暴有效

zookeeper bin目录下执行 启动&#xff1a;./zkServer.sh start 停止&#xff1a;./zkServer.sh stop 重启&#xff1a;./zkServer.sh restart 状态&#xff1a;./zkServer.sh status kafka bin目录下执行 启动&#xff1a;./kafka-server-start.sh -daemon …/config/server.…

直播预告|手把手教你玩转 Milvus Lite !

Milvus Lite&#xff08;https://milvus.io/docs/milvus_lite.md&#xff09;是一个轻量级向量数据库&#xff0c;支持本地运行&#xff0c;可用于搭建 Python 应用&#xff0c;由 Zilliz 基于全球最受欢迎的开源向量数据库 Milvus&#xff08;https://milvus.io/intro&#xf…