修复混合精度vae相关的配置错误,确保在推理阶段正确使用了混合精度模型,并且导出了正确精度的检查点文件。

This commit is contained in:
2026-02-08 12:35:59 +00:00
parent e6c55a648c
commit e588182642
5 changed files with 178 additions and 25 deletions

View File

@@ -1105,6 +1105,10 @@ class LatentDiffusion(DDPM):
else:
reshape_back = False
# Align input dtype with VAE weights (e.g. fp32 samples → bf16 VAE)
vae_dtype = next(self.first_stage_model.parameters()).dtype
z = z.to(dtype=vae_dtype)
if not self.perframe_ae:
z = 1. / self.scale_factor * z
results = self.first_stage_model.decode(z, **kwargs)
@@ -2457,7 +2461,6 @@ class DiffusionWrapper(pl.LightningModule):
Returns:
Output from the inner diffusion model (tensor or tuple, depending on the model).
"""
if self.conditioning_key is None:
out = self.diffusion_model(x, t)
elif self.conditioning_key == 'concat':