Implementing text2draw with the CLIP model

Reference papers

CLIPDraw: Exploring Text-to-Drawing Synthesis through Language-Image Encoders
StyleCLIPDraw: Coupling Content and Style in Text-to-Drawing Translation

Practice

Code with data augmentation

The script below uses pydiffvg to render a closed Bézier path and optimises its control points with Adam, maximising the CLIP similarity between the rendering and the prompt "A picture of triangle". Each rendering is scored through four randomly augmented copies.

import math
import collections
import CLIP_.clip as clip
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
import webp
from PIL import Image
import skimage
import torchvision
import pydiffvg
import os
import torch.nn.functional as F


class GeometrymatchLoss(torch.nn.Module):
    def __init__(self, device, reference_images_path):
        super(GeometrymatchLoss, self).__init__()
        self.device = device
        self.model, clip_preprocess = clip.load('ViT-B/32', self.device, jit=False)
        self.model.eval()
        # CLIP normalisation only (resizing is handled separately)
        self.preprocess = transforms.Compose([clip_preprocess.transforms[0], clip_preprocess.transforms[-1]])
        # note: this assignment shadows the method of the same name defined below
        self.reference_images_feature = self.reference_images_feature(reference_images_path)
        self.reference_images_feature = self.reference_images_feature / self.reference_images_feature.norm(dim=-1, keepdim=True)
        self.text = clip.tokenize(["A picture of triangle"]).to(device)
        self.text_features = self.model.encode_text(self.text)
        # self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)
        print("text_features.requires_grad:", self.text_features.requires_grad)
        self.text_features = self.text_features.detach()
        self.shape_groups = [pydiffvg.ShapeGroup(shape_ids=torch.tensor([0]),
                                                 fill_color=torch.tensor([0.0, 0.0, 0.0, 1.0]),
                                                 stroke_color=torch.tensor([0.0, 0.0, 0.0, 1.0]))]
        # image augmentation transformation
        self.augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
        ])

    def forward(self, t, canvas_width, canvas_height, shapes):
        scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, self.shape_groups)
        # render the image
        render = pydiffvg.RenderFunction.apply
        target = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args)
        if target.shape[-1] == 4:
            target = self.compose_image_with_white_background(target)
        if t % 100 == 0:
            pydiffvg.imwrite(target.cpu(), f'learn/log_augs/output_{t}.png', gamma=2.2)
        # targets_ = self.preprocess(target.permute(2, 0, 1).unsqueeze(0)).to(self.device)
        img = target.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)  # HWC -> NCHW
        loss = 0
        NUM_AUGS = 4
        img_augs = []
        for n in range(NUM_AUGS):
            img_augs.append(self.augment_trans(img))
        im_batch = torch.cat(img_augs)
        image_features = self.model.encode_image(im_batch)
        # logit_scale = self.model.logit_scale.exp()
        for n in range(NUM_AUGS):
            loss -= torch.cosine_similarity(self.text_features, image_features[n:n + 1], dim=1)
        return loss

    def compose_image_with_white_background(self, img: torch.Tensor) -> torch.Tensor:
        if img.shape[-1] == 3:  # return img if it is already rgb
            return img
        # compose img with a white background
        alpha = img[:, :, 3:4]
        img = alpha * img[:, :, :3] + (1 - alpha) * torch.ones(img.shape[0], img.shape[1], 3, device=self.device)
        return img

    def read_png_image_from_path(self, path_to_png_image: str) -> torch.Tensor:
        numpy_image = skimage.io.imread(path_to_png_image)
        normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
        resizer = torchvision.transforms.Resize((224, 224))
        resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
        return resized_image

    def reference_images_feature(self, reference_images_path):
        reference_images_num = len(os.listdir(reference_images_path))
        reference_images_feature = []
        for i in range(reference_images_num):
            i_reference_image = self.read_png_image_from_path(os.path.join(reference_images_path, str(i) + ".png"))
            if i_reference_image.shape[-1] == 4:
                i_reference_image = self.compose_image_with_white_background(i_reference_image)
            # targets_ = self.preprocess(i_reference_image.permute(2, 0, 1).unsqueeze(0)).to(self.device)
            i_reference_image_features = self.model.encode_image(
                i_reference_image.permute(2, 0, 1).unsqueeze(0).to(self.device)).detach()
            reference_images_feature.append(i_reference_image_features)
        return torch.cat(reference_images_feature)


def read_png_image_from_path(path_to_png_image: str) -> torch.Tensor:
    if path_to_png_image.endswith('.webp'):
        numpy_image = np.array(webp.load_image(path_to_png_image))
    else:
        numpy_image = skimage.io.imread(path_to_png_image)
    normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
    resizer = torchvision.transforms.Resize((224, 224))
    resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
    return resized_image


if __name__ == '__main__':
    torch.autograd.set_detect_anomaly(True)
    from tqdm import tqdm

    def get_bezier_circle(radius: float = 80,
                          segments: int = 4,
                          bias: np.ndarray = np.asarray([100., 100.])):
        # control points of a rough circle, used to initialise the path
        deg = torch.arange(0, segments * 3 + 1) * 2 * np.pi / (segments * 3 + 1)
        points = torch.stack((torch.cos(deg), torch.sin(deg))).T
        points = points * radius + torch.tensor(bias).unsqueeze(dim=0)
        points = points.type(torch.FloatTensor).contiguous()
        return points

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    matchLoss = GeometrymatchLoss(device, "reference_images/")
    # print(matchLoss.reference_images_feature.shape)
    # img1 = read_png_image_from_path('learn/output.png')
    canvas_width, canvas_height = 224, 224
    num_segments = 4
    points1 = get_bezier_circle()
    path = pydiffvg.Path(num_control_points=torch.tensor(num_segments * [2] + [0], dtype=torch.int32),
                         points=points1, stroke_width=torch.tensor(2.0), is_closed=True)
    shapes = [path]
    path.points.requires_grad = True
    # sanity check: path.points references the same tensor as points1
    print(id(path.points))
    print(id(points1))
    points_vars = []
    points_vars.append(path.points)
    points_optim = torch.optim.Adam(points_vars, lr=1)
    pbar = tqdm(range(100000))
    print(points1)
    for t in pbar:
        points_optim.zero_grad()
        match_loss = matchLoss(t, 224, 224, shapes)
        match_loss.backward()
        # print(path.points.grad)
        points_optim.step()
        pbar.set_postfix({"match_loss": f"{match_loss.item()}"})
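Since the optimisation variable is the vector path itself, the final shape can also be exported as an SVG rather than only as the raster snapshots written every 100 steps. A minimal sketch, assuming the installed pydiffvg build exposes the save_svg helper from the reference diffvg repository:

# export the optimised path as a vector file after the training loop
# (assumes pydiffvg.save_svg is available, as in the reference diffvg repo)
pydiffvg.save_svg("learn/final.svg", canvas_width, canvas_height,
                  shapes, matchLoss.shape_groups)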

Result after 1000 iterations:
[image]

Without image augmentation

The script is identical to the augmented version above except for three changes: the text features are normalised in __init__, the intermediate renders are written to learn/log/ instead of learn/log_augs/, and the forward pass encodes the single rendered image directly rather than a batch of augmented copies. Only the changed parts are shown here; the imports, helper methods, and training loop are unchanged.

        # in __init__: the text features are normalised in this version
        self.text_features = self.model.encode_text(self.text)
        self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)
        print("text_features.requires_grad:", self.text_features.requires_grad)
        self.text_features = self.text_features.detach()

    def forward(self, t, canvas_width, canvas_height, shapes):
        scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, self.shape_groups)
        # render the image
        render = pydiffvg.RenderFunction.apply
        target = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args)
        if target.shape[-1] == 4:
            target = self.compose_image_with_white_background(target)
        if t % 100 == 0:
            pydiffvg.imwrite(target.cpu(), f'learn/log/output_{t}.png', gamma=2.2)
        img = target.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)  # HWC -> NCHW
        loss = 0
        NUM_AUGS = 4
        img_augs = []
        for n in range(NUM_AUGS):
            img_augs.append(self.augment_trans(img))
        im_batch = torch.cat(img_augs)  # note: built but never used in this version
        # the un-augmented render is encoded directly
        image_features = self.model.encode_image(img)
        self.targets_features: torch.Tensor = image_features[0]
        self.targets_features = self.targets_features / self.targets_features.norm(dim=-1, keepdim=True)
        loss -= torch.cosine_similarity(self.text_features, self.targets_features, dim=1)
        return loss

Result after 1000 iterations:
[image]
Result after 2000 iterations:
[image]
Result after 4000 iterations:
[image]
Result after 8000 iterations:
[image]

Why the results are poor without image augmentation

Explanation from the paper CLIPDraw: Exploring Text-to-Drawing Synthesis through Language-Image Encoders

[image: excerpt from the CLIPDraw paper]

Explanation from the paper StyleCLIPDraw: Coupling Content and Style in Text-to-Drawing Translation

[image: excerpt from the StyleCLIPDraw paper]

My understanding

Many different images can match a given piece of text, and to a human some of them are plainly unrelated to it. Without image augmentation, the optimisation is very likely to converge to such a local optimum. Augmenting the image before computing the loss (perspective distortion, random crops, and so on) fixes this: an image that genuinely matches the text keeps roughly the same similarity to it under these transforms, while an irrelevant image generally loses similarity once transformed. Scoring the drawing across several augmentations therefore keeps irrelevant images from skewing the result.
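The loss in the augmented version implements exactly this idea: similarity is averaged over several randomly transformed copies of the rendering, so a drawing scores well only if it matches the text under every view. A minimal standalone sketch (the helper name augmented_clip_loss and the num_augs parameter are illustrative, distilled from the training code above):

import torch
from torchvision import transforms

# the same augmentations as in the training code above
augment = transforms.Compose([
    transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
    transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
])

def augmented_clip_loss(img, text_features, clip_model, num_augs=4):
    # img: (1, 3, H, W) rendered image; text_features: (1, D) detached text embedding
    batch = torch.cat([augment(img) for _ in range(num_augs)])  # (num_augs, 3, 224, 224)
    image_features = clip_model.encode_image(batch)             # (num_augs, D)
    sims = torch.cosine_similarity(text_features, image_features, dim=1)
    # the loss is low only when every augmented view still matches the text
    return -sims.mean()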
