Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9a08e27a19 | |||
| b558856e1e | |||
| dcbcb2c377 | |||
| ff43432ef9 | |||
| afa12ba031 | |||
| bf4d66c874 |
@@ -4,7 +4,12 @@
|
|||||||
"Bash(conda env list:*)",
|
"Bash(conda env list:*)",
|
||||||
"Bash(mamba env:*)",
|
"Bash(mamba env:*)",
|
||||||
"Bash(micromamba env list:*)",
|
"Bash(micromamba env list:*)",
|
||||||
"Bash(echo:*)"
|
"Bash(echo:*)",
|
||||||
|
"Bash(git show:*)",
|
||||||
|
"Bash(nvidia-smi:*)",
|
||||||
|
"Bash(conda activate unifolm-wma)",
|
||||||
|
"Bash(conda info:*)",
|
||||||
|
"Bash(direnv allow:*)"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
2
.envrc
Normal file
2
.envrc
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
eval "$(conda shell.bash hook 2>/dev/null)"
|
||||||
|
conda activate unifolm-wma
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -131,3 +131,4 @@ Experiment/log
|
|||||||
*.ckpt
|
*.ckpt
|
||||||
|
|
||||||
*.0
|
*.0
|
||||||
|
ckpts/unifolm_wma_dual.ckpt.prepared.pt
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import logging
|
|||||||
import einops
|
import einops
|
||||||
import warnings
|
import warnings
|
||||||
import imageio
|
import imageio
|
||||||
|
import atexit
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
@@ -16,8 +18,9 @@ from tqdm import tqdm
|
|||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from eval_utils import populate_queues, log_to_tensorboard
|
from eval_utils import populate_queues
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from typing import Optional, List, Any
|
||||||
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
@@ -153,6 +156,81 @@ def save_results(video: Tensor, filename: str, fps: int = 8) -> None:
|
|||||||
options={'crf': '10'})
|
options={'crf': '10'})
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Async I/O ==========
|
||||||
|
_io_executor: Optional[ThreadPoolExecutor] = None
|
||||||
|
_io_futures: List[Any] = []
|
||||||
|
|
||||||
|
|
||||||
|
def _get_io_executor() -> ThreadPoolExecutor:
    """Return the shared background I/O thread pool, creating it on first use.

    The pool is kept in the module-level ``_io_executor`` variable and is
    built with two worker threads the first time this function is called;
    later calls return that same instance.
    """
    global _io_executor
    pool = _io_executor
    if pool is None:
        pool = ThreadPoolExecutor(max_workers=2)
        _io_executor = pool
    return pool
