保存结果一次
This commit is contained in:
@@ -719,6 +719,18 @@ class WMAModel(nn.Module):
|
||||
run_head = kwargs.pop("run_head", True)
|
||||
backbone_block_profiler = kwargs.pop("backbone_block_profiler", None)
|
||||
backbone_step_index = kwargs.pop("backbone_step_index", None)
|
||||
backbone_reuse_blocks = kwargs.pop("backbone_reuse_blocks", None)
|
||||
backbone_reuse_start_step = kwargs.pop("backbone_reuse_start_step",
|
||||
None)
|
||||
backbone_reuse_schedule_steps = kwargs.pop(
|
||||
"backbone_reuse_schedule_steps", None)
|
||||
backbone_reuse_force_compute_steps = kwargs.pop(
|
||||
"backbone_reuse_force_compute_steps", None)
|
||||
backbone_reuse_mode = kwargs.pop("backbone_reuse_mode", "disabled")
|
||||
backbone_reuse_cache = kwargs.pop("backbone_reuse_cache", None)
|
||||
backbone_reuse_step_stats = kwargs.pop("backbone_reuse_step_stats",
|
||||
None)
|
||||
backbone_reuse_branch = kwargs.pop("backbone_reuse_branch", "single")
|
||||
t_emb = timestep_embedding(timesteps,
|
||||
self.model_channels,
|
||||
repeat_only=False).type(x.dtype)
|
||||
@@ -816,6 +828,32 @@ class WMAModel(nn.Module):
|
||||
)
|
||||
return out
|
||||
|
||||
reuse_cache_branch: Dict[str, Tensor] | None = None
|
||||
if backbone_reuse_cache is not None:
|
||||
reuse_cache_branch = backbone_reuse_cache.setdefault(
|
||||
backbone_reuse_branch, {})
|
||||
|
||||
def should_reuse_output_block(block_name: str) -> bool:
|
||||
if backbone_reuse_mode != "reuse_output":
|
||||
return False
|
||||
if backbone_step_index is None or backbone_reuse_start_step is None:
|
||||
return False
|
||||
if backbone_reuse_blocks is None or block_name not in backbone_reuse_blocks:
|
||||
return False
|
||||
if int(backbone_step_index) < int(backbone_reuse_start_step):
|
||||
return False
|
||||
if (backbone_reuse_force_compute_steps is not None
|
||||
and int(backbone_step_index)
|
||||
in backbone_reuse_force_compute_steps):
|
||||
return False
|
||||
if (backbone_reuse_schedule_steps is not None
|
||||
and int(backbone_step_index)
|
||||
in backbone_reuse_schedule_steps):
|
||||
return False
|
||||
if reuse_cache_branch is None:
|
||||
return False
|
||||
return block_name in reuse_cache_branch
|
||||
|
||||
h = x.type(self.dtype)
|
||||
adapter_idx = 0
|
||||
hs = []
|
||||
@@ -861,6 +899,7 @@ class WMAModel(nn.Module):
|
||||
hs_out = []
|
||||
for id, module in enumerate(self.output_blocks):
|
||||
skip_h = hs.pop()
|
||||
block_name = f"output_{id}"
|
||||
|
||||
def run_output_block() -> Tensor:
|
||||
return module(torch.cat([h, skip_h], dim=1),
|
||||
@@ -868,12 +907,24 @@ class WMAModel(nn.Module):
|
||||
context=context,
|
||||
batch_size=b)
|
||||
|
||||
h = run_block_with_profile(
|
||||
block_name=f"output_{id}",
|
||||
block_stage="output_blocks",
|
||||
block_index=id,
|
||||
fn=run_output_block,
|
||||
)
|
||||
if should_reuse_output_block(block_name):
|
||||
h = reuse_cache_branch[block_name].to(device=h.device,
|
||||
dtype=h.dtype)
|
||||
if backbone_reuse_step_stats is not None:
|
||||
backbone_reuse_step_stats.setdefault(
|
||||
backbone_reuse_branch, set()).add(block_name)
|
||||
else:
|
||||
h = run_block_with_profile(
|
||||
block_name=block_name,
|
||||
block_stage="output_blocks",
|
||||
block_index=id,
|
||||
fn=run_output_block,
|
||||
)
|
||||
if (reuse_cache_branch is not None and backbone_reuse_mode ==
|
||||
"reuse_output"
|
||||
and backbone_reuse_blocks is not None
|
||||
and block_name in backbone_reuse_blocks):
|
||||
reuse_cache_branch[block_name] = h.detach().clone()
|
||||
if isinstance(module[-1], Upsample):
|
||||
hs_a.append(
|
||||
rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
|
||||
|
||||
Reference in New Issue
Block a user