从头开始构建一个小规模的文生视频模型

OpenAI 的 Sora、Stability AI 的 Stable Video Diffusion 以及许多其他已经发布或未来将出现的文本生成视频模型,是继大语言模型 (LLM) 之后 2024 年最流行的 AI 趋势之一。

在这篇博客中,作者将展示如何将从头开始构建一个小规模的文本生成视频模型,涵盖了从理解理论概念、到编写整个架构再到生成最终结果的所有内容。

由于作者没有大算力的 GPU,所以仅编写了小规模架构。以下是在不同处理器上训练模型所需时间的比较。

图片

作者表示,在 CPU 上运行显然需要更长的时间来训练模型。如果你需要快速测试代码中的更改并查看结果,CPU 不是最佳选择。因此建议使用 Colab 或 Kaggle 的 T4 GPU 进行更高效、更快速的训练。

技术交流&资料

技术要学会分享、交流,不建议闭门造车。一个人可以走的很快、一堆人可以走的更远。

成立了算法面试和技术交流群,相关资料、技术交流&答疑,均可加我们的交流群获取,群友已超过2000人,添加时最好的备注方式为:来源+兴趣方向,方便找到志同道合的朋友。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2040,备注:来自CSDN + 技术交流

构建目标

我们采用了与传统机器学习或深度学习模型类似的方法,即在数据集上进行训练,然后在未见过数据上进行测试。在文本转视频的背景下,假设有一个包含 10 万个狗捡球和猫追老鼠视频的训练数据集,然后训练模型来生成猫捡球或狗追老鼠的视频。

图片

图源:iStock, GettyImages

虽然此类训练数据集在互联网上很容易获得,但所需的算力极高。因此,我们将使用由 Python 代码生成的移动对象视频数据集。同时使用 GAN(生成对抗网络)架构来创建模型,而不是 OpenAI Sora 使用的扩散模型。

我们也尝试使用扩散模型,但内存要求超出了自己的能力。另一方面,GAN 可以更容易、更快地进行训练和测试。

准备条件

我们将使用 OOP(面向对象编程),因此必须对它以及神经网络有基本的了解。此外 GAN(生成对抗网络)的知识不是必需的,因为这里简单介绍它们的架构。

  • OOP:https://www.youtube.com/watch?v=q2SGW2VgwAM

  • 神经网络理论:https://www.youtube.com/watch?v=Jy4wM2X21u0

  • GAN 架构:https://www.youtube.com/watch?v=TpMIssRdhco

  • Python 基础:https://www.youtube.com/watch?v=eWRfhZUzrAc

了解 GAN 架构

什么是 GAN?

生成对抗网络是一种深度学习模型,其中两个神经网络相互竞争:一个从给定的数据集创建新数据(如图像或音乐),另一个则判断数据是真实的还是虚假的。这个过程一直持续到生成的数据与原始数据无法区分。

真实世界应用

  • 生成图像:GAN 根据文本 prompt 创建逼真的图像或修改现有图像,例如增强分辨率或为黑白照片添加颜色。

  • 数据增强:GAN 生成合成数据来训练其他机器学习模型,例如为欺诈检测系统创建欺诈交易数据。

  • 补充缺失信息:GAN 可以填充缺失数据,例如根据地形图生成地下图像以用于能源应用。

  • 生成 3D 模型:GAN 将 2D 图像转换为 3D 模型,在医疗保健等领域非常有用,可用于为手术规划创建逼真的器官图像。

GAN 工作原理

GAN 由两个深度神经网络组成:生成器和判别器。这两个网络在对抗设置中一起训练,其中一个网络生成新数据,另一个网络评估数据是真是假。

图片

GAN 训练示例

让我们以图像到图像的转换为例,解释一下 GAN 模型,重点是修改人脸。

1. 输入图像:输入图像是一张真实的人脸图像。

2. 属性修改:生成器会修改人脸的属性,比如给眼睛加上墨镜。

3. 生成图像:生成器会创建一组添加了太阳镜的图像。

4. 判别器的任务:判别器接收到混合的真实图像(带有太阳镜的人)和生成的图像(添加了太阳镜的人脸)。

5. 评估:判别器尝试区分真实图像和生成图像。

6. 反馈回路:如果判别器正确识别出假图像,生成器会调整其参数以生成更逼真的图像。如果生成器成功欺骗了判别器,判别器会更新其参数以提高检测能力。

