早停特征验证,早停不通

This commit is contained in:
qhy
2026-03-15 12:41:53 +08:00
parent db9cc5766d
commit 7e45eba18b
227 changed files with 24579 additions and 163 deletions

View File

@@ -1,6 +1,7 @@
import numpy as np
import torch
import copy
import time
from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg
from unifolm_wma.utils.common import noise_like
@@ -106,6 +107,9 @@ class DDIMSampler(object):
fs=None,
timestep_spacing='uniform', #uniform_trailing for starting from last timestep
guidance_rescale=0.0,
action_T=None,
state_T=None,
record_step_outputs=False,
**kwargs):
# Check condition bs
@@ -161,6 +165,9 @@ class DDIMSampler(object):
precision=precision,
fs=fs,
guidance_rescale=guidance_rescale,
action_T=action_T,
state_T=state_T,
record_step_outputs=record_step_outputs,
**kwargs)
return samples, actions, states, intermediates
@@ -187,24 +194,30 @@ class DDIMSampler(object):
precision=None,
fs=None,
guidance_rescale=0.0,
action_T=None,
state_T=None,
record_step_outputs=False,
**kwargs):
device = self.model.betas.device
dp_ddim_scheduler_action = self.model.dp_noise_scheduler_action
dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state
b = shape[0]
horizon = shape[2] if len(shape) >= 3 else 16
if x_T is None:
img = torch.randn(shape, device=device)
action = torch.randn((b, 16, self.model.agent_action_dim),
device=device)
state = torch.randn((b, 16, self.model.agent_state_dim),
device=device)
else:
img = x_T
action = torch.randn((b, 16, self.model.agent_action_dim),
if action_T is None:
action = torch.randn((b, horizon, self.model.agent_action_dim),
device=device)
state = torch.randn((b, 16, self.model.agent_state_dim),
else:
action = action_T
if state_T is None:
state = torch.randn((b, horizon, self.model.agent_state_dim),
device=device)
else:
state = state_T
if precision is not None:
if precision == 16:
@@ -228,6 +241,13 @@ class DDIMSampler(object):
'x_inter_state': [state],
'pred_x0_state': [state],
}
if record_step_outputs:
intermediates['analysis_init'] = {
'img': img.detach().cpu(),
'action': action.detach().cpu(),
'state': state.detach().cpu(),
}
intermediates['step_records'] = []
time_range = reversed(range(
0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
@@ -238,6 +258,9 @@ class DDIMSampler(object):
iterator = time_range
clean_cond = kwargs.pop("clean_cond", False)
sync_device = device if isinstance(device, torch.device) else torch.device(
device)
should_sync = record_step_outputs and sync_device.type == "cuda"
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
@@ -254,6 +277,10 @@ class DDIMSampler(object):
img_orig = self.model.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
if should_sync:
torch.cuda.synchronize(sync_device)
step_start_time = time.time()
outs = self.p_sample_ddim(
img,
action,
@@ -290,6 +317,10 @@ class DDIMSampler(object):
generator=None,
).prev_sample
if should_sync:
torch.cuda.synchronize(sync_device)
step_time_s = time.time() - step_start_time
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
@@ -298,6 +329,16 @@ class DDIMSampler(object):
intermediates['pred_x0'].append(pred_x0)
intermediates['x_inter_action'].append(action)
intermediates['x_inter_state'].append(state)
if record_step_outputs:
intermediates['step_records'].append({
'step_index': i + 1,
'ddim_timestep': int(step),
'img': img.detach().cpu(),
'pred_x0': pred_x0.detach().cpu(),
'action': action.detach().cpu(),
'state': state.detach().cpu(),
'step_time_s': step_time_s,
})
return img, action, state, intermediates