VAE 也做 BF16

不修改该权重可获得更好的精度
This commit is contained in:
2026-01-18 21:14:55 +08:00
parent e1b029201e
commit a90efc6718
6 changed files with 67 additions and 16 deletions

View File

@@ -2032,6 +2032,13 @@ class LatentVisualDiffusion(LatentDiffusion):
target_dtype: torch.dtype | None) -> Tensor:
use_bf16 = (self.projector_bf16 and x.device.type == "cuda"
and torch.cuda.is_bf16_supported())
if not use_bf16:
weight_dtype = None
for param in projector.parameters():
weight_dtype = param.dtype
break
if weight_dtype is not None and x.dtype != weight_dtype:
x = x.to(dtype=weight_dtype)
if use_bf16:
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
out = projector(x)