6 Commits

Author SHA1 Message Date
qhy
9a08e27a19 KV 融合实现完成。改动总结:速度略有提升,PSNR 略微上升
attention.py — 3处改动:
  1. __init__ 添加 _kv_fused = False 标志
  2. 新增 fuse_kv() 方法:将 to_k + to_v → to_kv,同时处理 _ip/_as/_aa 辅助 KV 对
  3. bmm_forward 两个分支加 _kv_fused 判断,用 to_kv().chunk(2, dim=-1) 替代分别调用
2026-02-11 12:36:38 +08:00
qhy
b558856e1e fix bugs 2026-02-10 22:35:45 +08:00
qhy
dcbcb2c377 - state_unet 放到一个独立的 CUDA stream 上执行
- action_unet 在默认 stream 上同时执行
  - 用 wait_stream 确保两者都完成后再返回
两个 1D UNet 输入完全独立,共享的 hs_a 和 context_action 都是只读的。GPU 利用率只有 ~31%,小张量 kernel 不会打满 GPU,两个 stream 可以真正并行。
2026-02-10 21:41:48 +08:00
qhy
ff43432ef9 结果 2026-02-10 20:01:25 +08:00
qhy
afa12ba031 每步迭代保存异步 2026-02-10 19:54:53 +08:00
qhy
bf4d66c874 跳过模型加载 2026-02-10 19:36:17 +08:00
8 changed files with 263 additions and 98 deletions

View File

@@ -4,7 +4,12 @@
"Bash(conda env list:*)", "Bash(conda env list:*)",
"Bash(mamba env:*)", "Bash(mamba env:*)",
"Bash(micromamba env list:*)", "Bash(micromamba env list:*)",
"Bash(echo:*)" "Bash(echo:*)",
"Bash(git show:*)",
"Bash(nvidia-smi:*)",
"Bash(conda activate unifolm-wma)",
"Bash(conda info:*)",
"Bash(direnv allow:*)"
] ]
} }
} }

2
.envrc Normal file
View File

@@ -0,0 +1,2 @@
# Load conda's bash shell integration (any error output from the hook command
# is discarded), then activate the project's conda environment.
eval "$(conda shell.bash hook 2>/dev/null)"
conda activate unifolm-wma

1
.gitignore vendored
View File

@@ -131,3 +131,4 @@ Experiment/log
*.ckpt *.ckpt
*.0 *.0
ckpts/unifolm_wma_dual.ckpt.prepared.pt

View File

@@ -9,6 +9,8 @@ import logging
import einops import einops
import warnings import warnings
import imageio import imageio
import atexit
from concurrent.futures import ThreadPoolExecutor
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from omegaconf import OmegaConf from omegaconf import OmegaConf
@@ -16,8 +18,9 @@ from tqdm import tqdm
from einops import rearrange, repeat from einops import rearrange, repeat
from collections import OrderedDict from collections import OrderedDict
from torch import nn from torch import nn
from eval_utils import populate_queues, log_to_tensorboard from eval_utils import populate_queues
from collections import deque from collections import deque
from typing import Optional, List, Any
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
@@ -153,6 +156,81 @@ def save_results(video: Tensor, filename: str, fps: int = 8) -> None:
options={'crf': '10'}) options={'crf': '10'})
# ========== Async I/O ==========
# Shared thread pool for background save/log jobs. It stays None until
# _get_io_executor() creates it on first use.
_io_executor: Optional[ThreadPoolExecutor] = None
# Futures of the jobs submitted so far; _flush_io() waits on each of them and
# then empties this list.
_io_futures: List[Any] = []
def _get_io_executor() -> ThreadPoolExecutor:
    """Return the module-wide I/O thread pool, creating it on first call.

    Returns:
        The shared ``ThreadPoolExecutor`` (two worker threads).
    """
    global _io_executor
    if _io_executor is not None:
        return _io_executor
    # First use: build the pool once and remember it for later callers.
    _io_executor = ThreadPoolExecutor(max_workers=2)
    return _io_executor
def _flush_io():
    """Block until every queued background I/O job has finished.

    A job that raised is reported on stdout and otherwise ignored, so one
    failed write does not prevent the remaining jobs from being awaited.
    The list of pending futures is emptied afterwards.
    """
    for pending in _io_futures:
        try:
            # result() waits for completion and re-raises the job's error.
            pending.result()
        except Exception as e:
            print(f">>> [async I/O] error: {e}")
    _io_futures.clear()
