速度变化不明显psnr显著提升

This commit is contained in:
qhy
2026-02-11 16:38:21 +08:00
parent f386a5810b
commit 3101252c25
4 changed files with 58 additions and 37 deletions

View File

@@ -450,8 +450,9 @@ def image_guided_synthesis_sim_mode(
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
cond_img_emb = model.embedder(cond_img)
cond_img_emb = model.image_proj_model(cond_img_emb)
with torch.cuda.amp.autocast(dtype=torch.float16):
cond_img_emb = model.embedder(cond_img)
cond_img_emb = model.image_proj_model(cond_img_emb)
if model.model.conditioning_key == 'hybrid':
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
@@ -465,11 +466,12 @@ def image_guided_synthesis_sim_mode(
prompts = [""] * batch_size
cond_ins_emb = model.get_learned_conditioning(prompts)
cond_state_emb = model.state_projector(observation['observation.state'])
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
with torch.cuda.amp.autocast(dtype=torch.float16):
cond_state_emb = model.state_projector(observation['observation.state'])
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
cond_action_emb = model.action_projector(observation['action'])
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
cond_action_emb = model.action_projector(observation['action'])
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
if not sim_mode:
cond_action_emb = torch.zeros_like(cond_action_emb)
@@ -571,11 +573,22 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
torch.save(model, prepared_path)
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
# ---- BF16: only convert the diffusion backbone, keep VAE/CLIP/embedder in FP32 ----
# ---- FP16: convert diffusion backbone + conditioning modules ----
model.model.to(torch.float16)
model.model.diffusion_model.dtype = torch.float16
print(">>> Diffusion backbone (model.model) converted to FP16.")
# Projectors / MLP → FP16
model.image_proj_model.half()
model.state_projector.half()
model.action_projector.half()
print(">>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.")
# Text/image encoders → FP16
model.cond_stage_model.half()
model.embedder.half()
print(">>> Encoders (cond_stage_model, embedder) converted to FP16.")
# Build normalizer (always needed, independent of model loading path)
logging.info("***** Configing Data *****")
data = instantiate_from_config(config.data)