Layer 3: 延迟 decode,只解码 CLIP 需要的 1 帧
- world model 调用 decode_video=False,跳过 16 帧全量 decode
- 只 decode 最后 1 帧给 CLIP embedding / observation queue
- 存 raw latent,循环结束后统一 batch decode 生成最终视频
- 每轮省 15 次 VAE decode,8 轮共省 120 次
- 跳过中间迭代的 wm tensorboard/mp4 保存

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -559,6 +559,7 @@ def image_guided_synthesis_sim_mode(
|
|||||||
autocast_ctx = nullcontext()
|
autocast_ctx = nullcontext()
|
||||||
|
|
||||||
batch_variants = None
|
batch_variants = None
|
||||||
|
samples = None
|
||||||
if ddim_sampler is not None:
|
if ddim_sampler is not None:
|
||||||
with autocast_ctx:
|
with autocast_ctx:
|
||||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||||
@@ -583,7 +584,7 @@ def image_guided_synthesis_sim_mode(
|
|||||||
batch_images = model.decode_first_stage(samples)
|
batch_images = model.decode_first_stage(samples)
|
||||||
batch_variants = batch_images
|
batch_variants = batch_images
|
||||||
|
|
||||||
return batch_variants, actions, states
|
return batch_variants, actions, states, samples
|
||||||
|
|
||||||
|
|
||||||
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||||
@@ -693,7 +694,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
sample_save_dir = f'{video_save_dir}/wm/{fs}'
|
sample_save_dir = f'{video_save_dir}/wm/{fs}'
|
||||||
os.makedirs(sample_save_dir, exist_ok=True)
|
os.makedirs(sample_save_dir, exist_ok=True)
|
||||||
# For collecting interaction videos
|
# For collecting interaction videos
|
||||||
wm_video = []
|
wm_latent = []
|
||||||
# Initialize observation queues
|
# Initialize observation queues
|
||||||
cond_obs_queues = {
|
cond_obs_queues = {
|
||||||
"observation.images.top":
|
"observation.images.top":
|
||||||
@@ -749,7 +750,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
|
|
||||||
# Use world-model in policy to generate action
|
# Use world-model in policy to generate action
|
||||||
print(f'>>> Step {itr}: generating actions ...')
|
print(f'>>> Step {itr}: generating actions ...')
|
||||||
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
|
pred_videos_0, pred_actions, _, _ = image_guided_synthesis_sim_mode(
|
||||||
model,
|
model,
|
||||||
sample['instruction'],
|
sample['instruction'],
|
||||||
observation,
|
observation,
|
||||||
@@ -791,7 +792,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
|
|
||||||
# Interaction with the world-model
|
# Interaction with the world-model
|
||||||
print(f'>>> Step {itr}: interacting with world model ...')
|
print(f'>>> Step {itr}: interacting with world model ...')
|
||||||
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
|
pred_videos_1, _, pred_states, wm_samples = image_guided_synthesis_sim_mode(
|
||||||
model,
|
model,
|
||||||
"",
|
"",
|
||||||
observation,
|
observation,
|
||||||
@@ -804,12 +805,16 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
fs=model_input_fs,
|
fs=model_input_fs,
|
||||||
text_input=False,
|
text_input=False,
|
||||||
timestep_spacing=args.timestep_spacing,
|
timestep_spacing=args.timestep_spacing,
|
||||||
guidance_rescale=args.guidance_rescale)
|
guidance_rescale=args.guidance_rescale,
|
||||||
|
decode_video=False)
|
||||||
|
|
||||||
|
# Decode only the last frame for CLIP embedding in next iteration
|
||||||
|
last_frame_pixel = model.decode_first_stage(wm_samples[:, :, -1:, :, :])
|
||||||
|
|
||||||
for idx in range(args.exe_steps):
|
for idx in range(args.exe_steps):
|
||||||
observation = {
|
observation = {
|
||||||
'observation.images.top':
|
'observation.images.top':
|
||||||
pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
|
last_frame_pixel[0, :, 0:1].permute(1, 0, 2, 3),
|
||||||
'observation.state':
|
'observation.state':
|
||||||
torch.zeros_like(pred_states[0][idx:idx + 1]) if
|
torch.zeros_like(pred_states[0][idx:idx + 1]) if
|
||||||
args.zero_pred_state else pred_states[0][idx:idx + 1],
|
args.zero_pred_state else pred_states[0][idx:idx + 1],
|
||||||
@@ -827,30 +832,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
pred_videos_0,
|
pred_videos_0,
|
||||||
sample_tag,
|
sample_tag,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
# Save videos environment changes via world-model interaction
|
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
|
||||||
log_to_tensorboard(writer,
|
|
||||||
pred_videos_1,
|
|
||||||
sample_tag,
|
|
||||||
fps=args.save_fps)
|
|
||||||
|
|
||||||
# Save the imagen videos for decision-making
|
|
||||||
if pred_videos_0 is not None:
|
|
||||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
|
||||||
save_results(pred_videos_0.cpu(),
|
|
||||||
sample_video_file,
|
|
||||||
fps=args.save_fps)
|
|
||||||
# Save videos environment changes via world-model interaction
|
|
||||||
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
|
|
||||||
save_results(pred_videos_1.cpu(),
|
|
||||||
sample_video_file,
|
|
||||||
fps=args.save_fps)
|
|
||||||
|
|
||||||
print('>' * 24)
|
print('>' * 24)
|
||||||
# Collect the result of world-model interactions
|
# Store raw latent for deferred decode
|
||||||
wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())
|
wm_latent.append(wm_samples[:, :, :args.exe_steps].cpu())
|
||||||
|
|
||||||
full_video = torch.cat(wm_video, dim=2)
|
# Deferred decode: batch decode all stored latents
|
||||||
|
full_latent = torch.cat(wm_latent, dim=2).to(device)
|
||||||
|
full_video = model.decode_first_stage(full_latent).cpu()
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
||||||
log_to_tensorboard(writer,
|
log_to_tensorboard(writer,
|
||||||
full_video,
|
full_video,
|
||||||
|
|||||||
Reference in New Issue
Block a user