# Wait for outstanding background I/O jobs when the interpreter exits.
atexit.register(_flush_io)
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
    """Encode a video tensor to an H.264 file (runs in a background thread).

    Args:
        video_cpu: 5-D tensor with values nominally in [-1, 1]; axis 0 is
            tiled side by side and axis 2 is iterated as the frame axis
            (presumably (batch, channel, time, height, width) - TODO confirm).
        filename: Destination path of the video file.
        fps: Frame rate written to the file.
    """
    clipped = video_cpu.float().clamp(-1., 1.)
    samples = clipped.shape[0]
    # Bring the frame axis to the front so that each iteration yields one
    # time step holding all samples.
    per_frame = clipped.permute(2, 0, 1, 3, 4)
    sheets = torch.stack(
        [
            torchvision.utils.make_grid(sheet, nrow=int(samples), padding=0)
            for sheet in per_frame
        ],
        dim=0,
    )
    # Map [-1, 1] -> [0, 1] -> [0, 255] and move channels last for the writer.
    sheets = (sheets + 1.0) / 2.0
    frames = (sheets * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(filename,
                               frames,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
    """Queue a video for encoding on the background I/O pool.

    Args:
        video: Video tensor to save; it is detached and copied to the CPU on
            the calling thread before the job is submitted.
        filename: Destination path of the video file.
        fps: Frame rate written to the file.
    """
    host_copy = video.detach().cpu()
    pending = _get_io_executor().submit(_save_results_sync, host_copy,
                                        filename, fps)
    # Keep the future so that _flush_io() can wait for it later.
    _io_futures.append(pending)
def _log_to_tb_sync(writer, video_cpu: Tensor, tag: str, fps: int) -> None:
    """Write a video tensor to TensorBoard (runs in a background thread).

    Tensors that are not 5-D are ignored.

    Args:
        writer: Object exposing ``add_video(tag, tensor, fps=...)``.
        video_cpu: Video tensor; axis 0 is tiled side by side and axis 2 is
            iterated as the frame axis. Values are shifted from [-1, 1] to
            [0, 1] without clamping.
        tag: TensorBoard tag to log under.
        fps: Frame rate passed to the writer.
    """
    if video_cpu.dim() != 5:
        return
    samples = video_cpu.shape[0]
    sheets = [
        torchvision.utils.make_grid(sheet, nrow=int(samples), padding=0)
        for sheet in video_cpu.permute(2, 0, 1, 3, 4)
    ]
    stacked = torch.stack(sheets, dim=0)
    stacked = (stacked + 1.0) / 2.0
    # add_video receives a leading axis of size one in front of the frames.
    writer.add_video(tag, stacked.unsqueeze(dim=0), fps=fps)
def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> None:
    """Queue a TensorBoard video log on the background I/O pool.

    Anything that is not a 5-D ``torch.Tensor`` is silently skipped.

    Args:
        writer: TensorBoard writer handed to the background job.
        data: Video tensor; it is detached and copied to the CPU on the
            calling thread before the job is submitted.
        tag: TensorBoard tag to log under.
        fps: Frame rate passed to the writer.
    """
    if not isinstance(data, torch.Tensor) or data.dim() != 5:
        return
    host_copy = data.detach().cpu()
    pending = _get_io_executor().submit(_log_to_tb_sync, writer, host_copy,
                                        tag, fps)
    # Keep the future so that _flush_io() can wait for it later.
    _io_futures.append(pending)
def get_init_frame_path(data_dir: str, sample: dict) -> str: def get_init_frame_path(data_dir: str, sample: dict) -> str:
"""Construct the init_frame path from directory and sample metadata. """Construct the init_frame path from directory and sample metadata.
@@ -462,26 +540,51 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv") csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
df = pd.read_csv(csv_path) df = pd.read_csv(csv_path)
# Load config # Load config (always needed for data setup)
config = OmegaConf.load(args.config) config = OmegaConf.load(args.config)
prepared_path = args.ckpt_path + ".prepared.pt"
if os.path.exists(prepared_path):
# ---- Fast path: load the fully-prepared model ----
print(f">>> Loading prepared model from {prepared_path} ...")
model = torch.load(prepared_path,
map_location=f"cuda:{gpu_no}",
weights_only=False,
mmap=True)
model.eval()
print(f">>> Prepared model loaded.")
else:
# ---- Normal path: construct + load checkpoint ----
config['model']['params']['wma_config']['params'][ config['model']['params']['wma_config']['params'][
'use_checkpoint'] = False 'use_checkpoint'] = False
model = instantiate_from_config(config.model) model = instantiate_from_config(config.model)
model.perframe_ae = args.perframe_ae model.perframe_ae = args.perframe_ae
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
model = load_model_checkpoint(model, args.ckpt_path) model = load_model_checkpoint(model, args.ckpt_path)
model.eval() model.eval()
model = model.cuda(gpu_no)
print(f'>>> Load pre-trained model ...') print(f'>>> Load pre-trained model ...')
# Build unnomalizer # Save prepared model for fast loading next time
print(f">>> Saving prepared model to {prepared_path} ...")
torch.save(model, prepared_path)
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
# Build normalizer (always needed, independent of model loading path)
logging.info("***** Configing Data *****") logging.info("***** Configing Data *****")
data = instantiate_from_config(config.data) data = instantiate_from_config(config.data)
data.setup() data.setup()
print(">>> Dataset is successfully loaded ...") print(">>> Dataset is successfully loaded ...")
model = model.cuda(gpu_no)
device = get_device_from_parameters(model) device = get_device_from_parameters(model)
# Fuse KV projections in attention layers (to_k + to_v → to_kv)
from unifolm_wma.modules.attention import CrossAttention
kv_count = sum(1 for m in model.modules()
if isinstance(m, CrossAttention) and m.fuse_kv())
print(f" ✓ KV fused: {kv_count} attention layers")
# Run over data # Run over data
assert (args.height % 16 == 0) and ( assert (args.height % 16 == 0) and (
args.width % 16 args.width % 16
@@ -654,16 +757,16 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
cond_obs_queues = populate_queues(cond_obs_queues, cond_obs_queues = populate_queues(cond_obs_queues,
observation) observation)
# Save the imagen videos for decision-making # Save the imagen videos for decision-making (async)
if pred_videos_0 is not None: if pred_videos_0 is not None:
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}" sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
log_to_tensorboard(writer, log_to_tensorboard_async(writer,
pred_videos_0, pred_videos_0,
sample_tag, sample_tag,
fps=args.save_fps) fps=args.save_fps)
# Save videos environment changes via world-model interaction # Save videos environment changes via world-model interaction
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}" sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
log_to_tensorboard(writer, log_to_tensorboard_async(writer,
pred_videos_1, pred_videos_1,
sample_tag, sample_tag,
fps=args.save_fps) fps=args.save_fps)
@@ -671,12 +774,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Save the imagen videos for decision-making # Save the imagen videos for decision-making
if pred_videos_0 is not None: if pred_videos_0 is not None:
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4' sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_0.cpu(), save_results_async(pred_videos_0,
sample_video_file, sample_video_file,
fps=args.save_fps) fps=args.save_fps)
# Save videos environment changes via world-model interaction # Save videos environment changes via world-model interaction
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4' sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_1.cpu(), save_results_async(pred_videos_1,
sample_video_file, sample_video_file,
fps=args.save_fps) fps=args.save_fps)
@@ -686,12 +789,15 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
full_video = torch.cat(wm_video, dim=2) full_video = torch.cat(wm_video, dim=2)
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full" sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
log_to_tensorboard(writer, log_to_tensorboard_async(writer,
full_video, full_video,
sample_tag, sample_tag,
fps=args.save_fps) fps=args.save_fps)
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4" sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
save_results(full_video, sample_full_video_file, fps=args.save_fps) save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
# Wait for all async I/O to complete
_flush_io()
def get_parser(): def get_parser():

