轻量投影/MLP BF16

psnr指标反而比只量化扩散主干模型要低,原因不明
This commit is contained in:
2026-01-18 18:26:37 +08:00
parent 2b634cde90
commit 3c0f409fcf
6 changed files with 96 additions and 48 deletions

View File

@@ -770,28 +770,36 @@ def image_guided_synthesis_sim_mode(
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
with profiler.profile_section("synthesis/conditioning_prep"):
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
cond_img_emb = model.embedder(cond_img)
cond_img_emb = model.image_proj_model(cond_img_emb)
if model.model.conditioning_key == 'hybrid':
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
img_cat_cond = z[:, :, -1:, :, :]
img_cat_cond = repeat(img_cat_cond,
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
cond_img_emb = model.embedder(cond_img)
if model.model.conditioning_key == 'hybrid':
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
img_cat_cond = z[:, :, -1:, :, :]
img_cat_cond = repeat(img_cat_cond,
'b c t h w -> b c (repeat t) h w',
repeat=noise_shape[2])
cond = {"c_concat": [img_cat_cond]}
if not text_input:
prompts = [""] * batch_size
cond_ins_emb = model.get_learned_conditioning(prompts)
cond_state_emb = model.state_projector(observation['observation.state'])
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
cond_action_emb = model.action_projector(observation['action'])
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
if not text_input:
prompts = [""] * batch_size
cond_ins_emb = model.get_learned_conditioning(prompts)
target_dtype = cond_ins_emb.dtype
cond_img_emb = model._projector_forward(model.image_proj_model,
cond_img_emb, target_dtype)
cond_state_emb = model._projector_forward(
model.state_projector, observation['observation.state'],
target_dtype)
cond_state_emb = cond_state_emb + model.agent_state_pos_emb.to(
dtype=target_dtype)
cond_action_emb = model._projector_forward(
model.action_projector, observation['action'], target_dtype)
cond_action_emb = cond_action_emb + model.agent_action_pos_emb.to(
dtype=target_dtype)
if not sim_mode:
cond_action_emb = torch.zeros_like(cond_action_emb)