多卡流水导出
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import time
|
||||
|
||||
from torch import Tensor
|
||||
from functools import partial
|
||||
@@ -686,6 +685,37 @@ class WMAModel(nn.Module):
|
||||
self.action_token_projector = instantiate_from_config(
|
||||
stem_process_config)
|
||||
|
||||
# Context precomputation cache
|
||||
self._ctx_cache_enabled = False
|
||||
self._ctx_cache = {}
|
||||
self._trt_backbone = None # TRT engine for video UNet backbone
|
||||
# Reusable CUDA stream for parallel state_unet / action_unet
|
||||
self._state_stream = torch.cuda.Stream()
|
||||
|
||||
def __getstate__(self):
    """Return the instance state to serialize, without the CUDA stream.

    Produces a shallow copy of ``self.__dict__`` with the ``_state_stream``
    entry left out; the instance itself is not modified.
    """
    return {
        name: value
        for name, value in self.__dict__.items()
        if name != '_state_stream'
    }
|
||||
|
||||
def __setstate__(self, state):
    """Restore instance state and rebuild the attributes that are not pickled.

    ``state`` is merged into ``self.__dict__``. The context-cache and
    TensorRT attributes are then back-filled with their defaults when the
    restored state does not provide them, and ``_state_stream`` (which
    ``__getstate__`` removes from the state) is recreated.

    Args:
        state: Mapping of attribute names to values, as produced by
            ``__getstate__``.
    """
    self.__dict__.update(state)

    # Back-fill attributes that are absent from the restored state.
    if not hasattr(self, '_ctx_cache_enabled'):
        self._ctx_cache_enabled = False
    if not hasattr(self, '_ctx_cache'):
        self._ctx_cache = {}
    if not hasattr(self, '_trt_backbone'):
        self._trt_backbone = None

    # A CUDA stream can only be created when CUDA is usable; creating one
    # unconditionally makes unpickling / deep-copying fail on CPU-only
    # hosts. Fall back to None there so the state can still be restored.
    if torch.cuda.is_available():
        self._state_stream = torch.cuda.Stream()
    else:
        self._state_stream = None
|
||||
|
||||
def load_trt_backbone(self, engine_path, n_hs_a=9):
    """Attach a TensorRT engine to use for the video UNet backbone.

    The engine wrapper is stored on ``self._trt_backbone`` and a
    confirmation line is printed.

    Args:
        engine_path: Engine location, passed through to ``TRTBackbone``.
        n_hs_a: Passed through to ``TRTBackbone`` as ``n_hs_a``
            (presumably the number of ``hs_a`` feature maps the engine
            returns -- confirm against ``TRTBackbone``).

    Raises:
        StopIteration: If the model has no parameters to take a device from.
    """
    # Function-scope import; presumably keeps the TensorRT dependency
    # optional for callers that never load an engine.
    from unifolm_wma.trt_utils import TRTBackbone

    # The engine is created on the device of the model's first parameter.
    device = next(self.parameters()).device
    backbone = TRTBackbone(engine_path, n_hs_a=n_hs_a, device=device)
    self._trt_backbone = backbone
    print(f">>> TRT backbone loaded from {engine_path} on {device}")