View File

@@ -100,6 +100,7 @@ class CrossAttention(nn.Module):
self.agent_action_context_len = agent_action_context_len self.agent_action_context_len = agent_action_context_len
self._kv_cache = {} self._kv_cache = {}
self._kv_cache_enabled = False self._kv_cache_enabled = False
self._kv_fused = False
self.cross_attention_scale_learnable = cross_attention_scale_learnable self.cross_attention_scale_learnable = cross_attention_scale_learnable
if self.image_cross_attention: if self.image_cross_attention:
@@ -117,6 +118,27 @@ class CrossAttention(nn.Module):
self.register_parameter('alpha_caa', self.register_parameter('alpha_caa',
nn.Parameter(torch.tensor(0.))) nn.Parameter(torch.tensor(0.)))
def fuse_kv(self):
    """Fuse each to_k/to_v projection pair into a single to_kv Linear.

    The fused layer stacks the key weights on top of the value weights, so
    ``to_kv(x).chunk(2, dim=-1)`` yields the same (k, v) as calling the two
    original layers separately. When ``image_cross_attention`` is set, the
    ``_ip``, ``_as`` and ``_aa`` auxiliary pairs are fused the same way.
    The original ``to_k*``/``to_v*`` modules are removed afterwards.

    All fused layers are built before any attribute is replaced, so a
    failure (for example a missing auxiliary projection) leaves the module
    unchanged instead of half-fused.

    Returns:
        True if this call performed the fusion, False if the layer was
        already fused (the call is then a no-op).

    Raises:
        AttributeError: If an expected to_k*/to_v* projection is missing.
        ValueError: If a key/value pair differs in weight shape or in
            whether it has a bias.
    """
    # Fusing twice would fail because to_k/to_v no longer exist.
    if getattr(self, '_kv_fused', False):
        return False

    def _fuse_pair(k_proj, v_proj):
        # Build one Linear whose output is [k_proj(x), v_proj(x)].
        k_w, v_w = k_proj.weight, v_proj.weight  # (out_features, in_features)
        if k_w.shape != v_w.shape:
            # chunk(2) only splits correctly when both halves are equal.
            raise ValueError(
                f"cannot fuse K/V projections with different weight shapes: "
                f"{tuple(k_w.shape)} vs {tuple(v_w.shape)}")
        has_bias = k_proj.bias is not None
        if has_bias != (v_proj.bias is not None):
            raise ValueError(
                "cannot fuse K/V projections when only one of them has a bias")
        # Allocate directly on the source device/dtype.
        fused = nn.Linear(k_w.shape[1],
                          k_w.shape[0] * 2,
                          bias=has_bias,
                          device=k_w.device,
                          dtype=k_w.dtype)
        with torch.no_grad():
            fused.weight.copy_(torch.cat([k_w, v_w], dim=0))
            if has_bias:
                fused.bias.copy_(torch.cat([k_proj.bias, v_proj.bias], dim=0))
        return fused

    suffixes = ['']
    if self.image_cross_attention:
        suffixes += ['_ip', '_as', '_aa']

    # Phase 1: build every fused layer (may raise, nothing modified yet).
    fused_layers = {
        suffix: _fuse_pair(getattr(self, f'to_k{suffix}'),
                           getattr(self, f'to_v{suffix}'))
        for suffix in suffixes
    }

    # Phase 2: swap the fused layers in and drop the originals.
    for suffix, layer in fused_layers.items():
        setattr(self, f'to_kv{suffix}', layer)
        delattr(self, f'to_k{suffix}')
        delattr(self, f'to_v{suffix}')

    self._kv_fused = True
    return True