通过这一对抗过程,两个网络都在不断改进。生成器越来越善于生成逼真的图像,而判别器则越来越善于识别假图像,直到达到平衡,判别器再也无法区分真实图像和生成的图像。此时,GAN 已成功学会生成逼真的修改图像。

设置背景

我们将使用一系列 Python 库,让我们导入它们。

# Operating System module for interacting with the operating system
import os# Module for generating random numbers
import random# Module for numerical operations
import numpy as np# OpenCV library for image processing
import cv2# Python Imaging Library for image processing
from PIL import Image, ImageDraw, ImageFont# PyTorch library for deep learning
import torch# Dataset class for creating custom datasets in PyTorch
from torch.utils.data import Dataset# Module for image transformations
import torchvision.transforms as transforms# Neural network module in PyTorch
import torch.nn as nn# Optimization algorithms in PyTorch
import torch.optim as optim# Function for padding sequences in PyTorch
from torch.nn.utils.rnn import pad_sequence# Function for saving images in PyTorch
from torchvision.utils import save_image# Module for plotting graphs and images
import matplotlib.pyplot as plt# Module for displaying rich content in IPython environments
from IPython.display import clear_output, display, HTML# Module for encoding and decoding binary data to text
import base64

现在我们已经导入了所有的库,下一步就是定义我们的训练数据,用于训练 GAN 架构。

对训练数据进行编码

我们需要至少 10000 个视频作为训练数据。为什么呢?因为我测试了较小数量的视频,结果非常糟糕,几乎没有任何效果。下一个重要问题是:这些视频内容是什么? 我们的训练视频数据集包括一个圆圈以不同方向和不同运动方式移动的视频。让我们来编写代码并生成 10,000 个视频,看看它的效果如何。

# Create a directory named 'training_dataset'
os.makedirs('training_dataset', exist_ok=True)# Define the number of videos to generate for the dataset
num_videos = 10000# Define the number of frames per video (1 Second Video)
frames_per_video = 10# Define the size of each image in the dataset
img_size = (64, 64)# Define the size of the shapes (Circle)
shape_size = 10 

设置一些基本参数后,接下来我们需要定义训练数据集的文本 prompt,并据此生成训练视频。

# Define text prompts and corresponding movements for circles
prompts_and_movements = [("circle moving down", "circle", "down"), # Move circle downward("circle moving left", "circle", "left"), # Move circle leftward("circle moving right", "circle", "right"), # Move circle rightward("circle moving diagonally up-right", "circle", "diagonal_up_right"), # Move circle diagonally up-right("circle moving diagonally down-left", "circle", "diagonal_down_left"), # Move circle diagonally down-left("circle moving diagonally up-left", "circle", "diagonal_up_left"), # Move circle diagonally up-left("circle moving diagonally down-right", "circle", "diagonal_down_right"), # Move circle diagonally down-right("circle rotating clockwise", "circle", "rotate_clockwise"), # Rotate circle clockwise("circle rotating counter-clockwise", "circle", "rotate_counter_clockwise"), # Rotate circle counter-clockwise("circle shrinking", "circle", "shrink"), # Shrink circle("circle expanding", "circle", "expand"), # Expand circle("circle bouncing vertically", "circle", "bounce_vertical"), # Bounce circle vertically("circle bouncing horizontally", "circle", "bounce_horizontal"), # Bounce circle horizontally("circle zigzagging vertically", "circle", "zigzag_vertical"), # Zigzag circle vertically("circle zigzagging horizontally", "circle", "zigzag_horizontal"), # Zigzag circle horizontally("circle moving up-left", "circle", "up_left"), # Move circle up-left("circle moving down-right", "circle", "down_right"), # Move circle down-right("circle moving down-left", "circle", "down_left"), # Move circle down-left
]

我们已经利用这些 prompt 定义了圆的几个运动轨迹。现在,我们需要编写一些数学公式,以便根据 prompt 移动圆。

