YAML 文件提供了一种清晰、简洁且易于理解的方式来描述配置信息,特别适用于机器学习模型的超参数调优和实验管理。
以 Latent Diffusion 官方代码仓库中的 https://github.com/CompVis/latent-diffusion/blob/main/configs/autoencoder/autoencoder_kl_32x32x4.yaml 为例(如下),该 YAML 配置文件,用于定义训练一个自编码器模型的设置,其中包含 3 个部分:
- model (AutoencoderKL的模型结构)
- data(DataModuleFromConfig中如何读入数据)
- lightning(设置回调函数和训练器)
model:
base_learning_rate: 4.5e-6
target: ldm.models.autoencoder.AutoencoderKL
params:
monitor: "val/rec_loss"
embed_dim: 4
lossconfig:
target: ldm.modules.losses.LPIPSWithDiscriminator
params:
disc_start: 50001
kl_weight: 0.000001
disc_weight: 0.5
ddconfig:
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
data:
target: main.DataModuleFromConfig
params:
batch_size: 12
wrap: True
train:
target: ldm.data.imagenet.ImageNetSRTrain
params:
size: 256
degradation: pil_nearest
validation:
target: ldm.data.imagenet.ImageNetSRValidation
params:
size: 256
degradation: pil_nearest
lightning:
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 1000
max_images: 8
increase_log_steps: True
trainer:
benchmark: True
accumulate_grad_batches: 2
Model
base_learning_rate: 4.5e-6
: 这是基础学习率,用于优化器的初始化。学习率表示在每次参数更新时,参数被调整的程度。target: ldm.models.autoencoder.AutoencoderKL
: 这是要训练的模型的类路径,即模型定义代码所在的位置。params
: 这里是模型的参数设置。monitor: "val/rec_loss"
: 监控的指标,通常是验证集上的重构损失。embed_dim: 4
: 嵌入维度,可能是自编码器中隐藏层的维度。lossconfig
: 损失函数的配置。-
target: ldm.modules.losses.LPIPSWithDiscriminator
: LPIPS损失所在位置。 -
params
: 参数设置。disc_start: 50001
: 鉴别器开始的步数。kl_weight: 0.000001
: KL散度的权重。disc_weight: 0.5
: 鉴别器权重。
-
ddconfig
: 双向变换的配置。double_z: True
: 是否使用双向Z变换。- 其他参数是有关双向变换网络结构的设置,包括通道数量、分辨率、残差块数量等。
Data
target: main.DataModuleFromConfig
: 数据模块的类路径。params
: 数据加载器的参数设置。batch_size: 12
: 批量大小,即每次迭代训练时传递给模型的样本数量。wrap: True
: 是否循环迭代数据。train
: 训练数据的设置。target: ldm.data.imagenet.ImageNetSRTrain
: 训练集加载器的类路径。params
: 参数设置。size: 256
: 数据的大小。degradation: pil_nearest
: 图像降质方法。
validation
: 验证集的设置。target: ldm.data.imagenet.ImageNetSRValidation
: 验证数据加载器的类路径。params
: 参数设置,与训练数据类似。
Lightning
callbacks
: 回调函数的设置。image_logger
: 图像记录器的设置。target: main.ImageLogger
: 图像记录器的类路径。params
: 参数设置。batch_frequency: 1000
: 记录图像的频率。max_images: 8
: 最大图像数量。increase_log_steps: True
: 是否逐步增加日志步骤。
trainer
: 训练器设置。benchmark: True
: 是否启用性能测试。accumulate_grad_batches: 2
: 梯度累积的步骤数量,用于处理较大的批次大小。