|
||||
|
||||
def forward(self,
|
||||
x: Tensor,
|
||||
x_action: Tensor,
|
||||
@@ -714,80 +744,70 @@ class WMAModel(nn.Module):
|
||||
Tuple of Tensors for predictions:
|
||||
|
||||
"""
|
||||
|
||||
b, _, t, _, _ = x.shape
|
||||
run_head = kwargs.pop("run_head", True)
|
||||
backbone_block_profiler = kwargs.pop("backbone_block_profiler", None)
|
||||
backbone_step_index = kwargs.pop("backbone_step_index", None)
|
||||
backbone_reuse_blocks = kwargs.pop("backbone_reuse_blocks", None)
|
||||
backbone_reuse_start_step = kwargs.pop("backbone_reuse_start_step",
|
||||
None)
|
||||
backbone_reuse_schedule_steps = kwargs.pop(
|
||||
"backbone_reuse_schedule_steps", None)
|
||||
backbone_reuse_force_compute_steps = kwargs.pop(
|
||||
"backbone_reuse_force_compute_steps", None)
|
||||
backbone_reuse_mode = kwargs.pop("backbone_reuse_mode", "disabled")
|
||||
backbone_reuse_cache = kwargs.pop("backbone_reuse_cache", None)
|
||||
backbone_reuse_step_stats = kwargs.pop("backbone_reuse_step_stats",
|
||||
None)
|
||||
backbone_reuse_branch = kwargs.pop("backbone_reuse_branch", "single")
|
||||
t_emb = timestep_embedding(timesteps,
|
||||
self.model_channels,
|
||||
repeat_only=False).type(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
bt, l_context, _ = context.shape
|
||||
if self.base_model_gen_only:
|
||||
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
|
||||
_ctx_key = context.data_ptr()
|
||||
if self._ctx_cache_enabled and _ctx_key in self._ctx_cache:
|
||||
context = self._ctx_cache[_ctx_key]
|
||||
else:
|
||||
if l_context == self.n_obs_steps + 77 + t * 16:
|
||||
context_agent_state = context[:, :self.n_obs_steps]
|
||||
context_text = context[:, self.n_obs_steps:self.n_obs_steps +
|
||||
77, :]
|
||||
context_img = context[:, self.n_obs_steps + 77:, :]
|
||||
context_agent_state = context_agent_state.repeat_interleave(
|
||||
repeats=t, dim=0)
|
||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||
context_img = rearrange(context_img,
|
||||
'b (t l) c -> (b t) l c',
|
||||
t=t)
|
||||
context = torch.cat(
|
||||
[context_agent_state, context_text, context_img], dim=1)
|
||||
elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
|
||||
context_agent_state = context[:, :self.n_obs_steps]
|
||||
context_agent_action = context[:, self.
|
||||
n_obs_steps:self.n_obs_steps +
|
||||
16, :]
|
||||
context_agent_action = rearrange(
|
||||
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
|
||||
context_agent_action = self.action_token_projector(
|
||||
context_agent_action)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'(b o) l d -> b o l d',
|
||||
o=t)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'b o (t l) d -> b o t l d',
|
||||
t=t)
|
||||
context_agent_action = context_agent_action.permute(
|
||||
0, 2, 1, 3, 4)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'b t o l d -> (b t) (o l) d')
|
||||
bt, l_context, _ = context.shape
|
||||
if self.base_model_gen_only:
|
||||
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
|
||||
else:
|
||||
if l_context == self.n_obs_steps + 77 + t * 16:
|
||||
context_agent_state = context[:, :self.n_obs_steps]
|
||||
context_text = context[:, self.n_obs_steps:self.n_obs_steps +
|
||||
77, :]
|
||||
context_img = context[:, self.n_obs_steps + 77:, :]
|
||||
context_agent_state = context_agent_state.repeat_interleave(
|
||||
repeats=t, dim=0)
|
||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||
context_img = rearrange(context_img,
|
||||
'b (t l) c -> (b t) l c',
|
||||
t=t)
|
||||
context = torch.cat(
|
||||
[context_agent_state, context_text, context_img], dim=1)
|
||||
elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
|
||||
context_agent_state = context[:, :self.n_obs_steps]
|
||||
context_agent_action = context[:, self.
|
||||
n_obs_steps:self.n_obs_steps +
|
||||
16, :]
|
||||
context_agent_action = rearrange(
|
||||
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
|
||||
context_agent_action = self.action_token_projector(
|
||||
context_agent_action)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'(b o) l d -> b o l d',
|
||||
o=t)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'b o (t l) d -> b o t l d',
|
||||
t=t)
|
||||
context_agent_action = context_agent_action.permute(
|
||||
0, 2, 1, 3, 4)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'b t o l d -> (b t) (o l) d')
|
||||
|
||||
context_text = context[:, self.n_obs_steps +
|
||||
16:self.n_obs_steps + 16 + 77, :]
|
||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||
context_text = context[:, self.n_obs_steps +
|
||||
16:self.n_obs_steps + 16 + 77, :]
|
||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||
|
||||
context_img = context[:, self.n_obs_steps + 16 + 77:, :]
|
||||
context_img = rearrange(context_img,
|
||||
'b (t l) c -> (b t) l c',
|
||||
t=t)
|
||||
context_agent_state = context_agent_state.repeat_interleave(
|
||||
repeats=t, dim=0)
|
||||
context = torch.cat([
|
||||
context_agent_state, context_agent_action, context_text,
|
||||
context_img
|
||||
],
|
||||
dim=1)
|
||||
context_img = context[:, self.n_obs_steps + 16 + 77:, :]
|
||||
context_img = rearrange(context_img,
|
||||
'b (t l) c -> (b t) l c',
|
||||
t=t)
|
||||
context_agent_state = context_agent_state.repeat_interleave(
|
||||
repeats=t, dim=0)
|
||||
context = torch.cat([
|
||||
context_agent_state, context_agent_action, context_text,
|
||||
context_img
|
||||
],
|
||||
dim=1)
|
||||
if self._ctx_cache_enabled:
|
||||
self._ctx_cache[_ctx_key] = context
|
||||
|
||||
emb = emb.repeat_interleave(repeats=t, dim=0)
|
||||
|
||||
@@ -807,150 +827,95 @@ class WMAModel(nn.Module):
|
||||
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
|
||||
emb = emb + fs_embed
|
||||
|
||||
def run_block_with_profile(block_name: str, block_stage: str,
|
||||
block_index: int | None,
|
||||
fn: Callable[[], Tensor]) -> Tensor:
|
||||
if backbone_block_profiler is None or backbone_step_index is None:
|
||||
return fn()
|
||||
if x.device.type == "cuda":
|
||||
torch.cuda.synchronize(x.device)
|
||||
start_time = time.perf_counter()
|
||||
out = fn()
|
||||
if x.device.type == "cuda":
|
||||
torch.cuda.synchronize(x.device)
|
||||
backbone_block_profiler.record_block(
|
||||
step=int(backbone_step_index),
|
||||
block_name=block_name,
|
||||
block_stage=block_stage,
|
||||
block_index=block_index,
|
||||
output=out,
|
||||
forward_time_ms=(time.perf_counter() - start_time) * 1000.0,
|
||||
)
|
||||
return out
|
||||
|
||||
reuse_cache_branch: Dict[str, Tensor] | None = None
|
||||
if backbone_reuse_cache is not None:
|
||||
reuse_cache_branch = backbone_reuse_cache.setdefault(
|
||||
backbone_reuse_branch, {})
|
||||
|
||||
def should_reuse_output_block(block_name: str) -> bool:
|
||||
if backbone_reuse_mode != "reuse_output":
|
||||
return False
|
||||
if backbone_step_index is None or backbone_reuse_start_step is None:
|
||||
return False
|
||||
if backbone_reuse_blocks is None or block_name not in backbone_reuse_blocks:
|
||||
return False
|
||||
if int(backbone_step_index) < int(backbone_reuse_start_step):
|
||||
return False
|
||||
if (backbone_reuse_force_compute_steps is not None
|
||||
and int(backbone_step_index)
|
||||
in backbone_reuse_force_compute_steps):
|
||||
return False
|
||||
if (backbone_reuse_schedule_steps is not None
|
||||
and int(backbone_step_index)
|
||||
in backbone_reuse_schedule_steps):
|
||||
return False
|
||||
if reuse_cache_branch is None:
|
||||
return False
|
||||
return block_name in reuse_cache_branch
|
||||
|
||||
h = x.type(self.dtype)
|
||||
adapter_idx = 0
|
||||
hs = []
|
||||
hs_a = []
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
def run_input_block() -> Tensor:
|
||||
block_out = module(h, emb, context=context, batch_size=b)
|
||||
if self._trt_backbone is not None:
|
||||
# TRT path: run backbone via TensorRT engine
|
||||
h_in = x.type(self.dtype).contiguous()
|
||||
y, hs_a = self._trt_backbone(h_in, emb.contiguous(), context.contiguous())
|
||||
else:
|
||||
# PyTorch path: original backbone
|
||||
h = x.type(self.dtype)
|
||||
adapter_idx = 0
|
||||
hs = []
|
||||
hs_a = []
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
h = module(h, emb, context=context, batch_size=b)
|
||||
if id == 0 and self.addition_attention:
|
||||
block_out = self.init_attn(block_out,
|
||||
emb,
|
||||
context=context,
|
||||
batch_size=b)
|
||||
return block_out
|
||||
h = self.init_attn(h, emb, context=context, batch_size=b)
|
||||
# plug-in adapter features
|
||||
if ((id + 1) % 3 == 0) and features_adapter is not None:
|
||||
h = h + features_adapter[adapter_idx]
|
||||
adapter_idx += 1
|
||||
if id != 0:
|
||||
if isinstance(module[0], Downsample):
|
||||
hs_a.append(
|
||||
rearrange(hs[-1], '(b t) c h w -> b t c h w', b=b))
|
||||
hs.append(h)
|
||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', b=b))
|
||||
|
||||
h = run_block_with_profile(
|
||||
block_name=f"input_{id}",
|
||||
block_stage="input_blocks",
|
||||
block_index=id,
|
||||
fn=run_input_block,
|
||||
)
|
||||
# plug-in adapter features
|
||||
if ((id + 1) % 3 == 0) and features_adapter is not None:
|
||||
h = h + features_adapter[adapter_idx]
|
||||
adapter_idx += 1
|
||||
if id != 0:
|
||||
if isinstance(module[0], Downsample):
|
||||
if features_adapter is not None:
|
||||
assert len(
|
||||
features_adapter) == adapter_idx, 'Wrong features_adapter'
|
||||
h = self.middle_block(h, emb, context=context, batch_size=b)
|
||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', b=b))
|
||||
|
||||
hs_out = []
|
||||
for module in self.output_blocks:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb, context=context, batch_size=b)
|
||||
if isinstance(module[-1], Upsample):
|
||||
hs_a.append(
|
||||
rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t))
|
||||
hs.append(h)
|
||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
|
||||
rearrange(hs_out[-1], '(b t) c h w -> b t c h w', b=b))
|
||||
hs_out.append(h)
|
||||
h = h.type(x.dtype)
|
||||
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', b=b))
|
||||
|
||||
if features_adapter is not None:
|
||||
assert len(
|
||||
features_adapter) == adapter_idx, 'Wrong features_adapter'
|
||||
h = run_block_with_profile(
|
||||
block_name="middle",
|
||||
block_stage="middle_block",
|
||||
block_index=0,
|
||||
fn=lambda: self.middle_block(h, emb, context=context, batch_size=b),
|
||||
)
|
||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
|
||||
y = self.out(h)
|
||||
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
|
||||
|
||||
hs_out = []
|
||||
for id, module in enumerate(self.output_blocks):
|
||||
skip_h = hs.pop()
|
||||
block_name = f"output_{id}"
|
||||
|
||||
def run_output_block() -> Tensor:
|
||||
return module(torch.cat([h, skip_h], dim=1),
|
||||
emb,
|
||||
context=context,
|
||||
batch_size=b)
|
||||
|
||||
if should_reuse_output_block(block_name):
|
||||
h = reuse_cache_branch[block_name].to(device=h.device,
|
||||
dtype=h.dtype)
|
||||
if backbone_reuse_step_stats is not None:
|
||||
backbone_reuse_step_stats.setdefault(
|
||||
backbone_reuse_branch, set()).add(block_name)
|
||||
else:
|
||||
h = run_block_with_profile(
|
||||
block_name=block_name,
|
||||
block_stage="output_blocks",
|
||||
block_index=id,
|
||||
fn=run_output_block,
|
||||
)
|
||||
if (reuse_cache_branch is not None and backbone_reuse_mode ==
|
||||
"reuse_output"
|
||||
and backbone_reuse_blocks is not None
|
||||
and block_name in backbone_reuse_blocks):
|
||||
reuse_cache_branch[block_name] = h.detach().clone()
|
||||
if isinstance(module[-1], Upsample):
|
||||
hs_a.append(
|
||||
rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
|
||||
hs_out.append(h)
|
||||
h = h.type(x.dtype)
|
||||
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
|
||||
|
||||
y = self.out(h)
|
||||
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
|
||||
|
||||
if not self.base_model_gen_only and run_head:
|
||||
if not self.base_model_gen_only:
|
||||
ba, _, _ = x_action.shape
|
||||
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
# Predict state
|
||||
if b > 1:
|
||||
s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
|
||||
ts_state = timesteps[:ba] if b > 1 else timesteps
|
||||
is_sim_mode = context_action[2] if len(context_action) > 2 else False
|
||||
|
||||
if is_sim_mode:
|
||||
# WM mode: only need state_unet, skip action_unet
|
||||
s_y = self.state_unet(x_state, ts_state, hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
a_y = torch.zeros_like(x_action)
|
||||
else:
|
||||
s_y = self.state_unet(x_state, timesteps, hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
elif not self.base_model_gen_only:
|
||||
a_y = None
|
||||
s_y = None
|
||||
# DM mode: only need action_unet, skip state_unet
|
||||
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
s_y = torch.zeros_like(x_state)
|
||||
else:
|
||||
a_y = torch.zeros_like(x_action)
|
||||
s_y = torch.zeros_like(x_state)
|
||||
|
||||
return y, a_y, s_y
|
||||
|
||||
|
||||
def enable_ctx_cache(model):
    """Switch on context precomputation caching throughout ``model``.

    Every ``WMAModel`` found in ``model.modules()`` gets
    ``_ctx_cache_enabled = True`` and a fresh, empty ``_ctx_cache``.
    Every ``ConditionalUnet1D`` found there gets
    ``_global_cond_cache_enabled = True`` and a fresh, empty
    ``_global_cond_cache``. Any previously cached entries are discarded.

    Args:
        model: Module whose sub-modules are scanned via ``model.modules()``.
    """
    for wma in filter(lambda sub: isinstance(sub, WMAModel), model.modules()):
        wma._ctx_cache_enabled, wma._ctx_cache = True, {}

    # conditional_unet1d cache
    from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D

    for head in filter(lambda sub: isinstance(sub, ConditionalUnet1D),
                       model.modules()):
        head._global_cond_cache_enabled, head._global_cond_cache = True, {}
|
||||
|
||||
|
||||
def disable_ctx_cache(model):
    """Switch off context precomputation caching and drop cached entries.

    Every ``WMAModel`` found in ``model.modules()`` gets
    ``_ctx_cache_enabled = False`` and an empty ``_ctx_cache``. Every
    ``ConditionalUnet1D`` found there gets
    ``_global_cond_cache_enabled = False`` and an empty
    ``_global_cond_cache``.

    Args:
        model: Module whose sub-modules are scanned via ``model.modules()``.
    """
    for wma in filter(lambda sub: isinstance(sub, WMAModel), model.modules()):
        wma._ctx_cache_enabled, wma._ctx_cache = False, {}

    from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D

    for head in filter(lambda sub: isinstance(sub, ConditionalUnet1D),
                       model.modules()):
        head._global_cond_cache_enabled, head._global_cond_cache = False, {}
|
||||
|
||||
Reference in New Issue
Block a user