# Define function with parameters
def create_image_with_moving_shape(size, frame_num, shape, direction):# Create a new RGB image with specified size and white backgroundimg = Image.new('RGB', size, color=(255, 255, 255)) # Create a drawing context for the imagedraw = ImageDraw.Draw(img) # Calculate the center coordinates of the imagecenter_x, center_y = size[0] // 2, size[1] // 2 # Initialize position with center for all movementsposition = (center_x, center_y) # Define a dictionary mapping directions to their respective position adjustments or image transformationsdirection_map = { # Adjust position downwards based on frame number"down": (0, frame_num * 5 % size[1]), # Adjust position to the left based on frame number"left": (-frame_num * 5 % size[0], 0), # Adjust position to the right based on frame number"right": (frame_num * 5 % size[0], 0), # Adjust position diagonally up and to the right"diagonal_up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]), # Adjust position diagonally down and to the left"diagonal_down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]), # Adjust position diagonally up and to the left"diagonal_up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]), # Adjust position diagonally down and to the right"diagonal_down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]), # Rotate the image clockwise based on frame number"rotate_clockwise": img.rotate(frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)), # Rotate the image counter-clockwise based on frame number"rotate_counter_clockwise": img.rotate(-frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)), # Adjust position for a bouncing effect vertically"bounce_vertical": (0, center_y - abs(frame_num * 5 % size[1] - center_y)), # Adjust position for a bouncing effect horizontally"bounce_horizontal": (center_x - abs(frame_num * 5 % size[0] - center_x), 0), # Adjust position for a zigzag effect vertically"zigzag_vertical": (0, center_y - frame_num * 5 % size[1]) if frame_num % 2 == 0 else (0, center_y + frame_num * 5 % size[1]), # Adjust position for a zigzag effect horizontally"zigzag_horizontal": (center_x - frame_num * 5 % size[0], center_y) if frame_num % 2 == 0 else (center_x + frame_num * 5 % size[0], center_y), # Adjust position upwards and to the right based on frame number"up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]), # Adjust position upwards and to the left based on frame number"up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]), # Adjust position downwards and to the right based on frame number"down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]), # Adjust position downwards and to the left based on frame number"down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]) }# Check if direction is in the direction mapif direction in direction_map: # Check if the direction maps to a position adjustmentif isinstance(direction_map[direction], tuple): # Update position based on the adjustmentposition = tuple(np.add(position, direction_map[direction])) else: # If the direction maps to an image transformation# Update the image based on the transformationimg = direction_map[direction] # Return the image as a numpy arrayreturn np.array(img)

上述函数用于根据所选方向在每一帧中移动我们的圆。我们只需在其上运行一个循环,直至生成所有视频的次数。

# Iterate over the number of videos to generate
for i in range(num_videos):# Randomly choose a prompt and movement from the predefined listprompt, shape, direction = random.choice(prompts_and_movements)# Create a directory for the current videovideo_dir = f'training_dataset/video_{i}'os.makedirs(video_dir, exist_ok=True)# Write the chosen prompt to a text file in the video directorywith open(f'{video_dir}/prompt.txt', 'w') as f:f.write(prompt)# Generate frames for the current videofor frame_num in range(frames_per_video):# Create an image with a moving shape based on the current frame number, shape, and directionimg = create_image_with_moving_shape(img_size, frame_num, shape, direction)# Save the generated image as a PNG file in the video directorycv2.imwrite(f'{video_dir}/frame_{frame_num}.png', img)

运行上述代码后,就会生成整个训练数据集。以下是训练数据集文件的结构。

图片

每个训练视频文件夹包含其帧以及对应的文本 prompt。让我们看一下我们的训练数据集样本。

在我们的训练数据集中,我们没有包含圆圈先向上移动然后向右移动的运动。我们将使用这个作为测试 prompt,来评估我们训练的模型在未见过的数据上的表现。

图片

还有一个重要的要点需要注意,我们的训练数据包含许多物体从场景中移出或部分出现在摄像机前方的样本,类似于我们在 OpenAI Sora 演示视频中观察到的情况。

图片

在我们的训练数据中包含此类样本的原因是为了测试当圆圈从角落进入场景时,模型是否能够保持一致性而不会破坏其形状。

现在我们的训练数据已经生成,需要将训练视频转换为张量,这是 PyTorch 等深度学习框架中使用的主要数据类型。此外,通过将数据缩放到较小的范围,执行归一化等转换有助于提高训练架构的收敛性和稳定性。

预处理训练数据

我们必须为文本转视频任务编写一个数据集类,它可以从训练数据集目录中读取视频帧及其相应的文本 prompt,使其可以在 PyTorch 中使用。

