OmegaConf
OmegaConf
是一个用于处理配置文件和命令行参数的Python库,它支持YAML和JSON格式的配置文件。OmegaConf
提供了一些高级功能,如配置合并、类型安全的配置访问、环境变量插值等。OmegaConf.load()
是这个库中的一个函数,用于加载和解析配置文件或字典对象。
当您使用 OmegaConf.load()
函数时,可以提供一个表示配置文件路径的字符串或一个Python字典。如果提供的是文件路径,OmegaConf.load()
会读取并解析该文件,将其内容转换为一个 OmegaConf
配置对象。如果提供的是字典对象,OmegaConf.load()
会将该字典转换为一个 OmegaConf
配置对象。
以下是一个简单的 OmegaConf.load()
示例:
from omegaconf import OmegaConf# 从 YAML 文件加载配置
config = OmegaConf.load("config.yaml")# 从 Python 字典加载配置
config_dict = {"param1": "value1", "param2": "value2"}
config = OmegaConf.load(config_dict)# 访问配置参数
param1 = config.param1
param2 = config.param2
在这个示例中,我们首先从一个名为 config.yaml
的YAML文件加载配置,然后从一个Python字典加载配置。加载后,我们可以使用点符号(.
)轻松访问配置中的参数。
通过使用 OmegaConf
,您可以更方便地管理和访问配置文件中的信息,从而简化应用程序的配置管理。
当使用 OmegaConf.load()
从 YAML 文件加载配置时,得到的 config
对象是一个 OmegaConf 容器(如 DictConfig
或 ListConfig
),而不是普通的 Python 字典。这些容器是 OmegaConf 库的特殊数据结构,它们为访问和操作配置数据提供了更高级的功能和类型安全性。
尽管 DictConfig
和 ListConfig
不是普通的 Python 字典或列表,但它们的使用方式与字典和列表非常类似。例如,您可以使用点符号(.
)或方括号([]
)来访问 DictConfig
中的元素。以下是一个从 YAML 文件加载配置并访问参数的示例:
# config.yaml
param1: value1
param2:- item1- item2
from omegaconf import OmegaConf# 从 YAML 文件加载配置
config = OmegaConf.load("config.yaml")# 访问配置参数
param1 = config.param1 # 或者 config["param1"]
param2_item1 = config.param2[0] # 或者 config["param2"][0]
尽管 OmegaConf 容器的使用方式类似于字典和列表,但它们提供了一些额外的功能,如类型安全访问、默认值、环境变量插值等。如果需要将 OmegaConf 容器转换为普通的 Python 字典,可以使用 OmegaConf.to_container()
函数:
config_dict = OmegaConf.to_container(config)
这样,config_dict
就是一个普通的 Python 字典,可以按照正常的字典操作进行访问和处理。
Instantiate_from_config
def instantiate_from_config(config):if not "target" in config:if config == '__is_first_stage__':return Noneelif config == "__is_unconditional__":return Noneraise KeyError("Expected key `target` to instantiate.")return get_obj_from_str(config["target"])(**config.get("params", dict()))
这段代码定义了一个名为 instantiate_from_config
的函数,它接受一个 config
参数。这个函数的主要目的是根据提供的配置信息,动态地实例化一个对象。具体来说,代码执行以下操作:
- 检查
config
中是否包含 “target” 键。如果不包含,执行以下操作:- 如果
config
等于 ‘is_first_stage’,则返回None
。 - 如果
config
等于 ‘is_unconditional’,则返回None
。 - 如果
config
中没有 “target” 键,抛出一个KeyError
异常,提示需要 “target” 键来实例化对象。
- 如果
- 如果配置包含 “target” 键,则使用
get_obj_from_str
函数根据 “target” 键的值获取一个类或函数对象。get_obj_from_str
函数的实现没有在这段代码中给出,但它通常会根据提供的全限定名(包括模块名和类/函数名)导入相应的对象。 - 使用
config.get("params", dict())
获取 “params” 键的值作为参数,如果 “params” 键不存在,则使用一个空字典。随后,通过解包参数字典(使用**
运算符)并调用从 “target” 键获取的对象,实例化该对象。
下面是一个简单的示例来说明一下instantiate_from_config:
# 假设我们有一个名为 my_class.py 的文件,其中包含一个名为 MyClass 的类
# my_class.py
class MyClass:def __init__(self, param1, param2):self.param1 = param1self.param2 = param2# 假设我们有一个名为 main.py 的文件,其中调用 instantiate_from_config 函数
# main.py
config = {"target": "my_class.MyClass","params": {"param1": "value1","param2": "value2"}
}obj = instantiate_from_config(config)
# 此时,obj 是一个 MyClass 的实例,使用 "value1" 和 "value2" 作为其构造函数参数
总之,instantiate_from_config
函数根据提供的配置信息(包括目标类/函数的全限定名和参数字典)动态地实例化一个对象。
return get_obj_from_str(config["target"])(**config.get("params", dict()))
中:
get_obj_from_str(config["target"])
用于根据传入的全限定名(config["target"]
)获取一个类或函数对象。然后,在获取到的类或函数对象后面加上括号 ()
,表示我们要实例化这个类(调用构造函数)或调用这个函数。这里的 **config.get("params", dict())
是将配置中的 “params” 键的值作为参数传递给类构造函数或函数。
ControlLDM ( LatentDiffusion ( DDPM))
DDPM: first_stage_key 最基础的扩散模型,一张原图,加噪,然后去噪,预测噪声,学习生成的过程
LatentDiffusion: cond_stage_key 加入了条件,比如ldm中的最右侧的各种prompt:txt, voice, img…
ControlLDM: control_key 加入了hint, 也就是controlNet中的control ,加入了control stage config
-
ControlUnetModel
-
ControlNet : 把原来的LDM的unet的encoder和mid_block部分加入zero_convolution
-
ControlLDM:
model = ControlLDM
装饰器
@torch.no_grad()
是一个 PyTorch 装饰器,用于指定在一段代码中关闭梯度计算。在 PyTorch 中,张量(Tensor)的计算通常会自动跟踪和记录计算图(computational graph),以便在反向传播(backpropagation)过程中计算梯度。然而,在某些情况下,我们不需要计算梯度,例如在模型评估和推理阶段。
使用 @torch.no_grad()
装饰器可以在特定的函数或代码块中关闭梯度计算,这有助于节省内存并提高计算效率。当你不需要更新模型参数(如在验证和测试阶段)时,这是一个非常有用的功能。下面是一个简单的例子:
@torch.no_grad()
def get_input(self, batch, k, bs=None, *args, **kwargs):x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)control = batch[self.control_key]if bs is not None:control = control[:bs]control = control.to(self.device)control = einops.rearrange(control, 'b h w c -> b c h w')control = control.to(memory_format=torch.contiguous_format).float()return x, dict(c_crossattn=[c], c_concat=[control])def apply_model(self, x_noisy, t, cond, *args, **kwargs):assert isinstance(cond, dict)diffusion_model = self.model.diffusion_modelcond_txt = torch.cat(cond['c_crossattn'], 1)if cond['c_concat'] is None:eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)else:control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)control = [c * scale for c, scale in zip(control, self.control_scales)]eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)return eps
在上面提供的代码片段中,apply_model
函数没有使用 @torch.no_grad()
装饰器。因此,在调用 apply_model
函数时,PyTorch 将正常跟踪和计算梯度。只有使用了 @torch.no_grad()
装饰器的 get_input
函数才会关闭梯度计算。
如果你希望在 apply_model
函数中也不计算梯度,可以在函数定义前添加 @torch.no_grad()
装饰器
@torch.no_grad()
装饰器仅直接应用于它所装饰的函数。在这个函数内部,所有涉及梯度计算的操作都将被禁用。然而,如果这个函数调用了其他函数,@torch.no_grad()
会影响到调用的函数。也就是说,被调用的函数中的梯度计算也会被禁用。让我们看一个例子来说明这一点:
import torch# 定义一个简单的模型
class MyModel(torch.nn.Module):def __init__(self):super(MyModel, self).__init__()self.linear = torch.nn.Linear(10, 5)def forward(self, x):return self.linear(x)def another_function(model, input_tensor):return model(input_tensor)@torch.no_grad()
def inference(model, input_tensor):model.eval()output = another_function(model, input_tensor)return outputmodel = MyModel()
input_tensor = torch.randn(1, 10)# 使用推理函数进行推理
output = inference(model, input_tensor)
print(output)
在这个例子中,我们定义了一个名为 another_function
的额外函数。inference
函数(带有 @torch.no_grad()
装饰器)调用了 another_function
。虽然 another_function
没有直接使用 @torch.no_grad()
装饰器,但在 inference
函数的上下文中,梯度计算仍然被禁用。因此,在这种情况下,被调用的 another_function
中的梯度计算也被禁用了。
self.control_model = instantiate_from_config(control_stage_config)
control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
hint.shape=[B,C,H,W], context: condition feature map
control_stage_config—》ControlNet
ControlNet就是把LDM的Unet的encoder和mid_block加上zero_conv的结构
timestep_embedding是把timesteps编码为维度为dim的张量