Dataset for Stable Diffusion

1.Dataset for Stable Diffusion

笔记来源:
1.Flickr8k数据集处理
2.处理Flickr8k数据集
3.Github:pytorch-stable-diffusion
4.Flickr 8k Dataset
5.dataset_flickr8k.json

1.1 Dataset

采用Flicker8k数据集,该数据集有两个文件,第一个文件为Flicker8k_Dataset (全部为图片),第二个文件为Flickr8k.token.txt (含两列image_id和caption),其中一个image_id对应5个caption (sentence)

1.2 Dataset description file

数据集文本描述文件:dataset_flickr8k.json
文件格式如下:
{“images”: [ {“sentids”: [ ],“imgid”: 0,“sentences”:[{“tokens”:[ ]}, {“tokens”:[ ], “raw”: “…”, “imgid”:0, “sentid”:0}, …, “split”: “train”, “filename”: …jpg}, {“sentids”…} ], “dataset”: “flickr8k”}

参数解释
“sentids”:[0,1,2,3,4]caption 的 id 范围(一个image对应5个caption,所以sentids从0到4)
“imgid”:0image 的 id(从0到7999共8000张image)
“sentences”:[ ]包含一张照片的5个caption
“tokens”:[ ]每个caption分割为单个word
“raw”: " "每个token连接起来的caption
“imgid”: 0与caption相匹配的image的id
“sentid”: 0imag0对应的具体的caption的id
“split”:" "将该image和对应caption划分到训练集or验证集or测试集
“filename”:“…jpg”image具体名称

dataset_flickr8k.json

1.3 Process Datasets

下面代码引用自:Flickr8k数据集处理(仅作学习使用)

import json
import os
import random
from collections import Counter, defaultdict
from matplotlib import pyplot as plt
from PIL import Image
from argparse import Namespace
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as transformsdef create_dataset(dataset='flickr8k', captions_per_image=5, min_word_count=5, max_len=30):"""Parameters:dataset: Name of the datasetcaptions_per_image: Number of captions per imagemin_word_count: Only consider words that appear at least this many times in the dataset (excluding the test set)max_len: Maximum number of words in a caption. Captions longer than this will be truncated.Output:A vocabulary file: vocab.jsonThree dataset files: train_data.json, val_data.json, test_data.json"""# Paths for reading data and saving processed data# Path to the dataset JSON fileflickr_json_path = ".../sd/data/dataset_flickr8k.json"# Folder containing imagesimage_folder = ".../sd/data/Flicker8k_Dataset"# Folder to save processed results# The % operator is used to format the string by replacing %s with the value of the dataset variable.# For example, if dataset is "flickr8k", the resulting output_folder will be# /home/wxy/Documents/PycharmProjects/pytorch-stable-diffusion/sd/data/flickr8k.output_folder = ".../sd/data/%s" % dataset# Ensure output directory existsos.makedirs(output_folder, exist_ok=True)print(f"Output folder: {output_folder}")# Read the dataset JSON filewith open(file=flickr_json_path, mode="r") as j:data = json.load(fp=j)# Initialize containers for image paths, captions, and vocabulary# Dictionary to store image pathsimage_paths = defaultdict(list)# Dictionary to store image captionsimage_captions = defaultdict(list)# Count the number of elements, then count and return a dictionary# key:element value:the number of elements.vocab = Counter()# read from file dataset_flickr8k.jsonfor img in data["images"]:  # Iterate over each image in the datasetsplit = img["split"]  # Determine the split (train, val, or test) for the imagecaptions = []for c in img["sentences"]:  # Iterate over each caption for the image# Update word frequency count, excluding test set dataif split != "test":  # Only update vocabulary for train/val splits# c['tokens'] is a list, The number of occurrences of each word in the list is increased by onevocab.update(c['tokens'])  # Update vocabulary with words in the caption# Only consider captions that are within the maximum lengthif len(c["tokens"]) <= max_len:captions.append(c["tokens"])  # Add the caption to the list if it meets the length requirementif len(captions) == 0:  # Skip images with no valid captionscontinue# Construct the full image path/home/wxy/Documents/PycharmProjects/pytorch-stable-diffusion# image_folder + image_name# ./Flicker8k_Dataset/img['filename']path = os.path.join(image_folder, img['filename'])# Save the full image path and its captions in the respective dictionariesimage_paths[split].append(path)image_captions[split].append(captions)'''After the above steps, we have:- vocab(a dict) keys:words、values: counts of all words- image_paths: (a dict) keys "train", "val", and "test"; values: lists of absolute image paths- image_captions: (a dict) keys: "train", "val", and "test"; values: lists of captions'''/home/wxy/Documents/PycharmProjects/pytorch-stable-diffusion........

我们通过dataset_flickr8k.json文件把数据集转化为三个词典

dictkeyvalue
vacabwordfrequency of words in all captions
image_path“train”、“val”、“test”lists of absolute image path
image_captions“train”、“val”、“test”lists of captions

我们通过Debug打印其中的内容

print(vocab)
print(image_paths["train"][1])
print(image_captions["train"][1])

def create_dataset(dataset='flickr8k', captions_per_image=5, min_word_count=5, max_len=30):"""Parameters:dataset: Name of the datasetcaptions_per_image: Number of captions per imagemin_word_count: Only consider words that appear at least this many times in the dataset (excluding the test set)max_len: Maximum number of words in a caption. Captions longer than this will be truncated.Output:A vocabulary file: vocab.jsonThree dataset files: train_data.json, val_data.json, test_data.json"""........# Create the vocabulary, adding placeholders for special tokens# Add placeholders<pad>, unregistered word identifiers<unk>, sentence beginning and end identifiers<start><end>words = [w for w in vocab.keys() if vocab[w] > min_word_count]  # Filter words by minimum countvocab = {k: v + 1 for v, k in enumerate(words)}  # Create the vocabulary with indices# Add special tokens to the vocabularyvocab['<pad>'] = 0vocab['<unk>'] = len(vocab)vocab['<start>'] = len(vocab)vocab['<end>'] = len(vocab)# Save the vocabulary to a filewith open(os.path.join(output_folder, 'vocab.json'), "w") as fw:json.dump(vocab, fw)# Process each dataset split (train, val, test)# Iterate over each split: split = "train" 、 split = "val" 和 split = "test"for split in image_paths:# List of image paths for the splitimgpaths = image_paths[split]  # type(imgpaths)=list# List of captions for the splitimcaps = image_captions[split]  # type(imcaps)=list# store result that converting words of caption to their respective indices in the vocabularyenc_captions = []for i, path in enumerate(imgpaths):# Check if the image can be openedimg = Image.open(path)# Ensure each image has the required number of captionsif len(imcaps[i]) < captions_per_image:filled_num = captions_per_image - len(imcaps[i])# Repeat captions if neededcaptions = imcaps[i] + [random.choice(imcaps[i]) for _ in range(0, filled_num)]else:# Randomly sample captions if there are more than neededcaptions = random.sample(imcaps[i], k=captions_per_image)assert len(captions) == captions_per_imagefor j, c in enumerate(captions):# Encode each caption by converting words to their respective indices in the vocabularyenc_c = [vocab['<start>']] + [vocab.get(word, vocab['<unk>']) for word in c] + [vocab["<end>"]]enc_captions.append(enc_c)assert len(imgpaths) * captions_per_image == len(enc_captions)data = {"IMAGES": imgpaths,"CAPTIONS": enc_captions}# Save the processed dataset for the current split (train,val,test)with open(os.path.join(output_folder, split + "_data.json"), 'w') as fw:json.dump(data, fw)create_dataset()

经过create_dataset函数,我们得到如下图的文件

四个文件的详细内容见下表

train_data.json中的第一个key:IMAGES
train_data.json中的第二个key:CAPTIONS
test_data.json中的第一个key:IMAGES
test_data.json中的第二个key:CAPTIONS
val_data.json中的第一个key:IMAGES
val_data.json中的第二个key:CAPTIONS
vocab.json开始部分
vocab.json结尾部分

生成vocab.json的关键代码
首先统计所有caption中word出现至少大于5次的word,而后给这些word依次赋予一个下标

# Create the vocabulary, adding placeholders for special tokens# Add placeholders<pad>, unregistered word identifiers<unk>, sentence beginning and end identifiers<start><end># Create a list of words from the vocabulary that have a frequency higher than 'min_word_count'# min_word_count: Only consider words that appear at least this many times in the dataset (excluding the test set)words = [w for w in vocab.keys() if vocab[w] > min_word_count]  # Filter words by minimum count# assign an index to each word, starting from 1 (indices start from 0, so add 1)vocab = {k: v + 1 for v, k in enumerate(words)}  # Create the vocabulary with indices

最终生成vocab.json

生成 [“split”]_data.json 的关键
读入文件dataset_flickr8k.json,并创建两个字典,第一个字典放置每张image的绝对路径,第二个字典放置描述image的caption,根据vocab将token换为下标保存,根据文件dataset_flickr8k.json中不同的split,这image的绝对路径和相应caption保存在不同文件中(train_data.json、test_data.json、val_data.json)

dataset_flickr8k.json

train_data.json


从vocab中获取token的下标得到CAPTION的编码

for j, c in enumerate(captions):# Encode each caption by converting words to their respective indices in the vocabularyenc_c = [vocab['<start>']] + [vocab.get(word, vocab['<unk>']) for word in c] + [vocab["<end>"]]enc_captions.append(enc_c)


尝试使用上面生成的测试集文件test_data.json和vocab.json输出某张image以及对应的caption
下面代码引用自:Flickr8k数据集处理(仅作学习使用)

'''
test
1.Iterates over the 5 captions for 下面代码引用自:[Flickr8k数据集处理](https://blog.csdn.net/weixin_48981284/article/details/134676813)(仅作学习使用)the 250th image.
2.Retrieves the word indices for each caption.
3.Converts the word indices to words using vocab_idx2word.
4.Joins the words to form complete sentences.
5.Prints each caption.
'''
import json
from PIL import Image
from matplotlib import pyplot as plt
# Load the vocabulary from the JSON file
with open('.../sd/data/flickr8k/vocab.json', 'r') as f:vocab = json.load(f)  # Load the vocabulary from the JSON file into a dictionary
# Create a dictionary to map indices to words
vocab_idx2word = {idx: word for word, idx in vocab.items()}
# Load the test data from the JSON file
with open('.../sd/data/flickr8k/test_data.json', 'r') as f:data = json.load(f)  # Load the test data from the JSON file into a dictionary
# Open and display the 250th image in the test set
# Open the image at index 250 in the 'IMAGES' list
content_img = Image.open(data['IMAGES'][250])
plt.figure(figsize=(6, 6))
plt.subplot(1,1,1)
plt.imshow(content_img)
plt.title('Image')
plt.axis('off')
plt.show()
# Print the lengths of the data, image list, and caption list
# Print the number of keys in the dataset dictionary (should be 2: 'IMAGES' and 'CAPTIONS')
print(len(data))
print(len(data['IMAGES']))  # Print the number of images in the 'IMAGES' list
print(len(data["CAPTIONS"]))  # Print the number of captions in the 'CAPTIONS' list
# Display the captions for the 300th image
# Iterate over the 5 captions associated with the 300th image
for i in range(5):# Get the word indices for the i-th caption of the 300th imageword_indices = data['CAPTIONS'][250 * 5 + i]# Convert indices to words and join them to form a captionprint(''.join([vocab_idx2word[idx] for idx in word_indices]))

data 的 key 有两个 IMAGES 和 CAPTIONS
测试集image有1000张,每张对应5个caption,共5000个caption
第250张图片的5个caption如下图

1.4 Dataloader

下面代码引用自:Flickr8k数据集处理(仅作学习使用)

import json
import os
import random
from collections import Counter, defaultdict
from PIL import Image
import torch
from torch.utils.data import Dataset
from torch.utils import data
import torchvision.transforms as transformsclass ImageTextDataset(Dataset):"""Pytorch Dataset class to generate data batches using torch DataLoader"""def __init__(self, dataset_path, vocab_path, split, captions_per_image=5, max_len=30, transform=None):"""Parameters:dataset_path: Path to the JSON file containing the datasetvocab_path: Path to the JSON file containing the vocabularysplit: The dataset split, which can be "train", "val", or "test"captions_per_image: Number of captions per imagemax_len: Maximum number of words per captiontransform: Image transformation methods"""self.split = split# Validate that the split is one of the allowed valuesassert self.split in {"train", "val", "test"}# Store captions per imageself.cpi = captions_per_image# Store maximum caption lengthself.max_len = max_len# Load the datasetwith open(dataset_path, "r") as f:self.data = json.load(f)# Load the vocabularywith open(vocab_path, "r") as f:self.vocab = json.load(f)# Store the image transformation methodsself.transform = transform# Number of captions in the dataset# Calculate the size of the datasetself.dataset_size = len(self.data["CAPTIONS"])def __getitem__(self, i):"""Retrieve the i-th sample from the dataset"""# Get [i // self.cpi]-th image corresponding to the i-th sample (each image has multiple captions)img = Image.open(self.data['IMAGES'][i // self.cpi]).convert("RGB")# Apply image transformation if providedif self.transform is not None:# Apply the transformation to the imageimg = self.transform(img)# Get the length of the captioncaplen = len(self.data["CAPTIONS"][i])# Pad the caption if its length is less than max_lenpad_caps = [self.vocab['<pad>']] * (self.max_len + 2 - caplen)# Convert the caption to a tensor and pad itcaption = torch.LongTensor(self.data["CAPTIONS"][i] + pad_caps)return img, caption, caplen  # Return the image, caption, and caption lengthdef __len__(self):return self.dataset_size  # Number of samples in the datasetdef make_train_val(data_dir, vocab_path, batch_size, workers=4):"""Create DataLoader objects for training, validation, and testing sets.Parameters:data_dir: Directory where the dataset JSON files are locatedvocab_path: Path to the vocabulary JSON filebatch_size: Number of samples per batchworkers: Number of subprocesses to use for data loading (default is 4)Returns:train_loader: DataLoader for the training setval_loader: DataLoader for the validation settest_loader: DataLoader for the test set"""# Define transformation for training settrain_tx = transforms.Compose([transforms.Resize(256),  # Resize images to 256x256transforms.ToTensor(),  # Convert image to PyTorch tensortransforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize using ImageNet mean and std])val_tx = transforms.Compose([transforms.Resize(256),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# Create dataset objects for training, validation, and test setstrain_set = ImageTextDataset(dataset_path=os.path.join(data_dir, "train_data.json"), vocab_path=vocab_path,split="train", transform=train_tx)vaild_set = ImageTextDataset(dataset_path=os.path.join(data_dir, "val_data.json"), vocab_path=vocab_path,split="val", transform=val_tx)test_set = ImageTextDataset(dataset_path=os.path.join(data_dir, "test_data.json"), vocab_path=vocab_path,split="test", transform=val_tx)# Create DataLoader for training set with data shufflingtrain_loder = data.DataLoader(dataset=train_set, batch_size=batch_size, shuffer=True,num_workers=workers, pin_memory=True)# Create DataLoader for validation set without data shufflingval_loder = data.DataLoader(dataset=vaild_set, batch_size=batch_size, shuffer=False,num_workers=workers, pin_memory=True, drop_last=False)# Create DataLoader for test set without data shufflingtest_loder = data.DataLoader(dataset=test_set, batch_size=batch_size, shuffer=False,num_workers=workers, pin_memory=True, drop_last=False)return train_loder, val_loder, test_loder

创建好train_loader后,接下来我们就可以着手开始训练SD了!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/44813.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端工程化10-webpack静态的模块化打包工具之各种loader处理器

9.1、案例编写 我们创建一个component.js 通过JavaScript创建了一个元素&#xff0c;并且希望给它设置一些样式&#xff1b; 我们自己写的css,要把他加入到Webpack的图结构当中&#xff0c;这样才能被webpack检测到进行打包&#xff0c; style.css–>div_cn.js–>main…

Flower花所比特币交易及交易费用科普

在加密货币交易中&#xff0c;选择一个可靠的平台至关重要。Flower花所通过提供比特币交易服务脱颖而出。本文将介绍在Flower花所进行比特币交易的基础知识及其交易费用。 什么是Flower花所&#xff1f; Flower花所是一家加密货币交易平台&#xff0c;为新手和资深交易者提供…

【C++】开源:drogon-web框架配置使用

&#x1f60f;★,:.☆(&#xffe3;▽&#xffe3;)/$:.★ &#x1f60f; 这篇文章主要介绍drogon-web框架配置使用。 无专精则不能成&#xff0c;无涉猎则不能通。——梁启超 欢迎来到我的博客&#xff0c;一起学习&#xff0c;共同进步。 喜欢的朋友可以关注一下&#xff0c;…

Linux系统编程-线程同步详解

线程同步是指多个线程协调工作&#xff0c;以便在共享资源的访问和操作过程中保持数据一致性和正确性。在多线程环境中&#xff0c;线程是并发执行的&#xff0c;因此如果多个线程同时访问和修改共享资源&#xff0c;可能会导致数据不一致、竞态条件&#xff08;race condition…

【密码学】消息认证

你发送给朋友一条消息&#xff08;内容&#xff1a;明天下午来我家吃饭&#xff09;&#xff0c;这一过程中你不想让除你朋友以外的人看到消息的内容&#xff0c;这就叫做消息的机密性&#xff0c;用来保护消息机密性的方式被叫做加密机制。 现在站在朋友的视角&#xff0c;某一…

使用PyQt5实现添加工具栏、增加SwitchButton控件

前言&#xff1a;通过在网上找到的“电池电压监控界面”&#xff0c;学习PyQt5中添加工具栏、增加SwitchButton控件&#xff0c;在滑块控件右侧增加文本显示、设置界面背景颜色、修改文本控件字体颜色等。 1. 上位机界面效果展示 网络上原图如下&#xff1a; 自己使用PyQt5做…

【Linux】多线程_3

文章目录 九、多线程3. C11中的多线程4. 线程的简单封装 未完待续 九、多线程 3. C11中的多线程 Linux中是根据多线程库来实现多线程的&#xff0c;C11也有自己的多线程&#xff0c;那它的多线程又是怎样的&#xff1f;我们来使用一些C11的多线程。 Makefile&#xff1a; te…

Linux - 探索命令行

探索命令行 Linux命令行中的命令使用格式都是相同的: 命令名称 参数1 参数2 参数3 ...参数之间用任意数量的空白字符分开. 关于命令行, 可以先阅读一些基本常识. 然后我们介绍最常用的一些命令: ls用于列出当前目录(即"文件夹")下的所有文件(或目录). 目录会用蓝色…

CSIP-FTE考试专业题

靶场下载链接&#xff1a; https://pan.baidu.com/s/1ce1Kk0hSYlxrUoRTnNsiKA?pwdha1x pte-2003密码&#xff1a;admin123 centos:root admin123 解压密码&#xff1a; PTE考试专用 下载好后直接用vmware打开&#xff0c;有两个靶机&#xff0c;一个是基础题&#x…

【CTF-Crypto】数论基础-02

【CTF-Crypto】数论基础-02 文章目录 【CTF-Crypto】数论基础-021-16 二次剩余1-20 模p下-1的平方根*1-21 Legendre符号*1-22 Jacobi符号*2-1 群*2-2 群的性质2-3 阿贝尔群*2-4 子群2-11 群同态2-18 原根2-21 什么是环2-23 什么是域2-25 子环2-26 理想2-32 多项式环 1-16 二次剩…

打造智慧校园德育管理,提升学生操行基础分

智慧校园的德育管理系统内嵌的操行基础分功能&#xff0c;是对学生日常行为规范和道德素养进行量化评估的一个创新实践。该功能通过将抽象的道德品质转化为具体可量化的指标&#xff0c;如遵守纪律、尊师重道、团结协作、爱护环境及参与集体活动的积极性等&#xff0c;为每个学…

医疗器械FDA |FDA网络安全测试具体内容

医疗器械FDA网络安全测试的具体内容涵盖了多个方面&#xff0c;以确保医疗器械在网络环境中的安全性和合规性。以下是根据权威来源归纳的FDA网络安全测试的具体内容&#xff1a; 一、技术文件审查 网络安全计划&#xff1a;制造商需要提交网络安全计划&#xff0c;详细描述产…

Spring Boot集成easyposter快速入门Demo

1.什么是easyposter&#xff1f; easyposter是一个简单的,便于扩展的绘制海报工具包 使用场景 在日常工作过程中&#xff0c;通常一些C端平台会伴随着海报生成与分享业务。因为随着移动互联网的迅猛发展&#xff0c;社交分享已成为我们日常生活的重要组成部分。海报分享作为…

visual studio 2019版下载以及与UE4虚幻引擎配置(过程记录)(官网无法下载visual studio 2019安装包)

一、概述 由于需要使用到UE4虚幻引擎&#xff0c;我使用的版本是4.27版本的&#xff0c;其官方默认的visual studio版本是2019版本的&#xff0c;相应的版本对应关系可以通过下面的官方网站对应关系查询。https://docs.unrealengine.com/4.27/zh-CN/ProductionPipelines/Develo…

MMSegmentation笔记

如何训练自制数据集&#xff1f; 首先需要在 mmsegmentation/mmseg/datasets 目录下创建一个自制数据集的配置文件&#xff0c;以我的苹果叶片病害分割数据集为例&#xff0c;创建了mmsegmentation/mmseg/datasets/appleleafseg.py 可以看到&#xff0c;这个配置文件主要定义…

python:使用matplotlib库绘制图像(四)

作者是跟着http://t.csdnimg.cn/4fVW0学习的&#xff0c;matplotlib系列文章是http://t.csdnimg.cn/4fVW0的自己学习过程中整理的详细说明版本&#xff0c;对小白更友好哦&#xff01; 四、条形图 1. 一个数据样本的条形图 条形图&#xff1a;常用于比较不同类别的数量或值&…

3dmax-vray5大常用材质设置方法

3dmax云渲染平台——渲染100 以高性价比著称&#xff0c;是预算有限的小伙伴首选。 15分钟0.2,60分钟内0.8;注册填邀请码【7788】可领30元礼包和免费渲染券 提供了多种机器配置选择(可以自行匹配环境)最高256G大内存机器&#xff0c;满足不同用户需求。 木纹材质 肌理调整&…

练习9.5 彩票分析

练习 9.14&#xff1a;彩票 创建⼀个列表或元素&#xff0c;其中包含 10 个数和 5 个字 ⺟。从这个列表或元组中随机选择 4 个数或字⺟&#xff0c;并打印⼀条消息&#xff0c; 指出只要彩票上是这 4 个数或字⺟&#xff0c;就中⼤奖了。 练习 9.15&#xff1a;彩票分析 可以使…

2.5 OJ 网站的使用与作业全解

目录 1 OJ 网站如何使用 1.1 注册账户 1.2 登录账户 1.3 做题步骤 2 本节课的 OJ 作业说明 3 章节综合判断题 4 课时2作业1 5 课时2作业2 6 课时2作业3 1 OJ 网站如何使用 〇J 是英文 Online Judge 的缩写&#xff0c;中文翻译过来是在线判题。当用户将自己编写的代码…

基于XC7VX690T FPGA+ZU15EG SOC的6U VPX总线实时信号处理平台(支持4路光纤)

6U VPX架构&#xff0c;符合VITA46规范板载高性能FPGA处理器&#xff1a;XC7VX690T-2FFG1927I板载1片高性能MPSOC&#xff1a;XCZU15EG-2FFVB1156I板载1片MCU&#xff0c;进行健康管理、时钟配置等V7 FPGA外挂2个FMC接口两片FPGA之间通过高速GTH进行互联 基于6U VPX总线架构的通…