|
||||||
|
|
||||||
|
|
||||||
|
def _flush_io() -> None:
    """Wait for all pending async I/O to finish.

    Blocks until every future recorded in the module-level ``_io_futures``
    list has completed, leaving the list empty. A task that raised is
    reported on stdout and does not stop the remaining futures from being
    awaited.
    """
    # Drain in batches: each batch is detached from the shared list *before*
    # we block on it, so a future appended while we wait is picked up by the
    # next pass instead of being thrown away by a trailing ``clear()``.
    while _io_futures:
        batch = _io_futures[:]
        del _io_futures[:len(batch)]
        for fut in batch:
            try:
                fut.result()
            except Exception as e:  # best-effort: report and keep draining
                print(f">>> [async I/O] error: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
atexit.register(_flush_io)
|
||||||
|
|
||||||
|
|
||||||
|
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
    """Encode a CPU video tensor to an H.264 file (runs in a background thread).

    The input is cast to float and clamped to [-1, 1]. For every index along
    dimension 2 the corresponding slices are tiled into a single image with
    ``make_grid`` (``nrow`` equal to the size of dimension 0, no padding).
    The stacked images are rescaled from [-1, 1] to 0..255 ``uint8``, moved
    to channel-last order and written to ``filename``.

    NOTE(review): presumably ``video_cpu`` is laid out as
    (batch, channels, time, height, width) -- confirm against the callers.
    """
    clamped = video_cpu.float().clamp(-1., 1.)
    batch = int(clamped.shape[0])
    # Bring dimension 2 to the front so iteration yields one sheet per step.
    per_step = clamped.permute(2, 0, 1, 3, 4)
    tiles = []
    for sheet in per_step:
        tiles.append(
            torchvision.utils.make_grid(sheet, nrow=batch, padding=0))
    frames = torch.stack(tiles, dim=0)
    frames = (frames + 1.0) / 2.0
    frames = (frames * 255).to(torch.uint8).permute(0, 2, 3, 1)
    encoder_options = {'crf': '10'}
    torchvision.io.write_video(filename,
                               frames,
                               fps=fps,
                               video_codec='h264',
                               options=encoder_options)
|
||||||
|
|
||||||
|
|
||||||
|
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
    """Queue a video for encoding on the background I/O pool.

    The tensor is detached and copied to the CPU in the calling thread; the
    encoding itself is handed to ``_save_results_sync`` on the shared
    executor. The resulting future is recorded in ``_io_futures`` so that
    ``_flush_io`` can wait for it later.
    """
    host_copy = video.detach().cpu()
    executor = _get_io_executor()
    pending = executor.submit(_save_results_sync, host_copy, filename, fps)
    _io_futures.append(pending)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_to_tb_sync(writer, video_cpu: Tensor, tag: str, fps: int) -> None:
    """Write a CPU video tensor to TensorBoard (runs in a background thread).

    Only 5-dimensional tensors are logged; anything else is ignored. For
    every index along dimension 2 the slices are tiled into one image with
    ``make_grid`` (``nrow`` equal to the size of dimension 0, no padding),
    the stack is shifted from [-1, 1] to [0, 1], given a leading axis of
    size one and passed to ``writer.add_video`` under ``tag``.

    NOTE(review): ``writer`` is presumably a TensorBoard ``SummaryWriter``;
    only its ``add_video`` method is used here.
    """
    if video_cpu.dim() != 5:
        return
    batch = int(video_cpu.shape[0])
    # Bring dimension 2 to the front so iteration yields one sheet per step.
    per_step = video_cpu.permute(2, 0, 1, 3, 4)
    tiles = []
    for sheet in per_step:
        tiles.append(
            torchvision.utils.make_grid(sheet, nrow=batch, padding=0))
    clip = torch.stack(tiles, dim=0)
    clip = (clip + 1.0) / 2.0
    clip = clip.unsqueeze(dim=0)
    writer.add_video(tag, clip, fps=fps)
|
||||||
|
|
||||||
|
|
||||||
|
def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> None:
    """Queue a video for TensorBoard logging on the background I/O pool.

    Inputs that are not 5-dimensional ``torch.Tensor`` objects are ignored.
    Otherwise the tensor is detached and copied to the CPU in the calling
    thread, ``_log_to_tb_sync`` is submitted to the shared executor, and the
    future is recorded in ``_io_futures`` for ``_flush_io`` to wait on.
    """
    if not isinstance(data, torch.Tensor) or data.dim() != 5:
        return
    host_copy = data.detach().cpu()
    executor = _get_io_executor()
    pending = executor.submit(_log_to_tb_sync, writer, host_copy, tag, fps)
    _io_futures.append(pending)
|
||||||
|
|
||||||
|
|
||||||
def get_init_frame_path(data_dir: str, sample: dict) -> str:
|
def get_init_frame_path(data_dir: str, sample: dict) -> str:
|
||||||
"""Construct the init_frame path from directory and sample metadata.
|
"""Construct the init_frame path from directory and sample metadata.
|
||||||
|
|
||||||
@@ -462,26 +540,51 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
||||||
df = pd.read_csv(csv_path)
|
df = pd.read_csv(csv_path)
|
||||||
|
|
||||||
# Load config
|
# Load config (always needed for data setup)
|
||||||
config = OmegaConf.load(args.config)
|
config = OmegaConf.load(args.config)
|
||||||
config['model']['params']['wma_config']['params'][
|
|
||||||
'use_checkpoint'] = False
|
|
||||||
model = instantiate_from_config(config.model)
|
|
||||||
model.perframe_ae = args.perframe_ae
|
|
||||||
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
|
||||||
model = load_model_checkpoint(model, args.ckpt_path)
|
|
||||||
model.eval()
|
|
||||||
print(f'>>> Load pre-trained model ...')
|
|
||||||
|
|
||||||
# Build unnomalizer
|
prepared_path = args.ckpt_path + ".prepared.pt"
|
||||||
|
if os.path.exists(prepared_path):
|
||||||
|
# ---- Fast path: load the fully-prepared model ----
|
||||||
|
print(f">>> Loading prepared model from {prepared_path} ...")
|
||||||
|
model = torch.load(prepared_path,
|
||||||
|
map_location=f"cuda:{gpu_no}",
|
||||||
|
weights_only=False,
|
||||||
|
mmap=True)
|
||||||
|
model.eval()
|
||||||
|
print(f">>> Prepared model loaded.")
|
||||||
|
else:
|
||||||
|
# ---- Normal path: construct + load checkpoint ----
|
||||||
|
config['model']['params']['wma_config']['params'][
|
||||||
|
'use_checkpoint'] = False
|
||||||
|
model = instantiate_from_config(config.model)
|
||||||
|
model.perframe_ae = args.perframe_ae
|
||||||
|
|
||||||
|
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
||||||
|
model = load_model_checkpoint(model, args.ckpt_path)
|
||||||
|
model.eval()
|
||||||
|
model = model.cuda(gpu_no)
|
||||||
|
print(f'>>> Load pre-trained model ...')
|
||||||
|
|
||||||
|
# Save prepared model for fast loading next time
|
||||||
|
print(f">>> Saving prepared model to {prepared_path} ...")
|
||||||
|
torch.save(model, prepared_path)
|
||||||
|
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
|
||||||
|
|
||||||
|
# Build normalizer (always needed, independent of model loading path)
|
||||||
logging.info("***** Configing Data *****")
|
logging.info("***** Configing Data *****")
|
||||||
data = instantiate_from_config(config.data)
|
data = instantiate_from_config(config.data)
|
||||||
data.setup()
|
data.setup()
|
||||||
print(">>> Dataset is successfully loaded ...")
|
print(">>> Dataset is successfully loaded ...")
|
||||||
|
|
||||||
model = model.cuda(gpu_no)
|
|
||||||
device = get_device_from_parameters(model)
|
device = get_device_from_parameters(model)
|
||||||
|
|
||||||
|
# Fuse KV projections in attention layers (to_k + to_v → to_kv)
|
||||||
|
from unifolm_wma.modules.attention import CrossAttention
|
||||||
|
kv_count = sum(1 for m in model.modules()
|
||||||
|
if isinstance(m, CrossAttention) and m.fuse_kv())
|
||||||
|
print(f" ✓ KV fused: {kv_count} attention layers")
|
||||||
|
|
||||||
# Run over data
|
# Run over data
|
||||||
assert (args.height % 16 == 0) and (
|
assert (args.height % 16 == 0) and (
|
||||||
args.width % 16
|
args.width % 16
|
||||||
@@ -654,31 +757,31 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
cond_obs_queues = populate_queues(cond_obs_queues,
|
cond_obs_queues = populate_queues(cond_obs_queues,
|
||||||
observation)
|
observation)
|
||||||
|
|
||||||
# Save the imagen videos for decision-making
|
# Save the imagen videos for decision-making (async)
|
||||||
if pred_videos_0 is not None:
|
if pred_videos_0 is not None:
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
||||||
log_to_tensorboard(writer,
|
log_to_tensorboard_async(writer,
|
||||||
pred_videos_0,
|
pred_videos_0,
|
||||||
sample_tag,
|
sample_tag,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
# Save videos environment changes via world-model interaction
|
# Save videos environment changes via world-model interaction
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
||||||
log_to_tensorboard(writer,
|
log_to_tensorboard_async(writer,
|
||||||
pred_videos_1,
|
pred_videos_1,
|
||||||
sample_tag,
|
sample_tag,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
|
|
||||||
# Save the imagen videos for decision-making
|
# Save the imagen videos for decision-making
|
||||||
if pred_videos_0 is not None:
|
if pred_videos_0 is not None:
|
||||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
||||||
save_results(pred_videos_0.cpu(),
|
save_results_async(pred_videos_0,
|
||||||
sample_video_file,
|
sample_video_file,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
# Save videos environment changes via world-model interaction
|
# Save videos environment changes via world-model interaction
|
||||||
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
|
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
|
||||||
save_results(pred_videos_1.cpu(),
|
save_results_async(pred_videos_1,
|
||||||
sample_video_file,
|
sample_video_file,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
|
|
||||||
print('>' * 24)
|
print('>' * 24)
|
||||||
# Collect the result of world-model interactions
|
# Collect the result of world-model interactions
|
||||||
@@ -686,12 +789,15 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
|
|
||||||
full_video = torch.cat(wm_video, dim=2)
|
full_video = torch.cat(wm_video, dim=2)
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
||||||
log_to_tensorboard(writer,
|
log_to_tensorboard_async(writer,
|
||||||
full_video,
|
full_video,
|
||||||
sample_tag,
|
sample_tag,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
||||||
save_results(full_video, sample_full_video_file, fps=args.save_fps)
|
save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
|
||||||
|
|
||||||
|
# Wait for all async I/O to complete
|
||||||
|
_flush_io()
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
|||||||
@@ -100,6 +100,7 @@ class CrossAttention(nn.Module):
|
|||||||
self.agent_action_context_len = agent_action_context_len
|
self.agent_action_context_len = agent_action_context_len
|
||||||
self._kv_cache = {}
|
self._kv_cache = {}
|
||||||
self._kv_cache_enabled = False
|
self._kv_cache_enabled = False
|
||||||
|
self._kv_fused = False
|
||||||
|
|
||||||
self.cross_attention_scale_learnable = cross_attention_scale_learnable
|
self.cross_attention_scale_learnable = cross_attention_scale_learnable
|
||||||
if self.image_cross_attention:
|
if self.image_cross_attention:
|
||||||
@@ -117,6 +118,27 @@ class CrossAttention(nn.Module):
|
|||||||
self.register_parameter('alpha_caa',
|
self.register_parameter('alpha_caa',
|
||||||
nn.Parameter(torch.tensor(0.)))
|
nn.Parameter(torch.tensor(0.)))
|
||||||
|
|
||||||
|
def fuse_kv(self):
    """Fuse every K/V projection pair into a single Linear layer.

    ``to_k``/``to_v`` become ``to_kv``; when ``self.image_cross_attention``
    is set, the ``_ip``, ``_as`` and ``_aa`` pairs become ``to_kv_ip``,
    ``to_kv_as`` and ``to_kv_aa``. The fused weight stacks K's rows above
    V's rows, so ``to_kv(x).chunk(2, dim=-1)`` yields ``(k, v)``.

    All fused layers are built before the module is modified, so a failure
    leaves the module exactly as it was.

    Returns:
        True if the projections were fused by this call, False if the module
        was already fused and nothing was changed.

    Raises:
        ValueError: if a K/V pair has differing weight shapes (the halves
            produced by ``chunk(2)`` would not line up with K and V), or if
            only one projection of a pair has a bias.
    """
    if getattr(self, '_kv_fused', False):
        # Already fused: to_k/to_v no longer exist, so there is nothing to do.
        return False

    def _fuse_pair(k_proj, v_proj):
        """Build one Linear equal to k_proj and v_proj concatenated on the output axis."""
        k_w, v_w = k_proj.weight, v_proj.weight  # each (out_features, in_features)
        if k_w.shape != v_w.shape:
            raise ValueError(
                f"cannot fuse K/V projections with different weight shapes: "
                f"{tuple(k_w.shape)} vs {tuple(v_w.shape)}")
        k_b = getattr(k_proj, 'bias', None)
        v_b = getattr(v_proj, 'bias', None)
        if (k_b is None) != (v_b is None):
            raise ValueError(
                "cannot fuse K/V projections when only one of them has a bias")
        # Allocate directly on the source device/dtype so the fused layer
        # never takes a detour through CPU / default-dtype storage.
        fused = nn.Linear(k_w.shape[1],
                          k_w.shape[0] * 2,
                          bias=k_b is not None,
                          device=k_w.device,
                          dtype=k_w.dtype)
        fused.weight = nn.Parameter(
            torch.cat([k_w.detach(), v_w.detach()], dim=0),
            requires_grad=k_w.requires_grad or v_w.requires_grad)
        if k_b is not None:
            fused.bias = nn.Parameter(
                torch.cat([k_b.detach(), v_b.detach()], dim=0),
                requires_grad=k_b.requires_grad or v_b.requires_grad)
        return fused

    pairs = [('to_k', 'to_v', 'to_kv')]
    if self.image_cross_attention:
        pairs.extend((f'to_k{suffix}', f'to_v{suffix}', f'to_kv{suffix}')
                     for suffix in ('_ip', '_as', '_aa'))

    # Phase 1: build everything (may raise) without touching the module.
    fused_layers = [
        _fuse_pair(getattr(self, k_attr), getattr(self, v_attr))
        for k_attr, v_attr, _ in pairs
    ]

    # Phase 2: swap the fused layers in and drop the originals.
    for (k_attr, v_attr, kv_attr), fused in zip(pairs, fused_layers):
        setattr(self, kv_attr, fused)
        delattr(self, k_attr)
        delattr(self, v_attr)

    self._kv_fused = True
    return True
|
||||||
|
|
||||||
def forward(self, x, context=None, mask=None):
|
def forward(self, x, context=None, mask=None):
|
||||||
spatial_self_attn = (context is None)
|
spatial_self_attn = (context is None)
|
||||||
k_ip, v_ip, out_ip = None, None, None
|
k_ip, v_ip, out_ip = None, None, None
|
||||||
@@ -143,19 +165,28 @@ class CrossAttention(nn.Module):
|
|||||||
self.agent_action_context_len +
|
self.agent_action_context_len +
|
||||||
self.text_context_len:, :]
|
self.text_context_len:, :]
|
||||||
|
|
||||||
k = self.to_k(context_ins)
|
if self._kv_fused:
|
||||||
v = self.to_v(context_ins)
|
k, v = self.to_kv(context_ins).chunk(2, dim=-1)
|
||||||
k_ip = self.to_k_ip(context_image)
|
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
|
||||||
v_ip = self.to_v_ip(context_image)
|
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1)
|
||||||
k_as = self.to_k_as(context_agent_state)
|
k_aa, v_aa = self.to_kv_aa(context_agent_action).chunk(2, dim=-1)
|
||||||
v_as = self.to_v_as(context_agent_state)
|
else:
|
||||||
k_aa = self.to_k_aa(context_agent_action)
|
k = self.to_k(context_ins)
|
||||||
v_aa = self.to_v_aa(context_agent_action)
|
v = self.to_v(context_ins)
|
||||||
|
k_ip = self.to_k_ip(context_image)
|
||||||
|
v_ip = self.to_v_ip(context_image)
|
||||||
|
k_as = self.to_k_as(context_agent_state)
|
||||||
|
v_as = self.to_v_as(context_agent_state)
|
||||||
|
k_aa = self.to_k_aa(context_agent_action)
|
||||||
|
v_aa = self.to_v_aa(context_agent_action)
|
||||||
else:
|
else:
|
||||||
if not spatial_self_attn:
|
if not spatial_self_attn:
|
||||||
context = context[:, :self.text_context_len, :]
|
context = context[:, :self.text_context_len, :]
|
||||||
k = self.to_k(context)
|
if self._kv_fused:
|
||||||
v = self.to_v(context)
|
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||||
|
else:
|
||||||
|
k = self.to_k(context)
|
||||||
|
v = self.to_v(context)
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||||
(q, k, v))
|
(q, k, v))
|
||||||
@@ -267,10 +298,14 @@ class CrossAttention(nn.Module):
|
|||||||
elif self.image_cross_attention and not spatial_self_attn:
|
elif self.image_cross_attention and not spatial_self_attn:
|
||||||
if context.shape[1] == self.text_context_len + self.video_length:
|
if context.shape[1] == self.text_context_len + self.video_length:
|
||||||
context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :]
|
context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :]
|
||||||
k = self.to_k(context)
|
if self._kv_fused:
|
||||||
v = self.to_v(context)
|
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||||
k_ip = self.to_k_ip(context_image)
|
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
|
||||||
v_ip = self.to_v_ip(context_image)
|
else:
|
||||||
|
k = self.to_k(context)
|
||||||
|
v = self.to_v(context)
|
||||||
|
k_ip = self.to_k_ip(context_image)
|
||||||
|
v_ip = self.to_v_ip(context_image)
|
||||||
k, v = map(_reshape_kv, (k, v))
|
k, v = map(_reshape_kv, (k, v))
|
||||||
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
|
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
|
||||||
if use_cache:
|
if use_cache:
|
||||||
@@ -279,12 +314,17 @@ class CrossAttention(nn.Module):
|
|||||||
context_agent_state = context[:, :self.agent_state_context_len, :]
|
context_agent_state = context[:, :self.agent_state_context_len, :]
|
||||||
context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :]
|
context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :]
|
||||||
context_image = context[:, self.agent_state_context_len+self.text_context_len:, :]
|
context_image = context[:, self.agent_state_context_len+self.text_context_len:, :]
|
||||||
k = self.to_k(context_ins)
|
if self._kv_fused:
|
||||||
v = self.to_v(context_ins)
|
k, v = self.to_kv(context_ins).chunk(2, dim=-1)
|
||||||
k_ip = self.to_k_ip(context_image)
|
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
|
||||||
v_ip = self.to_v_ip(context_image)
|
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1)
|
||||||
k_as = self.to_k_as(context_agent_state)
|
else:
|
||||||
v_as = self.to_v_as(context_agent_state)
|
k = self.to_k(context_ins)
|
||||||
|
v = self.to_v(context_ins)
|
||||||
|
k_ip = self.to_k_ip(context_image)
|
||||||
|
v_ip = self.to_v_ip(context_image)
|
||||||
|
k_as = self.to_k_as(context_agent_state)
|
||||||
|
v_as = self.to_v_as(context_agent_state)
|
||||||
k, v = map(_reshape_kv, (k, v))
|
k, v = map(_reshape_kv, (k, v))
|
||||||
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
|
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
|
||||||
k_as, v_as = map(_reshape_kv, (k_as, v_as))
|
k_as, v_as = map(_reshape_kv, (k_as, v_as))
|
||||||
@@ -296,14 +336,20 @@ class CrossAttention(nn.Module):
|
|||||||
context_ins = context[:, self.agent_state_context_len+self.agent_action_context_len:self.agent_state_context_len+self.agent_action_context_len+self.text_context_len, :]
|
context_ins = context[:, self.agent_state_context_len+self.agent_action_context_len:self.agent_state_context_len+self.agent_action_context_len+self.text_context_len, :]
|
||||||
context_image = context[:, self.agent_state_context_len+self.agent_action_context_len+self.text_context_len:, :]
|
context_image = context[:, self.agent_state_context_len+self.agent_action_context_len+self.text_context_len:, :]
|
||||||
|
|
||||||
k = self.to_k(context_ins)
|
if self._kv_fused:
|
||||||
v = self.to_v(context_ins)
|
k, v = self.to_kv(context_ins).chunk(2, dim=-1)
|
||||||
k_ip = self.to_k_ip(context_image)
|
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
|
||||||
v_ip = self.to_v_ip(context_image)
|
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1)
|
||||||
k_as = self.to_k_as(context_agent_state)
|
k_aa, v_aa = self.to_kv_aa(context_agent_action).chunk(2, dim=-1)
|
||||||
v_as = self.to_v_as(context_agent_state)
|
else:
|
||||||
k_aa = self.to_k_aa(context_agent_action)
|
k = self.to_k(context_ins)
|
||||||
v_aa = self.to_v_aa(context_agent_action)
|
v = self.to_v(context_ins)
|
||||||
|
k_ip = self.to_k_ip(context_image)
|
||||||
|
v_ip = self.to_v_ip(context_image)
|
||||||
|
k_as = self.to_k_as(context_agent_state)
|
||||||
|
v_as = self.to_v_as(context_agent_state)
|
||||||
|
k_aa = self.to_k_aa(context_agent_action)
|
||||||
|
v_aa = self.to_v_aa(context_agent_action)
|
||||||
|
|
||||||
k, v = map(_reshape_kv, (k, v))
|
k, v = map(_reshape_kv, (k, v))
|
||||||
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
|
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
|
||||||
@@ -328,8 +374,11 @@ class CrossAttention(nn.Module):
|
|||||||
if not spatial_self_attn:
|
if not spatial_self_attn:
|
||||||
assert 1 > 2, ">>> ERROR: you should never go into here ..."
|
assert 1 > 2, ">>> ERROR: you should never go into here ..."
|
||||||
context = context[:, :self.text_context_len, :]
|
context = context[:, :self.text_context_len, :]
|
||||||
k = self.to_k(context)
|
if self._kv_fused:
|
||||||
v = self.to_v(context)
|
k, v = self.to_kv(context).chunk(2, dim=-1)
|
||||||
|
else:
|
||||||
|
k = self.to_k(context)
|
||||||
|
v = self.to_v(context)
|
||||||
k, v = map(_reshape_kv, (k, v))
|
k, v = map(_reshape_kv, (k, v))
|
||||||
if use_cache:
|
if use_cache:
|
||||||
self._kv_cache = {'k': k, 'v': v}
|
self._kv_cache = {'k': k, 'v': v}
|
||||||
|
|||||||
@@ -688,6 +688,17 @@ class WMAModel(nn.Module):
|
|||||||
# Context precomputation cache
|
# Context precomputation cache
|
||||||
self._ctx_cache_enabled = False
|
self._ctx_cache_enabled = False
|
||||||
self._ctx_cache = {}
|
self._ctx_cache = {}
|
||||||
|
# Reusable CUDA stream for parallel state_unet / action_unet
|
||||||
|
self._state_stream = torch.cuda.Stream()
|
||||||
|
|
||||||
|
def __getstate__(self):
    """Return the picklable state: the module state without the CUDA stream.

    ``_state_stream`` is dropped from the state and re-created by
    ``__setstate__``. The base ``nn.Module.__getstate__`` is used when it
    exists so that whatever it strips from the state (recent PyTorch
    releases remove ``_compiled_call_impl`` there) is stripped here as well.
    """
    base_getstate = getattr(nn.Module, '__getstate__', None)
    if base_getstate is not None:
        # Copy defensively: the fallback object.__getstate__ (Python 3.11+)
        # may hand back the live instance __dict__.
        state = dict(base_getstate(self))
    else:
        state = self.__dict__.copy()
    state.pop('_state_stream', None)
    return state
|
||||||
|
|
||||||
|
def __setstate__(self, state):
    """Restore pickled state and re-create the side CUDA stream.

    Delegates to ``nn.Module.__setstate__`` so the base class can apply its
    own handling of the restored state, then rebuilds ``_state_stream``:

    * no CUDA available -> ``_state_stream`` is ``None``, so the object can
      still be unpickled for inspection (NOTE: ``forward`` uses the stream
      and therefore still requires CUDA);
    * otherwise the stream is created on the device of the first CUDA
      parameter, falling back to the current device when no parameter is on
      CUDA yet.

    NOTE(review): this assumes parameters already sit on their final device
    when ``__setstate__`` runs (e.g. under ``torch.load(map_location=...)``)
    -- confirm for the loading path in use.
    """
    nn.Module.__setstate__(self, state)
    if not torch.cuda.is_available():
        self._state_stream = None
        return
    param_device = next(
        (p.device for p in self.parameters() if p.is_cuda), None)
    # device=None makes torch.cuda.Stream use the current device.
    self._state_stream = torch.cuda.Stream(device=param_device)
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
@@ -842,15 +853,16 @@ class WMAModel(nn.Module):
|
|||||||
|
|
||||||
if not self.base_model_gen_only:
|
if not self.base_model_gen_only:
|
||||||
ba, _, _ = x_action.shape
|
ba, _, _ = x_action.shape
|
||||||
|
ts_state = timesteps[:ba] if b > 1 else timesteps
|
||||||
|
# Run action_unet and state_unet in parallel via CUDA streams
|
||||||
|
s_stream = self._state_stream
|
||||||
|
s_stream.wait_stream(torch.cuda.current_stream())
|
||||||
|
with torch.cuda.stream(s_stream):
|
||||||
|
s_y = self.state_unet(x_state, ts_state, hs_a,
|
||||||
|
context_action[:2], **kwargs)
|
||||||
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
|
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
|
||||||
context_action[:2], **kwargs)
|
context_action[:2], **kwargs)
|
||||||
# Predict state
|
torch.cuda.current_stream().wait_stream(s_stream)
|
||||||
if b > 1:
|
|
||||||
s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
|
|
||||||
context_action[:2], **kwargs)
|
|
||||||
else:
|
|
||||||
s_y = self.state_unet(x_state, timesteps, hs_a,
|
|
||||||
context_action[:2], **kwargs)
|
|
||||||
else:
|
else:
|
||||||
a_y = torch.zeros_like(x_action)
|
a_y = torch.zeros_like(x_action)
|
||||||
s_y = torch.zeros_like(x_state)
|
s_y = torch.zeros_like(x_state)
|
||||||
|
|||||||
@@ -1,24 +1,13 @@
|
|||||||
2026-02-10 17:39:22.590654: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-11 11:59:27.241485: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
2026-02-10 17:39:22.640645: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-11 11:59:27.291755: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-10 17:39:22.640689: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
2026-02-11 11:59:27.291807: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-10 17:39:22.642010: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-11 11:59:27.293169: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-10 17:39:22.649530: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
2026-02-11 11:59:27.300838: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-10 17:39:23.575804: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
2026-02-11 11:59:28.228009: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
Global seed set to 123
|
Global seed set to 123
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
>>> Prepared model loaded.
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
|
||||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
|
||||||
>>> model checkpoint loaded.
|
|
||||||
>>> Load pre-trained model ...
|
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -36,13 +25,16 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]
|
0%| | 0/11 [00:00<?, ?it/s]
|
||||||
|
9%|▉ | 1/11 [00:34<05:40, 34.05s/it]>>> Step 0: generating actions ...
|
||||||
|
>>> Step 0: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 1: generating actions ...
|
>>> Step 1: generating actions ...
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
@@ -92,9 +84,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
|||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
|
|
||||||
9%|▉ | 1/11 [00:35<05:55, 35.52s/it]
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
18%|█▊ | 2/11 [01:08<05:07, 34.17s/it]
|
18%|█▊ | 2/11 [01:08<05:07, 34.17s/it]
|
||||||
@@ -125,6 +115,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 6: generating actions ...
|
>>> Step 6: generating actions ...
|
||||||
>>> Step 6: interacting with world model ...
|
>>> Step 6: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "/home/qhy/unifolm-world-model-action/unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
|
"gt_video": "/home/qhy/unifolm-world-model-action/unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
|
||||||
"pred_video": "/home/qhy/unifolm-world-model-action/unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4",
|
"pred_video": "/home/qhy/unifolm-world-model-action/unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4",
|
||||||
"psnr": 25.12008483689618
|
"psnr": 28.167025381705358
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user