主干部分fp16,最敏感psnr=25.21,可以考虑对主干部分太敏感的部分回退fp32
This commit is contained in:
@@ -571,6 +571,11 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
torch.save(model, prepared_path)
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
# ---- BF16: only convert the diffusion backbone, keep VAE/CLIP/embedder in FP32 ----
model.model.to(torch.bfloat16)
model.model.diffusion_model.dtype = torch.bfloat16
print(">>> Diffusion backbone (model.model) converted to BF16.")
# Build normalizer (always needed, independent of model loading path)
logging.info("***** Configing Data *****")
data = instantiate_from_config(config.data)