def forward(self, x, context=None, mask=None): def forward(self, x, context=None, mask=None):
spatial_self_attn = (context is None) spatial_self_attn = (context is None)
k_ip, v_ip, out_ip = None, None, None k_ip, v_ip, out_ip = None, None, None
@@ -143,6 +165,12 @@ class CrossAttention(nn.Module):
self.agent_action_context_len + self.agent_action_context_len +
self.text_context_len:, :] self.text_context_len:, :]
if self._kv_fused:
k, v = self.to_kv(context_ins).chunk(2, dim=-1)
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1)
k_aa, v_aa = self.to_kv_aa(context_agent_action).chunk(2, dim=-1)
else:
k = self.to_k(context_ins) k = self.to_k(context_ins)
v = self.to_v(context_ins) v = self.to_v(context_ins)
k_ip = self.to_k_ip(context_image) k_ip = self.to_k_ip(context_image)
@@ -154,6 +182,9 @@ class CrossAttention(nn.Module):
else: else:
if not spatial_self_attn: if not spatial_self_attn:
context = context[:, :self.text_context_len, :] context = context[:, :self.text_context_len, :]
if self._kv_fused:
k, v = self.to_kv(context).chunk(2, dim=-1)
else:
k = self.to_k(context) k = self.to_k(context)
v = self.to_v(context) v = self.to_v(context)
@@ -267,6 +298,10 @@ class CrossAttention(nn.Module):
elif self.image_cross_attention and not spatial_self_attn: elif self.image_cross_attention and not spatial_self_attn:
if context.shape[1] == self.text_context_len + self.video_length: if context.shape[1] == self.text_context_len + self.video_length:
context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :] context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :]
if self._kv_fused:
k, v = self.to_kv(context).chunk(2, dim=-1)
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
else:
k = self.to_k(context) k = self.to_k(context)
v = self.to_v(context) v = self.to_v(context)
k_ip = self.to_k_ip(context_image) k_ip = self.to_k_ip(context_image)
@@ -279,6 +314,11 @@ class CrossAttention(nn.Module):
context_agent_state = context[:, :self.agent_state_context_len, :] context_agent_state = context[:, :self.agent_state_context_len, :]
context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :] context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :]
context_image = context[:, self.agent_state_context_len+self.text_context_len:, :] context_image = context[:, self.agent_state_context_len+self.text_context_len:, :]
if self._kv_fused:
k, v = self.to_kv(context_ins).chunk(2, dim=-1)
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1)
else:
k = self.to_k(context_ins) k = self.to_k(context_ins)
v = self.to_v(context_ins) v = self.to_v(context_ins)
k_ip = self.to_k_ip(context_image) k_ip = self.to_k_ip(context_image)
@@ -296,6 +336,12 @@ class CrossAttention(nn.Module):
context_ins = context[:, self.agent_state_context_len+self.agent_action_context_len:self.agent_state_context_len+self.agent_action_context_len+self.text_context_len, :] context_ins = context[:, self.agent_state_context_len+self.agent_action_context_len:self.agent_state_context_len+self.agent_action_context_len+self.text_context_len, :]
context_image = context[:, self.agent_state_context_len+self.agent_action_context_len+self.text_context_len:, :] context_image = context[:, self.agent_state_context_len+self.agent_action_context_len+self.text_context_len:, :]
if self._kv_fused:
k, v = self.to_kv(context_ins).chunk(2, dim=-1)
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1)
k_aa, v_aa = self.to_kv_aa(context_agent_action).chunk(2, dim=-1)
else:
k = self.to_k(context_ins) k = self.to_k(context_ins)
v = self.to_v(context_ins) v = self.to_v(context_ins)
k_ip = self.to_k_ip(context_image) k_ip = self.to_k_ip(context_image)
@@ -328,6 +374,9 @@ class CrossAttention(nn.Module):
if not spatial_self_attn: if not spatial_self_attn:
assert 1 > 2, ">>> ERROR: you should never go into here ..." assert 1 > 2, ">>> ERROR: you should never go into here ..."
context = context[:, :self.text_context_len, :] context = context[:, :self.text_context_len, :]
if self._kv_fused:
k, v = self.to_kv(context).chunk(2, dim=-1)
else:
k = self.to_k(context) k = self.to_k(context)
v = self.to_v(context) v = self.to_v(context)
k, v = map(_reshape_kv, (k, v)) k, v = map(_reshape_kv, (k, v))

