修复混合精度vae相关的配置错误,确保在推理阶段正确使用了混合精度模型,并且导出了正确精度的检查点文件。
This commit is contained in:
@@ -1105,6 +1105,10 @@ class LatentDiffusion(DDPM):
|
||||
else:
|
||||
reshape_back = False
|
||||
|
||||
# Align input dtype with VAE weights (e.g. fp32 samples → bf16 VAE)
|
||||
vae_dtype = next(self.first_stage_model.parameters()).dtype
|
||||
z = z.to(dtype=vae_dtype)
|
||||
|
||||
if not self.perframe_ae:
|
||||
z = 1. / self.scale_factor * z
|
||||
results = self.first_stage_model.decode(z, **kwargs)
|
||||
@@ -2457,7 +2461,6 @@ class DiffusionWrapper(pl.LightningModule):
|
||||
Returns:
|
||||
Output from the inner diffusion model (tensor or tuple, depending on the model).
|
||||
"""
|
||||
|
||||
if self.conditioning_key is None:
|
||||
out = self.diffusion_model(x, t)
|
||||
elif self.conditioning_key == 'concat':
|
||||
|
||||
Reference in New Issue
Block a user