算子融合
This commit is contained in:
@@ -55,16 +55,13 @@ class DDIMSampler(object):
|
|||||||
to_torch(self.model.alphas_cumprod_prev))
|
to_torch(self.model.alphas_cumprod_prev))
|
||||||
|
|
||||||
# Calculations for diffusion q(x_t | x_{t-1}) and others
|
# Calculations for diffusion q(x_t | x_{t-1}) and others
|
||||||
self.register_buffer('sqrt_alphas_cumprod',
|
# Computed directly on GPU to avoid CPU↔GPU transfers
|
||||||
to_torch(np.sqrt(alphas_cumprod.cpu())))
|
ac = to_torch(alphas_cumprod)
|
||||||
self.register_buffer('sqrt_one_minus_alphas_cumprod',
|
self.register_buffer('sqrt_alphas_cumprod', ac.sqrt())
|
||||||
to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
self.register_buffer('sqrt_one_minus_alphas_cumprod', (1. - ac).sqrt())
|
||||||
self.register_buffer('log_one_minus_alphas_cumprod',
|
self.register_buffer('log_one_minus_alphas_cumprod', (1. - ac).log())
|
||||||
to_torch(np.log(1. - alphas_cumprod.cpu())))
|
self.register_buffer('sqrt_recip_alphas_cumprod', ac.rsqrt())
|
||||||
self.register_buffer('sqrt_recip_alphas_cumprod',
|
self.register_buffer('sqrt_recipm1_alphas_cumprod', (1. / ac - 1).sqrt())
|
||||||
to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
|
||||||
self.register_buffer('sqrt_recipm1_alphas_cumprod',
|
|
||||||
to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
|
||||||
|
|
||||||
# DDIM sampling parameters
|
# DDIM sampling parameters
|
||||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
||||||
@@ -86,6 +83,11 @@ class DDIMSampler(object):
|
|||||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||||
self.register_buffer('ddim_sqrt_one_minus_alphas',
|
self.register_buffer('ddim_sqrt_one_minus_alphas',
|
||||||
torch.sqrt(1. - ddim_alphas))
|
torch.sqrt(1. - ddim_alphas))
|
||||||
|
# Precomputed coefficients for DDIM update formula
|
||||||
|
self.register_buffer('ddim_sqrt_alphas', ddim_alphas.sqrt())
|
||||||
|
self.register_buffer('ddim_sqrt_alphas_prev', ddim_alphas_prev.sqrt())
|
||||||
|
self.register_buffer('ddim_dir_coeff',
|
||||||
|
(1. - ddim_alphas_prev - ddim_sigmas**2).sqrt())
|
||||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
|
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
|
||||||
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||||
@@ -208,18 +210,11 @@ class DDIMSampler(object):
|
|||||||
dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state
|
dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state
|
||||||
|
|
||||||
b = shape[0]
|
b = shape[0]
|
||||||
if x_T is None:
|
|
||||||
img = torch.randn(shape, device=device)
|
|
||||||
action = torch.randn((b, 16, self.model.agent_action_dim),
|
|
||||||
device=device)
|
|
||||||
state = torch.randn((b, 16, self.model.agent_state_dim),
|
|
||||||
device=device)
|
|
||||||
else:
|
|
||||||
img = x_T
|
|
||||||
action = torch.randn((b, 16, self.model.agent_action_dim),
|
action = torch.randn((b, 16, self.model.agent_action_dim),
|
||||||
device=device)
|
device=device)
|
||||||
state = torch.randn((b, 16, self.model.agent_state_dim),
|
state = torch.randn((b, 16, self.model.agent_state_dim),
|
||||||
device=device)
|
device=device)
|
||||||
|
img = torch.randn(shape, device=device) if x_T is None else x_T
|
||||||
|
|
||||||
if precision is not None:
|
if precision is not None:
|
||||||
if precision == 16:
|
if precision == 16:
|
||||||
@@ -362,12 +357,13 @@ class DDIMSampler(object):
|
|||||||
**kwargs)
|
**kwargs)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
model_output = e_t_uncond + unconditional_guidance_scale * (
|
model_output = torch.lerp(e_t_uncond, e_t_cond,
|
||||||
e_t_cond - e_t_uncond)
|
unconditional_guidance_scale)
|
||||||
model_output_action = e_t_uncond_action + unconditional_guidance_scale * (
|
model_output_action = torch.lerp(e_t_uncond_action,
|
||||||
e_t_cond_action - e_t_uncond_action)
|
e_t_cond_action,
|
||||||
model_output_state = e_t_uncond_state + unconditional_guidance_scale * (
|
unconditional_guidance_scale)
|
||||||
e_t_cond_state - e_t_uncond_state)
|
model_output_state = torch.lerp(e_t_uncond_state, e_t_cond_state,
|
||||||
|
unconditional_guidance_scale)
|
||||||
|
|
||||||
if guidance_rescale > 0.0:
|
if guidance_rescale > 0.0:
|
||||||
model_output = rescale_noise_cfg(
|
model_output = rescale_noise_cfg(
|
||||||
@@ -396,18 +392,28 @@ class DDIMSampler(object):
|
|||||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||||
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||||
|
|
||||||
|
if use_original_steps:
|
||||||
|
sqrt_alphas = alphas.sqrt()
|
||||||
|
sqrt_alphas_prev = alphas_prev.sqrt()
|
||||||
|
dir_coeffs = (1. - alphas_prev - sigmas**2).sqrt()
|
||||||
|
else:
|
||||||
|
sqrt_alphas = self.ddim_sqrt_alphas
|
||||||
|
sqrt_alphas_prev = self.ddim_sqrt_alphas_prev
|
||||||
|
dir_coeffs = self.ddim_dir_coeff
|
||||||
|
|
||||||
if is_video:
|
if is_video:
|
||||||
size = (1, 1, 1, 1, 1)
|
size = (1, 1, 1, 1, 1)
|
||||||
else:
|
else:
|
||||||
size = (1, 1, 1, 1)
|
size = (1, 1, 1, 1)
|
||||||
|
|
||||||
a_t = alphas[index].view(size)
|
sqrt_at = sqrt_alphas[index].view(size)
|
||||||
a_prev = alphas_prev[index].view(size)
|
sqrt_a_prev = sqrt_alphas_prev[index].view(size)
|
||||||
sigma_t = sigmas[index].view(size)
|
sigma_t = sigmas[index].view(size)
|
||||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].view(size)
|
sqrt_one_minus_at = sqrt_one_minus_alphas[index].view(size)
|
||||||
|
dir_coeff = dir_coeffs[index].view(size)
|
||||||
|
|
||||||
if self.model.parameterization != "v":
|
if self.model.parameterization != "v":
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / sqrt_at
|
||||||
else:
|
else:
|
||||||
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
||||||
|
|
||||||
@@ -420,14 +426,11 @@ class DDIMSampler(object):
|
|||||||
if quantize_denoised:
|
if quantize_denoised:
|
||||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
|
||||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
|
||||||
|
|
||||||
noise = sigma_t * noise_like(x.shape, device,
|
noise = sigma_t * noise_like(x.shape, device,
|
||||||
repeat_noise) * temperature
|
repeat_noise) * temperature
|
||||||
if noise_dropout > 0.:
|
if noise_dropout > 0.:
|
||||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||||
|
x_prev = sqrt_a_prev * pred_x0 + dir_coeff * e_t + noise
|
||||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
|
||||||
|
|
||||||
return x_prev, pred_x0, model_output_action, model_output_state
|
return x_prev, pred_x0, model_output_action, model_output_state
|
||||||
|
|
||||||
@@ -475,7 +478,7 @@ class DDIMSampler(object):
|
|||||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||||
else:
|
else:
|
||||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
sqrt_alphas_cumprod = self.ddim_sqrt_alphas
|
||||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||||
|
|
||||||
if noise is None:
|
if noise is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user