View File

@@ -688,6 +688,17 @@ class WMAModel(nn.Module):
# Context precomputation cache # Context precomputation cache
self._ctx_cache_enabled = False self._ctx_cache_enabled = False
self._ctx_cache = {} self._ctx_cache = {}
# Reusable CUDA stream for parallel state_unet / action_unet
self._state_stream = torch.cuda.Stream()
def __getstate__(self):
    """Return the instance attributes to pickle, without the CUDA stream.

    Returns:
        A shallow copy of the instance ``__dict__`` with the
        ``_state_stream`` entry left out; the instance itself is untouched.
    """
    return {
        name: value
        for name, value in self.__dict__.items()
        if name != '_state_stream'
    }
def __setstate__(self, state):
    """Restore pickled attributes and attach a fresh CUDA stream.

    The stream is not part of the pickled state, so a new one is created
    here. On a host without CUDA the attribute is set to None instead of
    raising, which keeps the object loadable there (e.g. to inspect or
    convert weights).

    Args:
        state: Attribute dictionary produced by ``__getstate__``.
    """
    self.__dict__.update(state)
    # torch.cuda.Stream() needs a usable CUDA runtime; fall back to None
    # when there is none.
    self._state_stream = (torch.cuda.Stream()
                          if torch.cuda.is_available() else None)