# Define a dataset class inheriting from torch.utils.data.Dataset
class TextToVideoDataset(Dataset):def __init__(self, root_dir, transform=None):# Initialize the dataset with root directory and optional transformself.root_dir = root_dirself.transform = transform# List all subdirectories in the root directoryself.video_dirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]# Initialize lists to store frame paths and corresponding promptsself.frame_paths = []self.prompts = []# Loop through each video directoryfor video_dir in self.video_dirs:# List all PNG files in the video directory and store their pathsframes = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith('.png')]self.frame_paths.extend(frames)# Read the prompt text file in the video directory and store its contentwith open(os.path.join(video_dir, 'prompt.txt'), 'r') as f:prompt = f.read().strip()# Repeat the prompt for each frame in the video and store in prompts listself.prompts.extend([prompt] * len(frames))# Return the total number of samples in the datasetdef __len__(self):return len(self.frame_paths)# Retrieve a sample from the dataset given an indexdef __getitem__(self, idx):# Get the path of the frame corresponding to the given indexframe_path = self.frame_paths[idx]# Open the image using PIL (Python Imaging Library)image = Image.open(frame_path)# Get the prompt corresponding to the given indexprompt = self.prompts[idx]# Apply transformation if specifiedif self.transform:image = self.transform(image)# Return the transformed image and the promptreturn image, prompt

在继续编写架构代码之前,我们需要对训练数据进行归一化处理。我们使用 16 的 batch 大小并对数据进行混洗以引入更多随机性。

实现文本嵌入层

你可能已经看到,在 Transformer 架构中,起点是将文本输入转换为嵌入,从而在多头注意力中进行进一步处理。类似地,我们在这里必须编写一个文本嵌入层。基于该层,GAN 架构训练在我们的嵌入数据和图像张量上进行。

# Define a class for text embedding
class TextEmbedding(nn.Module):# Constructor method with vocab_size and embed_size parametersdef __init__(self, vocab_size, embed_size):# Call the superclass constructorsuper(TextEmbedding, self).__init__()# Initialize embedding layerself.embedding = nn.Embedding(vocab_size, embed_size)# Define the forward pass methoddef forward(self, x):# Return embedded representation of inputreturn self.embedding(x) 

词汇量将基于我们的训练数据,在稍后进行计算。嵌入大小将为 10。如果使用更大的数据集,你还可以使用 Hugging Face 上已有的嵌入模型。

实现生成器层

现在我们已经知道生成器在 GAN 中的作用,接下来让我们对这一层进行编码,然后了解其内容。

class Generator(nn.Module):def __init__(self, text_embed_size):super(Generator, self).__init__()# Fully connected layer that takes noise and text embedding as inputself.fc1 = nn.Linear(100 + text_embed_size, 256 * 8 * 8)# Transposed convolutional layers to upsample the inputself.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)self.deconv3 = nn.ConvTranspose2d(64, 3, 4, 2, 1) # Output has 3 channels for RGB images# Activation functionsself.relu = nn.ReLU(True) # ReLU activation functionself.tanh = nn.Tanh() # Tanh activation function for final outputdef forward(self, noise, text_embed):# Concatenate noise and text embedding along the channel dimensionx = torch.cat((noise, text_embed), dim=1)# Fully connected layer followed by reshaping to 4D tensorx = self.fc1(x).view(-1, 256, 8, 8)# Upsampling through transposed convolution layers with ReLU activationx = self.relu(self.deconv1(x))x = self.relu(self.deconv2(x))# Final layer with Tanh activation to ensure output values are between -1 and 1 (for images)x = self.tanh(self.deconv3(x))return x

该 Generator 类负责根据随机噪声和文本嵌入的组合创建视频帧,旨在根据给定的文本描述生成逼真的视频帧。该网络从完全连接层 (nn.Linear) 开始,将噪声向量和文本嵌入组合成单个特征向量。然后,该向量被重新整形并经过一系列的转置卷积层 (nn.ConvTranspose2d),这些层将特征图逐步上采样到所需的视频帧大小。

这些层使用 ReLU 激活 (nn.ReLU) 实现非线性,最后一层使用 Tanh 激活 (nn.Tanh) 将输出缩放到 [-1, 1] 的范围。因此,生成器将抽象的高维输入转换为以视觉方式表示输入文本的连贯视频帧。

实现判别器层

