5 Commits
second ... qhy3

16 changed files with 590 additions and 209 deletions

View File

@@ -0,0 +1,10 @@
{
"permissions": {
"allow": [
"Bash(conda env list:*)",
"Bash(mamba env:*)",
"Bash(micromamba env list:*)",
"Bash(echo:*)"
]
}
}

7
.gitignore vendored
View File

@@ -55,7 +55,6 @@ coverage.xml
*.pot *.pot
# Django stuff: # Django stuff:
*.log
local_settings.py local_settings.py
db.sqlite3 db.sqlite3
@@ -121,7 +120,7 @@ localTest/
fig/ fig/
figure/ figure/
*.mp4 *.mp4
*.json
Data/ControlVAE.yml Data/ControlVAE.yml
Data/Misc Data/Misc
Data/Pretrained Data/Pretrained
@@ -129,4 +128,6 @@ Data/utils.py
Experiment/checkpoint Experiment/checkpoint
Experiment/log Experiment/log
*.ckpt *.ckpt
*.0

View File

@@ -222,7 +222,7 @@ data:
test: test:
target: unifolm_wma.data.wma_data.WMAData target: unifolm_wma.data.wma_data.WMAData
params: params:
data_dir: '/mnt/ASC1637/unifolm-world-model-action/examples/world_model_interaction_prompts' data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
video_length: ${model.params.wma_config.params.temporal_length} video_length: ${model.params.wma_config.params.temporal_length}
frame_stride: 2 frame_stride: 2
load_raw_resolution: True load_raw_resolution: True

View File

@@ -16,6 +16,9 @@ from collections import OrderedDict
from unifolm_wma.models.samplers.ddim import DDIMSampler from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config from unifolm_wma.utils.utils import instantiate_from_config
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]: def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
""" """

View File

@@ -19,6 +19,9 @@ from fastapi.responses import JSONResponse
from typing import Any, Dict, Optional, Tuple, List from typing import Any, Dict, Optional, Tuple, List
from datetime import datetime from datetime import datetime
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from unifolm_wma.utils.utils import instantiate_from_config from unifolm_wma.utils.utils import instantiate_from_config
from unifolm_wma.models.samplers.ddim import DDIMSampler from unifolm_wma.models.samplers.ddim import DDIMSampler

View File

