添加CrossAttention kv缓存,减少重复计算,提升性能,psnr=31.8022 dB
This commit is contained in:
@@ -6,6 +6,7 @@ from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim
|
||||
from unifolm_wma.utils.common import noise_like
|
||||
from unifolm_wma.utils.common import extract_into_tensor
|
||||
from tqdm import tqdm
|
||||
from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
@@ -243,63 +244,67 @@ class DDIMSampler(object):
|
||||
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
|
||||
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
||||
ts = torch.empty((b, ), device=device, dtype=torch.long)
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts.fill_(step)
|
||||
enable_cross_attn_kv_cache(self.model)
|
||||
try:
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts.fill_(step)
|
||||
|
||||
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
if clean_cond:
|
||||
img_orig = x0
|
||||
else:
|
||||
img_orig = self.model.q_sample(x0, ts)
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
if clean_cond:
|
||||
img_orig = x0
|
||||
else:
|
||||
img_orig = self.model.q_sample(x0, ts)
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
action,
|
||||
state,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
fs=fs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
action,
|
||||
state,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
fs=fs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
|
||||
img, pred_x0, model_output_action, model_output_state = outs
|
||||
img, pred_x0, model_output_action, model_output_state = outs
|
||||
|
||||
action = dp_ddim_scheduler_action.step(
|
||||
model_output_action,
|
||||
step,
|
||||
action,
|
||||
generator=None,
|
||||
).prev_sample
|
||||
state = dp_ddim_scheduler_state.step(
|
||||
model_output_state,
|
||||
step,
|
||||
state,
|
||||
generator=None,
|
||||
).prev_sample
|
||||
action = dp_ddim_scheduler_action.step(
|
||||
model_output_action,
|
||||
step,
|
||||
action,
|
||||
generator=None,
|
||||
).prev_sample
|
||||
state = dp_ddim_scheduler_state.step(
|
||||
model_output_state,
|
||||
step,
|
||||
state,
|
||||
generator=None,
|
||||
).prev_sample
|
||||
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
intermediates['x_inter_action'].append(action)
|
||||
intermediates['x_inter_state'].append(state)
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
intermediates['x_inter_action'].append(action)
|
||||
intermediates['x_inter_state'].append(state)
|
||||
finally:
|
||||
disable_cross_attn_kv_cache(self.model)
|
||||
|
||||
return img, action, state, intermediates
|
||||
|
||||
|
||||
Reference in New Issue
Block a user