video_backbone剖析

This commit is contained in:
qhy
2026-03-16 10:30:54 +08:00
parent 7e45eba18b
commit 8ca159d375
282 changed files with 174952 additions and 1350 deletions

View File

@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from torch import Tensor
from functools import partial
@@ -715,6 +716,9 @@ class WMAModel(nn.Module):
"""
b, _, t, _, _ = x.shape
run_head = kwargs.pop("run_head", True)
backbone_block_profiler = kwargs.pop("backbone_block_profiler", None)
backbone_step_index = kwargs.pop("backbone_step_index", None)
t_emb = timestep_embedding(timesteps,
self.model_channels,
repeat_only=False).type(x.dtype)
@@ -791,14 +795,47 @@ class WMAModel(nn.Module):
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
emb = emb + fs_embed
def run_block_with_profile(block_name: str, block_stage: str,
                           block_index: int | None,
                           fn: Callable[[], Tensor]) -> Tensor:
    """Run ``fn`` and, when profiling is enabled, record its wall time.

    Profiling is enabled only when both ``backbone_block_profiler`` and
    ``backbone_step_index`` (taken from the enclosing scope) are not
    ``None``; otherwise ``fn`` is simply called and its result returned.

    Args:
        block_name: Label forwarded to ``record_block`` as ``block_name``.
        block_stage: Label forwarded to ``record_block`` as ``block_stage``.
        block_index: Index forwarded to ``record_block`` as ``block_index``.
        fn: Zero-argument callable that performs the block's computation.

    Returns:
        Whatever ``fn`` returns.
    """
    profiling_enabled = (backbone_block_profiler is not None
                         and backbone_step_index is not None)
    if not profiling_enabled:
        return fn()

    # CUDA kernels are launched asynchronously, so the device is
    # synchronized on both sides of the call to make the host-side
    # timer cover the block's actual execution.
    needs_sync = x.device.type == "cuda"
    if needs_sync:
        torch.cuda.synchronize(x.device)
    started_at = time.perf_counter()
    result = fn()
    if needs_sync:
        torch.cuda.synchronize(x.device)
    elapsed_ms = (time.perf_counter() - started_at) * 1000.0

    backbone_block_profiler.record_block(
        step=int(backbone_step_index),
        block_name=block_name,
        block_stage=block_stage,
        block_index=block_index,
        output=result,
        forward_time_ms=elapsed_ms,
    )
    return result
h = x.type(self.dtype)
adapter_idx = 0
hs = []
hs_a = []
for id, module in enumerate(self.input_blocks):
h = module(h, emb, context=context, batch_size=b)
if id == 0 and self.addition_attention:
h = self.init_attn(h, emb, context=context, batch_size=b)
def run_input_block() -> Tensor:
    """Apply the current input block, plus ``init_attn`` on the first one.

    Reads ``module``, ``h``, ``emb``, ``context``, ``b`` and ``id`` from the
    enclosing loop iteration. ``self.init_attn`` is applied to the block's
    output only when ``id == 0`` and ``self.addition_attention`` is truthy.
    """
    shared_kwargs = dict(context=context, batch_size=b)
    features = module(h, emb, **shared_kwargs)
    if not (id == 0 and self.addition_attention):
        return features
    return self.init_attn(features, emb, **shared_kwargs)
h = run_block_with_profile(
block_name=f"input_{id}",
block_stage="input_blocks",
block_index=id,
fn=run_input_block,
)
# plug-in adapter features
if ((id + 1) % 3 == 0) and features_adapter is not None:
h = h + features_adapter[adapter_idx]
@@ -813,13 +850,30 @@ class WMAModel(nn.Module):
if features_adapter is not None:
assert len(
features_adapter) == adapter_idx, 'Wrong features_adapter'
h = self.middle_block(h, emb, context=context, batch_size=b)
h = run_block_with_profile(
block_name="middle",
block_stage="middle_block",
block_index=0,
fn=lambda: self.middle_block(h, emb, context=context, batch_size=b),
)
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
hs_out = []
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb, context=context, batch_size=b)
for id, module in enumerate(self.output_blocks):
skip_h = hs.pop()
def run_output_block() -> Tensor:
    """Concatenate ``h`` with the popped skip tensor and run the block.

    ``h`` and ``skip_h`` are joined along ``dim=1`` before being passed to
    the current output ``module`` together with ``emb``, ``context`` and
    the batch size ``b`` from the enclosing scope.
    """
    merged = torch.cat([h, skip_h], dim=1)
    return module(merged, emb, context=context, batch_size=b)
h = run_block_with_profile(
block_name=f"output_{id}",
block_stage="output_blocks",
block_index=id,
fn=run_output_block,
)
if isinstance(module[-1], Upsample):
hs_a.append(
rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
@@ -830,7 +884,7 @@ class WMAModel(nn.Module):
y = self.out(h)
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
if not self.base_model_gen_only:
if not self.base_model_gen_only and run_head:
ba, _, _ = x_action.shape
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
context_action[:2], **kwargs)
@@ -841,6 +895,9 @@ class WMAModel(nn.Module):
else:
s_y = self.state_unet(x_state, timesteps, hs_a,
context_action[:2], **kwargs)
elif not self.base_model_gen_only:
a_y = None
s_y = None
else:
a_y = torch.zeros_like(x_action)
s_y = torch.zeros_like(x_state)