优化写入后新的所有结果
This commit is contained in:
@@ -10,6 +10,7 @@ import einops
|
||||
import warnings
|
||||
import imageio
|
||||
import atexit
|
||||
import multiprocessing as mp
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from pytorch_lightning import seed_everything
|
||||
@@ -231,6 +232,32 @@ def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> N
|
||||
_io_futures.append(fut)
|
||||
|
||||
|
||||
def _video_tensor_to_frames(video: Tensor) -> np.ndarray:
    """Render a batched video tensor into per-timestep grid images (BGR uint8).

    Assumes *video* is laid out as (batch, channels, time, H, W) with values
    roughly in [-1, 1] — TODO(review): confirm against callers.

    Returns:
        np.ndarray of shape (time, grid_H, grid_W, 3), uint8, with the
        channel axis reversed (RGB -> BGR). Note the result is a
        negative-stride view, not a contiguous copy.
    """
    clamped = torch.clamp(video.float(), -1., 1.)
    batch = clamped.shape[0]
    # Bring the time axis to the front: (time, batch, channels, H, W),
    # so each timestep becomes one grid of the whole batch.
    per_step = clamped.permute(2, 0, 1, 3, 4)
    step_grids = []
    for step in per_step:
        step_grids.append(
            torchvision.utils.make_grid(step, nrow=int(batch), padding=0))
    stacked = torch.stack(step_grids, dim=0)
    # Map [-1, 1] -> [0, 255] and switch to channels-last (T, H, W, C).
    stacked = ((stacked + 1.0) / 2.0 * 255).to(torch.uint8).permute(0, 2, 3, 1)
    # Reverse the channel axis: RGB -> BGR (consumer flips it back before
    # writing; presumably BGR is kept for an imageio/OpenCV path elsewhere).
    return stacked.numpy()[:, :, :, ::-1]
|
||||
|
||||
|
||||
def _video_writer_process(q: mp.Queue, filename: str, fps: int):
    """Worker-process loop: drain video segments from *q*, write one mp4.

    Blocks on the queue until a ``None`` sentinel arrives, converting each
    received tensor segment to frame images as it comes in. On shutdown,
    the accumulated segments are concatenated along time and written to
    *filename*. If no segment was ever received, no file is written.

    Args:
        q: queue of video tensors (``None`` terminates the loop).
        filename: output .mp4 path.
        fps: playback frame rate for the written video.
    """
    segments = []
    while True:
        item = q.get()
        # ``None`` is the shutdown sentinel pushed by the producer.
        if item is None:
            break
        segments.append(_video_tensor_to_frames(item))
    if not segments:
        return
    all_frames = np.concatenate(segments, axis=0)
    # Frames arrive in BGR; flip back to RGB. The .copy() is required:
    # torch.from_numpy rejects the negative-stride view produced by [::-1].
    rgb = torch.from_numpy(all_frames[:, :, :, ::-1].copy())  # BGR → RGB
    torchvision.io.write_video(filename, rgb, fps=fps,
                               video_codec='h264', options={'crf': '10'})
|
||||
|
||||
|
||||
def get_init_frame_path(data_dir: str, sample: dict) -> str:
|
||||
"""Construct the init_frame path from directory and sample metadata.
|
||||
|
||||
@@ -648,8 +675,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
# For saving environmental changes in world-model
|
||||
sample_save_dir = f'{video_save_dir}/wm/{fs}'
|
||||
os.makedirs(sample_save_dir, exist_ok=True)
|
||||
# For collecting interaction videos
|
||||
wm_latent = []
|
||||
# Writer process for incremental video saving
|
||||
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
||||
write_q = mp.Queue()
|
||||
writer_proc = mp.Process(
|
||||
target=_video_writer_process,
|
||||
args=(write_q, sample_full_video_file, args.save_fps))
|
||||
writer_proc.start()
|
||||
# Initialize observation queues
|
||||
cond_obs_queues = {
|
||||
"observation.images.top":
|
||||
@@ -789,19 +821,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
fps=args.save_fps)
|
||||
|
||||
print('>' * 24)
|
||||
# Store raw latent for deferred decode
|
||||
wm_latent.append(wm_samples[:, :, :args.exe_steps].cpu())
|
||||
# Decode segment and send to writer process
|
||||
seg_video = model.decode_first_stage(
|
||||
wm_samples[:, :, :args.exe_steps]).detach().cpu()
|
||||
write_q.put(seg_video)
|
||||
|
||||
# Deferred decode: batch decode all stored latents
|
||||
full_latent = torch.cat(wm_latent, dim=2).to(device)
|
||||
full_video = model.decode_first_stage(full_latent).cpu()
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
||||
log_to_tensorboard_async(writer,
|
||||
full_video,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
||||
save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
|
||||
# Stop writer process
|
||||
write_q.put(None)
|
||||
writer_proc.join()
|
||||
|
||||
# Wait for all async I/O to complete
|
||||
_flush_io()
|
||||
|
||||
Reference in New Issue
Block a user