在编写完生成器层之后,我们需要实现另一半,即判别器部分。

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()# Convolutional layers to process input imagesself.conv1 = nn.Conv2d(3, 64, 4, 2, 1)   # 3 input channels (RGB), 64 output channels, kernel size 4x4, stride 2, padding 1self.conv2 = nn.Conv2d(64, 128, 4, 2, 1) # 64 input channels, 128 output channels, kernel size 4x4, stride 2, padding 1self.conv3 = nn.Conv2d(128, 256, 4, 2, 1) # 128 input channels, 256 output channels, kernel size 4x4, stride 2, padding 1# Fully connected layer for classificationself.fc1 = nn.Linear(256 * 8 * 8, 1)  # Input size 256x8x8 (output size of last convolution), output size 1 (binary classification)# Activation functionsself.leaky_relu = nn.LeakyReLU(0.2, inplace=True)  # Leaky ReLU activation with negative slope 0.2self.sigmoid = nn.Sigmoid()  # Sigmoid activation for final output (probability)def forward(self, input):# Pass input through convolutional layers with LeakyReLU activationx = self.leaky_relu(self.conv1(input))x = self.leaky_relu(self.conv2(x))x = self.leaky_relu(self.conv3(x))# Flatten the output of convolutional layersx = x.view(-1, 256 * 8 * 8)# Pass through fully connected layer with Sigmoid activation for binary classificationx = self.sigmoid(self.fc1(x))return x

判别器类用作二元分类器,区分真实视频帧和生成的视频帧。目的是评估视频帧的真实性,从而指导生成器产生更真实的输出。该网络由卷积层 (nn.Conv2d) 组成,这些卷积层从输入视频帧中提取分层特征, Leaky ReLU 激活 (nn.LeakyReLU) 增加非线性,同时允许负值的小梯度。

然后,特征图被展平并通过完全连接层 (nn.Linear),最终以 S 形激活 (nn.Sigmoid) 输出指示帧是真实还是假的概率分数。

通过训练判别器准确地对帧进行分类,生成器同时接受训练以创建更令人信服的视频帧,从而骗过判别器。

编写训练参数

我们必须设置用于训练 GAN 的基础组件,例如损失函数、优化器等。

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Create a simple vocabulary for text prompts
all_prompts = [prompt for prompt, _, _ in prompts_and_movements]  # Extract all prompts from prompts_and_movements list
vocab = {word: idx for idx, word in enumerate(set(" ".join(all_prompts).split()))}  # Create a vocabulary dictionary where each unique word is assigned an index
vocab_size = len(vocab)  # Size of the vocabulary
embed_size = 10  # Size of the text embedding vectordef encode_text(prompt):# Encode a given prompt into a tensor of indices using the vocabularyreturn torch.tensor([vocab[word] for word in prompt.split()])# Initialize models, loss function, and optimizers
text_embedding = TextEmbedding(vocab_size, embed_size).to(device)  # Initialize TextEmbedding model with vocab_size and embed_size
netG = Generator(embed_size).to(device)  # Initialize Generator model with embed_size
netD = Discriminator().to(device)  # Initialize Discriminator model
criterion = nn.BCELoss().to(device)  # Binary Cross Entropy loss function
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Adam optimizer for Discriminator
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Adam optimizer for Generator

这是我们必须转换代码以在 GPU 上运行的部分(如果可用)。我们已经编写了代码来查找 vocab_size,并且我们正在为生成器和判别器使用 ADAM 优化器。你可以选择自己的优化器。在这里,我们将学习率设置为较小的值 0.0002,嵌入大小为 10,这比其他可供公众使用的 Hugging Face 模型要小得多。

编写训练 loop

就像其他神经网络一样,我们将以类似的方式对 GAN 架构训练进行编码。

