多卡流水导出

This commit is contained in:
qhy
2026-05-17 15:05:30 +08:00
parent 9d2d57d96b
commit afd90e59fe
6 changed files with 1787 additions and 1611 deletions

View File

@@ -1,7 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from torch import Tensor
from functools import partial
@@ -686,6 +685,37 @@ class WMAModel(nn.Module):
self.action_token_projector = instantiate_from_config(
stem_process_config)
# Context precomputation cache
self._ctx_cache_enabled = False
self._ctx_cache = {}
self._trt_backbone = None # TRT engine for video UNet backbone
# Reusable CUDA stream for parallel state_unet / action_unet
self._state_stream = torch.cuda.Stream()
def __getstate__(self):
state = self.__dict__.copy()
state.pop('_state_stream', None)
return state
def __setstate__(self, state):
self.__dict__.update(state)
if not hasattr(self, '_ctx_cache_enabled'):
self._ctx_cache_enabled = False
if not hasattr(self, '_ctx_cache'):
self._ctx_cache = {}
if not hasattr(self, '_trt_backbone'):
self._trt_backbone = None
self._state_stream = torch.cuda.Stream()
def load_trt_backbone(self, engine_path, n_hs_a=9):
"""Load a TensorRT engine for the video UNet backbone."""
from unifolm_wma.trt_utils import TRTBackbone
device = next(self.parameters()).device
self._trt_backbone = TRTBackbone(engine_path,
n_hs_a=n_hs_a,
device=device)
print(f">>> TRT backbone loaded from {engine_path} on {device}")
def forward(self,
x: Tensor,
x_action: Tensor,
@@ -714,80 +744,70 @@ class WMAModel(nn.Module):
Tuple of Tensors for predictions:
"""
b, _, t, _, _ = x.shape
run_head = kwargs.pop("run_head", True)
backbone_block_profiler = kwargs.pop("backbone_block_profiler", None)
backbone_step_index = kwargs.pop("backbone_step_index", None)
backbone_reuse_blocks = kwargs.pop("backbone_reuse_blocks", None)
backbone_reuse_start_step = kwargs.pop("backbone_reuse_start_step",
None)
backbone_reuse_schedule_steps = kwargs.pop(
"backbone_reuse_schedule_steps", None)
backbone_reuse_force_compute_steps = kwargs.pop(
"backbone_reuse_force_compute_steps", None)
backbone_reuse_mode = kwargs.pop("backbone_reuse_mode", "disabled")
backbone_reuse_cache = kwargs.pop("backbone_reuse_cache", None)
backbone_reuse_step_stats = kwargs.pop("backbone_reuse_step_stats",
None)
backbone_reuse_branch = kwargs.pop("backbone_reuse_branch", "single")
t_emb = timestep_embedding(timesteps,
self.model_channels,
repeat_only=False).type(x.dtype)
emb = self.time_embed(t_emb)
bt, l_context, _ = context.shape
if self.base_model_gen_only:
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
_ctx_key = context.data_ptr()
if self._ctx_cache_enabled and _ctx_key in self._ctx_cache:
context = self._ctx_cache[_ctx_key]
else:
if l_context == self.n_obs_steps + 77 + t * 16:
context_agent_state = context[:, :self.n_obs_steps]
context_text = context[:, self.n_obs_steps:self.n_obs_steps +
77, :]
context_img = context[:, self.n_obs_steps + 77:, :]
context_agent_state = context_agent_state.repeat_interleave(
repeats=t, dim=0)
context_text = context_text.repeat_interleave(repeats=t, dim=0)
context_img = rearrange(context_img,
'b (t l) c -> (b t) l c',
t=t)
context = torch.cat(
[context_agent_state, context_text, context_img], dim=1)
elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
context_agent_state = context[:, :self.n_obs_steps]
context_agent_action = context[:, self.
n_obs_steps:self.n_obs_steps +
16, :]
context_agent_action = rearrange(
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
context_agent_action = self.action_token_projector(
context_agent_action)
context_agent_action = rearrange(context_agent_action,
'(b o) l d -> b o l d',
o=t)
context_agent_action = rearrange(context_agent_action,
'b o (t l) d -> b o t l d',
t=t)
context_agent_action = context_agent_action.permute(
0, 2, 1, 3, 4)
context_agent_action = rearrange(context_agent_action,
'b t o l d -> (b t) (o l) d')
bt, l_context, _ = context.shape
if self.base_model_gen_only:
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
else:
if l_context == self.n_obs_steps + 77 + t * 16:
context_agent_state = context[:, :self.n_obs_steps]
context_text = context[:, self.n_obs_steps:self.n_obs_steps +
77, :]
context_img = context[:, self.n_obs_steps + 77:, :]
context_agent_state = context_agent_state.repeat_interleave(
repeats=t, dim=0)
context_text = context_text.repeat_interleave(repeats=t, dim=0)
context_img = rearrange(context_img,
'b (t l) c -> (b t) l c',
t=t)
context = torch.cat(
[context_agent_state, context_text, context_img], dim=1)
elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
context_agent_state = context[:, :self.n_obs_steps]
context_agent_action = context[:, self.
n_obs_steps:self.n_obs_steps +
16, :]
context_agent_action = rearrange(
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
context_agent_action = self.action_token_projector(
context_agent_action)
context_agent_action = rearrange(context_agent_action,
'(b o) l d -> b o l d',
o=t)
context_agent_action = rearrange(context_agent_action,
'b o (t l) d -> b o t l d',
t=t)
context_agent_action = context_agent_action.permute(
0, 2, 1, 3, 4)
context_agent_action = rearrange(context_agent_action,
'b t o l d -> (b t) (o l) d')
context_text = context[:, self.n_obs_steps +
16:self.n_obs_steps + 16 + 77, :]
context_text = context_text.repeat_interleave(repeats=t, dim=0)
context_text = context[:, self.n_obs_steps +
16:self.n_obs_steps + 16 + 77, :]
context_text = context_text.repeat_interleave(repeats=t, dim=0)
context_img = context[:, self.n_obs_steps + 16 + 77:, :]
context_img = rearrange(context_img,
'b (t l) c -> (b t) l c',
t=t)
context_agent_state = context_agent_state.repeat_interleave(
repeats=t, dim=0)
context = torch.cat([
context_agent_state, context_agent_action, context_text,
context_img
],
dim=1)
context_img = context[:, self.n_obs_steps + 16 + 77:, :]
context_img = rearrange(context_img,
'b (t l) c -> (b t) l c',
t=t)
context_agent_state = context_agent_state.repeat_interleave(
repeats=t, dim=0)
context = torch.cat([
context_agent_state, context_agent_action, context_text,
context_img
],
dim=1)
if self._ctx_cache_enabled:
self._ctx_cache[_ctx_key] = context
emb = emb.repeat_interleave(repeats=t, dim=0)
@@ -807,150 +827,95 @@ class WMAModel(nn.Module):
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
emb = emb + fs_embed
def run_block_with_profile(block_name: str, block_stage: str,
block_index: int | None,
fn: Callable[[], Tensor]) -> Tensor:
if backbone_block_profiler is None or backbone_step_index is None:
return fn()
if x.device.type == "cuda":
torch.cuda.synchronize(x.device)
start_time = time.perf_counter()
out = fn()
if x.device.type == "cuda":
torch.cuda.synchronize(x.device)
backbone_block_profiler.record_block(
step=int(backbone_step_index),
block_name=block_name,
block_stage=block_stage,
block_index=block_index,
output=out,
forward_time_ms=(time.perf_counter() - start_time) * 1000.0,
)
return out
reuse_cache_branch: Dict[str, Tensor] | None = None
if backbone_reuse_cache is not None:
reuse_cache_branch = backbone_reuse_cache.setdefault(
backbone_reuse_branch, {})
def should_reuse_output_block(block_name: str) -> bool:
if backbone_reuse_mode != "reuse_output":
return False
if backbone_step_index is None or backbone_reuse_start_step is None:
return False
if backbone_reuse_blocks is None or block_name not in backbone_reuse_blocks:
return False
if int(backbone_step_index) < int(backbone_reuse_start_step):
return False
if (backbone_reuse_force_compute_steps is not None
and int(backbone_step_index)
in backbone_reuse_force_compute_steps):
return False
if (backbone_reuse_schedule_steps is not None
and int(backbone_step_index)
in backbone_reuse_schedule_steps):
return False
if reuse_cache_branch is None:
return False
return block_name in reuse_cache_branch
h = x.type(self.dtype)
adapter_idx = 0
hs = []
hs_a = []
for id, module in enumerate(self.input_blocks):
def run_input_block() -> Tensor:
block_out = module(h, emb, context=context, batch_size=b)
if self._trt_backbone is not None:
# TRT path: run backbone via TensorRT engine
h_in = x.type(self.dtype).contiguous()
y, hs_a = self._trt_backbone(h_in, emb.contiguous(), context.contiguous())
else:
# PyTorch path: original backbone
h = x.type(self.dtype)
adapter_idx = 0
hs = []
hs_a = []
for id, module in enumerate(self.input_blocks):
h = module(h, emb, context=context, batch_size=b)
if id == 0 and self.addition_attention:
block_out = self.init_attn(block_out,
emb,
context=context,
batch_size=b)
return block_out
h = self.init_attn(h, emb, context=context, batch_size=b)
# plug-in adapter features
if ((id + 1) % 3 == 0) and features_adapter is not None:
h = h + features_adapter[adapter_idx]
adapter_idx += 1
if id != 0:
if isinstance(module[0], Downsample):
hs_a.append(
rearrange(hs[-1], '(b t) c h w -> b t c h w', b=b))
hs.append(h)
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', b=b))
h = run_block_with_profile(
block_name=f"input_{id}",
block_stage="input_blocks",
block_index=id,
fn=run_input_block,
)
# plug-in adapter features
if ((id + 1) % 3 == 0) and features_adapter is not None:
h = h + features_adapter[adapter_idx]
adapter_idx += 1
if id != 0:
if isinstance(module[0], Downsample):
if features_adapter is not None:
assert len(
features_adapter) == adapter_idx, 'Wrong features_adapter'
h = self.middle_block(h, emb, context=context, batch_size=b)
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', b=b))
hs_out = []
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb, context=context, batch_size=b)
if isinstance(module[-1], Upsample):
hs_a.append(
rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t))
hs.append(h)
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
rearrange(hs_out[-1], '(b t) c h w -> b t c h w', b=b))
hs_out.append(h)
h = h.type(x.dtype)
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', b=b))
if features_adapter is not None:
assert len(
features_adapter) == adapter_idx, 'Wrong features_adapter'
h = run_block_with_profile(
block_name="middle",
block_stage="middle_block",
block_index=0,
fn=lambda: self.middle_block(h, emb, context=context, batch_size=b),
)
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
y = self.out(h)
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
hs_out = []
for id, module in enumerate(self.output_blocks):
skip_h = hs.pop()
block_name = f"output_{id}"
def run_output_block() -> Tensor:
return module(torch.cat([h, skip_h], dim=1),
emb,
context=context,
batch_size=b)
if should_reuse_output_block(block_name):
h = reuse_cache_branch[block_name].to(device=h.device,
dtype=h.dtype)
if backbone_reuse_step_stats is not None:
backbone_reuse_step_stats.setdefault(
backbone_reuse_branch, set()).add(block_name)
else:
h = run_block_with_profile(
block_name=block_name,
block_stage="output_blocks",
block_index=id,
fn=run_output_block,
)
if (reuse_cache_branch is not None and backbone_reuse_mode ==
"reuse_output"
and backbone_reuse_blocks is not None
and block_name in backbone_reuse_blocks):
reuse_cache_branch[block_name] = h.detach().clone()
if isinstance(module[-1], Upsample):
hs_a.append(
rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
hs_out.append(h)
h = h.type(x.dtype)
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
y = self.out(h)
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
if not self.base_model_gen_only and run_head:
if not self.base_model_gen_only:
ba, _, _ = x_action.shape
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
context_action[:2], **kwargs)
# Predict state
if b > 1:
s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
ts_state = timesteps[:ba] if b > 1 else timesteps
is_sim_mode = context_action[2] if len(context_action) > 2 else False
if is_sim_mode:
# WM mode: only need state_unet, skip action_unet
s_y = self.state_unet(x_state, ts_state, hs_a,
context_action[:2], **kwargs)
a_y = torch.zeros_like(x_action)
else:
s_y = self.state_unet(x_state, timesteps, hs_a,
context_action[:2], **kwargs)
elif not self.base_model_gen_only:
a_y = None
s_y = None
# DM mode: only need action_unet, skip state_unet
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
context_action[:2], **kwargs)
s_y = torch.zeros_like(x_state)
else:
a_y = torch.zeros_like(x_action)
s_y = torch.zeros_like(x_state)
return y, a_y, s_y
def enable_ctx_cache(model):
"""Enable context precomputation cache on WMAModel and its action/state UNets."""
for m in model.modules():
if isinstance(m, WMAModel):
m._ctx_cache_enabled = True
m._ctx_cache = {}
# conditional_unet1d cache
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
for m in model.modules():
if isinstance(m, ConditionalUnet1D):
m._global_cond_cache_enabled = True
m._global_cond_cache = {}
def disable_ctx_cache(model):
"""Disable and clear context precomputation cache."""
for m in model.modules():
if isinstance(m, WMAModel):
m._ctx_cache_enabled = False
m._ctx_cache = {}
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
for m in model.modules():
if isinstance(m, ConditionalUnet1D):
m._global_cond_cache_enabled = False
m._global_cond_cache = {}