Compare commits
4 Commits
a09d35ae5b
...
third
| Author | SHA1 | Date | |
|---|---|---|---|
| 25c6a328ef | |||
| 1d23e5d36d | |||
| 57ba85d147 | |||
| 2cef3e9e45 |
@@ -559,6 +559,7 @@ def image_guided_synthesis_sim_mode(
|
|||||||
autocast_ctx = nullcontext()
|
autocast_ctx = nullcontext()
|
||||||
|
|
||||||
batch_variants = None
|
batch_variants = None
|
||||||
|
samples = None
|
||||||
if ddim_sampler is not None:
|
if ddim_sampler is not None:
|
||||||
with autocast_ctx:
|
with autocast_ctx:
|
||||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||||
@@ -583,7 +584,7 @@ def image_guided_synthesis_sim_mode(
|
|||||||
batch_images = model.decode_first_stage(samples)
|
batch_images = model.decode_first_stage(samples)
|
||||||
batch_variants = batch_images
|
batch_variants = batch_images
|
||||||
|
|
||||||
return batch_variants, actions, states
|
return batch_variants, actions, states, samples
|
||||||
|
|
||||||
|
|
||||||
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||||
@@ -625,6 +626,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
# Compile hot ResBlocks for operator fusion
|
# Compile hot ResBlocks for operator fusion
|
||||||
apply_torch_compile(model)
|
apply_torch_compile(model)
|
||||||
|
|
||||||
|
# Fuse KV projections in attention layers (to_k + to_v → to_kv)
|
||||||
|
from unifolm_wma.modules.attention import CrossAttention
|
||||||
|
kv_count = sum(1 for m in model.modules()
|
||||||
|
if isinstance(m, CrossAttention) and m.fuse_kv())
|
||||||
|
print(f" ✓ KV fused: {kv_count} attention layers")
|
||||||
|
|
||||||
# Export precision-converted checkpoint if requested
|
# Export precision-converted checkpoint if requested
|
||||||
if args.export_precision_ckpt:
|
if args.export_precision_ckpt:
|
||||||
export_path = args.export_precision_ckpt
|
export_path = args.export_precision_ckpt
|
||||||
@@ -687,7 +694,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
sample_save_dir = f'{video_save_dir}/wm/{fs}'
|
sample_save_dir = f'{video_save_dir}/wm/{fs}'
|
||||||
os.makedirs(sample_save_dir, exist_ok=True)
|
os.makedirs(sample_save_dir, exist_ok=True)
|
||||||
# For collecting interaction videos
|
# For collecting interaction videos
|
||||||
wm_video = []
|
wm_latent = []
|
||||||
# Initialize observation queues
|
# Initialize observation queues
|
||||||
cond_obs_queues = {
|
cond_obs_queues = {
|
||||||
"observation.images.top":
|
"observation.images.top":
|
||||||
@@ -743,7 +750,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
|
|
||||||
# Use world-model in policy to generate action
|
# Use world-model in policy to generate action
|
||||||
print(f'>>> Step {itr}: generating actions ...')
|
print(f'>>> Step {itr}: generating actions ...')
|
||||||
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
|
pred_videos_0, pred_actions, _, _ = image_guided_synthesis_sim_mode(
|
||||||
model,
|
model,
|
||||||
sample['instruction'],
|
sample['instruction'],
|
||||||
observation,
|
observation,
|
||||||
@@ -785,7 +792,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
|
|
||||||
# Interaction with the world-model
|
# Interaction with the world-model
|
||||||
print(f'>>> Step {itr}: interacting with world model ...')
|
print(f'>>> Step {itr}: interacting with world model ...')
|
||||||
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
|
pred_videos_1, _, pred_states, wm_samples = image_guided_synthesis_sim_mode(
|
||||||
model,
|
model,
|
||||||
"",
|
"",
|
||||||
observation,
|
observation,
|
||||||
@@ -798,12 +805,16 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
fs=model_input_fs,
|
fs=model_input_fs,
|
||||||
text_input=False,
|
text_input=False,
|
||||||
timestep_spacing=args.timestep_spacing,
|
timestep_spacing=args.timestep_spacing,
|
||||||
guidance_rescale=args.guidance_rescale)
|
guidance_rescale=args.guidance_rescale,
|
||||||
|
decode_video=False)
|
||||||
|
|
||||||
|
# Decode only the last frame for CLIP embedding in next iteration
|
||||||
|
last_frame_pixel = model.decode_first_stage(wm_samples[:, :, -1:, :, :])
|
||||||
|
|
||||||
for idx in range(args.exe_steps):
|
for idx in range(args.exe_steps):
|
||||||
observation = {
|
observation = {
|
||||||
'observation.images.top':
|
'observation.images.top':
|
||||||
pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
|
last_frame_pixel[0, :, 0:1].permute(1, 0, 2, 3),
|
||||||
'observation.state':
|
'observation.state':
|
||||||
torch.zeros_like(pred_states[0][idx:idx + 1]) if
|
torch.zeros_like(pred_states[0][idx:idx + 1]) if
|
||||||
args.zero_pred_state else pred_states[0][idx:idx + 1],
|
args.zero_pred_state else pred_states[0][idx:idx + 1],
|
||||||
@@ -821,30 +832,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
pred_videos_0,
|
pred_videos_0,
|
||||||
sample_tag,
|
sample_tag,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
# Save videos environment changes via world-model interaction
|
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
|
||||||
log_to_tensorboard(writer,
|
|
||||||
pred_videos_1,
|
|
||||||
sample_tag,
|
|
||||||
fps=args.save_fps)
|
|
||||||
|
|
||||||
# Save the imagen videos for decision-making
|
|
||||||
if pred_videos_0 is not None:
|
|
||||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
|
||||||
save_results(pred_videos_0.cpu(),
|
|
||||||
sample_video_file,
|
|
||||||
fps=args.save_fps)
|
|
||||||
# Save videos environment changes via world-model interaction
|
|
||||||
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
|
|
||||||
save_results(pred_videos_1.cpu(),
|
|
||||||
sample_video_file,
|
|
||||||
fps=args.save_fps)
|
|
||||||
|
|
||||||
print('>' * 24)
|
print('>' * 24)
|
||||||
# Collect the result of world-model interactions
|
# Store raw latent for deferred decode
|
||||||
wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())
|
wm_latent.append(wm_samples[:, :, :args.exe_steps].cpu())
|
||||||
|
|
||||||
full_video = torch.cat(wm_video, dim=2)
|
# Deferred decode: batch decode all stored latents
|
||||||
|
full_latent = torch.cat(wm_latent, dim=2).to(device)
|
||||||
|
full_video = model.decode_first_stage(full_latent).cpu()
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
||||||
log_to_tensorboard(writer,
|
log_to_tensorboard(writer,
|
||||||
full_video,
|
full_video,
|
||||||
|
|||||||
@@ -567,6 +567,11 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
# Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
# Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
timesteps = timesteps.expand(sample.shape[0])
|
timesteps = timesteps.expand(sample.shape[0])
|
||||||
global_feature = self.diffusion_step_encoder(timesteps)
|
global_feature = self.diffusion_step_encoder(timesteps)
|
||||||
|
# Pre-expand global_feature once (reused in every down/mid/up block)
|
||||||
|
if self.use_linear_act_proj:
|
||||||
|
global_feature_expanded = global_feature.unsqueeze(1).expand(-1, T, -1)
|
||||||
|
else:
|
||||||
|
global_feature_expanded = global_feature.unsqueeze(1).expand(-1, 2, -1)
|
||||||
(imagen_cond_down, imagen_cond_mid, imagen_cond_up
|
(imagen_cond_down, imagen_cond_mid, imagen_cond_up
|
||||||
) = imagen_cond[0:4], imagen_cond[4], imagen_cond[5:] #NOTE HAND CODE
|
) = imagen_cond[0:4], imagen_cond[4], imagen_cond[5:] #NOTE HAND CODE
|
||||||
|
|
||||||
@@ -603,15 +608,11 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
|
|
||||||
if self.use_linear_act_proj:
|
if self.use_linear_act_proj:
|
||||||
imagen_cond = imagen_cond.reshape(B, T, -1)
|
imagen_cond = imagen_cond.reshape(B, T, -1)
|
||||||
cur_global_feature = global_feature.unsqueeze(
|
|
||||||
1).repeat_interleave(repeats=T, dim=1)
|
|
||||||
else:
|
else:
|
||||||
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
||||||
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
||||||
cur_global_feature = global_feature.unsqueeze(
|
|
||||||
1).repeat_interleave(repeats=2, dim=1)
|
|
||||||
cur_global_feature = torch.cat(
|
cur_global_feature = torch.cat(
|
||||||
[cur_global_feature, global_cond, imagen_cond], axis=-1)
|
[global_feature_expanded, global_cond, imagen_cond], axis=-1)
|
||||||
x = resnet(x, cur_global_feature)
|
x = resnet(x, cur_global_feature)
|
||||||
x = resnet2(x, cur_global_feature)
|
x = resnet2(x, cur_global_feature)
|
||||||
h.append(x)
|
h.append(x)
|
||||||
@@ -638,15 +639,11 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
|
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
|
||||||
if self.use_linear_act_proj:
|
if self.use_linear_act_proj:
|
||||||
imagen_cond = imagen_cond.reshape(B, T, -1)
|
imagen_cond = imagen_cond.reshape(B, T, -1)
|
||||||
cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
|
|
||||||
repeats=T, dim=1)
|
|
||||||
else:
|
else:
|
||||||
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
||||||
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
||||||
cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
|
|
||||||
repeats=2, dim=1)
|
|
||||||
cur_global_feature = torch.cat(
|
cur_global_feature = torch.cat(
|
||||||
[cur_global_feature, global_cond, imagen_cond], axis=-1)
|
[global_feature_expanded, global_cond, imagen_cond], axis=-1)
|
||||||
x = resnet(x, cur_global_feature)
|
x = resnet(x, cur_global_feature)
|
||||||
x = resnet2(x, cur_global_feature)
|
x = resnet2(x, cur_global_feature)
|
||||||
|
|
||||||
@@ -683,16 +680,12 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
|
|
||||||
if self.use_linear_act_proj:
|
if self.use_linear_act_proj:
|
||||||
imagen_cond = imagen_cond.reshape(B, T, -1)
|
imagen_cond = imagen_cond.reshape(B, T, -1)
|
||||||
cur_global_feature = global_feature.unsqueeze(
|
|
||||||
1).repeat_interleave(repeats=T, dim=1)
|
|
||||||
else:
|
else:
|
||||||
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
||||||
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
||||||
cur_global_feature = global_feature.unsqueeze(
|
|
||||||
1).repeat_interleave(repeats=2, dim=1)
|
|
||||||
|
|
||||||
cur_global_feature = torch.cat(
|
cur_global_feature = torch.cat(
|
||||||
[cur_global_feature, global_cond, imagen_cond], axis=-1)
|
[global_feature_expanded, global_cond, imagen_cond], axis=-1)
|
||||||
|
|
||||||
x = torch.cat((x, h.pop()), dim=1)
|
x = torch.cat((x, h.pop()), dim=1)
|
||||||
x = resnet(x, cur_global_feature)
|
x = resnet(x, cur_global_feature)
|
||||||
|
|||||||
@@ -251,6 +251,13 @@ class DDIMSampler(object):
|
|||||||
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
|
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
|
||||||
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
||||||
ts = torch.empty((b, ), device=device, dtype=torch.long)
|
ts = torch.empty((b, ), device=device, dtype=torch.long)
|
||||||
|
noise_buf = torch.empty_like(img)
|
||||||
|
# Pre-convert schedule arrays to inference dtype (avoid per-step .to())
|
||||||
|
_dtype = img.dtype
|
||||||
|
_alphas = (self.model.alphas_cumprod if ddim_use_original_steps else self.ddim_alphas).to(_dtype)
|
||||||
|
_alphas_prev = (self.model.alphas_cumprod_prev if ddim_use_original_steps else self.ddim_alphas_prev).to(_dtype)
|
||||||
|
_sqrt_one_minus = (self.model.sqrt_one_minus_alphas_cumprod if ddim_use_original_steps else self.ddim_sqrt_one_minus_alphas).to(_dtype)
|
||||||
|
_sigmas = (self.ddim_sigmas_for_original_num_steps if ddim_use_original_steps else self.ddim_sigmas).to(_dtype)
|
||||||
enable_cross_attn_kv_cache(self.model)
|
enable_cross_attn_kv_cache(self.model)
|
||||||
enable_ctx_cache(self.model)
|
enable_ctx_cache(self.model)
|
||||||
try:
|
try:
|
||||||
@@ -286,6 +293,8 @@ class DDIMSampler(object):
|
|||||||
x0=x0,
|
x0=x0,
|
||||||
fs=fs,
|
fs=fs,
|
||||||
guidance_rescale=guidance_rescale,
|
guidance_rescale=guidance_rescale,
|
||||||
|
noise_buf=noise_buf,
|
||||||
|
schedule_arrays=(_alphas, _alphas_prev, _sqrt_one_minus, _sigmas),
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
img, pred_x0, model_output_action, model_output_state = outs
|
img, pred_x0, model_output_action, model_output_state = outs
|
||||||
@@ -339,6 +348,8 @@ class DDIMSampler(object):
|
|||||||
mask=None,
|
mask=None,
|
||||||
x0=None,
|
x0=None,
|
||||||
guidance_rescale=0.0,
|
guidance_rescale=0.0,
|
||||||
|
noise_buf=None,
|
||||||
|
schedule_arrays=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
@@ -384,16 +395,18 @@ class DDIMSampler(object):
|
|||||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
|
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
|
||||||
**corrector_kwargs)
|
**corrector_kwargs)
|
||||||
|
|
||||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
if schedule_arrays is not None:
|
||||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
alphas, alphas_prev, sqrt_one_minus_alphas, sigmas = schedule_arrays
|
||||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
else:
|
||||||
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
alphas = (self.model.alphas_cumprod if use_original_steps else self.ddim_alphas).to(x.dtype)
|
||||||
|
alphas_prev = (self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev).to(x.dtype)
|
||||||
|
sqrt_one_minus_alphas = (self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas).to(x.dtype)
|
||||||
|
sigmas = (self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas).to(x.dtype)
|
||||||
|
|
||||||
# Use 0-d tensors directly (already on device); broadcasting handles shape
|
a_t = alphas[index]
|
||||||
a_t = alphas[index].to(x.dtype)
|
a_prev = alphas_prev[index]
|
||||||
a_prev = alphas_prev[index].to(x.dtype)
|
sigma_t = sigmas[index]
|
||||||
sigma_t = sigmas[index].to(x.dtype)
|
sqrt_one_minus_at = sqrt_one_minus_alphas[index]
|
||||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].to(x.dtype)
|
|
||||||
|
|
||||||
if self.model.parameterization != "v":
|
if self.model.parameterization != "v":
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
@@ -411,6 +424,10 @@ class DDIMSampler(object):
|
|||||||
|
|
||||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||||
|
|
||||||
|
if noise_buf is not None:
|
||||||
|
noise_buf.normal_()
|
||||||
|
noise = sigma_t * noise_buf * temperature
|
||||||
|
else:
|
||||||
noise = sigma_t * noise_like(x.shape, device,
|
noise = sigma_t * noise_like(x.shape, device,
|
||||||
repeat_noise) * temperature
|
repeat_noise) * temperature
|
||||||
if noise_dropout > 0.:
|
if noise_dropout > 0.:
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ class CrossAttention(nn.Module):
|
|||||||
self.agent_action_context_len = agent_action_context_len
|
self.agent_action_context_len = agent_action_context_len
|
||||||
self._kv_cache = {}
|
self._kv_cache = {}
|
||||||
self._kv_cache_enabled = False
|
self._kv_cache_enabled = False
|
||||||
|
self._kv_fused = False
|
||||||
|
|
||||||
self.cross_attention_scale_learnable = cross_attention_scale_learnable
|
self.cross_attention_scale_learnable = cross_attention_scale_learnable
|
||||||
if self.image_cross_attention:
|
if self.image_cross_attention:
|
||||||
@@ -116,6 +117,27 @@ class CrossAttention(nn.Module):
|
|||||||
self.register_parameter('alpha_caa',
|
self.register_parameter('alpha_caa',
|
||||||
nn.Parameter(torch.tensor(0.)))
|
nn.Parameter(torch.tensor(0.)))
|
||||||
|
|
||||||
|
def fuse_kv(self):
|
||||||
|
"""Fuse to_k/to_v into to_kv (2 Linear → 1). Works for all layers."""
|
||||||
|
k_w = self.to_k.weight # (inner_dim, context_dim)
|
||||||
|
v_w = self.to_v.weight
|
||||||
|
self.to_kv = nn.Linear(k_w.shape[1], k_w.shape[0] * 2, bias=False)
|
||||||
|
self.to_kv.weight = nn.Parameter(torch.cat([k_w, v_w], dim=0))
|
||||||
|
del self.to_k, self.to_v
|
||||||
|
if self.image_cross_attention:
|
||||||
|
for suffix in ('_ip', '_as', '_aa'):
|
||||||
|
k_attr = f'to_k{suffix}'
|
||||||
|
v_attr = f'to_v{suffix}'
|
||||||
|
kw = getattr(self, k_attr).weight
|
||||||
|
vw = getattr(self, v_attr).weight
|
||||||
|
fused = nn.Linear(kw.shape[1], kw.shape[0] * 2, bias=False)
|
||||||
|
fused.weight = nn.Parameter(torch.cat([kw, vw], dim=0))
|
||||||
|
setattr(self, f'to_kv{suffix}', fused)
|
||||||
|
delattr(self, k_attr)
|
||||||
|
delattr(self, v_attr)
|
||||||
|
self._kv_fused = True
|
||||||
|
return True
|
||||||
|
|
||||||
def forward(self, x, context=None, mask=None):
|
def forward(self, x, context=None, mask=None):
|
||||||
spatial_self_attn = (context is None)
|
spatial_self_attn = (context is None)
|
||||||
k_ip, v_ip, out_ip = None, None, None
|
k_ip, v_ip, out_ip = None, None, None
|
||||||
@@ -276,6 +298,12 @@ class CrossAttention(nn.Module):
|
|||||||
self.agent_action_context_len +
|
self.agent_action_context_len +
|
||||||
self.text_context_len:, :]
|
self.text_context_len:, :]
|
||||||
|
|
||||||
|
if self._kv_fused:
|
||||||
|
k, v = self.to_kv(context_ins).chunk(2, dim=-1)
|
||||||
|
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
|
||||||
|
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1)
|
||||||
|
k_aa, v_aa = self.to_kv_aa(context_agent_action).chunk(2, dim=-1)
|
||||||
|
else:
|
||||||
k = self.to_k(context_ins)
|
k = self.to_k(context_ins)
|
||||||
v = self.to_v(context_ins)
|
v = self.to_v(context_ins)
|
||||||
k_ip = self.to_k_ip(context_image)
|
k_ip = self.to_k_ip(context_image)
|
||||||
@@ -304,6 +332,9 @@ class CrossAttention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
if not spatial_self_attn:
|
if not spatial_self_attn:
|
||||||
context = context[:, :self.text_context_len, :]
|
context = context[:, :self.text_context_len, :]
|
||||||
|
if self._kv_fused:
|
||||||
|
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||||
|
else:
|
||||||
k = self.to_k(context)
|
k = self.to_k(context)
|
||||||
v = self.to_v(context)
|
v = self.to_v(context)
|
||||||
|
|
||||||
|
|||||||
@@ -690,6 +690,8 @@ class WMAModel(nn.Module):
|
|||||||
self._ctx_cache = {}
|
self._ctx_cache = {}
|
||||||
# fs_embed cache
|
# fs_embed cache
|
||||||
self._fs_embed_cache = None
|
self._fs_embed_cache = None
|
||||||
|
# Pre-created CUDA stream for parallel action/state UNet
|
||||||
|
self._side_stream = torch.cuda.Stream() if not self.base_model_gen_only else None
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
@@ -849,8 +851,8 @@ class WMAModel(nn.Module):
|
|||||||
if not self.base_model_gen_only:
|
if not self.base_model_gen_only:
|
||||||
ba, _, _ = x_action.shape
|
ba, _, _ = x_action.shape
|
||||||
ts_state = timesteps[:ba] if b > 1 else timesteps
|
ts_state = timesteps[:ba] if b > 1 else timesteps
|
||||||
# Run action_unet and state_unet in parallel via CUDA streams
|
# Run action_unet and state_unet in parallel via pre-created CUDA stream
|
||||||
s_stream = torch.cuda.Stream()
|
s_stream = self._side_stream
|
||||||
s_stream.wait_stream(torch.cuda.current_stream())
|
s_stream.wait_stream(torch.cuda.current_stream())
|
||||||
with torch.cuda.stream(s_stream):
|
with torch.cuda.stream(s_stream):
|
||||||
s_y = self.state_unet(x_state, ts_state, hs_a,
|
s_y = self.state_unet(x_state, ts_state, hs_a,
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
2026-02-10 10:36:44.797852: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-11 06:58:19.745318: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
2026-02-10 10:36:44.801300: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-11 06:58:19.748691: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
2026-02-10 10:36:44.837891: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-11 06:58:19.782405: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-10 10:36:44.837946: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
2026-02-11 06:58:19.782465: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-10 10:36:44.839880: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-11 06:58:19.784464: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-10 10:36:44.849073: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-11 06:58:19.793381: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
2026-02-10 10:36:44.849365: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
2026-02-11 06:58:19.794103: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-10 10:36:45.644793: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
2026-02-11 06:58:20.607029: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
[rank: 0] Global seed set to 123
|
[rank: 0] Global seed set to 123
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
@@ -41,6 +41,7 @@ INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|||||||
⚠ Found 601 fp32 params, converting to bf16
|
⚠ Found 601 fp32 params, converting to bf16
|
||||||
✓ All parameters converted to bfloat16
|
✓ All parameters converted to bfloat16
|
||||||
✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
|
✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -65,8 +66,31 @@ DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
|||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
|
||||||
12%|█▎ | 1/8 [01:01<07:08, 61.25s/it]
|
12%|█▎ | 1/8 [01:01<07:08, 61.25s/it]
|
||||||
|
25%|██▌ | 2/8 [01:58<05:53, 58.90s/it]
|
||||||
|
38%|███▊ | 3/8 [02:55<04:50, 58.14s/it]
|
||||||
|
50%|█████ | 4/8 [03:52<03:51, 57.79s/it]
|
||||||
|
62%|██████▎ | 5/8 [04:50<02:52, 57.60s/it]
|
||||||
|
75%|███████▌ | 6/8 [05:47<01:54, 57.48s/it]
|
||||||
|
88%|████████▊ | 7/8 [06:44<00:57, 57.41s/it]
|
||||||
|
100%|██████████| 8/8 [07:42<00:00, 57.36s/it]
|
||||||
|
100%|██████████| 8/8 [07:42<00:00, 57.75s/it]
|
||||||
|
>>> Step 0: generating actions ...
|
||||||
|
>>> Step 0: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 1: generating actions ...
|
||||||
|
>>> Step 1: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 2: generating actions ...
|
||||||
|
>>> Step 2: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 3: generating actions ...
|
||||||
|
>>> Step 3: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 4: generating actions ...
|
||||||
|
>>> Step 4: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 5: generating actions ...
|
>>> Step 5: generating actions ...
|
||||||
>>> Step 5: interacting with world model ...
|
>>> Step 5: interacting with world model ...
|
||||||
@@ -116,30 +140,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
|||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
|
|
||||||
12%|█▎ | 1/8 [01:06<07:46, 66.62s/it]
|
|
||||||
25%|██▌ | 2/8 [02:07<06:17, 62.97s/it]
|
|
||||||
38%|███▊ | 3/8 [03:07<05:08, 61.80s/it]
|
|
||||||
50%|█████ | 4/8 [04:07<04:05, 61.30s/it]
|
|
||||||
62%|██████▎ | 5/8 [05:08<03:03, 61.02s/it]
|
|
||||||
75%|███████▌ | 6/8 [06:08<02:01, 60.84s/it]
|
|
||||||
88%|████████▊ | 7/8 [07:09<01:00, 60.68s/it]
|
|
||||||
100%|██████████| 8/8 [08:09<00:00, 60.66s/it]
|
|
||||||
100%|██████████| 8/8 [08:09<00:00, 61.25s/it]
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 1: generating actions ...
|
|
||||||
>>> Step 1: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 2: generating actions ...
|
|
||||||
>>> Step 2: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 3: generating actions ...
|
|
||||||
>>> Step 3: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 4: generating actions ...
|
|
||||||
>>> Step 4: interacting with world model ...
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
>>> Step 5: generating actions ...
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
>>> Step 5: interacting with world model ...
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
|
"gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
|
||||||
"pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
|
"pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
|
||||||
"psnr": 31.802224855380352
|
"psnr": 30.34518638635329
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user