Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9a08e27a19 | |||
| b558856e1e | |||
| dcbcb2c377 | |||
| ff43432ef9 | |||
| afa12ba031 | |||
| bf4d66c874 | |||
| 9347a4ebe5 | |||
| 223a50f9e0 | |||
| 2a6068f9e4 | |||
| 91a9b0febc | |||
| ed637c972b |
15
.claude/settings.local.json
Normal file
15
.claude/settings.local.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(conda env list:*)",
|
||||
"Bash(mamba env:*)",
|
||||
"Bash(micromamba env list:*)",
|
||||
"Bash(echo:*)",
|
||||
"Bash(git show:*)",
|
||||
"Bash(nvidia-smi:*)",
|
||||
"Bash(conda activate unifolm-wma)",
|
||||
"Bash(conda info:*)",
|
||||
"Bash(direnv allow:*)"
|
||||
]
|
||||
}
|
||||
}
|
||||
2
.envrc
Normal file
2
.envrc
Normal file
@@ -0,0 +1,2 @@
|
||||
eval "$(conda shell.bash hook 2>/dev/null)"
|
||||
conda activate unifolm-wma
|
||||
8
.gitignore
vendored
8
.gitignore
vendored
@@ -55,7 +55,6 @@ coverage.xml
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
@@ -121,7 +120,7 @@ localTest/
|
||||
fig/
|
||||
figure/
|
||||
*.mp4
|
||||
*.json
|
||||
|
||||
Data/ControlVAE.yml
|
||||
Data/Misc
|
||||
Data/Pretrained
|
||||
@@ -129,4 +128,7 @@ Data/utils.py
|
||||
Experiment/checkpoint
|
||||
Experiment/log
|
||||
|
||||
*.ckpt
|
||||
*.ckpt
|
||||
|
||||
*.0
|
||||
ckpts/unifolm_wma_dual.ckpt.prepared.pt
|
||||
|
||||
@@ -222,7 +222,7 @@ data:
|
||||
test:
|
||||
target: unifolm_wma.data.wma_data.WMAData
|
||||
params:
|
||||
data_dir: '/mnt/ASC1637/unifolm-world-model-action/examples/world_model_interaction_prompts'
|
||||
data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
|
||||
video_length: ${model.params.wma_config.params.temporal_length}
|
||||
frame_stride: 2
|
||||
load_raw_resolution: True
|
||||
|
||||
@@ -16,6 +16,9 @@ from collections import OrderedDict
|
||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
|
||||
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
|
||||
"""
|
||||
|
||||
@@ -19,6 +19,9 @@ from fastapi.responses import JSONResponse
|
||||
from typing import Any, Dict, Optional, Tuple, List
|
||||
from datetime import datetime
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ import logging
|
||||
import einops
|
||||
import warnings
|
||||
import imageio
|
||||
import atexit
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from pytorch_lightning import seed_everything
|
||||
from omegaconf import OmegaConf
|
||||
@@ -16,8 +18,12 @@ from tqdm import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from collections import OrderedDict
|
||||
from torch import nn
|
||||
from eval_utils import populate_queues, log_to_tensorboard
|
||||
from eval_utils import populate_queues
|
||||
from collections import deque
|
||||
from typing import Optional, List, Any
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
from torch import Tensor
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from PIL import Image
|
||||
@@ -150,6 +156,81 @@ def save_results(video: Tensor, filename: str, fps: int = 8) -> None:
|
||||
options={'crf': '10'})
|
||||
|
||||
|
||||
# ========== Async I/O ==========
|
||||
_io_executor: Optional[ThreadPoolExecutor] = None
|
||||
_io_futures: List[Any] = []
|
||||
|
||||
|
||||
def _get_io_executor() -> ThreadPoolExecutor:
|
||||
global _io_executor
|
||||
if _io_executor is None:
|
||||
_io_executor = ThreadPoolExecutor(max_workers=2)
|
||||
return _io_executor
|
||||
|
||||
|
||||
def _flush_io():
|
||||
"""Wait for all pending async I/O to finish."""
|
||||
global _io_futures
|
||||
for fut in _io_futures:
|
||||
try:
|
||||
fut.result()
|
||||
except Exception as e:
|
||||
print(f">>> [async I/O] error: {e}")
|
||||
_io_futures.clear()
|
||||
|
||||
|
||||
atexit.register(_flush_io)
|
||||
|
||||
|
||||
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
|
||||
"""Synchronous save on CPU tensor (runs in background thread)."""
|
||||
video = torch.clamp(video_cpu.float(), -1., 1.)
|
||||
n = video.shape[0]
|
||||
video = video.permute(2, 0, 1, 3, 4)
|
||||
frame_grids = [
|
||||
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
|
||||
for framesheet in video
|
||||
]
|
||||
grid = torch.stack(frame_grids, dim=0)
|
||||
grid = (grid + 1.0) / 2.0
|
||||
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
|
||||
torchvision.io.write_video(filename,
|
||||
grid,
|
||||
fps=fps,
|
||||
video_codec='h264',
|
||||
options={'crf': '10'})
|
||||
|
||||
|
||||
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
|
||||
"""Submit video saving to background thread pool."""
|
||||
video_cpu = video.detach().cpu()
|
||||
fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename, fps)
|
||||
_io_futures.append(fut)
|
||||
|
||||
|
||||
def _log_to_tb_sync(writer, video_cpu: Tensor, tag: str, fps: int) -> None:
|
||||
"""Synchronous TensorBoard log on CPU tensor (runs in background thread)."""
|
||||
if video_cpu.dim() == 5:
|
||||
n = video_cpu.shape[0]
|
||||
video = video_cpu.permute(2, 0, 1, 3, 4)
|
||||
frame_grids = [
|
||||
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
|
||||
for framesheet in video
|
||||
]
|
||||
grid = torch.stack(frame_grids, dim=0)
|
||||
grid = (grid + 1.0) / 2.0
|
||||
grid = grid.unsqueeze(dim=0)
|
||||
writer.add_video(tag, grid, fps=fps)
|
||||
|
||||
|
||||
def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> None:
|
||||
"""Submit TensorBoard logging to background thread pool."""
|
||||
if isinstance(data, torch.Tensor) and data.dim() == 5:
|
||||
data_cpu = data.detach().cpu()
|
||||
fut = _get_io_executor().submit(_log_to_tb_sync, writer, data_cpu, tag, fps)
|
||||
_io_futures.append(fut)
|
||||
|
||||
|
||||
def get_init_frame_path(data_dir: str, sample: dict) -> str:
|
||||
"""Construct the init_frame path from directory and sample metadata.
|
||||
|
||||
@@ -327,7 +408,8 @@ def image_guided_synthesis_sim_mode(
|
||||
timestep_spacing: str = 'uniform',
|
||||
guidance_rescale: float = 0.0,
|
||||
sim_mode: bool = True,
|
||||
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
decode_video: bool = True,
|
||||
**kwargs) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
|
||||
|
||||
@@ -350,10 +432,13 @@ def image_guided_synthesis_sim_mode(
|
||||
timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
|
||||
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
|
||||
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
|
||||
decode_video (bool): Whether to decode latent samples to pixel-space video.
|
||||
Set to False to skip VAE decode for speed when only actions/states are needed.
|
||||
**kwargs: Additional arguments passed to the DDIM sampler.
|
||||
|
||||
Returns:
|
||||
batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W].
|
||||
batch_variants (torch.Tensor | None): Predicted pixel-space video frames [B, C, T, H, W],
|
||||
or None when decode_video=False.
|
||||
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
|
||||
states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
|
||||
"""
|
||||
@@ -406,6 +491,7 @@ def image_guided_synthesis_sim_mode(
|
||||
kwargs.update({"unconditional_conditioning_img_nonetext": None})
|
||||
cond_mask = None
|
||||
cond_z0 = None
|
||||
batch_variants = None
|
||||
if ddim_sampler is not None:
|
||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||
S=ddim_steps,
|
||||
@@ -424,9 +510,10 @@ def image_guided_synthesis_sim_mode(
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
|
||||
# Reconstruct from latent to pixel space
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
batch_variants = batch_images
|
||||
if decode_video:
|
||||
# Reconstruct from latent to pixel space
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
batch_variants = batch_images
|
||||
|
||||
return batch_variants, actions, states
|
||||
|
||||
@@ -453,26 +540,51 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
# Load config
|
||||
# Load config (always needed for data setup)
|
||||
config = OmegaConf.load(args.config)
|
||||
config['model']['params']['wma_config']['params'][
|
||||
'use_checkpoint'] = False
|
||||
model = instantiate_from_config(config.model)
|
||||
model.perframe_ae = args.perframe_ae
|
||||
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
||||
model = load_model_checkpoint(model, args.ckpt_path)
|
||||
model.eval()
|
||||
print(f'>>> Load pre-trained model ...')
|
||||
|
||||
# Build unnomalizer
|
||||
prepared_path = args.ckpt_path + ".prepared.pt"
|
||||
if os.path.exists(prepared_path):
|
||||
# ---- Fast path: load the fully-prepared model ----
|
||||
print(f">>> Loading prepared model from {prepared_path} ...")
|
||||
model = torch.load(prepared_path,
|
||||
map_location=f"cuda:{gpu_no}",
|
||||
weights_only=False,
|
||||
mmap=True)
|
||||
model.eval()
|
||||
print(f">>> Prepared model loaded.")
|
||||
else:
|
||||
# ---- Normal path: construct + load checkpoint ----
|
||||
config['model']['params']['wma_config']['params'][
|
||||
'use_checkpoint'] = False
|
||||
model = instantiate_from_config(config.model)
|
||||
model.perframe_ae = args.perframe_ae
|
||||
|
||||
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
||||
model = load_model_checkpoint(model, args.ckpt_path)
|
||||
model.eval()
|
||||
model = model.cuda(gpu_no)
|
||||
print(f'>>> Load pre-trained model ...')
|
||||
|
||||
# Save prepared model for fast loading next time
|
||||
print(f">>> Saving prepared model to {prepared_path} ...")
|
||||
torch.save(model, prepared_path)
|
||||
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
|
||||
|
||||
# Build normalizer (always needed, independent of model loading path)
|
||||
logging.info("***** Configing Data *****")
|
||||
data = instantiate_from_config(config.data)
|
||||
data.setup()
|
||||
print(">>> Dataset is successfully loaded ...")
|
||||
|
||||
model = model.cuda(gpu_no)
|
||||
device = get_device_from_parameters(model)
|
||||
|
||||
# Fuse KV projections in attention layers (to_k + to_v → to_kv)
|
||||
from unifolm_wma.modules.attention import CrossAttention
|
||||
kv_count = sum(1 for m in model.modules()
|
||||
if isinstance(m, CrossAttention) and m.fuse_kv())
|
||||
print(f" ✓ KV fused: {kv_count} attention layers")
|
||||
|
||||
# Run over data
|
||||
assert (args.height % 16 == 0) and (
|
||||
args.width % 16
|
||||
@@ -587,7 +699,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
fs=model_input_fs,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
sim_mode=False)
|
||||
sim_mode=False,
|
||||
decode_video=not args.fast_policy_no_decode)
|
||||
|
||||
# Update future actions in the observation queues
|
||||
for idx in range(len(pred_actions[0])):
|
||||
@@ -644,29 +757,31 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
cond_obs_queues = populate_queues(cond_obs_queues,
|
||||
observation)
|
||||
|
||||
# Save the imagen videos for decision-making
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
||||
log_to_tensorboard(writer,
|
||||
pred_videos_0,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
# Save the imagen videos for decision-making (async)
|
||||
if pred_videos_0 is not None:
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
||||
log_to_tensorboard_async(writer,
|
||||
pred_videos_0,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
# Save videos environment changes via world-model interaction
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
||||
log_to_tensorboard(writer,
|
||||
pred_videos_1,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
log_to_tensorboard_async(writer,
|
||||
pred_videos_1,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
|
||||
# Save the imagen videos for decision-making
|
||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
||||
save_results(pred_videos_0.cpu(),
|
||||
sample_video_file,
|
||||
fps=args.save_fps)
|
||||
if pred_videos_0 is not None:
|
||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
||||
save_results_async(pred_videos_0,
|
||||
sample_video_file,
|
||||
fps=args.save_fps)
|
||||
# Save videos environment changes via world-model interaction
|
||||
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
|
||||
save_results(pred_videos_1.cpu(),
|
||||
sample_video_file,
|
||||
fps=args.save_fps)
|
||||
save_results_async(pred_videos_1,
|
||||
sample_video_file,
|
||||
fps=args.save_fps)
|
||||
|
||||
print('>' * 24)
|
||||
# Collect the result of world-model interactions
|
||||
@@ -674,12 +789,15 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
full_video = torch.cat(wm_video, dim=2)
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
||||
log_to_tensorboard(writer,
|
||||
full_video,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
log_to_tensorboard_async(writer,
|
||||
full_video,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
||||
save_results(full_video, sample_full_video_file, fps=args.save_fps)
|
||||
save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
|
||||
|
||||
# Wait for all async I/O to complete
|
||||
_flush_io()
|
||||
|
||||
|
||||
def get_parser():
|
||||
@@ -794,6 +912,11 @@ def get_parser():
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="not using the predicted states as comparison")
|
||||
parser.add_argument(
|
||||
"--fast_policy_no_decode",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Speed mode: policy pass only predicts actions, skip policy video decode/log/save.")
|
||||
parser.add_argument("--save_fps",
|
||||
type=int,
|
||||
default=8,
|
||||
|
||||
@@ -11,6 +11,9 @@ from unifolm_wma.utils.utils import instantiate_from_config
|
||||
from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy
|
||||
from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
|
||||
def get_parser(**parser_kwargs):
|
||||
parser = argparse.ArgumentParser(**parser_kwargs)
|
||||
|
||||
@@ -501,6 +501,10 @@ class ConditionalUnet1D(nn.Module):
|
||||
self.last_frame_only = last_frame_only
|
||||
self.horizon = horizon
|
||||
|
||||
# Context precomputation cache
|
||||
self._global_cond_cache_enabled = False
|
||||
self._global_cond_cache = {}
|
||||
|
||||
def forward(self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
@@ -530,14 +534,20 @@ class ConditionalUnet1D(nn.Module):
|
||||
B, T, D = sample.shape
|
||||
if self.use_linear_act_proj:
|
||||
sample = self.proj_in_action(sample.unsqueeze(-1))
|
||||
global_cond = self.obs_encoder(cond)
|
||||
global_cond = rearrange(global_cond,
|
||||
'(b t) d -> b 1 (t d)',
|
||||
b=B,
|
||||
t=self.n_obs_steps)
|
||||
global_cond = repeat(global_cond,
|
||||
'b c d -> b (repeat c) d',
|
||||
repeat=T)
|
||||
_gc_key = (cond['image'].data_ptr(), cond['agent_pos'].data_ptr())
|
||||
if self._global_cond_cache_enabled and _gc_key in self._global_cond_cache:
|
||||
global_cond = self._global_cond_cache[_gc_key]
|
||||
else:
|
||||
global_cond = self.obs_encoder(cond)
|
||||
global_cond = rearrange(global_cond,
|
||||
'(b t) d -> b 1 (t d)',
|
||||
b=B,
|
||||
t=self.n_obs_steps)
|
||||
global_cond = repeat(global_cond,
|
||||
'b c d -> b (repeat c) d',
|
||||
repeat=T)
|
||||
if self._global_cond_cache_enabled:
|
||||
self._global_cond_cache[_gc_key] = global_cond
|
||||
else:
|
||||
sample = einops.rearrange(sample, 'b h t -> b t h')
|
||||
sample = self.proj_in_horizon(sample)
|
||||
|
||||
@@ -6,6 +6,8 @@ from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim
|
||||
from unifolm_wma.utils.common import noise_like
|
||||
from unifolm_wma.utils.common import extract_into_tensor
|
||||
from tqdm import tqdm
|
||||
from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache
|
||||
from unifolm_wma.modules.networks.wma_model import enable_ctx_cache, disable_ctx_cache
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
@@ -67,11 +69,12 @@ class DDIMSampler(object):
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
# Ensure tensors are on correct device for efficient indexing
|
||||
self.register_buffer('ddim_sigmas', to_torch(torch.as_tensor(ddim_sigmas)))
|
||||
self.register_buffer('ddim_alphas', to_torch(torch.as_tensor(ddim_alphas)))
|
||||
self.register_buffer('ddim_alphas_prev', to_torch(torch.as_tensor(ddim_alphas_prev)))
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas',
|
||||
np.sqrt(1. - ddim_alphas))
|
||||
to_torch(torch.as_tensor(np.sqrt(1. - ddim_alphas))))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
|
||||
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
@@ -241,63 +244,70 @@ class DDIMSampler(object):
|
||||
|
||||
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
|
||||
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b, ), step, device=device, dtype=torch.long)
|
||||
ts = torch.empty((b, ), device=device, dtype=torch.long)
|
||||
enable_cross_attn_kv_cache(self.model)
|
||||
enable_ctx_cache(self.model)
|
||||
try:
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts.fill_(step)
|
||||
|
||||
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
if clean_cond:
|
||||
img_orig = x0
|
||||
else:
|
||||
img_orig = self.model.q_sample(x0, ts)
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
if clean_cond:
|
||||
img_orig = x0
|
||||
else:
|
||||
img_orig = self.model.q_sample(x0, ts)
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
action,
|
||||
state,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
fs=fs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
action,
|
||||
state,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
fs=fs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
|
||||
img, pred_x0, model_output_action, model_output_state = outs
|
||||
img, pred_x0, model_output_action, model_output_state = outs
|
||||
|
||||
action = dp_ddim_scheduler_action.step(
|
||||
model_output_action,
|
||||
step,
|
||||
action,
|
||||
generator=None,
|
||||
).prev_sample
|
||||
state = dp_ddim_scheduler_state.step(
|
||||
model_output_state,
|
||||
step,
|
||||
state,
|
||||
generator=None,
|
||||
).prev_sample
|
||||
action = dp_ddim_scheduler_action.step(
|
||||
model_output_action,
|
||||
step,
|
||||
action,
|
||||
generator=None,
|
||||
).prev_sample
|
||||
state = dp_ddim_scheduler_state.step(
|
||||
model_output_state,
|
||||
step,
|
||||
state,
|
||||
generator=None,
|
||||
).prev_sample
|
||||
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
intermediates['x_inter_action'].append(action)
|
||||
intermediates['x_inter_state'].append(state)
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
intermediates['x_inter_action'].append(action)
|
||||
intermediates['x_inter_state'].append(state)
|
||||
finally:
|
||||
disable_cross_attn_kv_cache(self.model)
|
||||
disable_ctx_cache(self.model)
|
||||
|
||||
return img, action, state, intermediates
|
||||
|
||||
@@ -325,10 +335,6 @@ class DDIMSampler(object):
|
||||
guidance_rescale=0.0,
|
||||
**kwargs):
|
||||
b, *_, device = *x.shape, x.device
|
||||
if x.dim() == 5:
|
||||
is_video = True
|
||||
else:
|
||||
is_video = False
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
model_output, model_output_action, model_output_state = self.model.apply_model(
|
||||
@@ -377,17 +383,11 @@ class DDIMSampler(object):
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
if is_video:
|
||||
size = (b, 1, 1, 1, 1)
|
||||
else:
|
||||
size = (b, 1, 1, 1)
|
||||
|
||||
a_t = torch.full(size, alphas[index], device=device)
|
||||
a_prev = torch.full(size, alphas_prev[index], device=device)
|
||||
sigma_t = torch.full(size, sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full(size,
|
||||
sqrt_one_minus_alphas[index],
|
||||
device=device)
|
||||
# Use 0-d tensors directly (already on device); broadcasting handles shape
|
||||
a_t = alphas[index]
|
||||
a_prev = alphas_prev[index]
|
||||
sigma_t = sigmas[index]
|
||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index]
|
||||
|
||||
if self.model.parameterization != "v":
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
@@ -395,12 +395,8 @@ class DDIMSampler(object):
|
||||
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
||||
|
||||
if self.model.use_dynamic_rescale:
|
||||
scale_t = torch.full(size,
|
||||
self.ddim_scale_arr[index],
|
||||
device=device)
|
||||
prev_scale_t = torch.full(size,
|
||||
self.ddim_scale_arr_prev[index],
|
||||
device=device)
|
||||
scale_t = self.ddim_scale_arr[index]
|
||||
prev_scale_t = self.ddim_scale_arr_prev[index]
|
||||
rescale = (prev_scale_t / scale_t)
|
||||
pred_x0 *= rescale
|
||||
|
||||
|
||||
@@ -98,6 +98,10 @@ class CrossAttention(nn.Module):
|
||||
self.text_context_len = text_context_len
|
||||
self.agent_state_context_len = agent_state_context_len
|
||||
self.agent_action_context_len = agent_action_context_len
|
||||
self._kv_cache = {}
|
||||
self._kv_cache_enabled = False
|
||||
self._kv_fused = False
|
||||
|
||||
self.cross_attention_scale_learnable = cross_attention_scale_learnable
|
||||
if self.image_cross_attention:
|
||||
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
@@ -114,6 +118,27 @@ class CrossAttention(nn.Module):
|
||||
self.register_parameter('alpha_caa',
|
||||
nn.Parameter(torch.tensor(0.)))
|
||||
|
||||
def fuse_kv(self):
|
||||
"""Fuse to_k/to_v into to_kv (2 Linear → 1). Works for all layers."""
|
||||
k_w = self.to_k.weight # (inner_dim, context_dim)
|
||||
v_w = self.to_v.weight
|
||||
self.to_kv = nn.Linear(k_w.shape[1], k_w.shape[0] * 2, bias=False)
|
||||
self.to_kv.weight = nn.Parameter(torch.cat([k_w, v_w], dim=0))
|
||||
del self.to_k, self.to_v
|
||||
if self.image_cross_attention:
|
||||
for suffix in ('_ip', '_as', '_aa'):
|
||||
k_attr = f'to_k{suffix}'
|
||||
v_attr = f'to_v{suffix}'
|
||||
kw = getattr(self, k_attr).weight
|
||||
vw = getattr(self, v_attr).weight
|
||||
fused = nn.Linear(kw.shape[1], kw.shape[0] * 2, bias=False)
|
||||
fused.weight = nn.Parameter(torch.cat([kw, vw], dim=0))
|
||||
setattr(self, f'to_kv{suffix}', fused)
|
||||
delattr(self, k_attr)
|
||||
delattr(self, v_attr)
|
||||
self._kv_fused = True
|
||||
return True
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
spatial_self_attn = (context is None)
|
||||
k_ip, v_ip, out_ip = None, None, None
|
||||
@@ -140,19 +165,28 @@ class CrossAttention(nn.Module):
|
||||
self.agent_action_context_len +
|
||||
self.text_context_len:, :]
|
||||
|
||||
k = self.to_k(context_ins)
|
||||
v = self.to_v(context_ins)
|
||||
k_ip = self.to_k_ip(context_image)
|
||||
v_ip = self.to_v_ip(context_image)
|
||||
k_as = self.to_k_as(context_agent_state)
|
||||
v_as = self.to_v_as(context_agent_state)
|
||||
k_aa = self.to_k_aa(context_agent_action)
|
||||
v_aa = self.to_v_aa(context_agent_action)
|
||||
if self._kv_fused:
|
||||
k, v = self.to_kv(context_ins).chunk(2, dim=-1)
|
||||
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
|
||||
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1)
|
||||
k_aa, v_aa = self.to_kv_aa(context_agent_action).chunk(2, dim=-1)
|
||||
else:
|
||||
k = self.to_k(context_ins)
|
||||
v = self.to_v(context_ins)
|
||||
k_ip = self.to_k_ip(context_image)
|
||||
v_ip = self.to_v_ip(context_image)
|
||||
k_as = self.to_k_as(context_agent_state)
|
||||
v_as = self.to_v_as(context_agent_state)
|
||||
k_aa = self.to_k_aa(context_agent_action)
|
||||
v_aa = self.to_v_aa(context_agent_action)
|
||||
else:
|
||||
if not spatial_self_attn:
|
||||
context = context[:, :self.text_context_len, :]
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
if self._kv_fused:
|
||||
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||
else:
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(q, k, v))
|
||||
@@ -236,134 +270,162 @@ class CrossAttention(nn.Module):
|
||||
k_ip, v_ip, out_ip = None, None, None
|
||||
k_as, v_as, out_as = None, None, None
|
||||
k_aa, v_aa, out_aa = None, None, None
|
||||
attn_mask_aa = None
|
||||
|
||||
h = self.heads
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
if self.image_cross_attention and not spatial_self_attn:
|
||||
b, _, _ = q.shape
|
||||
q = q.unsqueeze(3).reshape(b, q.shape[1], h, self.dim_head).permute(0, 2, 1, 3).reshape(b * h, q.shape[1], self.dim_head).contiguous()
|
||||
|
||||
def _reshape_kv(t):
|
||||
return t.unsqueeze(3).reshape(b, t.shape[1], h, self.dim_head).permute(0, 2, 1, 3).reshape(b * h, t.shape[1], self.dim_head).contiguous()
|
||||
|
||||
use_cache = self._kv_cache_enabled and not spatial_self_attn
|
||||
cache_hit = use_cache and len(self._kv_cache) > 0
|
||||
|
||||
if cache_hit:
|
||||
k = self._kv_cache['k']
|
||||
v = self._kv_cache['v']
|
||||
k_ip = self._kv_cache.get('k_ip')
|
||||
v_ip = self._kv_cache.get('v_ip')
|
||||
k_as = self._kv_cache.get('k_as')
|
||||
v_as = self._kv_cache.get('v_as')
|
||||
k_aa = self._kv_cache.get('k_aa')
|
||||
v_aa = self._kv_cache.get('v_aa')
|
||||
attn_mask_aa = self._kv_cache.get('attn_mask_aa')
|
||||
elif self.image_cross_attention and not spatial_self_attn:
|
||||
if context.shape[1] == self.text_context_len + self.video_length:
|
||||
context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :]
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
k_ip = self.to_k_ip(context_image)
|
||||
v_ip = self.to_v_ip(context_image)
|
||||
if self._kv_fused:
|
||||
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
|
||||
else:
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
k_ip = self.to_k_ip(context_image)
|
||||
v_ip = self.to_v_ip(context_image)
|
||||
k, v = map(_reshape_kv, (k, v))
|
||||
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
|
||||
if use_cache:
|
||||
self._kv_cache = {'k': k, 'v': v, 'k_ip': k_ip, 'v_ip': v_ip}
|
||||
elif context.shape[1] == self.agent_state_context_len + self.text_context_len + self.video_length:
|
||||
context_agent_state = context[:, :self.agent_state_context_len, :]
|
||||
context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :]
|
||||
context_image = context[:, self.agent_state_context_len+self.text_context_len:, :]
|
||||
k = self.to_k(context_ins)
|
||||
v = self.to_v(context_ins)
|
||||
k_ip = self.to_k_ip(context_image)
|
||||
v_ip = self.to_v_ip(context_image)
|
||||
k_as = self.to_k_as(context_agent_state)
|
||||
v_as = self.to_v_as(context_agent_state)
|
||||
if self._kv_fused:
|
||||
k, v = self.to_kv(context_ins).chunk(2, dim=-1)
|
||||
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
|
||||
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1)
|
||||
else:
|
||||
k = self.to_k(context_ins)
|
||||
v = self.to_v(context_ins)
|
||||
k_ip = self.to_k_ip(context_image)
|
||||
v_ip = self.to_v_ip(context_image)
|
||||
k_as = self.to_k_as(context_agent_state)
|
||||
v_as = self.to_v_as(context_agent_state)
|
||||
k, v = map(_reshape_kv, (k, v))
|
||||
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
|
||||
k_as, v_as = map(_reshape_kv, (k_as, v_as))
|
||||
if use_cache:
|
||||
self._kv_cache = {'k': k, 'v': v, 'k_ip': k_ip, 'v_ip': v_ip, 'k_as': k_as, 'v_as': v_as}
|
||||
else:
|
||||
context_agent_state = context[:, :self.agent_state_context_len, :]
|
||||
context_agent_action = context[:, self.agent_state_context_len:self.agent_state_context_len+self.agent_action_context_len, :]
|
||||
context_ins = context[:, self.agent_state_context_len+self.agent_action_context_len:self.agent_state_context_len+self.agent_action_context_len+self.text_context_len, :]
|
||||
context_image = context[:, self.agent_state_context_len+self.agent_action_context_len+self.text_context_len:, :]
|
||||
|
||||
k = self.to_k(context_ins)
|
||||
v = self.to_v(context_ins)
|
||||
k_ip = self.to_k_ip(context_image)
|
||||
v_ip = self.to_v_ip(context_image)
|
||||
k_as = self.to_k_as(context_agent_state)
|
||||
v_as = self.to_v_as(context_agent_state)
|
||||
k_aa = self.to_k_aa(context_agent_action)
|
||||
v_aa = self.to_v_aa(context_agent_action)
|
||||
if self._kv_fused:
|
||||
k, v = self.to_kv(context_ins).chunk(2, dim=-1)
|
||||
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
|
||||
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1)
|
||||
k_aa, v_aa = self.to_kv_aa(context_agent_action).chunk(2, dim=-1)
|
||||
else:
|
||||
k = self.to_k(context_ins)
|
||||
v = self.to_v(context_ins)
|
||||
k_ip = self.to_k_ip(context_image)
|
||||
v_ip = self.to_v_ip(context_image)
|
||||
k_as = self.to_k_as(context_agent_state)
|
||||
v_as = self.to_v_as(context_agent_state)
|
||||
k_aa = self.to_k_aa(context_agent_action)
|
||||
v_aa = self.to_v_aa(context_agent_action)
|
||||
|
||||
attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
|
||||
q.shape[1],
|
||||
k_aa.shape[1],
|
||||
block_size=16).to(k_aa.device)
|
||||
k, v = map(_reshape_kv, (k, v))
|
||||
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
|
||||
k_as, v_as = map(_reshape_kv, (k_as, v_as))
|
||||
k_aa, v_aa = map(_reshape_kv, (k_aa, v_aa))
|
||||
|
||||
attn_mask_aa_raw = self._get_attn_mask_aa(x.shape[0],
|
||||
q.shape[1],
|
||||
k_aa.shape[1],
|
||||
block_size=16,
|
||||
device=k_aa.device)
|
||||
attn_mask_aa = attn_mask_aa_raw.unsqueeze(1).repeat(1, h, 1, 1).reshape(
|
||||
b * h, attn_mask_aa_raw.shape[1], attn_mask_aa_raw.shape[2]).to(q.dtype)
|
||||
|
||||
if use_cache:
|
||||
self._kv_cache = {
|
||||
'k': k, 'v': v, 'k_ip': k_ip, 'v_ip': v_ip,
|
||||
'k_as': k_as, 'v_as': v_as, 'k_aa': k_aa, 'v_aa': v_aa,
|
||||
'attn_mask_aa': attn_mask_aa,
|
||||
}
|
||||
else:
|
||||
if not spatial_self_attn:
|
||||
assert 1 > 2, ">>> ERROR: you should never go into here ..."
|
||||
context = context[:, :self.text_context_len, :]
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
b, _, _ = q.shape
|
||||
q = q.unsqueeze(3).reshape(b, q.shape[1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(b * self.heads, q.shape[1], self.dim_head).contiguous()
|
||||
if self._kv_fused:
|
||||
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||
else:
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
k, v = map(_reshape_kv, (k, v))
|
||||
if use_cache:
|
||||
self._kv_cache = {'k': k, 'v': v}
|
||||
if k is not None:
|
||||
k, v = map(
|
||||
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
||||
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
||||
b * self.heads, t.shape[1], self.dim_head).contiguous(),
|
||||
(k, v),
|
||||
)
|
||||
out = xformers.ops.memory_efficient_attention(q,
|
||||
k,
|
||||
v,
|
||||
attn_bias=None,
|
||||
op=None)
|
||||
out = (out.unsqueeze(0).reshape(
|
||||
b, self.heads, out.shape[1],
|
||||
b, h, out.shape[1],
|
||||
self.dim_head).permute(0, 2, 1,
|
||||
3).reshape(b, out.shape[1],
|
||||
self.heads * self.dim_head))
|
||||
h * self.dim_head))
|
||||
|
||||
if k_ip is not None:
|
||||
# For image cross-attention
|
||||
k_ip, v_ip = map(
|
||||
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
||||
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
||||
b * self.heads, t.shape[1], self.dim_head).contiguous(
|
||||
),
|
||||
(k_ip, v_ip),
|
||||
)
|
||||
out_ip = xformers.ops.memory_efficient_attention(q,
|
||||
k_ip,
|
||||
v_ip,
|
||||
attn_bias=None,
|
||||
op=None)
|
||||
out_ip = (out_ip.unsqueeze(0).reshape(
|
||||
b, self.heads, out_ip.shape[1],
|
||||
b, h, out_ip.shape[1],
|
||||
self.dim_head).permute(0, 2, 1,
|
||||
3).reshape(b, out_ip.shape[1],
|
||||
self.heads * self.dim_head))
|
||||
h * self.dim_head))
|
||||
|
||||
if k_as is not None:
|
||||
# For agent state cross-attention
|
||||
k_as, v_as = map(
|
||||
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
||||
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
||||
b * self.heads, t.shape[1], self.dim_head).contiguous(
|
||||
),
|
||||
(k_as, v_as),
|
||||
)
|
||||
out_as = xformers.ops.memory_efficient_attention(q,
|
||||
k_as,
|
||||
v_as,
|
||||
attn_bias=None,
|
||||
op=None)
|
||||
out_as = (out_as.unsqueeze(0).reshape(
|
||||
b, self.heads, out_as.shape[1],
|
||||
b, h, out_as.shape[1],
|
||||
self.dim_head).permute(0, 2, 1,
|
||||
3).reshape(b, out_as.shape[1],
|
||||
self.heads * self.dim_head))
|
||||
h * self.dim_head))
|
||||
|
||||
if k_aa is not None:
|
||||
# For agent action cross-attention
|
||||
k_aa, v_aa = map(
|
||||
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
||||
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
||||
b * self.heads, t.shape[1], self.dim_head).contiguous(
|
||||
),
|
||||
(k_aa, v_aa),
|
||||
)
|
||||
|
||||
attn_mask_aa = attn_mask_aa.unsqueeze(1).repeat(1,self.heads,1,1).reshape(
|
||||
b * self.heads, attn_mask_aa.shape[1], attn_mask_aa.shape[2])
|
||||
attn_mask_aa = attn_mask_aa.to(q.dtype)
|
||||
|
||||
out_aa = xformers.ops.memory_efficient_attention(
|
||||
q, k_aa, v_aa, attn_bias=attn_mask_aa, op=None)
|
||||
|
||||
out_aa = (out_aa.unsqueeze(0).reshape(
|
||||
b, self.heads, out_aa.shape[1],
|
||||
b, h, out_aa.shape[1],
|
||||
self.dim_head).permute(0, 2, 1,
|
||||
3).reshape(b, out_aa.shape[1],
|
||||
self.heads * self.dim_head))
|
||||
h * self.dim_head))
|
||||
if exists(mask):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -386,17 +448,43 @@ class CrossAttention(nn.Module):
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
def _get_attn_mask_aa(self, b, l1, l2, block_size=16):
|
||||
def _get_attn_mask_aa(self, b, l1, l2, block_size=16, device=None):
|
||||
cache_key = (b, l1, l2, block_size)
|
||||
if hasattr(self, '_attn_mask_aa_cache_key') and self._attn_mask_aa_cache_key == cache_key:
|
||||
cached = self._attn_mask_aa_cache
|
||||
if device is not None and cached.device != torch.device(device):
|
||||
cached = cached.to(device)
|
||||
self._attn_mask_aa_cache = cached
|
||||
return cached
|
||||
|
||||
target_device = device if device is not None else 'cpu'
|
||||
num_token = l2 // block_size
|
||||
start_positions = ((torch.arange(b) % block_size) + 1) * num_token
|
||||
col_indices = torch.arange(l2)
|
||||
start_positions = ((torch.arange(b, device=target_device) % block_size) + 1) * num_token
|
||||
col_indices = torch.arange(l2, device=target_device)
|
||||
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
|
||||
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
|
||||
attn_mask = torch.zeros_like(mask, dtype=torch.float)
|
||||
attn_mask = torch.zeros(b, l1, l2, dtype=torch.float, device=target_device)
|
||||
attn_mask[mask] = float('-inf')
|
||||
|
||||
self._attn_mask_aa_cache_key = cache_key
|
||||
self._attn_mask_aa_cache = attn_mask
|
||||
return attn_mask
|
||||
|
||||
|
||||
def enable_cross_attn_kv_cache(module):
|
||||
for m in module.modules():
|
||||
if isinstance(m, CrossAttention):
|
||||
m._kv_cache_enabled = True
|
||||
m._kv_cache = {}
|
||||
|
||||
|
||||
def disable_cross_attn_kv_cache(module):
|
||||
for m in module.modules():
|
||||
if isinstance(m, CrossAttention):
|
||||
m._kv_cache_enabled = False
|
||||
m._kv_cache = {}
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
|
||||
@@ -685,6 +685,21 @@ class WMAModel(nn.Module):
|
||||
self.action_token_projector = instantiate_from_config(
|
||||
stem_process_config)
|
||||
|
||||
# Context precomputation cache
|
||||
self._ctx_cache_enabled = False
|
||||
self._ctx_cache = {}
|
||||
# Reusable CUDA stream for parallel state_unet / action_unet
|
||||
self._state_stream = torch.cuda.Stream()
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state.pop('_state_stream', None)
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
self._state_stream = torch.cuda.Stream()
|
||||
|
||||
def forward(self,
|
||||
x: Tensor,
|
||||
x_action: Tensor,
|
||||
@@ -720,58 +735,64 @@ class WMAModel(nn.Module):
|
||||
repeat_only=False).type(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
bt, l_context, _ = context.shape
|
||||
if self.base_model_gen_only:
|
||||
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
|
||||
_ctx_key = context.data_ptr()
|
||||
if self._ctx_cache_enabled and _ctx_key in self._ctx_cache:
|
||||
context = self._ctx_cache[_ctx_key]
|
||||
else:
|
||||
if l_context == self.n_obs_steps + 77 + t * 16:
|
||||
context_agent_state = context[:, :self.n_obs_steps]
|
||||
context_text = context[:, self.n_obs_steps:self.n_obs_steps +
|
||||
77, :]
|
||||
context_img = context[:, self.n_obs_steps + 77:, :]
|
||||
context_agent_state = context_agent_state.repeat_interleave(
|
||||
repeats=t, dim=0)
|
||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||
context_img = rearrange(context_img,
|
||||
'b (t l) c -> (b t) l c',
|
||||
t=t)
|
||||
context = torch.cat(
|
||||
[context_agent_state, context_text, context_img], dim=1)
|
||||
elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
|
||||
context_agent_state = context[:, :self.n_obs_steps]
|
||||
context_agent_action = context[:, self.
|
||||
n_obs_steps:self.n_obs_steps +
|
||||
16, :]
|
||||
context_agent_action = rearrange(
|
||||
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
|
||||
context_agent_action = self.action_token_projector(
|
||||
context_agent_action)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'(b o) l d -> b o l d',
|
||||
o=t)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'b o (t l) d -> b o t l d',
|
||||
t=t)
|
||||
context_agent_action = context_agent_action.permute(
|
||||
0, 2, 1, 3, 4)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'b t o l d -> (b t) (o l) d')
|
||||
bt, l_context, _ = context.shape
|
||||
if self.base_model_gen_only:
|
||||
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
|
||||
else:
|
||||
if l_context == self.n_obs_steps + 77 + t * 16:
|
||||
context_agent_state = context[:, :self.n_obs_steps]
|
||||
context_text = context[:, self.n_obs_steps:self.n_obs_steps +
|
||||
77, :]
|
||||
context_img = context[:, self.n_obs_steps + 77:, :]
|
||||
context_agent_state = context_agent_state.repeat_interleave(
|
||||
repeats=t, dim=0)
|
||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||
context_img = rearrange(context_img,
|
||||
'b (t l) c -> (b t) l c',
|
||||
t=t)
|
||||
context = torch.cat(
|
||||
[context_agent_state, context_text, context_img], dim=1)
|
||||
elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
|
||||
context_agent_state = context[:, :self.n_obs_steps]
|
||||
context_agent_action = context[:, self.
|
||||
n_obs_steps:self.n_obs_steps +
|
||||
16, :]
|
||||
context_agent_action = rearrange(
|
||||
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
|
||||
context_agent_action = self.action_token_projector(
|
||||
context_agent_action)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'(b o) l d -> b o l d',
|
||||
o=t)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'b o (t l) d -> b o t l d',
|
||||
t=t)
|
||||
context_agent_action = context_agent_action.permute(
|
||||
0, 2, 1, 3, 4)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'b t o l d -> (b t) (o l) d')
|
||||
|
||||
context_text = context[:, self.n_obs_steps +
|
||||
16:self.n_obs_steps + 16 + 77, :]
|
||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||
context_text = context[:, self.n_obs_steps +
|
||||
16:self.n_obs_steps + 16 + 77, :]
|
||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||
|
||||
context_img = context[:, self.n_obs_steps + 16 + 77:, :]
|
||||
context_img = rearrange(context_img,
|
||||
'b (t l) c -> (b t) l c',
|
||||
t=t)
|
||||
context_agent_state = context_agent_state.repeat_interleave(
|
||||
repeats=t, dim=0)
|
||||
context = torch.cat([
|
||||
context_agent_state, context_agent_action, context_text,
|
||||
context_img
|
||||
],
|
||||
dim=1)
|
||||
context_img = context[:, self.n_obs_steps + 16 + 77:, :]
|
||||
context_img = rearrange(context_img,
|
||||
'b (t l) c -> (b t) l c',
|
||||
t=t)
|
||||
context_agent_state = context_agent_state.repeat_interleave(
|
||||
repeats=t, dim=0)
|
||||
context = torch.cat([
|
||||
context_agent_state, context_agent_action, context_text,
|
||||
context_img
|
||||
],
|
||||
dim=1)
|
||||
if self._ctx_cache_enabled:
|
||||
self._ctx_cache[_ctx_key] = context
|
||||
|
||||
emb = emb.repeat_interleave(repeats=t, dim=0)
|
||||
|
||||
@@ -832,17 +853,45 @@ class WMAModel(nn.Module):
|
||||
|
||||
if not self.base_model_gen_only:
|
||||
ba, _, _ = x_action.shape
|
||||
ts_state = timesteps[:ba] if b > 1 else timesteps
|
||||
# Run action_unet and state_unet in parallel via CUDA streams
|
||||
s_stream = self._state_stream
|
||||
s_stream.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(s_stream):
|
||||
s_y = self.state_unet(x_state, ts_state, hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
# Predict state
|
||||
if b > 1:
|
||||
s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
else:
|
||||
s_y = self.state_unet(x_state, timesteps, hs_a,
|
||||
context_action[:2], **kwargs)
|
||||
torch.cuda.current_stream().wait_stream(s_stream)
|
||||
else:
|
||||
a_y = torch.zeros_like(x_action)
|
||||
s_y = torch.zeros_like(x_state)
|
||||
|
||||
return y, a_y, s_y
|
||||
|
||||
|
||||
def enable_ctx_cache(model):
|
||||
"""Enable context precomputation cache on WMAModel and its action/state UNets."""
|
||||
for m in model.modules():
|
||||
if isinstance(m, WMAModel):
|
||||
m._ctx_cache_enabled = True
|
||||
m._ctx_cache = {}
|
||||
# conditional_unet1d cache
|
||||
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
||||
for m in model.modules():
|
||||
if isinstance(m, ConditionalUnet1D):
|
||||
m._global_cond_cache_enabled = True
|
||||
m._global_cond_cache = {}
|
||||
|
||||
|
||||
def disable_ctx_cache(model):
|
||||
"""Disable and clear context precomputation cache."""
|
||||
for m in model.modules():
|
||||
if isinstance(m, WMAModel):
|
||||
m._ctx_cache_enabled = False
|
||||
m._ctx_cache = {}
|
||||
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
||||
for m in model.modules():
|
||||
if isinstance(m, ConditionalUnet1D):
|
||||
m._global_cond_cache_enabled = False
|
||||
m._global_cond_cache = {}
|
||||
|
||||
121
unitree_z1_dual_arm_cleanup_pencils/case1/output.log
Normal file
121
unitree_z1_dual_arm_cleanup_pencils/case1/output.log
Normal file
@@ -0,0 +1,121 @@
|
||||
2026-02-10 15:38:28.973314: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-10 15:38:29.023024: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-10 15:38:29.023070: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-10 15:38:29.024393: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-10 15:38:29.031901: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-10 15:38:29.955454: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
Global seed set to 123
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/8 [00:00<?, ?it/s]>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
12%|█▎ | 1/8 [01:14<08:41, 74.51s/it]
|
||||
25%|██▌ | 2/8 [02:29<07:28, 74.79s/it]
|
||||
38%|███▊ | 3/8 [03:44<06:14, 74.81s/it]
|
||||
50%|█████ | 4/8 [04:59<04:59, 74.78s/it]
|
||||
62%|██████▎ | 5/8 [06:13<03:44, 74.73s/it]
|
||||
75%|███████▌ | 6/8 [07:28<02:29, 74.66s/it]
|
||||
88%|████████▊ | 7/8 [08:42<01:14, 74.56s/it]
|
||||
100%|██████████| 8/8 [09:56<00:00, 74.51s/it]
|
||||
100%|██████████| 8/8 [09:56<00:00, 74.62s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case1/unitree_z1_dual_arm_cleanup_pencils_case1.mp4",
|
||||
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
|
||||
"psnr": 47.911564449209735
|
||||
}
|
||||
120
unitree_z1_dual_arm_stackbox_v2/case1/output.log
Normal file
120
unitree_z1_dual_arm_stackbox_v2/case1/output.log
Normal file
@@ -0,0 +1,120 @@
|
||||
2026-02-11 11:59:27.241485: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-11 11:59:27.291755: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-11 11:59:27.291807: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-11 11:59:27.293169: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-11 11:59:27.300838: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-11 11:59:28.228009: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
Global seed set to 123
|
||||
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||
>>> Prepared model loaded.
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
✓ KV fused: 66 attention layers
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/11 [00:00<?, ?it/s]
|
||||
9%|▉ | 1/11 [00:34<05:40, 34.05s/it]>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
18%|█▊ | 2/11 [01:08<05:07, 34.17s/it]
|
||||
27%|██▋ | 3/11 [01:42<04:33, 34.16s/it]
|
||||
36%|███▋ | 4/11 [02:16<03:59, 34.18s/it]
|
||||
45%|████▌ | 5/11 [02:50<03:24, 34.14s/it]
|
||||
55%|█████▍ | 6/11 [03:24<02:50, 34.10s/it]
|
||||
64%|██████▎ | 7/11 [03:58<02:16, 34.07s/it]
|
||||
73%|███████▎ | 8/11 [04:32<01:42, 34.03s/it]
|
||||
82%|████████▏ | 9/11 [05:06<01:08, 34.02s/it]
|
||||
91%|█████████ | 10/11 [05:40<00:34, 34.04s/it]
|
||||
100%|██████████| 11/11 [06:14<00:00, 34.03s/it]
|
||||
100%|██████████| 11/11 [06:14<00:00, 34.07s/it]
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
5
unitree_z1_dual_arm_stackbox_v2/case1/psnr_result.json
Normal file
5
unitree_z1_dual_arm_stackbox_v2/case1/psnr_result.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "/home/qhy/unifolm-world-model-action/unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
|
||||
"pred_video": "/home/qhy/unifolm-world-model-action/unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4",
|
||||
"psnr": 28.167025381705358
|
||||
}
|
||||
@@ -20,5 +20,6 @@ dataset="unitree_z1_dual_arm_stackbox_v2"
|
||||
--n_iter 11 \
|
||||
--timestep_spacing 'uniform_trailing' \
|
||||
--guidance_rescale 0.7 \
|
||||
--perframe_ae
|
||||
--perframe_ae \
|
||||
--fast_policy_no_decode
|
||||
} 2>&1 | tee "${res_dir}/output.log"
|
||||
|
||||
Reference in New Issue
Block a user