Implementing text2draw with the CLIP model

Reference papers

CLIPDraw: Exploring Text-to-Drawing Synthesis through Language-Image Encoders
StyleCLIPDraw: Coupling Content and Style in Text-to-Drawing Translation

Practice

Code with data augmentation

import math
import collections
import CLIP_.clip as clip
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
import webp
from PIL import Image
import skimage
import torchvision
import pydiffvg
import os
import torch.nn.functional as F


class GeometrymatchLoss(torch.nn.Module):
    def __init__(self, device, reference_images_path):
        super(GeometrymatchLoss, self).__init__()
        self.device = device
        self.model, clip_preprocess = clip.load('ViT-B/32', self.device, jit=False)
        self.model.eval()
        # clip normalisation
        self.preprocess = transforms.Compose([clip_preprocess.transforms[0], clip_preprocess.transforms[-1]])
        # Note: this assignment shadows the method of the same name defined below.
        self.reference_images_feature = self.reference_images_feature(reference_images_path)
        self.reference_images_feature = self.reference_images_feature / self.reference_images_feature.norm(dim=-1, keepdim=True)
        self.text = clip.tokenize(["A picture of triangle"]).to(device)
        self.text_features = self.model.encode_text(self.text)
        # self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)
        print("text_features.requires_grad:", self.text_features.requires_grad)
        self.text_features = self.text_features.detach()
        self.shape_groups = [pydiffvg.ShapeGroup(shape_ids=torch.tensor([0]),
                                                 fill_color=torch.tensor([0.0, 0.0, 0.0, 1.0]),
                                                 stroke_color=torch.tensor([0.0, 0.0, 0.0, 1.0]))]
        # Image augmentation transformation
        self.augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
        ])

    def forward(self, t, canvas_width, canvas_height, shapes):
        scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, self.shape_groups)
        # Render the image
        render = pydiffvg.RenderFunction.apply
        target = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args)
        if target.shape[-1] == 4:
            target = self.compose_image_with_white_background(target)
        if t % 100 == 0:
            pydiffvg.imwrite(target.cpu(), f'learn/log_augs/output_{t}.png', gamma=2.2)
        img = target.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)
        loss = 0
        NUM_AUGS = 4
        img_augs = []
        for n in range(NUM_AUGS):
            img_augs.append(self.augment_trans(img))
        im_batch = torch.cat(img_augs)
        # Encode the batch of augmented renders and accumulate the negative
        # cosine similarity of each augmentation to the text prompt.
        image_features = self.model.encode_image(im_batch)
        for n in range(NUM_AUGS):
            loss -= torch.cosine_similarity(self.text_features, image_features[n:n + 1], dim=1)
        return loss

    def compose_image_with_white_background(self, img: torch.tensor) -> torch.tensor:
        if img.shape[-1] == 3:  # return img if it is already rgb
            return img
        # Compose img with white background
        alpha = img[:, :, 3:4]
        img = alpha * img[:, :, :3] + (1 - alpha) * torch.ones(img.shape[0], img.shape[1], 3, device=self.device)
        return img

    def read_png_image_from_path(self, path_to_png_image: str) -> torch.tensor:
        numpy_image = skimage.io.imread(path_to_png_image)
        normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
        resizer = torchvision.transforms.Resize((224, 224))
        resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
        return resized_image

    def reference_images_feature(self, reference_images_path):
        reference_images_num = len(os.listdir(reference_images_path))
        reference_images_feature = []
        for i in range(reference_images_num):
            i_reference_image = self.read_png_image_from_path(os.path.join(reference_images_path, str(i) + ".png"))
            if i_reference_image.shape[-1] == 4:
                i_reference_image = self.compose_image_with_white_background(i_reference_image)
            i_reference_image_features = self.model.encode_image(
                i_reference_image.permute(2, 0, 1).unsqueeze(0).to(self.device)).detach()
            reference_images_feature.append(i_reference_image_features)
        return torch.cat(reference_images_feature)


def read_png_image_from_path(path_to_png_image: str) -> torch.tensor:
    if path_to_png_image.endswith('.webp'):
        numpy_image = np.array(webp.load_image(path_to_png_image))
    else:
        numpy_image = skimage.io.imread(path_to_png_image)
    normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
    resizer = torchvision.transforms.Resize((224, 224))
    resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
    return resized_image


if __name__ == '__main__':
    torch.autograd.set_detect_anomaly(True)
    from tqdm import tqdm

    def get_bezier_circle(radius: float = 80,
                          segments: int = 4,
                          bias: np.array = np.asarray([100., 100.])):
        # Control points of a closed Bezier curve laid out on a circle,
        # used as the initial shape for optimization.
        deg = torch.arange(0, segments * 3 + 1) * 2 * np.pi / (segments * 3 + 1)
        points = torch.stack((torch.cos(deg), torch.sin(deg))).T
        points = points * radius + torch.tensor(bias).unsqueeze(dim=0)
        points = points.type(torch.FloatTensor).contiguous()
        return points

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    matchLoss = GeometrymatchLoss(device, "reference_images/")
    canvas_width, canvas_height = 224, 224
    num_segments = 4
    points1 = get_bezier_circle()
    path = pydiffvg.Path(num_control_points=torch.tensor(num_segments * [2] + [0], dtype=torch.int32),
                         points=points1, stroke_width=torch.tensor(2.0), is_closed=True)
    shapes = [path]
    path.points.requires_grad = True
    points_vars = [path.points]
    points_optim = torch.optim.Adam(points_vars, lr=1)
    pbar = tqdm(range(100000))
    for t in pbar:
        points_optim.zero_grad()
        match_loss = matchLoss(t, 224, 224, shapes)
        match_loss.backward()
        points_optim.step()
        pbar.set_postfix({"match_loss": f"{match_loss.item()}"})

Result after 1,000 iterations: [figure]

Without image augmentation

import math
import collections
import CLIP_.clip as clip
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
import webp
from PIL import Image
import skimage
import torchvision
import pydiffvg
import os
import torch.nn.functional as F


class GeometrymatchLoss(torch.nn.Module):
    def __init__(self, device, reference_images_path):
        super(GeometrymatchLoss, self).__init__()
        self.device = device
        self.model, clip_preprocess = clip.load('ViT-B/32', self.device, jit=False)
        self.model.eval()
        # clip normalisation
        self.preprocess = transforms.Compose([clip_preprocess.transforms[0], clip_preprocess.transforms[-1]])
        # self.preprocess = transforms.Compose([clip_preprocess.transforms[-1]])  # clip normalisation
        self.reference_images_feature = self.reference_images_feature(reference_images_path)
        self.reference_images_feature = self.reference_images_feature / self.reference_images_feature.norm(dim=-1, keepdim=True)
        self.text = clip.tokenize(["A picture of triangle"]).to(device)
        # self.text = clip.tokenize(["A picture of rectangle", "A picture of triangle", "A picture of circle",
        #                            "A picture of pentagon", "A picture of five-pointed star"]).to(device)
        self.text_features = self.model.encode_text(self.text)
        self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)
        print("text_features.requires_grad:", self.text_features.requires_grad)
        self.text_features = self.text_features.detach()
        self.shape_groups = [pydiffvg.ShapeGroup(shape_ids=torch.tensor([0]),
                                                 fill_color=torch.tensor([0.0, 0.0, 0.0, 1.0]),
                                                 stroke_color=torch.tensor([0.0, 0.0, 0.0, 1.0]))]
        # Image augmentation transformation (defined, but not used by the loss below)
        self.augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
        ])

    def forward(self, t, canvas_width, canvas_height, shapes):
        scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, self.shape_groups)
        # Render the image
        render = pydiffvg.RenderFunction.apply
        target = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args)
        if target.shape[-1] == 4:
            target = self.compose_image_with_white_background(target)
        if t % 100 == 0:
            pydiffvg.imwrite(target.cpu(), f'learn/log/output_{t}.png', gamma=2.2)
        img = target.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)
        loss = 0
        # The augmented batch is still built here, but it is left unused:
        # the loss is computed on the single un-augmented render only.
        NUM_AUGS = 4
        img_augs = []
        for n in range(NUM_AUGS):
            img_augs.append(self.augment_trans(img))
        im_batch = torch.cat(img_augs)
        image_features = self.model.encode_image(img)
        self.targets_features: torch.tensor = image_features[0]
        self.targets_features = self.targets_features / self.targets_features.norm(dim=-1, keepdim=True)
        loss -= torch.cosine_similarity(self.text_features, self.targets_features, dim=1)
        return loss

    def compose_image_with_white_background(self, img: torch.tensor) -> torch.tensor:
        if img.shape[-1] == 3:  # return img if it is already rgb
            return img
        # Compose img with white background
        alpha = img[:, :, 3:4]
        img = alpha * img[:, :, :3] + (1 - alpha) * torch.ones(img.shape[0], img.shape[1], 3, device=self.device)
        return img

    def read_png_image_from_path(self, path_to_png_image: str) -> torch.tensor:
        numpy_image = skimage.io.imread(path_to_png_image)
        normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
        resizer = torchvision.transforms.Resize((224, 224))
        resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
        return resized_image

    def reference_images_feature(self, reference_images_path):
        reference_images_num = len(os.listdir(reference_images_path))
        reference_images_feature = []
        for i in range(reference_images_num):
            i_reference_image = self.read_png_image_from_path(os.path.join(reference_images_path, str(i) + ".png"))
            if i_reference_image.shape[-1] == 4:
                i_reference_image = self.compose_image_with_white_background(i_reference_image)
            i_reference_image_features = self.model.encode_image(
                i_reference_image.permute(2, 0, 1).unsqueeze(0).to(self.device)).detach()
            reference_images_feature.append(i_reference_image_features)
        return torch.cat(reference_images_feature)


def read_png_image_from_path(path_to_png_image: str) -> torch.tensor:
    if path_to_png_image.endswith('.webp'):
        numpy_image = np.array(webp.load_image(path_to_png_image))
    else:
        numpy_image = skimage.io.imread(path_to_png_image)
    normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
    resizer = torchvision.transforms.Resize((224, 224))
    resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
    return resized_image


if __name__ == '__main__':
    torch.autograd.set_detect_anomaly(True)
    from tqdm import tqdm

    def get_bezier_circle(radius: float = 80,
                          segments: int = 4,
                          bias: np.array = np.asarray([100., 100.])):
        # Control points of a closed Bezier curve laid out on a circle,
        # used as the initial shape for optimization.
        deg = torch.arange(0, segments * 3 + 1) * 2 * np.pi / (segments * 3 + 1)
        points = torch.stack((torch.cos(deg), torch.sin(deg))).T
        points = points * radius + torch.tensor(bias).unsqueeze(dim=0)
        points = points.type(torch.FloatTensor).contiguous()
        return points

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    matchLoss = GeometrymatchLoss(device, "reference_images/")
    canvas_width, canvas_height = 224, 224
    num_segments = 4
    points1 = get_bezier_circle()
    path = pydiffvg.Path(num_control_points=torch.tensor(num_segments * [2] + [0], dtype=torch.int32),
                         points=points1, stroke_width=torch.tensor(2.0), is_closed=True)
    shapes = [path]
    path.points.requires_grad = True
    points_vars = [path.points]
    points_optim = torch.optim.Adam(points_vars, lr=1)
    pbar = tqdm(range(100000))
    for t in pbar:
        points_optim.zero_grad()
        match_loss = matchLoss(t, 224, 224, shapes)
        match_loss.backward()
        points_optim.step()
        pbar.set_postfix({"match_loss": f"{match_loss.item()}"})
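Note that the only substantive change relative to the augmented version is in forward: the loss is now computed from self.model.encode_image(img) on the single raw render rather than on a batch of augmented copies, so a drawing only needs to match the text from one fixed view.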

Result after 1,000 iterations: [figure]
Result after 2,000 iterations: [figure]
Result after 4,000 iterations: [figure]
Result after 8,000 iterations: [figure]

Why the results are poor without image augmentation

Explanation from the paper CLIPDraw: Exploring Text-to-Drawing Synthesis through Language-Image Encoders

[figure: excerpt from the CLIPDraw paper]

Explanation from the paper StyleCLIPDraw: Coupling Content and Style in Text-to-Drawing Translation

[figure: excerpt from the StyleCLIPDraw paper]

My understanding

Many different images can match the same text, and to a human, some of them are simply unrelated to it; without augmentation, the optimization is very likely to settle on such a local optimum. Augmenting the image before computing the loss fixes this: under perspective and similar transforms, a genuinely text-relevant image keeps roughly the same similarity to the text, while an irrelevant image generally loses similarity once transformed. Averaging the loss over several augmentations therefore keeps irrelevant images from skewing the result.
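To make this concrete, here is a minimal sketch of the idea, distilled from the training code above. It assumes a loaded CLIP model and precomputed text features; the function name augmented_clip_loss is mine, not from the original code.

import torch
from torchvision import transforms

# The same augmentations used in the training script above.
augment_trans = transforms.Compose([
    transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
    transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
])

def augmented_clip_loss(model, text_features, img, num_augs=4):
    # img: a rendered drawing of shape [1, 3, 224, 224] with values in [0, 1].
    # Score several randomly warped/cropped copies of the same drawing: a
    # drawing that matches the text only "by accident" loses similarity once
    # it is perspective-warped and cropped, so it is penalized on average.
    im_batch = torch.cat([augment_trans(img) for _ in range(num_augs)])
    image_features = model.encode_image(im_batch)
    loss = 0
    for n in range(num_augs):
        loss -= torch.cosine_similarity(text_features, image_features[n:n + 1], dim=1)
    return loss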
