保存结果一次

This commit is contained in:
qhy
2026-03-18 20:52:13 +08:00
parent 8ca159d375
commit 9d2d57d96b
15 changed files with 2312 additions and 15 deletions

View File

@@ -719,6 +719,18 @@ class WMAModel(nn.Module):
run_head = kwargs.pop("run_head", True)
backbone_block_profiler = kwargs.pop("backbone_block_profiler", None)
backbone_step_index = kwargs.pop("backbone_step_index", None)
backbone_reuse_blocks = kwargs.pop("backbone_reuse_blocks", None)
backbone_reuse_start_step = kwargs.pop("backbone_reuse_start_step",
None)
backbone_reuse_schedule_steps = kwargs.pop(
"backbone_reuse_schedule_steps", None)
backbone_reuse_force_compute_steps = kwargs.pop(
"backbone_reuse_force_compute_steps", None)
backbone_reuse_mode = kwargs.pop("backbone_reuse_mode", "disabled")
backbone_reuse_cache = kwargs.pop("backbone_reuse_cache", None)
backbone_reuse_step_stats = kwargs.pop("backbone_reuse_step_stats",
None)
backbone_reuse_branch = kwargs.pop("backbone_reuse_branch", "single")
t_emb = timestep_embedding(timesteps,
self.model_channels,
repeat_only=False).type(x.dtype)
@@ -816,6 +828,32 @@ class WMAModel(nn.Module):
)
return out
reuse_cache_branch: Dict[str, Tensor] | None = None
if backbone_reuse_cache is not None:
reuse_cache_branch = backbone_reuse_cache.setdefault(
backbone_reuse_branch, {})
def should_reuse_output_block(block_name: str) -> bool:
    """Tell whether the cached tensor for ``block_name`` should be used.

    Reads the reuse settings captured from the enclosing scope and returns
    ``True`` only when every one of the following holds:

    * ``backbone_reuse_mode`` equals ``"reuse_output"``;
    * both ``backbone_step_index`` and ``backbone_reuse_start_step`` are set;
    * ``backbone_reuse_blocks`` is set and contains ``block_name``;
    * the step index is not below ``backbone_reuse_start_step``;
    * the step index is in neither ``backbone_reuse_force_compute_steps``
      nor ``backbone_reuse_schedule_steps`` (each ignored when ``None``);
    * ``reuse_cache_branch`` exists and already has an entry for
      ``block_name``.
    """
    if backbone_reuse_mode != "reuse_output":
        return False

    step_unknown = backbone_step_index is None
    if step_unknown or backbone_reuse_start_step is None:
        return False

    block_not_selected = (backbone_reuse_blocks is None
                          or block_name not in backbone_reuse_blocks)
    if block_not_selected:
        return False

    current_step = int(backbone_step_index)
    if current_step < int(backbone_reuse_start_step):
        return False

    # A step listed in either collection is never served from the cache.
    # NOTE(review): presumably these are the steps on which the block is
    # recomputed so the cache gets refreshed -- confirm against the caller.
    for excluded_steps in (backbone_reuse_force_compute_steps,
                           backbone_reuse_schedule_steps):
        if excluded_steps is not None and current_step in excluded_steps:
            return False

    # Reuse is only possible once a tensor has been stored for this block.
    return reuse_cache_branch is not None and block_name in reuse_cache_branch
h = x.type(self.dtype)
adapter_idx = 0
hs = []
@@ -861,6 +899,7 @@ class WMAModel(nn.Module):
hs_out = []
for id, module in enumerate(self.output_blocks):
skip_h = hs.pop()
block_name = f"output_{id}"
def run_output_block() -> Tensor:
return module(torch.cat([h, skip_h], dim=1),
@@ -868,12 +907,24 @@ class WMAModel(nn.Module):
context=context,
batch_size=b)
h = run_block_with_profile(
block_name=f"output_{id}",
block_stage="output_blocks",
block_index=id,
fn=run_output_block,
)
if should_reuse_output_block(block_name):
h = reuse_cache_branch[block_name].to(device=h.device,
dtype=h.dtype)
if backbone_reuse_step_stats is not None:
backbone_reuse_step_stats.setdefault(
backbone_reuse_branch, set()).add(block_name)
else:
h = run_block_with_profile(
block_name=block_name,
block_stage="output_blocks",
block_index=id,
fn=run_output_block,
)
if (reuse_cache_branch is not None and backbone_reuse_mode ==
"reuse_output"
and backbone_reuse_blocks is not None
and block_name in backbone_reuse_blocks):
reuse_cache_branch[block_name] = h.detach().clone()
if isinstance(module[-1], Upsample):
hs_a.append(
rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))