# Number of epochs
num_epochs = 13# Iterate over each epoch
for epoch in range(num_epochs):# Iterate over each batch of datafor i, (data, prompts) in enumerate(dataloader):# Move real data to devicereal_data = data.to(device)# Convert prompts to listprompts = [prompt for prompt in prompts]# Update DiscriminatornetD.zero_grad()  # Zero the gradients of the Discriminatorbatch_size = real_data.size(0)  # Get the batch sizelabels = torch.ones(batch_size, 1).to(device)  # Create labels for real data (ones)output = netD(real_data)  # Forward pass real data through DiscriminatorlossD_real = criterion(output, labels)  # Calculate loss on real datalossD_real.backward()  # Backward pass to calculate gradients# Generate fake datanoise = torch.randn(batch_size, 100).to(device)  # Generate random noisetext_embeds = torch.stack([text_embedding(encode_text(prompt).to(device)).mean(dim=0) for prompt in prompts])  # Encode prompts into text embeddingsfake_data = netG(noise, text_embeds)  # Generate fake data from noise and text embeddingslabels = torch.zeros(batch_size, 1).to(device)  # Create labels for fake data (zeros)output = netD(fake_data.detach())  # Forward pass fake data through Discriminator (detach to avoid gradients flowing back to Generator)lossD_fake = criterion(output, labels)  # Calculate loss on fake datalossD_fake.backward()  # Backward pass to calculate gradientsoptimizerD.step()  # Update Discriminator parameters# Update GeneratornetG.zero_grad()  # Zero the gradients of the Generatorlabels = torch.ones(batch_size, 1).to(device)  # Create labels for fake data (ones) to fool Discriminatoroutput = netD(fake_data)  # Forward pass fake data (now updated) through DiscriminatorlossG = criterion(output, labels)  # Calculate loss for Generator based on Discriminator's responselossG.backward()  # Backward pass to calculate gradientsoptimizerG.step()  # Update Generator parameters# Print epoch informationprint(f"Epoch [{epoch + 1}/{num_epochs}] Loss D: {lossD_real + lossD_fake}, Loss G: {lossG}")

通过反向传播,我们的损失将针对生成器和判别器进行调整。我们在训练 loop 中使用了 13 个 epoch。我们测试了不同的值,但如果 epoch 高于这个值,结果并没有太大差异。此外,过度拟合的风险很高。如果我们的数据集更加多样化,包含更多动作和形状,则可以考虑使用更高的 epoch,但在这里没有这样做。

当我们运行此代码时,它会开始训练,并在每个 epoch 之后 print 生成器和判别器的损失。

## OUTPUT ##Epoch [1/13] Loss D: 0.8798642754554749, Loss G: 1.300612449645996
Epoch [2/13] Loss D: 0.8235711455345154, Loss G: 1.3729925155639648
Epoch [3/13] Loss D: 0.6098687052726746, Loss G: 1.3266581296920776...

保存训练的模型

训练完成后,我们需要保存训练好的 GAN 架构的判别器和生成器,这只需两行代码即可实现。

# Save the Generator model's state dictionary to a file named 'generator.pth'
torch.save(netG.state_dict(), 'generator.pth')# Save the Discriminator model's state dictionary to a file named 'discriminator.pth'
torch.save(netD.state_dict(), 'discriminator.pth')

生成 AI 视频

正如我们所讨论的,我们在未见过的数据上测试模型的方法与我们训练数据中涉及狗取球和猫追老鼠的示例类似。因此,我们的测试 prompt 可能涉及猫取球或狗追老鼠等场景。

在我们的特定情况下,圆圈向上移动然后向右移动的运动在训练数据中不存在,因此模型不熟悉这种特定运动。但是,模型已经在其他动作上进行了训练。我们可以使用此动作作为 prompt 来测试我们训练过的模型并观察其性能。

# Inference function to generate a video based on a given text promptdef generate_video(text_prompt, num_frames=10):    # Create a directory for the generated video frames based on the text prompt    os.makedirs(f'generated_video_{text_prompt.replace(" ", "_")}', exist_ok=True)        # Encode the text prompt into a text embedding tensor    text_embed = text_embedding(encode_text(text_prompt).to(device)).mean(dim=0).unsqueeze(0)        # Generate frames for the video    for frame_num in range(num_frames):        # Generate random noise        noise = torch.randn(1, 100).to(device)                # Generate a fake frame using the Generator network        with torch.no_grad():            fake_frame = netG(noise, text_embed)                # Save the generated fake frame as an image file        save_image(fake_frame, f'generated_video_{text_prompt.replace(" ", "_")}/frame_{frame_num}.png')# usage of the generate_video function with a specific text promptgenerate_video('circle moving up-right')

当我们运行上述代码时,它将生成一个目录,其中包含我们生成视频的所有帧。我们需要使用一些代码将所有这些帧合并为一个短视频。

