对扩散主干做 BF16
量化对象:model.model(扩散 UNet/WMAModel 主体)
This commit is contained in:
@@ -714,22 +714,23 @@ def preprocess_observation(
|
||||
return return_observations
|
||||
|
||||
|
||||
def image_guided_synthesis_sim_mode(
|
||||
model: torch.nn.Module,
|
||||
prompts: list[str],
|
||||
observation: dict,
|
||||
noise_shape: tuple[int, int, int, int, int],
|
||||
action_cond_step: int = 16,
|
||||
n_samples: int = 1,
|
||||
ddim_steps: int = 50,
|
||||
ddim_eta: float = 1.0,
|
||||
unconditional_guidance_scale: float = 1.0,
|
||||
fs: int | None = None,
|
||||
text_input: bool = True,
|
||||
timestep_spacing: str = 'uniform',
|
||||
guidance_rescale: float = 0.0,
|
||||
sim_mode: bool = True,
|
||||
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def image_guided_synthesis_sim_mode(
|
||||
model: torch.nn.Module,
|
||||
prompts: list[str],
|
||||
observation: dict,
|
||||
noise_shape: tuple[int, int, int, int, int],
|
||||
action_cond_step: int = 16,
|
||||
n_samples: int = 1,
|
||||
ddim_steps: int = 50,
|
||||
ddim_eta: float = 1.0,
|
||||
unconditional_guidance_scale: float = 1.0,
|
||||
fs: int | None = None,
|
||||
text_input: bool = True,
|
||||
timestep_spacing: str = 'uniform',
|
||||
guidance_rescale: float = 0.0,
|
||||
sim_mode: bool = True,
|
||||
diffusion_autocast_dtype: Optional[torch.dtype] = None,
|
||||
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
|
||||
|
||||
@@ -750,9 +751,10 @@ def image_guided_synthesis_sim_mode(
|
||||
fs (int | None): Frame index to condition on, broadcasted across the batch if specified. Default is None.
|
||||
text_input (bool): Whether to use text prompt as conditioning. If False, uses empty strings. Default is True.
|
||||
timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
|
||||
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
|
||||
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
|
||||
**kwargs: Additional arguments passed to the DDIM sampler.
|
||||
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
|
||||
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
|
||||
diffusion_autocast_dtype (Optional[torch.dtype]): Autocast dtype for diffusion sampling (e.g., torch.bfloat16).
|
||||
**kwargs: Additional arguments passed to the DDIM sampler.
|
||||
|
||||
Returns:
|
||||
batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W].
|
||||
@@ -810,31 +812,37 @@ def image_guided_synthesis_sim_mode(
|
||||
uc = None
|
||||
kwargs.update({"unconditional_conditioning_img_nonetext": None})
|
||||
cond_mask = None
|
||||
cond_z0 = None
|
||||
|
||||
if ddim_sampler is not None:
|
||||
with profiler.profile_section("synthesis/ddim_sampling"):
|
||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||
S=ddim_steps,
|
||||
conditioning=cond,
|
||||
batch_size=batch_size,
|
||||
shape=noise_shape[1:],
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=ddim_eta,
|
||||
cfg_img=None,
|
||||
mask=cond_mask,
|
||||
x0=cond_z0,
|
||||
fs=fs,
|
||||
timestep_spacing=timestep_spacing,
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
|
||||
# Reconstruct from latent to pixel space
|
||||
with profiler.profile_section("synthesis/decode_first_stage"):
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
batch_variants = batch_images
|
||||
cond_z0 = None
|
||||
|
||||
if ddim_sampler is not None:
|
||||
with profiler.profile_section("synthesis/ddim_sampling"):
|
||||
autocast_ctx = nullcontext()
|
||||
if diffusion_autocast_dtype is not None and model.device.type == "cuda":
|
||||
autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype)
|
||||
with autocast_ctx:
|
||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||
S=ddim_steps,
|
||||
conditioning=cond,
|
||||
batch_size=batch_size,
|
||||
shape=noise_shape[1:],
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=ddim_eta,
|
||||
cfg_img=None,
|
||||
mask=cond_mask,
|
||||
x0=cond_z0,
|
||||
fs=fs,
|
||||
timestep_spacing=timestep_spacing,
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
|
||||
# Reconstruct from latent to pixel space
|
||||
with profiler.profile_section("synthesis/decode_first_stage"):
|
||||
if samples.dtype != torch.float32:
|
||||
samples = samples.float()
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
batch_variants = batch_images
|
||||
|
||||
return batch_variants, actions, states
|
||||
|
||||
@@ -889,6 +897,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
model = model.cuda(gpu_no)
|
||||
device = get_device_from_parameters(model)
|
||||
|
||||
diffusion_autocast_dtype = None
|
||||
if args.diffusion_dtype == "bf16":
|
||||
with profiler.profile_section("model_loading/diffusion_bf16"):
|
||||
model.model.to(dtype=torch.bfloat16)
|
||||
diffusion_autocast_dtype = torch.bfloat16
|
||||
print(">>> diffusion backbone set to bfloat16")
|
||||
|
||||
log_inference_precision(model)
|
||||
|
||||
profiler.record_memory("after_model_load")
|
||||
@@ -1014,20 +1029,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
# Use world-model in policy to generate action
|
||||
print(f'>>> Step {itr}: generating actions ...')
|
||||
with profiler.profile_section("action_generation"):
|
||||
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
sample['instruction'],
|
||||
observation,
|
||||
noise_shape,
|
||||
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
sample['instruction'],
|
||||
observation,
|
||||
noise_shape,
|
||||
action_cond_step=args.exe_steps,
|
||||
ddim_steps=args.ddim_steps,
|
||||
ddim_eta=args.ddim_eta,
|
||||
unconditional_guidance_scale=args.
|
||||
unconditional_guidance_scale,
|
||||
fs=model_input_fs,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
sim_mode=False)
|
||||
unconditional_guidance_scale=args.
|
||||
unconditional_guidance_scale,
|
||||
fs=model_input_fs,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
sim_mode=False,
|
||||
diffusion_autocast_dtype=diffusion_autocast_dtype)
|
||||
|
||||
# Update future actions in the observation queues
|
||||
with profiler.profile_section("update_action_queues"):
|
||||
@@ -1058,20 +1074,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
# Interaction with the world-model
|
||||
print(f'>>> Step {itr}: interacting with world model ...')
|
||||
with profiler.profile_section("world_model_interaction"):
|
||||
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
"",
|
||||
observation,
|
||||
noise_shape,
|
||||
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
"",
|
||||
observation,
|
||||
noise_shape,
|
||||
action_cond_step=args.exe_steps,
|
||||
ddim_steps=args.ddim_steps,
|
||||
ddim_eta=args.ddim_eta,
|
||||
unconditional_guidance_scale=args.
|
||||
unconditional_guidance_scale,
|
||||
fs=model_input_fs,
|
||||
text_input=False,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale)
|
||||
unconditional_guidance_scale,
|
||||
fs=model_input_fs,
|
||||
text_input=False,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
diffusion_autocast_dtype=diffusion_autocast_dtype)
|
||||
|
||||
with profiler.profile_section("update_state_queues"):
|
||||
for step_idx in range(args.exe_steps):
|
||||
@@ -1216,13 +1233,20 @@ def get_parser():
|
||||
help=
|
||||
"Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--perframe_ae",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help=
|
||||
"Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--perframe_ae",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help=
|
||||
"Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--diffusion_dtype",
|
||||
type=str,
|
||||
choices=["fp32", "bf16"],
|
||||
default="fp32",
|
||||
help="Dtype for diffusion backbone weights and sampling autocast."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_action_steps",
|
||||
type=int,
|
||||
|
||||
Reference in New Issue
Block a user