轻量投影/MLP BF16
psnr指标反而比只量化扩散主干模型要低,原因不明
This commit is contained in:
@@ -289,13 +289,15 @@ def image_guided_synthesis(model: torch.nn.Module,
|
||||
if not text_input:
|
||||
prompts = [""] * batch_size
|
||||
|
||||
b, c, t, h, w = videos.shape
|
||||
img = videos[:, :, 0]
|
||||
img_emb = model.embedder(img)
|
||||
img_emb = model.image_proj_model(img_emb)
|
||||
img_emb = rearrange(img_emb, 'b (t l) c -> (b t) l c', t=t)
|
||||
cond_emb = model.get_learned_conditioning(prompts)
|
||||
cond_emb = cond_emb.repeat_interleave(repeats=t, dim=0)
|
||||
b, c, t, h, w = videos.shape
|
||||
img = videos[:, :, 0]
|
||||
img_emb = model.embedder(img)
|
||||
cond_emb = model.get_learned_conditioning(prompts)
|
||||
target_dtype = cond_emb.dtype
|
||||
img_emb = model._projector_forward(model.image_proj_model, img_emb,
|
||||
target_dtype)
|
||||
img_emb = rearrange(img_emb, 'b (t l) c -> (b t) l c', t=t)
|
||||
cond_emb = cond_emb.repeat_interleave(repeats=t, dim=0)
|
||||
|
||||
cond = {"c_crossattn": [torch.cat([cond_emb, img_emb], dim=1)]}
|
||||
if model.model.conditioning_key == 'hybrid':
|
||||
|
||||
@@ -168,26 +168,33 @@ def image_guided_synthesis(
|
||||
batch_size = noise_shape[0]
|
||||
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
||||
|
||||
img = observation['observation.images.top']
|
||||
cond_img = img[:, -1, ...]
|
||||
cond_img_emb = model.embedder(cond_img)
|
||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
||||
|
||||
if model.model.conditioning_key == 'hybrid':
|
||||
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
||||
img_cat_cond = z[:, :, -1:, :, :]
|
||||
img = observation['observation.images.top']
|
||||
cond_img = img[:, -1, ...]
|
||||
cond_img_emb = model.embedder(cond_img)
|
||||
|
||||
if model.model.conditioning_key == 'hybrid':
|
||||
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
||||
img_cat_cond = z[:, :, -1:, :, :]
|
||||
img_cat_cond = repeat(img_cat_cond,
|
||||
'b c t h w -> b c (repeat t) h w',
|
||||
repeat=noise_shape[2])
|
||||
cond = {"c_concat": [img_cat_cond]}
|
||||
|
||||
cond_ins_emb = model.get_learned_conditioning(prompts)
|
||||
cond_state = model.state_projector(observation['observation.state'])
|
||||
cond_state_emb = model.agent_state_pos_emb + cond_state
|
||||
|
||||
cond_action = model.action_projector(observation['action'])
|
||||
cond_action_emb = model.agent_action_pos_emb + cond_action
|
||||
cond_action_emb = torch.zeros_like(cond_action_emb)
|
||||
cond = {"c_concat": [img_cat_cond]}
|
||||
|
||||
cond_ins_emb = model.get_learned_conditioning(prompts)
|
||||
target_dtype = cond_ins_emb.dtype
|
||||
cond_img_emb = model._projector_forward(model.image_proj_model,
|
||||
cond_img_emb, target_dtype)
|
||||
cond_state = model._projector_forward(model.state_projector,
|
||||
observation['observation.state'],
|
||||
target_dtype)
|
||||
cond_state_emb = model.agent_state_pos_emb.to(dtype=target_dtype) + cond_state
|
||||
|
||||
cond_action = model._projector_forward(model.action_projector,
|
||||
observation['action'],
|
||||
target_dtype)
|
||||
cond_action_emb = model.agent_action_pos_emb.to(
|
||||
dtype=target_dtype) + cond_action
|
||||
cond_action_emb = torch.zeros_like(cond_action_emb)
|
||||
|
||||
cond["c_crossattn"] = [
|
||||
torch.cat([cond_state_emb, cond_ins_emb, cond_img_emb], dim=1)
|
||||
|
||||
@@ -770,28 +770,36 @@ def image_guided_synthesis_sim_mode(
|
||||
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
||||
|
||||
with profiler.profile_section("synthesis/conditioning_prep"):
|
||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
||||
cond_img_emb = model.embedder(cond_img)
|
||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
||||
|
||||
if model.model.conditioning_key == 'hybrid':
|
||||
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
||||
img_cat_cond = z[:, :, -1:, :, :]
|
||||
img_cat_cond = repeat(img_cat_cond,
|
||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
||||
cond_img_emb = model.embedder(cond_img)
|
||||
|
||||
if model.model.conditioning_key == 'hybrid':
|
||||
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
||||
img_cat_cond = z[:, :, -1:, :, :]
|
||||
img_cat_cond = repeat(img_cat_cond,
|
||||
'b c t h w -> b c (repeat t) h w',
|
||||
repeat=noise_shape[2])
|
||||
cond = {"c_concat": [img_cat_cond]}
|
||||
|
||||
if not text_input:
|
||||
prompts = [""] * batch_size
|
||||
cond_ins_emb = model.get_learned_conditioning(prompts)
|
||||
|
||||
cond_state_emb = model.state_projector(observation['observation.state'])
|
||||
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
||||
|
||||
cond_action_emb = model.action_projector(observation['action'])
|
||||
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
||||
|
||||
if not text_input:
|
||||
prompts = [""] * batch_size
|
||||
cond_ins_emb = model.get_learned_conditioning(prompts)
|
||||
target_dtype = cond_ins_emb.dtype
|
||||
|
||||
cond_img_emb = model._projector_forward(model.image_proj_model,
|
||||
cond_img_emb, target_dtype)
|
||||
|
||||
cond_state_emb = model._projector_forward(
|
||||
model.state_projector, observation['observation.state'],
|
||||
target_dtype)
|
||||
cond_state_emb = cond_state_emb + model.agent_state_pos_emb.to(
|
||||
dtype=target_dtype)
|
||||
|
||||
cond_action_emb = model._projector_forward(
|
||||
model.action_projector, observation['action'], target_dtype)
|
||||
cond_action_emb = cond_action_emb + model.agent_action_pos_emb.to(
|
||||
dtype=target_dtype)
|
||||
|
||||
if not sim_mode:
|
||||
cond_action_emb = torch.zeros_like(cond_action_emb)
|
||||
|
||||
Reference in New Issue
Block a user