# Define the path to your folder containing the PNG frames
folder_path = 'generated_video_circle_moving_up-right'# Get the list of all PNG files in the folder
image_files = [f for f in os.listdir(folder_path) if f.endswith('.png')]# Sort the images by name (assuming they are numbered sequentially)
image_files.sort()# Create a list to store the frames
frames = []# Read each image and append it to the frames list
for image_file in image_files:image_path = os.path.join(folder_path, image_file)frame = cv2.imread(image_path)frames.append(frame)# Convert the frames list to a numpy array for easier processing
frames = np.array(frames)# Define the frame rate (frames per second)
fps = 10# Create a video writer object
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter('generated_video.avi', fourcc, fps, (frames[0].shape[1], frames[0].shape[0]))# Write each frame to the video
for frame in frames:out.write(frame)# Release the video writer
out.release()

确保文件夹路径指向你新生成的视频所在的位置。运行此代码后,你将成功创建 AI 视频。让我们看看它是什么样子。

图片

我们进行了多次训练,训练次数相同。在两种情况下,圆圈都是从底部开始,出现一半。好消息是,我们的模型在两种情况下都尝试执行直立运动。

例如,在尝试 1 中,圆圈沿对角线向上移动,然后执行向上运动,而在尝试 2 中,圆圈沿对角线移动,同时尺寸缩小。在两种情况下,圆圈都没有向左移动或完全消失,这是一个好兆头。

最后,作者表示已经测试了该架构的各个方面,发现训练数据是关键。通过在数据集中包含更多动作和形状,你可以增加可变性并提高模型的性能。由于数据是通过代码生成的,因此生成更多样的数据不会花费太多时间;相反,你可以专注于完善逻辑。

此外,文章中讨论的 GAN 架构相对简单。你可以通过集成高级技术或使用语言模型嵌入 (LLM) 而不是基本神经网络嵌入来使其更复杂。此外,调整嵌入大小等参数会显著影响模型的有效性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/38114.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言力扣刷题1——最长回文字串[双指针]

力扣算题1——最长回文字串[双指针] 一、博客声明二、题目描述三、解题思路1、思路说明2、知识补充a、malloc动态内存分配b、free释放内存c、strlen求字符数组长度d、strncpy函数 四、解题代码(附注释) 一、博客声明 找工作逃不过刷题,为了更…

Docker配置远程连接

前置条件:docker所在的服务器开放2375端口 文件:/usr/lib/systemd/system/docker.service 节点ExecStart 追加 -H tcp://0.0.0.0:2375

智慧校园变革之路:全平台综合概述与最佳实践

在当今信息化浪潮的推动下,"智慧校园"作为教育创新的前沿阵地,正逐步揭开其神秘面纱,引领一场前所未有的教育转型革命。它远超过单纯技术叠加的传统框架,而是深度融合云计算、大数据、物联网等前沿科技,精心…

【计算机毕业设计】基于Springboot的智能物流管理系统【源码+lw+部署文档】

包含论文源码的压缩包较大,请私信或者加我的绿色小软件获取 免责声明:资料部分来源于合法的互联网渠道收集和整理,部分自己学习积累成果,供大家学习参考与交流。收取的费用仅用于收集和整理资料耗费时间的酬劳。 本人尊重原创作者…

【Mac】Auto Mouse Click for Mac(高效、稳定的鼠标连点器软件)软件介绍

软件介绍 Auto Mouse Click for Mac 是一款专为 macOS 平台设计的自动鼠标点击软件,它可以帮助用户自动化重复的鼠标点击操作,从而提高工作效率。以下是这款软件的主要特点和功能: 1.自动化点击操作:Auto Mouse Click 允许用户录…

等保测评过程中会用到哪些工具或服务

等保测评工具和服务概述 等保测评,即信息安全等级保护测评,是一项重要的信息安全工作,它涉及信息系统的技术和管理两个方面,包括物理安全、网络安全、主机安全、应用安全、数据安全等多个维度。等保测评工具和服务是企业进行信息…

神经网络实战2-损失函数和反向传播

其实就是通过求偏导的方式,求出各个权重大小 loss函数是找最小值的,要求导,在计算机里面计算导数是倒着来的,所以叫反向传播。 import torch from torch.nn import L1Lossinputstorch.tensor([1,2,3],dtypetorch.float32) targe…

使用Llama3/Qwen2等开源大模型,部署团队私有化Code Copilot和使用教程

目前市面上有不少基于大模型的 Code Copilot 产品,部分产品对于个人开发者来说可免费使用,比如阿里的通义灵码、百度的文心快码等。这些免费的产品均通过 API 的方式提供服务,因此调用时均必须联网、同时需要把代码、提示词等内容作为 API 的…