@@ -18,6 +18,9 @@ from collections import OrderedDict
from torch import nn from torch import nn
from eval_utils import populate_queues, log_to_tensorboard from eval_utils import populate_queues, log_to_tensorboard
from collections import deque from collections import deque
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from torch import Tensor from torch import Tensor
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from PIL import Image from PIL import Image
@@ -327,7 +330,8 @@ def image_guided_synthesis_sim_mode(
timestep_spacing: str = 'uniform', timestep_spacing: str = 'uniform',
guidance_rescale: float = 0.0, guidance_rescale: float = 0.0,
sim_mode: bool = True, sim_mode: bool = True,
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: decode_video: bool = True,
**kwargs) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor]:
""" """
Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text). Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
@@ -350,10 +354,13 @@ def image_guided_synthesis_sim_mode(
timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace". timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance. guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model. sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
decode_video (bool): Whether to decode latent samples to pixel-space video.
Set to False to skip VAE decode for speed when only actions/states are needed.
**kwargs: Additional arguments passed to the DDIM sampler. **kwargs: Additional arguments passed to the DDIM sampler.
Returns: Returns:
batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W]. batch_variants (torch.Tensor | None): Predicted pixel-space video frames [B, C, T, H, W],
or None when decode_video=False.
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding. actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding. states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
""" """
@@ -406,6 +413,7 @@ def image_guided_synthesis_sim_mode(
kwargs.update({"unconditional_conditioning_img_nonetext": None}) kwargs.update({"unconditional_conditioning_img_nonetext": None})
cond_mask = None cond_mask = None
cond_z0 = None cond_z0 = None
batch_variants = None
if ddim_sampler is not None: if ddim_sampler is not None:
samples, actions, states, intermedia = ddim_sampler.sample( samples, actions, states, intermedia = ddim_sampler.sample(
S=ddim_steps, S=ddim_steps,
@@ -424,9 +432,10 @@ def image_guided_synthesis_sim_mode(
guidance_rescale=guidance_rescale, guidance_rescale=guidance_rescale,
**kwargs) **kwargs)
# Reconstruct from latent to pixel space if decode_video:
batch_images = model.decode_first_stage(samples) # Reconstruct from latent to pixel space
batch_variants = batch_images batch_images = model.decode_first_stage(samples)
batch_variants = batch_images
return batch_variants, actions, states return batch_variants, actions, states
@@ -587,7 +596,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
fs=model_input_fs, fs=model_input_fs,
timestep_spacing=args.timestep_spacing, timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale, guidance_rescale=args.guidance_rescale,
sim_mode=False) sim_mode=False,
decode_video=not args.fast_policy_no_decode)
# Update future actions in the observation queues # Update future actions in the observation queues
for idx in range(len(pred_actions[0])): for idx in range(len(pred_actions[0])):
@@ -645,11 +655,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
observation) observation)
# Save the imagen videos for decision-making # Save the imagen videos for decision-making
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}" if pred_videos_0 is not None:
log_to_tensorboard(writer, sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
pred_videos_0, log_to_tensorboard(writer,
sample_tag, pred_videos_0,
fps=args.save_fps) sample_tag,
fps=args.save_fps)
# Save videos environment changes via world-model interaction # Save videos environment changes via world-model interaction
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}" sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
log_to_tensorboard(writer, log_to_tensorboard(writer,
@@ -658,10 +669,11 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
fps=args.save_fps) fps=args.save_fps)
# Save the imagen videos for decision-making # Save the imagen videos for decision-making
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4' if pred_videos_0 is not None:
save_results(pred_videos_0.cpu(), sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
sample_video_file, save_results(pred_videos_0.cpu(),
fps=args.save_fps) sample_video_file,
fps=args.save_fps)
# Save videos environment changes via world-model interaction # Save videos environment changes via world-model interaction
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4' sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_1.cpu(), save_results(pred_videos_1.cpu(),
@@ -794,6 +806,11 @@ def get_parser():
action='store_true', action='store_true',
default=False, default=False,
help="not using the predicted states as comparison") help="not using the predicted states as comparison")
parser.add_argument(
"--fast_policy_no_decode",
action='store_true',
default=False,
help="Speed mode: policy pass only predicts actions, skip policy video decode/log/save.")
parser.add_argument("--save_fps", parser.add_argument("--save_fps",
type=int, type=int,
default=8, default=8,

View File

@@ -11,6 +11,9 @@ from unifolm_wma.utils.utils import instantiate_from_config
from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy
from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def get_parser(**parser_kwargs): def get_parser(**parser_kwargs):
parser = argparse.ArgumentParser(**parser_kwargs) parser = argparse.ArgumentParser(**parser_kwargs)

View File

@@ -501,6 +501,10 @@ class ConditionalUnet1D(nn.Module):
self.last_frame_only = last_frame_only self.last_frame_only = last_frame_only
self.horizon = horizon self.horizon = horizon
# Context precomputation cache
self._global_cond_cache_enabled = False
self._global_cond_cache = {}
def forward(self, def forward(self,
sample: torch.Tensor, sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int], timestep: Union[torch.Tensor, float, int],
@@ -530,14 +534,20 @@ class ConditionalUnet1D(nn.Module):
B, T, D = sample.shape B, T, D = sample.shape
if self.use_linear_act_proj: if self.use_linear_act_proj:
sample = self.proj_in_action(sample.unsqueeze(-1)) sample = self.proj_in_action(sample.unsqueeze(-1))
global_cond = self.obs_encoder(cond) _gc_key = (cond['image'].data_ptr(), cond['agent_pos'].data_ptr())
global_cond = rearrange(global_cond, if self._global_cond_cache_enabled and _gc_key in self._global_cond_cache:
'(b t) d -> b 1 (t d)', global_cond = self._global_cond_cache[_gc_key]
b=B, else:
t=self.n_obs_steps) global_cond = self.obs_encoder(cond)
global_cond = repeat(global_cond, global_cond = rearrange(global_cond,
'b c d -> b (repeat c) d', '(b t) d -> b 1 (t d)',
repeat=T) b=B,
t=self.n_obs_steps)
global_cond = repeat(global_cond,
'b c d -> b (repeat c) d',
repeat=T)
if self._global_cond_cache_enabled:
self._global_cond_cache[_gc_key] = global_cond
else: else:
sample = einops.rearrange(sample, 'b h t -> b t h') sample = einops.rearrange(sample, 'b h t -> b t h')
sample = self.proj_in_horizon(sample) sample = self.proj_in_horizon(sample)

View File

@@ -6,6 +6,8 @@ from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim
from unifolm_wma.utils.common import noise_like from unifolm_wma.utils.common import noise_like
from unifolm_wma.utils.common import extract_into_tensor from unifolm_wma.utils.common import extract_into_tensor
from tqdm import tqdm from tqdm import tqdm
from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache
from unifolm_wma.modules.networks.wma_model import enable_ctx_cache, disable_ctx_cache
class DDIMSampler(object): class DDIMSampler(object):
@@ -67,11 +69,12 @@ class DDIMSampler(object):
ddim_timesteps=self.ddim_timesteps, ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta, eta=ddim_eta,
verbose=verbose) verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas) # Ensure tensors are on correct device for efficient indexing
self.register_buffer('ddim_alphas', ddim_alphas) self.register_buffer('ddim_sigmas', to_torch(torch.as_tensor(ddim_sigmas)))
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) self.register_buffer('ddim_alphas', to_torch(torch.as_tensor(ddim_alphas)))
self.register_buffer('ddim_alphas_prev', to_torch(torch.as_tensor(ddim_alphas_prev)))
self.register_buffer('ddim_sqrt_one_minus_alphas', self.register_buffer('ddim_sqrt_one_minus_alphas',
np.sqrt(1. - ddim_alphas)) to_torch(torch.as_tensor(np.sqrt(1. - ddim_alphas))))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
(1 - self.alphas_cumprod / self.alphas_cumprod_prev)) (1 - self.alphas_cumprod / self.alphas_cumprod_prev))
@@ -241,63 +244,70 @@ class DDIMSampler(object):
dp_ddim_scheduler_action.set_timesteps(len(timesteps)) dp_ddim_scheduler_action.set_timesteps(len(timesteps))
dp_ddim_scheduler_state.set_timesteps(len(timesteps)) dp_ddim_scheduler_state.set_timesteps(len(timesteps))
for i, step in enumerate(iterator): ts = torch.empty((b, ), device=device, dtype=torch.long)
index = total_steps - i - 1 enable_cross_attn_kv_cache(self.model)
ts = torch.full((b, ), step, device=device, dtype=torch.long) enable_ctx_cache(self.model)
try:
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts.fill_(step)
# Use mask to blend noised original latent (img_orig) & new sampled latent (img) # Use mask to blend noised original latent (img_orig) & new sampled latent (img)
if mask is not None: if mask is not None:
assert x0 is not None assert x0 is not None
if clean_cond: if clean_cond:
img_orig = x0 img_orig = x0
else: else:
img_orig = self.model.q_sample(x0, ts) img_orig = self.model.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_ddim( outs = self.p_sample_ddim(
img, img,
action, action,
state, state,
cond, cond,
ts, ts,
index=index, index=index,
use_original_steps=ddim_use_original_steps, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, quantize_denoised=quantize_denoised,
temperature=temperature, temperature=temperature,
noise_dropout=noise_dropout, noise_dropout=noise_dropout,
score_corrector=score_corrector, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs, corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning, unconditional_conditioning=unconditional_conditioning,
mask=mask, mask=mask,
x0=x0, x0=x0,
fs=fs, fs=fs,
guidance_rescale=guidance_rescale, guidance_rescale=guidance_rescale,
**kwargs) **kwargs)
img, pred_x0, model_output_action, model_output_state = outs img, pred_x0, model_output_action, model_output_state = outs
action = dp_ddim_scheduler_action.step( action = dp_ddim_scheduler_action.step(
model_output_action, model_output_action,
step, step,
action, action,
generator=None, generator=None,
).prev_sample ).prev_sample
state = dp_ddim_scheduler_state.step( state = dp_ddim_scheduler_state.step(
model_output_state, model_output_state,
step, step,
state, state,
generator=None, generator=None,
).prev_sample ).prev_sample
if callback: callback(i) if callback: callback(i)
if img_callback: img_callback(pred_x0, i) if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1: if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img) intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0) intermediates['pred_x0'].append(pred_x0)
intermediates['x_inter_action'].append(action) intermediates['x_inter_action'].append(action)
intermediates['x_inter_state'].append(state) intermediates['x_inter_state'].append(state)
finally:
disable_cross_attn_kv_cache(self.model)
disable_ctx_cache(self.model)
return img, action, state, intermediates return img, action, state, intermediates
@@ -325,10 +335,6 @@ class DDIMSampler(object):
guidance_rescale=0.0, guidance_rescale=0.0,
**kwargs): **kwargs):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
if x.dim() == 5:
is_video = True
else:
is_video = False
if unconditional_conditioning is None or unconditional_guidance_scale == 1.: if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output, model_output_action, model_output_state = self.model.apply_model( model_output, model_output_action, model_output_state = self.model.apply_model(
@@ -377,17 +383,11 @@ class DDIMSampler(object):
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
if is_video: # Use 0-d tensors directly (already on device); broadcasting handles shape
size = (b, 1, 1, 1, 1) a_t = alphas[index]
else: a_prev = alphas_prev[index]
size = (b, 1, 1, 1) sigma_t = sigmas[index]
sqrt_one_minus_at = sqrt_one_minus_alphas[index]
a_t = torch.full(size, alphas[index], device=device)
a_prev = torch.full(size, alphas_prev[index], device=device)
sigma_t = torch.full(size, sigmas[index], device=device)
sqrt_one_minus_at = torch.full(size,
sqrt_one_minus_alphas[index],
device=device)
if self.model.parameterization != "v": if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
@@ -395,12 +395,8 @@ class DDIMSampler(object):
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if self.model.use_dynamic_rescale: if self.model.use_dynamic_rescale:
scale_t = torch.full(size, scale_t = self.ddim_scale_arr[index]
self.ddim_scale_arr[index], prev_scale_t = self.ddim_scale_arr_prev[index]
device=device)
prev_scale_t = torch.full(size,
self.ddim_scale_arr_prev[index],
device=device)
rescale = (prev_scale_t / scale_t) rescale = (prev_scale_t / scale_t)
pred_x0 *= rescale pred_x0 *= rescale

View File

@@ -98,6 +98,9 @@ class CrossAttention(nn.Module):
self.text_context_len = text_context_len self.text_context_len = text_context_len
self.agent_state_context_len = agent_state_context_len self.agent_state_context_len = agent_state_context_len
self.agent_action_context_len = agent_action_context_len self.agent_action_context_len = agent_action_context_len
self._kv_cache = {}
self._kv_cache_enabled = False
self.cross_attention_scale_learnable = cross_attention_scale_learnable self.cross_attention_scale_learnable = cross_attention_scale_learnable
if self.image_cross_attention: if self.image_cross_attention:
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
@@ -236,17 +239,42 @@ class CrossAttention(nn.Module):
k_ip, v_ip, out_ip = None, None, None k_ip, v_ip, out_ip = None, None, None
k_as, v_as, out_as = None, None, None k_as, v_as, out_as = None, None, None
k_aa, v_aa, out_aa = None, None, None k_aa, v_aa, out_aa = None, None, None
attn_mask_aa = None
h = self.heads
q = self.to_q(x) q = self.to_q(x)
context = default(context, x) context = default(context, x)
if self.image_cross_attention and not spatial_self_attn: b, _, _ = q.shape
q = q.unsqueeze(3).reshape(b, q.shape[1], h, self.dim_head).permute(0, 2, 1, 3).reshape(b * h, q.shape[1], self.dim_head).contiguous()
def _reshape_kv(t):
return t.unsqueeze(3).reshape(b, t.shape[1], h, self.dim_head).permute(0, 2, 1, 3).reshape(b * h, t.shape[1], self.dim_head).contiguous()
use_cache = self._kv_cache_enabled and not spatial_self_attn
cache_hit = use_cache and len(self._kv_cache) > 0
if cache_hit:
k = self._kv_cache['k']
v = self._kv_cache['v']
k_ip = self._kv_cache.get('k_ip')
v_ip = self._kv_cache.get('v_ip')
k_as = self._kv_cache.get('k_as')
v_as = self._kv_cache.get('v_as')
k_aa = self._kv_cache.get('k_aa')
v_aa = self._kv_cache.get('v_aa')
attn_mask_aa = self._kv_cache.get('attn_mask_aa')
elif self.image_cross_attention and not spatial_self_attn:
if context.shape[1] == self.text_context_len + self.video_length: if context.shape[1] == self.text_context_len + self.video_length:
context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :] context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :]
k = self.to_k(context) k = self.to_k(context)
v = self.to_v(context) v = self.to_v(context)
k_ip = self.to_k_ip(context_image) k_ip = self.to_k_ip(context_image)
v_ip = self.to_v_ip(context_image) v_ip = self.to_v_ip(context_image)
k, v = map(_reshape_kv, (k, v))
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
if use_cache:
self._kv_cache = {'k': k, 'v': v, 'k_ip': k_ip, 'v_ip': v_ip}
elif context.shape[1] == self.agent_state_context_len + self.text_context_len + self.video_length: elif context.shape[1] == self.agent_state_context_len + self.text_context_len + self.video_length:
context_agent_state = context[:, :self.agent_state_context_len, :] context_agent_state = context[:, :self.agent_state_context_len, :]
context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :] context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :]
@@ -257,6 +285,11 @@ class CrossAttention(nn.Module):
v_ip = self.to_v_ip(context_image) v_ip = self.to_v_ip(context_image)
k_as = self.to_k_as(context_agent_state) k_as = self.to_k_as(context_agent_state)
v_as = self.to_v_as(context_agent_state) v_as = self.to_v_as(context_agent_state)
k, v = map(_reshape_kv, (k, v))
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
k_as, v_as = map(_reshape_kv, (k_as, v_as))
if use_cache:
self._kv_cache = {'k': k, 'v': v, 'k_ip': k_ip, 'v_ip': v_ip, 'k_as': k_as, 'v_as': v_as}
else: else:
context_agent_state = context[:, :self.agent_state_context_len, :] context_agent_state = context[:, :self.agent_state_context_len, :]
context_agent_action = context[:, self.agent_state_context_len:self.agent_state_context_len+self.agent_action_context_len, :] context_agent_action = context[:, self.agent_state_context_len:self.agent_state_context_len+self.agent_action_context_len, :]
@@ -272,98 +305,78 @@ class CrossAttention(nn.Module):
k_aa = self.to_k_aa(context_agent_action) k_aa = self.to_k_aa(context_agent_action)
v_aa = self.to_v_aa(context_agent_action) v_aa = self.to_v_aa(context_agent_action)
attn_mask_aa = self._get_attn_mask_aa(x.shape[0], k, v = map(_reshape_kv, (k, v))
q.shape[1], k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
k_aa.shape[1], k_as, v_as = map(_reshape_kv, (k_as, v_as))
block_size=16).to(k_aa.device) k_aa, v_aa = map(_reshape_kv, (k_aa, v_aa))
attn_mask_aa_raw = self._get_attn_mask_aa(x.shape[0],
q.shape[1],
k_aa.shape[1],
block_size=16,
device=k_aa.device)
attn_mask_aa = attn_mask_aa_raw.unsqueeze(1).repeat(1, h, 1, 1).reshape(
b * h, attn_mask_aa_raw.shape[1], attn_mask_aa_raw.shape[2]).to(q.dtype)
if use_cache:
self._kv_cache = {
'k': k, 'v': v, 'k_ip': k_ip, 'v_ip': v_ip,
'k_as': k_as, 'v_as': v_as, 'k_aa': k_aa, 'v_aa': v_aa,
'attn_mask_aa': attn_mask_aa,
}
else: else:
if not spatial_self_attn: if not spatial_self_attn:
assert 1 > 2, ">>> ERROR: you should never go into here ..." assert 1 > 2, ">>> ERROR: you should never go into here ..."
context = context[:, :self.text_context_len, :] context = context[:, :self.text_context_len, :]
k = self.to_k(context) k = self.to_k(context)
v = self.to_v(context) v = self.to_v(context)
k, v = map(_reshape_kv, (k, v))
b, _, _ = q.shape if use_cache:
q = q.unsqueeze(3).reshape(b, q.shape[1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(b * self.heads, q.shape[1], self.dim_head).contiguous() self._kv_cache = {'k': k, 'v': v}
if k is not None: if k is not None:
k, v = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(),
(k, v),
)
out = xformers.ops.memory_efficient_attention(q, out = xformers.ops.memory_efficient_attention(q,
k, k,
v, v,
attn_bias=None, attn_bias=None,
op=None) op=None)
out = (out.unsqueeze(0).reshape( out = (out.unsqueeze(0).reshape(
b, self.heads, out.shape[1], b, h, out.shape[1],
self.dim_head).permute(0, 2, 1, self.dim_head).permute(0, 2, 1,
3).reshape(b, out.shape[1], 3).reshape(b, out.shape[1],
self.heads * self.dim_head)) h * self.dim_head))
if k_ip is not None: if k_ip is not None:
# For image cross-attention
k_ip, v_ip = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(
),
(k_ip, v_ip),
)
out_ip = xformers.ops.memory_efficient_attention(q, out_ip = xformers.ops.memory_efficient_attention(q,
k_ip, k_ip,
v_ip, v_ip,
attn_bias=None, attn_bias=None,
op=None) op=None)
out_ip = (out_ip.unsqueeze(0).reshape( out_ip = (out_ip.unsqueeze(0).reshape(
b, self.heads, out_ip.shape[1], b, h, out_ip.shape[1],
self.dim_head).permute(0, 2, 1, self.dim_head).permute(0, 2, 1,
3).reshape(b, out_ip.shape[1], 3).reshape(b, out_ip.shape[1],
self.heads * self.dim_head)) h * self.dim_head))
if k_as is not None: if k_as is not None:
# For agent state cross-attention
k_as, v_as = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(
),
(k_as, v_as),
)
out_as = xformers.ops.memory_efficient_attention(q, out_as = xformers.ops.memory_efficient_attention(q,
k_as, k_as,
v_as, v_as,
attn_bias=None, attn_bias=None,
op=None) op=None)
out_as = (out_as.unsqueeze(0).reshape( out_as = (out_as.unsqueeze(0).reshape(
b, self.heads, out_as.shape[1], b, h, out_as.shape[1],
self.dim_head).permute(0, 2, 1, self.dim_head).permute(0, 2, 1,
3).reshape(b, out_as.shape[1], 3).reshape(b, out_as.shape[1],
self.heads * self.dim_head)) h * self.dim_head))
if k_aa is not None: if k_aa is not None:
# For agent action cross-attention
k_aa, v_aa = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(
),
(k_aa, v_aa),
)
attn_mask_aa = attn_mask_aa.unsqueeze(1).repeat(1,self.heads,1,1).reshape(
b * self.heads, attn_mask_aa.shape[1], attn_mask_aa.shape[2])
attn_mask_aa = attn_mask_aa.to(q.dtype)
out_aa = xformers.ops.memory_efficient_attention( out_aa = xformers.ops.memory_efficient_attention(
q, k_aa, v_aa, attn_bias=attn_mask_aa, op=None) q, k_aa, v_aa, attn_bias=attn_mask_aa, op=None)
out_aa = (out_aa.unsqueeze(0).reshape( out_aa = (out_aa.unsqueeze(0).reshape(
b, self.heads, out_aa.shape[1], b, h, out_aa.shape[1],
self.dim_head).permute(0, 2, 1, self.dim_head).permute(0, 2, 1,
3).reshape(b, out_aa.shape[1], 3).reshape(b, out_aa.shape[1],
self.heads * self.dim_head)) h * self.dim_head))
if exists(mask): if exists(mask):
raise NotImplementedError raise NotImplementedError
@@ -386,17 +399,43 @@ class CrossAttention(nn.Module):
return self.to_out(out) return self.to_out(out)
def _get_attn_mask_aa(self, b, l1, l2, block_size=16): def _get_attn_mask_aa(self, b, l1, l2, block_size=16, device=None):
cache_key = (b, l1, l2, block_size)
if hasattr(self, '_attn_mask_aa_cache_key') and self._attn_mask_aa_cache_key == cache_key:
cached = self._attn_mask_aa_cache
if device is not None and cached.device != torch.device(device):
cached = cached.to(device)
self._attn_mask_aa_cache = cached
return cached
target_device = device if device is not None else 'cpu'
num_token = l2 // block_size num_token = l2 // block_size
start_positions = ((torch.arange(b) % block_size) + 1) * num_token start_positions = ((torch.arange(b, device=target_device) % block_size) + 1) * num_token
col_indices = torch.arange(l2) col_indices = torch.arange(l2, device=target_device)
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1) mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
mask = mask_2d.unsqueeze(1).expand(b, l1, l2) mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
attn_mask = torch.zeros_like(mask, dtype=torch.float) attn_mask = torch.zeros(b, l1, l2, dtype=torch.float, device=target_device)
attn_mask[mask] = float('-inf') attn_mask[mask] = float('-inf')
self._attn_mask_aa_cache_key = cache_key
self._attn_mask_aa_cache = attn_mask
return attn_mask return attn_mask
def enable_cross_attn_kv_cache(module):
for m in module.modules():
if isinstance(m, CrossAttention):
m._kv_cache_enabled = True
m._kv_cache = {}
def disable_cross_attn_kv_cache(module):
for m in module.modules():
if isinstance(m, CrossAttention):
m._kv_cache_enabled = False
m._kv_cache = {}
class BasicTransformerBlock(nn.Module): class BasicTransformerBlock(nn.Module):
def __init__(self, def __init__(self,

View File

@@ -685,6 +685,10 @@ class WMAModel(nn.Module):
self.action_token_projector = instantiate_from_config( self.action_token_projector = instantiate_from_config(
stem_process_config) stem_process_config)
# Context precomputation cache
self._ctx_cache_enabled = False
self._ctx_cache = {}
def forward(self, def forward(self,
x: Tensor, x: Tensor,
x_action: Tensor, x_action: Tensor,
@@ -720,58 +724,64 @@ class WMAModel(nn.Module):
repeat_only=False).type(x.dtype) repeat_only=False).type(x.dtype)
emb = self.time_embed(t_emb) emb = self.time_embed(t_emb)
bt, l_context, _ = context.shape _ctx_key = context.data_ptr()
if self.base_model_gen_only: if self._ctx_cache_enabled and _ctx_key in self._ctx_cache:
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE context = self._ctx_cache[_ctx_key]
else: else:
if l_context == self.n_obs_steps + 77 + t * 16: bt, l_context, _ = context.shape
context_agent_state = context[:, :self.n_obs_steps] if self.base_model_gen_only:
context_text = context[:, self.n_obs_steps:self.n_obs_steps + assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
77, :] else:
context_img = context[:, self.n_obs_steps + 77:, :] if l_context == self.n_obs_steps + 77 + t * 16:
context_agent_state = context_agent_state.repeat_interleave( context_agent_state = context[:, :self.n_obs_steps]
repeats=t, dim=0) context_text = context[:, self.n_obs_steps:self.n_obs_steps +
context_text = context_text.repeat_interleave(repeats=t, dim=0) 77, :]
context_img = rearrange(context_img, context_img = context[:, self.n_obs_steps + 77:, :]
'b (t l) c -> (b t) l c', context_agent_state = context_agent_state.repeat_interleave(
t=t) repeats=t, dim=0)
context = torch.cat( context_text = context_text.repeat_interleave(repeats=t, dim=0)
[context_agent_state, context_text, context_img], dim=1) context_img = rearrange(context_img,
elif l_context == self.n_obs_steps + 16 + 77 + t * 16: 'b (t l) c -> (b t) l c',
context_agent_state = context[:, :self.n_obs_steps] t=t)
context_agent_action = context[:, self. context = torch.cat(
n_obs_steps:self.n_obs_steps + [context_agent_state, context_text, context_img], dim=1)
16, :] elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
context_agent_action = rearrange( context_agent_state = context[:, :self.n_obs_steps]
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d') context_agent_action = context[:, self.
context_agent_action = self.action_token_projector( n_obs_steps:self.n_obs_steps +
context_agent_action) 16, :]
context_agent_action = rearrange(context_agent_action, context_agent_action = rearrange(
'(b o) l d -> b o l d', context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
o=t) context_agent_action = self.action_token_projector(
context_agent_action = rearrange(context_agent_action, context_agent_action)
'b o (t l) d -> b o t l d', context_agent_action = rearrange(context_agent_action,
t=t) '(b o) l d -> b o l d',
context_agent_action = context_agent_action.permute( o=t)
0, 2, 1, 3, 4) context_agent_action = rearrange(context_agent_action,
context_agent_action = rearrange(context_agent_action, 'b o (t l) d -> b o t l d',
'b t o l d -> (b t) (o l) d') t=t)
context_agent_action = context_agent_action.permute(
0, 2, 1, 3, 4)
context_agent_action = rearrange(context_agent_action,
'b t o l d -> (b t) (o l) d')
context_text = context[:, self.n_obs_steps + context_text = context[:, self.n_obs_steps +
16:self.n_obs_steps + 16 + 77, :] 16:self.n_obs_steps + 16 + 77, :]
context_text = context_text.repeat_interleave(repeats=t, dim=0) context_text = context_text.repeat_interleave(repeats=t, dim=0)
context_img = context[:, self.n_obs_steps + 16 + 77:, :] context_img = context[:, self.n_obs_steps + 16 + 77:, :]
context_img = rearrange(context_img, context_img = rearrange(context_img,
'b (t l) c -> (b t) l c', 'b (t l) c -> (b t) l c',
t=t) t=t)
context_agent_state = context_agent_state.repeat_interleave( context_agent_state = context_agent_state.repeat_interleave(
repeats=t, dim=0) repeats=t, dim=0)
context = torch.cat([ context = torch.cat([
context_agent_state, context_agent_action, context_text, context_agent_state, context_agent_action, context_text,
context_img context_img
], ],
dim=1) dim=1)
if self._ctx_cache_enabled:
self._ctx_cache[_ctx_key] = context
emb = emb.repeat_interleave(repeats=t, dim=0) emb = emb.repeat_interleave(repeats=t, dim=0)
@@ -846,3 +856,30 @@ class WMAModel(nn.Module):
s_y = torch.zeros_like(x_state) s_y = torch.zeros_like(x_state)
return y, a_y, s_y return y, a_y, s_y
def enable_ctx_cache(model):
"""Enable context precomputation cache on WMAModel and its action/state UNets."""
for m in model.modules():
if isinstance(m, WMAModel):
m._ctx_cache_enabled = True
m._ctx_cache = {}
# conditional_unet1d cache
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
for m in model.modules():
if isinstance(m, ConditionalUnet1D):
m._global_cond_cache_enabled = True
m._global_cond_cache = {}
def disable_ctx_cache(model):
"""Disable and clear context precomputation cache."""
for m in model.modules():
if isinstance(m, WMAModel):
m._ctx_cache_enabled = False
m._ctx_cache = {}
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
for m in model.modules():
if isinstance(m, ConditionalUnet1D):
m._global_cond_cache_enabled = False
m._global_cond_cache = {}

View File

@@ -0,0 +1,121 @@
2026-02-10 15:38:28.973314: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-10 15:38:29.023024: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-10 15:38:29.023070: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-10 15:38:29.024393: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-10 15:38:29.031901: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-10 15:38:29.955454: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/8 [00:00<?, ?it/s]>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
12%|█▎ | 1/8 [01:14<08:41, 74.51s/it]
25%|██▌ | 2/8 [02:29<07:28, 74.79s/it]
38%|███▊ | 3/8 [03:44<06:14, 74.81s/it]
50%|█████ | 4/8 [04:59<04:59, 74.78s/it]
62%|██████▎ | 5/8 [06:13<03:44, 74.73s/it]
75%|███████▌ | 6/8 [07:28<02:29, 74.66s/it]
88%|████████▊ | 7/8 [08:42<01:14, 74.56s/it]
100%|██████████| 8/8 [09:56<00:00, 74.51s/it]
100%|██████████| 8/8 [09:56<00:00, 74.62s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case1/unitree_z1_dual_arm_cleanup_pencils_case1.mp4",
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
"psnr": 47.911564449209735
}

View File

@@ -0,0 +1,130 @@
2026-02-10 17:39:22.590654: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-10 17:39:22.640645: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-10 17:39:22.640689: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-10 17:39:22.642010: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-10 17:39:22.649530: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-10 17:39:23.575804: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [00:35<05:55, 35.52s/it]
18%|█▊ | 2/11 [01:11<05:21, 35.73s/it]
27%|██▋ | 3/11 [01:47<04:48, 36.04s/it]
36%|███▋ | 4/11 [02:24<04:13, 36.19s/it]
45%|████▌ | 5/11 [03:00<03:37, 36.25s/it]
55%|█████▍ | 6/11 [03:36<03:00, 36.16s/it]
64%|██████▎ | 7/11 [04:12<02:24, 36.09s/it]
73%|███████▎ | 8/11 [04:48<01:48, 36.08s/it]
82%|████████▏ | 9/11 [05:24<01:12, 36.06s/it]
91%|█████████ | 10/11 [06:00<00:36, 36.07s/it]
100%|██████████| 11/11 [06:36<00:00, 36.07s/it]
100%|██████████| 11/11 [06:36<00:00, 36.07s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "/home/qhy/unifolm-world-model-action/unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
"pred_video": "/home/qhy/unifolm-world-model-action/unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4",
"psnr": 25.12008483689618
}

View File

@@ -20,5 +20,6 @@ dataset="unitree_z1_dual_arm_stackbox_v2"
--n_iter 11 \ --n_iter 11 \
--timestep_spacing 'uniform_trailing' \ --timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \ --guidance_rescale 0.7 \
--perframe_ae --perframe_ae \
--fast_policy_no_decode
} 2>&1 | tee "${res_dir}/output.log" } 2>&1 | tee "${res_dir}/output.log"