早停特征验证,早停不通
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import copy
|
||||
import time
|
||||
|
||||
from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg
|
||||
from unifolm_wma.utils.common import noise_like
|
||||
@@ -106,6 +107,9 @@ class DDIMSampler(object):
|
||||
fs=None,
|
||||
timestep_spacing='uniform', #uniform_trailing for starting from last timestep
|
||||
guidance_rescale=0.0,
|
||||
action_T=None,
|
||||
state_T=None,
|
||||
record_step_outputs=False,
|
||||
**kwargs):
|
||||
|
||||
# Check condition bs
|
||||
@@ -161,6 +165,9 @@ class DDIMSampler(object):
|
||||
precision=precision,
|
||||
fs=fs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
action_T=action_T,
|
||||
state_T=state_T,
|
||||
record_step_outputs=record_step_outputs,
|
||||
**kwargs)
|
||||
return samples, actions, states, intermediates
|
||||
|
||||
@@ -187,24 +194,30 @@ class DDIMSampler(object):
|
||||
precision=None,
|
||||
fs=None,
|
||||
guidance_rescale=0.0,
|
||||
action_T=None,
|
||||
state_T=None,
|
||||
record_step_outputs=False,
|
||||
**kwargs):
|
||||
device = self.model.betas.device
|
||||
dp_ddim_scheduler_action = self.model.dp_noise_scheduler_action
|
||||
dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state
|
||||
|
||||
b = shape[0]
|
||||
horizon = shape[2] if len(shape) >= 3 else 16
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
action = torch.randn((b, 16, self.model.agent_action_dim),
|
||||
device=device)
|
||||
state = torch.randn((b, 16, self.model.agent_state_dim),
|
||||
device=device)
|
||||
else:
|
||||
img = x_T
|
||||
action = torch.randn((b, 16, self.model.agent_action_dim),
|
||||
if action_T is None:
|
||||
action = torch.randn((b, horizon, self.model.agent_action_dim),
|
||||
device=device)
|
||||
state = torch.randn((b, 16, self.model.agent_state_dim),
|
||||
else:
|
||||
action = action_T
|
||||
if state_T is None:
|
||||
state = torch.randn((b, horizon, self.model.agent_state_dim),
|
||||
device=device)
|
||||
else:
|
||||
state = state_T
|
||||
|
||||
if precision is not None:
|
||||
if precision == 16:
|
||||
@@ -228,6 +241,13 @@ class DDIMSampler(object):
|
||||
'x_inter_state': [state],
|
||||
'pred_x0_state': [state],
|
||||
}
|
||||
if record_step_outputs:
|
||||
intermediates['analysis_init'] = {
|
||||
'img': img.detach().cpu(),
|
||||
'action': action.detach().cpu(),
|
||||
'state': state.detach().cpu(),
|
||||
}
|
||||
intermediates['step_records'] = []
|
||||
time_range = reversed(range(
|
||||
0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
|
||||
@@ -238,6 +258,9 @@ class DDIMSampler(object):
|
||||
iterator = time_range
|
||||
|
||||
clean_cond = kwargs.pop("clean_cond", False)
|
||||
sync_device = device if isinstance(device, torch.device) else torch.device(
|
||||
device)
|
||||
should_sync = record_step_outputs and sync_device.type == "cuda"
|
||||
|
||||
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
|
||||
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
||||
@@ -254,6 +277,10 @@ class DDIMSampler(object):
|
||||
img_orig = self.model.q_sample(x0, ts)
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
if should_sync:
|
||||
torch.cuda.synchronize(sync_device)
|
||||
step_start_time = time.time()
|
||||
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
action,
|
||||
@@ -290,6 +317,10 @@ class DDIMSampler(object):
|
||||
generator=None,
|
||||
).prev_sample
|
||||
|
||||
if should_sync:
|
||||
torch.cuda.synchronize(sync_device)
|
||||
step_time_s = time.time() - step_start_time
|
||||
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
@@ -298,6 +329,16 @@ class DDIMSampler(object):
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
intermediates['x_inter_action'].append(action)
|
||||
intermediates['x_inter_state'].append(state)
|
||||
if record_step_outputs:
|
||||
intermediates['step_records'].append({
|
||||
'step_index': i + 1,
|
||||
'ddim_timestep': int(step),
|
||||
'img': img.detach().cpu(),
|
||||
'pred_x0': pred_x0.detach().cpu(),
|
||||
'action': action.detach().cpu(),
|
||||
'state': state.detach().cpu(),
|
||||
'step_time_s': step_time_s,
|
||||
})
|
||||
|
||||
return img, action, state, intermediates
|
||||
|
||||
|
||||
Reference in New Issue
Block a user