面了英伟达算法岗,被疯狂拷打。。。

节前,我们组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、今年参加社招和校招面试的同学。 针对大模型技术趋势、算法项目落地经验分享、新手如何入门算法岗、该如何准备面试攻略、面试常考点等热门话题进行了深入的讨论。 总结链接如…

Python逻辑控制语句 之 循环语句--for循环

1.for 的介绍 for 循环 也称为是 for 遍历, 也可以做指定次数的循环遍历: 是从容器中将数据逐个取出的过程.容器: 字符串/列表/元组/字典 2.for 的语法 (1)for 循环遍历字符串 for 变量 in 字符串: 重复执⾏的代码 字符串中存在多少个字符, 代码就执行…

YOLOv8 的简介 及C#中如何简单应用YOLOv8

YOLOv8 是 YOLO(You Only Look Once)系列中的最新版本,是一种用于目标检测和图像分割的深度学习模型。YOLO模型以其快速和准确的目标检测性能而著称,广泛应用于实时应用程序中。 主要特点 高效性:YOLOv8 在保持高检测…

【HBZ分享】如何实现系统的高可用

如何实现系统的高可用? 高可用架构设计 采用集群,多机房,多副本,负载均衡,热备份等手段,确保系统冗余 与 可恢复能力,避免单点故障。服务容错 与 降级 面对故障时,通常采用合适的…

解决java中时间参数的问题

在java的日常开发中,我们经常需要去接收前端传递过来的时间字符串,同时给前端返回数据时,也会涉及到时间字段的数据传递,那么我们要如何处理这些字段呢? 知识铺垫:java最后返回的时间是时间世界&#xff0…

鲁工小装载机-前后桥传动轴油封更换记录

鲁工装载机 因前后桥大量漏齿轮油,故拆开查看、更换油封 一: 如图圈起来的地方是螺丝和钢板相别,用200的焊接电流用电焊机点开一个豁口后拆除螺丝。 转轴是拆除传动轴后的样子。 这就是拆下来的样子,这玩意插上边那图&…

Python 3 字符串

Python 3 字符串 字符串在Python中是一种基本的数据类型,用于存储文本数据。Python 3中的字符串是由Unicode字符组成的序列,这使得它可以轻松地处理多种语言的文本。在本文中,我们将深入探讨Python 3中字符串的各个方面,包括创建字符串、字符串操作、格式化和常见的方法。…

F12开发者工具怎么用(小白版)

F12开发者工具(也称为浏览器开发者工具)是现代浏览器(如Chrome、Firefox、Edge等)内置的工具集,主要用于网页开发和调试。以下是使用这些工具的一些基本指南: 打开开发者工具 按下 F12 键,或者…

探索Scala在大数据开发中的高级功能

目录 2. Scala的语言特性 2.1 静态类型和类型推断 2.2 面向对象与函数式编程 3. 高级集合操作 3.1 不可变集合 3.2 高阶函数 4. 并发与并行处理 4.1 Future与Promise 4.2 Akka Actor模型 5. Spark与Scala的结合 5.1 RDD和DataFrame 5.2 Spark SQL与数据处理 6. 高…

八爪鱼现金流-033,升级日志,里程碑4

2024年6月30日15:48:46 v-4.0.0 定时任务发送邮件提醒功能开发: 发送邮箱定时任务。提醒月报记账. 工资日 5号 15号 25号 晚上17:30发送 里程碑版本4完成。 八爪鱼现金流 八爪鱼

【论文阅读】A Survey on Large Language Model based Autonomous Agents

文章目录 1 大语言模型的构建1.1分析模块 profiling module1.2 记忆模块 memory module1.2.1 记忆结构1.2.2 记忆形式1.2.3 记忆运行 1.3 规划模块 planning module1.3.1 无反馈规划1.3.2 有反馈计划 1.4 执行模块 action module1.4.1 执行目标1.4.2 执行空间 2 Agent能力提升2…

深度剖析:前端如何驾驭海量数据,实现流畅渲染的多种途径

文章目录 一、分批渲染1、setTimeout定时器分批渲染2、使用requestAnimationFrame()改进渲染2.1、什么是requestAnimationFrame2.2、为什么使用requestAnimationFrame而不是setTimeout或setInterval2.3、requestAnimationFrame的优势和适用场景 二、滚动触底加载数据三、Elemen…