轻量投影/MLP BF16
PSNR 指标反而比只量化扩散主干模型时更低，原因不明
This commit is contained in:
@@ -289,13 +289,15 @@ def image_guided_synthesis(model: torch.nn.Module,
|
|||||||
if not text_input:
|
if not text_input:
|
||||||
prompts = [""] * batch_size
|
prompts = [""] * batch_size
|
||||||
|
|
||||||
b, c, t, h, w = videos.shape
|
b, c, t, h, w = videos.shape
|
||||||
img = videos[:, :, 0]
|
img = videos[:, :, 0]
|
||||||
img_emb = model.embedder(img)
|
img_emb = model.embedder(img)
|
||||||
img_emb = model.image_proj_model(img_emb)
|
cond_emb = model.get_learned_conditioning(prompts)
|
||||||
img_emb = rearrange(img_emb, 'b (t l) c -> (b t) l c', t=t)
|
target_dtype = cond_emb.dtype
|
||||||
cond_emb = model.get_learned_conditioning(prompts)
|
img_emb = model._projector_forward(model.image_proj_model, img_emb,
|
||||||
cond_emb = cond_emb.repeat_interleave(repeats=t, dim=0)
|
target_dtype)
|
||||||
|
img_emb = rearrange(img_emb, 'b (t l) c -> (b t) l c', t=t)
|
||||||
|
cond_emb = cond_emb.repeat_interleave(repeats=t, dim=0)
|
||||||
|
|
||||||
cond = {"c_crossattn": [torch.cat([cond_emb, img_emb], dim=1)]}
|
cond = {"c_crossattn": [torch.cat([cond_emb, img_emb], dim=1)]}
|
||||||
if model.model.conditioning_key == 'hybrid':
|
if model.model.conditioning_key == 'hybrid':
|
||||||
|
|||||||
@@ -168,26 +168,33 @@ def image_guided_synthesis(
|
|||||||
batch_size = noise_shape[0]
|
batch_size = noise_shape[0]
|
||||||
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
||||||
|
|
||||||
img = observation['observation.images.top']
|
img = observation['observation.images.top']
|
||||||
cond_img = img[:, -1, ...]
|
cond_img = img[:, -1, ...]
|
||||||
cond_img_emb = model.embedder(cond_img)
|
cond_img_emb = model.embedder(cond_img)
|
||||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
|
||||||
|
if model.model.conditioning_key == 'hybrid':
|
||||||
if model.model.conditioning_key == 'hybrid':
|
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
||||||
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
img_cat_cond = z[:, :, -1:, :, :]
|
||||||
img_cat_cond = z[:, :, -1:, :, :]
|
|
||||||
img_cat_cond = repeat(img_cat_cond,
|
img_cat_cond = repeat(img_cat_cond,
|
||||||
'b c t h w -> b c (repeat t) h w',
|
'b c t h w -> b c (repeat t) h w',
|
||||||
repeat=noise_shape[2])
|
repeat=noise_shape[2])
|
||||||
cond = {"c_concat": [img_cat_cond]}
|
cond = {"c_concat": [img_cat_cond]}
|
||||||
|
|
||||||
cond_ins_emb = model.get_learned_conditioning(prompts)
|
cond_ins_emb = model.get_learned_conditioning(prompts)
|
||||||
cond_state = model.state_projector(observation['observation.state'])
|
target_dtype = cond_ins_emb.dtype
|
||||||
cond_state_emb = model.agent_state_pos_emb + cond_state
|
cond_img_emb = model._projector_forward(model.image_proj_model,
|
||||||
|
cond_img_emb, target_dtype)
|
||||||
cond_action = model.action_projector(observation['action'])
|
cond_state = model._projector_forward(model.state_projector,
|
||||||
cond_action_emb = model.agent_action_pos_emb + cond_action
|
observation['observation.state'],
|
||||||
cond_action_emb = torch.zeros_like(cond_action_emb)
|
target_dtype)
|
||||||
|
cond_state_emb = model.agent_state_pos_emb.to(dtype=target_dtype) + cond_state
|
||||||
|
|
||||||
|
cond_action = model._projector_forward(model.action_projector,
|
||||||
|
observation['action'],
|
||||||
|
target_dtype)
|
||||||
|
cond_action_emb = model.agent_action_pos_emb.to(
|
||||||
|
dtype=target_dtype) + cond_action
|
||||||
|
cond_action_emb = torch.zeros_like(cond_action_emb)
|
||||||
|
|
||||||
cond["c_crossattn"] = [
|
cond["c_crossattn"] = [
|
||||||
torch.cat([cond_state_emb, cond_ins_emb, cond_img_emb], dim=1)
|
torch.cat([cond_state_emb, cond_ins_emb, cond_img_emb], dim=1)
|
||||||
|
|||||||
@@ -770,28 +770,36 @@ def image_guided_synthesis_sim_mode(
|
|||||||
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
||||||
|
|
||||||
with profiler.profile_section("synthesis/conditioning_prep"):
|
with profiler.profile_section("synthesis/conditioning_prep"):
|
||||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
||||||
cond_img_emb = model.embedder(cond_img)
|
cond_img_emb = model.embedder(cond_img)
|
||||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
|
||||||
|
if model.model.conditioning_key == 'hybrid':
|
||||||
if model.model.conditioning_key == 'hybrid':
|
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
||||||
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
img_cat_cond = z[:, :, -1:, :, :]
|
||||||
img_cat_cond = z[:, :, -1:, :, :]
|
img_cat_cond = repeat(img_cat_cond,
|
||||||
img_cat_cond = repeat(img_cat_cond,
|
|
||||||
'b c t h w -> b c (repeat t) h w',
|
'b c t h w -> b c (repeat t) h w',
|
||||||
repeat=noise_shape[2])
|
repeat=noise_shape[2])
|
||||||
cond = {"c_concat": [img_cat_cond]}
|
cond = {"c_concat": [img_cat_cond]}
|
||||||
|
|
||||||
if not text_input:
|
if not text_input:
|
||||||
prompts = [""] * batch_size
|
prompts = [""] * batch_size
|
||||||
cond_ins_emb = model.get_learned_conditioning(prompts)
|
cond_ins_emb = model.get_learned_conditioning(prompts)
|
||||||
|
target_dtype = cond_ins_emb.dtype
|
||||||
cond_state_emb = model.state_projector(observation['observation.state'])
|
|
||||||
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
cond_img_emb = model._projector_forward(model.image_proj_model,
|
||||||
|
cond_img_emb, target_dtype)
|
||||||
cond_action_emb = model.action_projector(observation['action'])
|
|
||||||
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
cond_state_emb = model._projector_forward(
|
||||||
|
model.state_projector, observation['observation.state'],
|
||||||
|
target_dtype)
|
||||||
|
cond_state_emb = cond_state_emb + model.agent_state_pos_emb.to(
|
||||||
|
dtype=target_dtype)
|
||||||
|
|
||||||
|
cond_action_emb = model._projector_forward(
|
||||||
|
model.action_projector, observation['action'], target_dtype)
|
||||||
|
cond_action_emb = cond_action_emb + model.agent_action_pos_emb.to(
|
||||||
|
dtype=target_dtype)
|
||||||
|
|
||||||
if not sim_mode:
|
if not sim_mode:
|
||||||
cond_action_emb = torch.zeros_like(cond_action_emb)
|
cond_action_emb = torch.zeros_like(cond_action_emb)
|
||||||
|
|||||||
@@ -1882,6 +1882,7 @@ class LatentVisualDiffusion(LatentDiffusion):
|
|||||||
dp_use_ema: bool = False,
|
dp_use_ema: bool = False,
|
||||||
pretrained_checkpoint: str | None = None,
|
pretrained_checkpoint: str | None = None,
|
||||||
decision_making_only: bool = True,
|
decision_making_only: bool = True,
|
||||||
|
projector_bf16: bool = True,
|
||||||
*args,
|
*args,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
@@ -1907,6 +1908,7 @@ class LatentVisualDiffusion(LatentDiffusion):
|
|||||||
dp_use_ema: If True, maintain EMA for action UNet head.
|
dp_use_ema: If True, maintain EMA for action UNet head.
|
||||||
pretrained_checkpoint: Optional path to a pretrained checkpoint.
|
pretrained_checkpoint: Optional path to a pretrained checkpoint.
|
||||||
decision_making_only: If True, use decision-only augmentation path.
|
decision_making_only: If True, use decision-only augmentation path.
|
||||||
|
projector_bf16: If True, run image/state/action projectors under BF16 autocast.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@@ -1917,6 +1919,7 @@ class LatentVisualDiffusion(LatentDiffusion):
|
|||||||
self.n_obs_steps_imagen = n_obs_steps_imagen
|
self.n_obs_steps_imagen = n_obs_steps_imagen
|
||||||
self.n_obs_steps_acting = n_obs_steps_acting
|
self.n_obs_steps_acting = n_obs_steps_acting
|
||||||
self.decision_making_only = decision_making_only
|
self.decision_making_only = decision_making_only
|
||||||
|
self.projector_bf16 = projector_bf16
|
||||||
|
|
||||||
self._init_embedder(img_cond_stage_config, freeze_embedder)
|
self._init_embedder(img_cond_stage_config, freeze_embedder)
|
||||||
self._init_img_ctx_projector(image_proj_stage_config,
|
self._init_img_ctx_projector(image_proj_stage_config,
|
||||||
@@ -2025,6 +2028,28 @@ class LatentVisualDiffusion(LatentDiffusion):
|
|||||||
self.agent_state_pos_emb = nn.Parameter(
|
self.agent_state_pos_emb = nn.Parameter(
|
||||||
torch.randn(1, self.n_obs_steps_imagen, self.global_emb_dim))
|
torch.randn(1, self.n_obs_steps_imagen, self.global_emb_dim))
|
||||||
|
|
||||||
|
def _projector_forward(self, projector: nn.Module, x: Tensor,
|
||||||
|
target_dtype: torch.dtype | None) -> Tensor:
|
||||||
|
use_bf16 = (self.projector_bf16 and x.device.type == "cuda"
|
||||||
|
and torch.cuda.is_bf16_supported())
|
||||||
|
if use_bf16:
|
||||||
|
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||||
|
out = projector(x)
|
||||||
|
else:
|
||||||
|
out = projector(x)
|
||||||
|
if not hasattr(self, "_printed_projector_bf16"):
|
||||||
|
print(
|
||||||
|
">>> projector bf16 autocast: "
|
||||||
|
f"enabled={self.projector_bf16} "
|
||||||
|
f"use_bf16={use_bf16} "
|
||||||
|
f"input={x.dtype} "
|
||||||
|
f"output={out.dtype} "
|
||||||
|
f"target={target_dtype}")
|
||||||
|
self._printed_projector_bf16 = True
|
||||||
|
if target_dtype is not None and out.dtype != target_dtype:
|
||||||
|
out = out.to(dtype=target_dtype)
|
||||||
|
return out
|
||||||
|
|
||||||
def _get_augmented_batch(
|
def _get_augmented_batch(
|
||||||
self,
|
self,
|
||||||
z: Tensor,
|
z: Tensor,
|
||||||
@@ -2166,6 +2191,7 @@ class LatentVisualDiffusion(LatentDiffusion):
|
|||||||
null_prompt = self.get_learned_conditioning([""])
|
null_prompt = self.get_learned_conditioning([""])
|
||||||
cond_ins_emb = torch.where(prompt_mask, null_prompt,
|
cond_ins_emb = torch.where(prompt_mask, null_prompt,
|
||||||
cond_ins_emb.detach())
|
cond_ins_emb.detach())
|
||||||
|
target_dtype = cond_ins_emb.dtype
|
||||||
|
|
||||||
# Get conditioning frames
|
# Get conditioning frames
|
||||||
cond_frame_index = 0
|
cond_frame_index = 0
|
||||||
@@ -2176,7 +2202,8 @@ class LatentVisualDiffusion(LatentDiffusion):
|
|||||||
|
|
||||||
cond_img = input_mask * img
|
cond_img = input_mask * img
|
||||||
cond_img_emb = self.embedder(cond_img)
|
cond_img_emb = self.embedder(cond_img)
|
||||||
cond_img_emb = self.image_proj_model(cond_img_emb)
|
cond_img_emb = self._projector_forward(self.image_proj_model,
|
||||||
|
cond_img_emb, target_dtype)
|
||||||
|
|
||||||
if self.model.conditioning_key == 'hybrid':
|
if self.model.conditioning_key == 'hybrid':
|
||||||
if self.interp_mode:
|
if self.interp_mode:
|
||||||
@@ -2191,11 +2218,15 @@ class LatentVisualDiffusion(LatentDiffusion):
|
|||||||
repeat=z.shape[2])
|
repeat=z.shape[2])
|
||||||
cond["c_concat"] = [img_cat_cond]
|
cond["c_concat"] = [img_cat_cond]
|
||||||
|
|
||||||
cond_action = self.action_projector(action)
|
cond_action = self._projector_forward(self.action_projector, action,
|
||||||
cond_action_emb = self.agent_action_pos_emb + cond_action
|
target_dtype)
|
||||||
|
cond_action_emb = self.agent_action_pos_emb.to(
|
||||||
|
dtype=target_dtype) + cond_action
|
||||||
# Get conditioning states
|
# Get conditioning states
|
||||||
cond_state = self.state_projector(obs_state)
|
cond_state = self._projector_forward(self.state_projector, obs_state,
|
||||||
cond_state_emb = self.agent_state_pos_emb + cond_state
|
target_dtype)
|
||||||
|
cond_state_emb = self.agent_state_pos_emb.to(
|
||||||
|
dtype=target_dtype) + cond_state
|
||||||
|
|
||||||
if self.decision_making_only:
|
if self.decision_making_only:
|
||||||
is_sim_mode = False
|
is_sim_mode = False
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user