优化写入后新的所有结果

This commit is contained in:
qhy
2026-02-19 20:18:31 +08:00
parent 5e0e21d91b
commit 43ab0f71b0
28 changed files with 1776 additions and 1199 deletions

View File

@@ -10,6 +10,7 @@ import einops
import warnings
import imageio
import atexit
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor
from pytorch_lightning import seed_everything
@@ -231,6 +232,32 @@ def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> N
_io_futures.append(fut)
def _video_tensor_to_frames(video: Tensor) -> np.ndarray:
    """Convert a video tensor in [-1, 1] to a uint8 frame array in BGR order.

    The input is indexed as ``video[batch, channel, time, h, w]`` (inferred
    from the ``permute(2, 0, 1, 3, 4)`` below — TODO confirm with callers).
    Every time step is tiled into one grid row holding all batch samples,
    yielding an array shaped (time, H, W, 3) whose channel axis is reversed
    (RGB -> BGR, i.e. ready for OpenCV-style consumers).
    """
    clamped = torch.clamp(video.float(), -1.0, 1.0)
    batch = clamped.shape[0]
    # Move the time axis to the front so we can build one grid per frame.
    per_step = clamped.permute(2, 0, 1, 3, 4)
    grids = torch.stack(
        [torchvision.utils.make_grid(step, nrow=int(batch), padding=0)
         for step in per_step],
        dim=0,
    )
    # Rescale [-1, 1] -> [0, 255] uint8 and go channels-last.
    frames = ((grids + 1.0) / 2.0 * 255).to(torch.uint8).permute(0, 2, 3, 1)
    # Reverse the channel axis: RGB -> BGR.
    return frames.numpy()[:, :, :, ::-1]
def _video_writer_process(q: mp.Queue, filename: str, fps: int):
    """Worker-process loop: drain video segments from *q*, then encode one mp4.

    Blocks on ``q.get()`` until a ``None`` sentinel arrives; every other item
    is a video tensor accepted by ``_video_tensor_to_frames``.  All received
    segments are concatenated along the time axis and written to *filename*
    as h264 (crf 10) at *fps*.  Nothing is written when no segment arrived.
    """
    # iter(callable, sentinel) keeps pulling until q.get() returns None.
    segments = [_video_tensor_to_frames(chunk) for chunk in iter(q.get, None)]
    if not segments:
        return
    stacked = np.concatenate(segments, axis=0)
    # Segments arrive in BGR (see _video_tensor_to_frames); write_video
    # expects RGB, so flip the channel axis back (copy() makes it contiguous).
    rgb = torch.from_numpy(stacked[:, :, :, ::-1].copy())
    torchvision.io.write_video(filename, rgb, fps=fps,
                               video_codec='h264', options={'crf': '10'})
def get_init_frame_path(data_dir: str, sample: dict) -> str:
"""Construct the init_frame path from directory and sample metadata.
@@ -648,8 +675,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# For saving environmental changes in world-model
sample_save_dir = f'{video_save_dir}/wm/{fs}'
os.makedirs(sample_save_dir, exist_ok=True)
# For collecting interaction videos
wm_latent = []
# Writer process for incremental video saving
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
write_q = mp.Queue()
writer_proc = mp.Process(
target=_video_writer_process,
args=(write_q, sample_full_video_file, args.save_fps))
writer_proc.start()
# Initialize observation queues
cond_obs_queues = {
"observation.images.top":
@@ -789,19 +821,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
fps=args.save_fps)
print('>' * 24)
# Store raw latent for deferred decode
wm_latent.append(wm_samples[:, :, :args.exe_steps].cpu())
# Decode segment and send to writer process
seg_video = model.decode_first_stage(
wm_samples[:, :, :args.exe_steps]).detach().cpu()
write_q.put(seg_video)
# Deferred decode: batch decode all stored latents
full_latent = torch.cat(wm_latent, dim=2).to(device)
full_video = model.decode_first_stage(full_latent).cpu()
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
log_to_tensorboard_async(writer,
full_video,
sample_tag,
fps=args.save_fps)
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
# Stop writer process
write_q.put(None)
writer_proc.join()
# Wait for all async I/O to complete
_flush_io()