def forward(self, def forward(self,
x: Tensor, x: Tensor,
@@ -842,15 +853,16 @@ class WMAModel(nn.Module):
if not self.base_model_gen_only: if not self.base_model_gen_only:
ba, _, _ = x_action.shape ba, _, _ = x_action.shape
ts_state = timesteps[:ba] if b > 1 else timesteps
# Run action_unet and state_unet in parallel via CUDA streams
s_stream = self._state_stream
s_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s_stream):
s_y = self.state_unet(x_state, ts_state, hs_a,
context_action[:2], **kwargs)
a_y = self.action_unet(x_action, timesteps[:ba], hs_a, a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
context_action[:2], **kwargs) context_action[:2], **kwargs)
# Predict state torch.cuda.current_stream().wait_stream(s_stream)
if b > 1:
s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
context_action[:2], **kwargs)
else:
s_y = self.state_unet(x_state, timesteps, hs_a,
context_action[:2], **kwargs)
else: else:
a_y = torch.zeros_like(x_action) a_y = torch.zeros_like(x_action)
s_y = torch.zeros_like(x_state) s_y = torch.zeros_like(x_state)

View File

@@ -1,24 +1,13 @@
2026-02-10 17:39:22.590654: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. 2026-02-11 11:59:27.241485: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-10 17:39:22.640645: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2026-02-11 11:59:27.291755: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-10 17:39:22.640689: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2026-02-11 11:59:27.291807: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-10 17:39:22.642010: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2026-02-11 11:59:27.293169: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-10 17:39:22.649530: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. 2026-02-11 11:59:27.300838: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-10 17:39:23.575804: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT 2026-02-11 11:59:28.228009: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123 Global seed set to 123
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode >>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08 >>> Prepared model loaded.
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data ***** INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded. >>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded. >>> unitree_z1_stackbox: data stats loaded.
@@ -36,13 +25,16 @@ INFO:root:***** Configing Data *****
>>> unitree_g1_pack_camera: data stats loaded. >>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated. >>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ... >>> Dataset is successfully loaded ...
✓ KV fused: 66 attention layers
>>> Generate 16 frames under each generation ... >>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5 DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13 DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9 DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096 DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s] 0%| | 0/11 [00:00<?, ?it/s]
9%|▉ | 1/11 [00:34<05:40, 34.05s/it]>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>> >>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ... >>> Step 1: generating actions ...
DEBUG:PIL.Image:Importing BlpImagePlugin DEBUG:PIL.Image:Importing BlpImagePlugin
@@ -92,9 +84,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin DEBUG:PIL.Image:Importing XpmImagePlugin
9%|▉ | 1/11 [00:35<05:55, 35.52s/it]
DEBUG:PIL.Image:Importing XVThumbImagePlugin DEBUG:PIL.Image:Importing XVThumbImagePlugin
18%|█▊ | 2/11 [01:08<05:07, 34.17s/it] 18%|█▊ | 2/11 [01:08<05:07, 34.17s/it]
@@ -125,6 +115,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
>>> Step 6: generating actions ... >>> Step 6: generating actions ...
>>> Step 6: interacting with world model ... >>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>> >>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ... >>> Step 7: generating actions ...
>>> Step 7: interacting with world model ... >>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>> >>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -1,5 +1,5 @@
{ {
"gt_video": "/home/qhy/unifolm-world-model-action/unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4", "gt_video": "/home/qhy/unifolm-world-model-action/unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
"pred_video": "/home/qhy/unifolm-world-model-action/unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4", "pred_video": "/home/qhy/unifolm-world-model-action/unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4",
"psnr": 25.12008483689618 "psnr": 28.167025381705358
} }