VAE 也做 BF16
不修改这些权重可以获得更好的精度
This commit is contained in:
@@ -2032,6 +2032,13 @@ class LatentVisualDiffusion(LatentDiffusion):
|
||||
target_dtype: torch.dtype | None) -> Tensor:
|
||||
use_bf16 = (self.projector_bf16 and x.device.type == "cuda"
|
||||
and torch.cuda.is_bf16_supported())
|
||||
if not use_bf16:
|
||||
weight_dtype = None
|
||||
for param in projector.parameters():
|
||||
weight_dtype = param.dtype
|
||||
break
|
||||
if weight_dtype is not None and x.dtype != weight_dtype:
|
||||
x = x.to(dtype=weight_dtype)
|
||||
if use_bf16:
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
out = projector(x)
|
||||
|
||||
Reference in New Issue
Block a user