video_backbone剖析
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import time
|
||||
|
||||
from torch import Tensor
|
||||
from functools import partial
|
||||
@@ -715,6 +716,9 @@ class WMAModel(nn.Module):
|
||||
"""
|
||||
|
||||
b, _, t, _, _ = x.shape
|
||||
run_head = kwargs.pop("run_head", True)
|
||||
backbone_block_profiler = kwargs.pop("backbone_block_profiler", None)
|
||||
backbone_step_index = kwargs.pop("backbone_step_index", None)
|
||||
t_emb = timestep_embedding(timesteps,
|
||||
self.model_channels,
|
||||
repeat_only=False).type(x.dtype)
|
||||
@@ -791,14 +795,47 @@ class WMAModel(nn.Module):
|
||||
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
|
||||
emb = emb + fs_embed
|
||||
|
||||
def run_block_with_profile(block_name: str, block_stage: str,
|
||||
block_index: int | None,
|
||||
fn: Callable[[], Tensor]) -> Tensor:
|
||||
if backbone_block_profiler is None or backbone_step_index is None:
|
||||
return fn()
|
||||
if x.device.type == "cuda":
|
||||
torch.cuda.synchronize(x.device)
|
||||
start_time = time.perf_counter()
|
||||
out = fn()
|
||||
if x.device.type == "cuda":
|
||||
torch.cuda.synchronize(x.device)
|
||||
backbone_block_profiler.record_block(
|
||||
step=int(backbone_step_index),
|
||||
block_name=block_name,
|
||||
block_stage=block_stage,
|
||||
block_index=block_index,
|
||||
output=out,
|
||||
forward_time_ms=(time.perf_counter() - start_time) * 1000.0,
|
||||
)
|
||||
return out
|
||||
|
||||
h = x.type(self.dtype)
|
||||
adapter_idx = 0
|
||||
hs = []
|
||||
hs_a = []
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
h = module(h, emb, context=context, batch_size=b)
|
||||
if id == 0 and self.addition_attention:
|
||||
h = self.init_attn(h, emb, context=context, batch_size=b)
|
||||
def run_input_block() -> Tensor:
|
||||
block_out = module(h, emb, context=context, batch_size=b)
|
||||
if id == 0 and self.addition_attention:
|
||||
block_out = self.init_attn(block_out,
|
||||
emb,
|
||||
context=context,
|
||||
batch_size=b)
|
||||
return block_out
|
||||
|
||||
h = run_block_with_profile(
|
||||
block_name=f"input_{id}",
|
||||
block_stage="input_blocks",
|
||||
block_index=id,
|
||||
fn=run_input_block,
|
||||
)
|
||||
# plug-in adapter features
|
||||
if ((id + 1) % 3 == 0) and features_adapter is not None:
|
||||
h = h + features_adapter[adapter_idx]
|
||||
@@ -813,13 +850,30 @@ class WMAModel(nn.Module):
|
||||
if features_adapter is not None:
|
||||
assert len(
|
||||
features_adapter) == adapter_idx, 'Wrong features_adapter'
|
||||
h = self.middle_block(h, emb, context=context, batch_size=b)
|
||||
h = run_block_with_profile(
|
||||
block_name="middle",
|
||||
block_stage="middle_block",
|
||||
block_index=0,
|
||||
fn=lambda: self.middle_block(h, emb, context=context, batch_size=b),
|
||||
)
|
||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
|
||||
|
||||
hs_out = []
|
||||
for module in self.output_blocks:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb, context=context, batch_size=b)
|
||||
for id, module in enumerate(self.output_blocks):
|
||||
skip_h = hs.pop()
|
||||
|
||||
def run_output_block() -> Tensor:
|
||||
return module(torch.cat([h, skip_h], dim=1),
|
||||
emb,
|
||||
context=context,
|
||||
batch_size=b)
|
||||
|
||||
h = run_block_with_profile(
|
||||
block_name=f"output_{id}",
|
||||
block_stage="output_blocks",
|
||||
block_index=id,
|
||||
fn=run_output_block,
|
||||
)
|
||||
if isinstance(module[-1], Upsample):
|
||||
hs_a.append(
|
||||
rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
|
||||
@@ -830,7 +884,7 @@ class WMAModel(nn.Module):
|
||||
y = self.out(h)
|
||||
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
|
||||
|
||||
if not self.base_model_gen_only:
|
||||
if not self.base_model_gen_only and run_head:
|
||||
ba, _, _ = x_action.shape
|
||||
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
@@ -841,6 +895,9 @@ class WMAModel(nn.Module):
|
||||
else:
|
||||
s_y = self.state_unet(x_state, timesteps, hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
elif not self.base_model_gen_only:
|
||||
a_y = None
|
||||
s_y = None
|
||||
else:
|
||||
a_y = torch.zeros_like(x_action)
|
||||
s_y = torch.zeros_like(x_state)
|
||||
|
||||
Reference in New Issue
Block a user