修复混合精度vae相关的配置错误,确保在推理阶段正确使用了混合精度模型,并且导出了正确精度的检查点文件。
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import argparse, os, glob
|
||||
from contextlib import nullcontext
|
||||
import pandas as pd
|
||||
import random
|
||||
import torch
|
||||
@@ -38,6 +39,68 @@ def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||
return next(iter(module.parameters())).device
|
||||
|
||||
|
||||
def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.Module:
    """Apply precision settings to model components based on command-line arguments.

    Four independently configurable components are handled:

    * Diffusion backbone (``args.diffusion_dtype``): ``"bf16"`` converts the
      weights to bfloat16 and records bfloat16 as the sampling autocast dtype;
      anything else keeps fp32 weights and disables autocast.
    * Projectors (``args.projector_mode``) and encoders (``args.encoder_mode``)
      share a three-way mode: ``"bf16_full"`` converts the weights (no autocast
      needed), ``"autocast"`` keeps fp32 weights but records bfloat16 for
      autocast compute, ``"fp32"`` keeps everything in fp32.
    * VAE (``args.vae_dtype``): ``"bf16"`` converts ``first_stage_model``;
      anything else keeps fp32 (VAE precision most affects image quality).

    The chosen autocast dtypes are stored on the model as
    ``diffusion_autocast_dtype`` / ``projector_autocast_dtype`` /
    ``encoder_autocast_dtype`` (``None`` means "no autocast"); the inference
    code reads them back via ``getattr``.

    Args:
        model (nn.Module): The model to apply precision settings to. Must
            expose ``model``, ``state_projector``, ``action_projector``,
            ``embedder``, ``image_proj_model`` and ``first_stage_model``
            submodules.
        args (argparse.Namespace): Parsed command-line arguments containing
            ``diffusion_dtype``, ``projector_mode``, ``encoder_mode`` and
            ``vae_dtype``.

    Returns:
        nn.Module: The same model instance, modified in place.
    """
    print(">>> Applying precision settings:")
    print(f" - Diffusion dtype: {args.diffusion_dtype}")
    print(f" - Projector mode: {args.projector_mode}")
    print(f" - Encoder mode: {args.encoder_mode}")
    print(f" - VAE dtype: {args.vae_dtype}")

    # 1. Diffusion backbone: full weight conversion + autocast, or pure fp32.
    if args.diffusion_dtype == "bf16":
        model.model.to(torch.bfloat16)
        model.diffusion_autocast_dtype = torch.bfloat16
        print(" ✓ Diffusion model weights converted to bfloat16")
    else:
        model.diffusion_autocast_dtype = None
        print(" ✓ Diffusion model using fp32")

    # 2. Projectors and 3. Encoders share the same three-way mode logic.
    model.projector_autocast_dtype = _apply_component_mode(
        args.projector_mode,
        (model.state_projector, model.action_projector),
        "Projectors",
    )
    model.encoder_autocast_dtype = _apply_component_mode(
        args.encoder_mode,
        (model.embedder, model.image_proj_model),
        "Encoders",
    )

    # 4. VAE precision; default stays fp32 because the VAE most affects
    # reconstruction quality.
    if args.vae_dtype == "bf16":
        model.first_stage_model.to(torch.bfloat16)
        print(" ✓ VAE converted to bfloat16")
    else:
        print(" ✓ VAE kept in fp32 for best quality")

    return model


def _apply_component_mode(mode: str, modules: tuple, label: str):
    """Apply one precision mode ("fp32"/"autocast"/"bf16_full") to *modules*.

    Returns the autocast dtype to record on the model: ``torch.bfloat16`` for
    "autocast" mode, ``None`` otherwise (bf16_full needs no autocast because
    the weights themselves are converted; fp32 keeps the original precision).
    """
    if mode == "bf16_full":
        for module in modules:
            module.to(torch.bfloat16)
        print(f" ✓ {label} converted to bfloat16")
        return None
    if mode == "autocast":
        print(f" ✓ {label} will use autocast (weights fp32, compute bf16)")
        return torch.bfloat16
    # "fp32" mode: do nothing, keep original precision.
    return None
|
||||
|
||||
|
||||
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
|
||||
"""Save a list of frames to a video file.
|
||||
|
||||
@@ -262,6 +325,11 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
|
||||
"""
|
||||
b, c, t, h, w = videos.shape
|
||||
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
||||
|
||||
# Auto-detect VAE dtype and convert input
|
||||
vae_dtype = next(model.first_stage_model.parameters()).dtype
|
||||
x = x.to(dtype=vae_dtype)
|
||||
|
||||
z = model.encode_first_stage(x)
|
||||
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
|
||||
return z
|
||||
@@ -363,10 +431,22 @@ def image_guided_synthesis_sim_mode(
|
||||
|
||||
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
||||
|
||||
# Auto-detect model dtype and convert inputs accordingly
|
||||
model_dtype = next(model.embedder.parameters()).dtype
|
||||
|
||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
||||
cond_img_emb = model.embedder(cond_img)
|
||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype)
|
||||
|
||||
# Encoder autocast: weights stay fp32, compute in bf16
|
||||
enc_ac_dtype = getattr(model, 'encoder_autocast_dtype', None)
|
||||
if enc_ac_dtype is not None and model.device.type == 'cuda':
|
||||
enc_ctx = torch.autocast('cuda', dtype=enc_ac_dtype)
|
||||
else:
|
||||
enc_ctx = nullcontext()
|
||||
|
||||
with enc_ctx:
|
||||
cond_img_emb = model.embedder(cond_img)
|
||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
||||
|
||||
if model.model.conditioning_key == 'hybrid':
|
||||
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
||||
@@ -380,11 +460,22 @@ def image_guided_synthesis_sim_mode(
|
||||
prompts = [""] * batch_size
|
||||
cond_ins_emb = model.get_learned_conditioning(prompts)
|
||||
|
||||
cond_state_emb = model.state_projector(observation['observation.state'])
|
||||
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
||||
# Auto-detect projector dtype and convert inputs
|
||||
projector_dtype = next(model.state_projector.parameters()).dtype
|
||||
|
||||
cond_action_emb = model.action_projector(observation['action'])
|
||||
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
||||
# Projector autocast: weights stay fp32, compute in bf16
|
||||
proj_ac_dtype = getattr(model, 'projector_autocast_dtype', None)
|
||||
if proj_ac_dtype is not None and model.device.type == 'cuda':
|
||||
proj_ctx = torch.autocast('cuda', dtype=proj_ac_dtype)
|
||||
else:
|
||||
proj_ctx = nullcontext()
|
||||
|
||||
with proj_ctx:
|
||||
cond_state_emb = model.state_projector(observation['observation.state'].to(dtype=projector_dtype))
|
||||
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
||||
|
||||
cond_action_emb = model.action_projector(observation['action'].to(dtype=projector_dtype))
|
||||
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
||||
|
||||
if not sim_mode:
|
||||
cond_action_emb = torch.zeros_like(cond_action_emb)
|
||||
@@ -406,8 +497,17 @@ def image_guided_synthesis_sim_mode(
|
||||
kwargs.update({"unconditional_conditioning_img_nonetext": None})
|
||||
cond_mask = None
|
||||
cond_z0 = None
|
||||
|
||||
# Setup autocast context for diffusion sampling
|
||||
autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None)
|
||||
if autocast_dtype is not None and model.device.type == 'cuda':
|
||||
autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype)
|
||||
else:
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
if ddim_sampler is not None:
|
||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||
with autocast_ctx:
|
||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||
S=ddim_steps,
|
||||
conditioning=cond,
|
||||
batch_size=batch_size,
|
||||
@@ -464,6 +564,17 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
model.eval()
|
||||
print(f'>>> Load pre-trained model ...')
|
||||
|
||||
# Apply precision settings before moving to GPU
|
||||
model = apply_precision_settings(model, args)
|
||||
|
||||
# Export precision-converted checkpoint if requested
|
||||
if args.export_precision_ckpt:
|
||||
export_path = args.export_precision_ckpt
|
||||
os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True)
|
||||
torch.save({"state_dict": model.state_dict()}, export_path)
|
||||
print(f">>> Precision-converted checkpoint saved to: {export_path}")
|
||||
return
|
||||
|
||||
# Build unnormalizer
|
||||
logging.info("***** Configing Data *****")
|
||||
data = instantiate_from_config(config.data)
|
||||
@@ -798,6 +909,35 @@ def get_parser():
|
||||
type=int,
|
||||
default=8,
|
||||
help="fps for the saving video")
|
||||
parser.add_argument(
|
||||
"--diffusion_dtype",
|
||||
type=str,
|
||||
choices=["fp32", "bf16"],
|
||||
default="bf16",
|
||||
help="Diffusion backbone precision (fp32/bf16)")
|
||||
parser.add_argument(
|
||||
"--projector_mode",
|
||||
type=str,
|
||||
choices=["fp32", "autocast", "bf16_full"],
|
||||
default="bf16_full",
|
||||
help="Projector precision mode (fp32/autocast/bf16_full)")
|
||||
parser.add_argument(
|
||||
"--encoder_mode",
|
||||
type=str,
|
||||
choices=["fp32", "autocast", "bf16_full"],
|
||||
default="bf16_full",
|
||||
help="Encoder precision mode (fp32/autocast/bf16_full)")
|
||||
parser.add_argument(
|
||||
"--vae_dtype",
|
||||
type=str,
|
||||
choices=["fp32", "bf16"],
|
||||
default="fp32",
|
||||
help="VAE precision (fp32/bf16, most affects image quality)")
|
||||
parser.add_argument(
|
||||
"--export_precision_ckpt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Export precision-converted checkpoint to this path, then exit.")
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user