速度变化不明显psnr显著提升
This commit is contained in:
@@ -450,8 +450,9 @@ def image_guided_synthesis_sim_mode(
|
|||||||
|
|
||||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
||||||
cond_img_emb = model.embedder(cond_img)
|
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
cond_img_emb = model.embedder(cond_img)
|
||||||
|
cond_img_emb = model.image_proj_model(cond_img_emb)
|
||||||
|
|
||||||
if model.model.conditioning_key == 'hybrid':
|
if model.model.conditioning_key == 'hybrid':
|
||||||
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
||||||
@@ -465,11 +466,12 @@ def image_guided_synthesis_sim_mode(
|
|||||||
prompts = [""] * batch_size
|
prompts = [""] * batch_size
|
||||||
cond_ins_emb = model.get_learned_conditioning(prompts)
|
cond_ins_emb = model.get_learned_conditioning(prompts)
|
||||||
|
|
||||||
cond_state_emb = model.state_projector(observation['observation.state'])
|
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||||
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
cond_state_emb = model.state_projector(observation['observation.state'])
|
||||||
|
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
||||||
|
|
||||||
cond_action_emb = model.action_projector(observation['action'])
|
cond_action_emb = model.action_projector(observation['action'])
|
||||||
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
||||||
|
|
||||||
if not sim_mode:
|
if not sim_mode:
|
||||||
cond_action_emb = torch.zeros_like(cond_action_emb)
|
cond_action_emb = torch.zeros_like(cond_action_emb)
|
||||||
@@ -571,11 +573,22 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
torch.save(model, prepared_path)
|
torch.save(model, prepared_path)
|
||||||
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
|
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
|
||||||
|
|
||||||
# ---- BF16: only convert the diffusion backbone, keep VAE/CLIP/embedder in FP32 ----
|
# ---- FP16: convert diffusion backbone + conditioning modules ----
|
||||||
model.model.to(torch.float16)
|
model.model.to(torch.float16)
|
||||||
model.model.diffusion_model.dtype = torch.float16
|
model.model.diffusion_model.dtype = torch.float16
|
||||||
print(">>> Diffusion backbone (model.model) converted to FP16.")
|
print(">>> Diffusion backbone (model.model) converted to FP16.")
|
||||||
|
|
||||||
|
# Projectors / MLP → FP16
|
||||||
|
model.image_proj_model.half()
|
||||||
|
model.state_projector.half()
|
||||||
|
model.action_projector.half()
|
||||||
|
print(">>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.")
|
||||||
|
|
||||||
|
# Text/image encoders → FP16
|
||||||
|
model.cond_stage_model.half()
|
||||||
|
model.embedder.half()
|
||||||
|
print(">>> Encoders (cond_stage_model, embedder) converted to FP16.")
|
||||||
|
|
||||||
# Build normalizer (always needed, independent of model loading path)
|
# Build normalizer (always needed, independent of model loading path)
|
||||||
logging.info("***** Configing Data *****")
|
logging.info("***** Configing Data *****")
|
||||||
data = instantiate_from_config(config.data)
|
data = instantiate_from_config(config.data)
|
||||||
|
|||||||
@@ -988,7 +988,7 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
def instantiate_cond_stage(self, config: OmegaConf) -> None:
|
def instantiate_cond_stage(self, config: OmegaConf) -> None:
|
||||||
"""
|
"""
|
||||||
Build the conditioning stage model.
|
Build the conditioning stage model. Frozen models are converted to FP16.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: OmegaConf config describing the conditioning model to instantiate.
|
config: OmegaConf config describing the conditioning model to instantiate.
|
||||||
@@ -1000,6 +1000,7 @@ class LatentDiffusion(DDPM):
|
|||||||
self.cond_stage_model.train = disabled_train
|
self.cond_stage_model.train = disabled_train
|
||||||
for param in self.cond_stage_model.parameters():
|
for param in self.cond_stage_model.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
self.cond_stage_model.half()
|
||||||
else:
|
else:
|
||||||
model = instantiate_from_config(config)
|
model = instantiate_from_config(config)
|
||||||
self.cond_stage_model = model
|
self.cond_stage_model = model
|
||||||
@@ -1014,17 +1015,18 @@ class LatentDiffusion(DDPM):
|
|||||||
Returns:
|
Returns:
|
||||||
Conditioning embedding as a tensor (shape depends on cond model).
|
Conditioning embedding as a tensor (shape depends on cond model).
|
||||||
"""
|
"""
|
||||||
if self.cond_stage_forward is None:
|
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||||
if hasattr(self.cond_stage_model, 'encode') and callable(
|
if self.cond_stage_forward is None:
|
||||||
self.cond_stage_model.encode):
|
if hasattr(self.cond_stage_model, 'encode') and callable(
|
||||||
c = self.cond_stage_model.encode(c)
|
self.cond_stage_model.encode):
|
||||||
if isinstance(c, DiagonalGaussianDistribution):
|
c = self.cond_stage_model.encode(c)
|
||||||
c = c.mode()
|
if isinstance(c, DiagonalGaussianDistribution):
|
||||||
|
c = c.mode()
|
||||||
|
else:
|
||||||
|
c = self.cond_stage_model(c)
|
||||||
else:
|
else:
|
||||||
c = self.cond_stage_model(c)
|
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
|
||||||
else:
|
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
|
||||||
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
|
|
||||||
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def get_first_stage_encoding(
|
def get_first_stage_encoding(
|
||||||
@@ -1957,6 +1959,7 @@ class LatentVisualDiffusion(LatentDiffusion):
|
|||||||
self.image_proj_model.train = disabled_train
|
self.image_proj_model.train = disabled_train
|
||||||
for param in self.image_proj_model.parameters():
|
for param in self.image_proj_model.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
self.image_proj_model.half()
|
||||||
|
|
||||||
def _init_embedder(self, config: OmegaConf, freeze: bool = True) -> None:
|
def _init_embedder(self, config: OmegaConf, freeze: bool = True) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -1972,6 +1975,7 @@ class LatentVisualDiffusion(LatentDiffusion):
|
|||||||
self.embedder.train = disabled_train
|
self.embedder.train = disabled_train
|
||||||
for param in self.embedder.parameters():
|
for param in self.embedder.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
self.embedder.half()
|
||||||
|
|
||||||
def init_normalizers(self, normalize_config: OmegaConf,
|
def init_normalizers(self, normalize_config: OmegaConf,
|
||||||
dataset_stats: Mapping[str, Any]) -> None:
|
dataset_stats: Mapping[str, Any]) -> None:
|
||||||
@@ -2175,8 +2179,9 @@ class LatentVisualDiffusion(LatentDiffusion):
|
|||||||
(random_num < 3 * self.uncond_prob).float(), "n -> n 1 1 1")
|
(random_num < 3 * self.uncond_prob).float(), "n -> n 1 1 1")
|
||||||
|
|
||||||
cond_img = input_mask * img
|
cond_img = input_mask * img
|
||||||
cond_img_emb = self.embedder(cond_img)
|
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||||
cond_img_emb = self.image_proj_model(cond_img_emb)
|
cond_img_emb = self.embedder(cond_img)
|
||||||
|
cond_img_emb = self.image_proj_model(cond_img_emb)
|
||||||
|
|
||||||
if self.model.conditioning_key == 'hybrid':
|
if self.model.conditioning_key == 'hybrid':
|
||||||
if self.interp_mode:
|
if self.interp_mode:
|
||||||
@@ -2191,11 +2196,12 @@ class LatentVisualDiffusion(LatentDiffusion):
|
|||||||
repeat=z.shape[2])
|
repeat=z.shape[2])
|
||||||
cond["c_concat"] = [img_cat_cond]
|
cond["c_concat"] = [img_cat_cond]
|
||||||
|
|
||||||
cond_action = self.action_projector(action)
|
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||||
cond_action_emb = self.agent_action_pos_emb + cond_action
|
cond_action = self.action_projector(action)
|
||||||
# Get conditioning states
|
cond_action_emb = self.agent_action_pos_emb + cond_action
|
||||||
cond_state = self.state_projector(obs_state)
|
# Get conditioning states
|
||||||
cond_state_emb = self.agent_state_pos_emb + cond_state
|
cond_state = self.state_projector(obs_state)
|
||||||
|
cond_state_emb = self.agent_state_pos_emb + cond_state
|
||||||
|
|
||||||
if self.decision_making_only:
|
if self.decision_making_only:
|
||||||
is_sim_mode = False
|
is_sim_mode = False
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
2026-02-11 16:14:08.942290: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-11 16:32:03.555597: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
2026-02-11 16:14:08.992267: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-11 16:32:03.605506: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-11 16:14:08.992319: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
2026-02-11 16:32:03.605550: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-11 16:14:08.993621: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-11 16:32:03.606879: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-11 16:14:09.001096: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
2026-02-11 16:32:03.614434: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-11 16:14:09.927986: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
2026-02-11 16:32:04.545234: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
Global seed set to 123
|
Global seed set to 123
|
||||||
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||||
>>> Prepared model loaded.
|
>>> Prepared model loaded.
|
||||||
>>> Diffusion backbone (model.model) converted to FP16.
|
>>> Diffusion backbone (model.model) converted to FP16.
|
||||||
|
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||||
|
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -32,7 +34,7 @@ DEBUG:h5py._conv:Creating converter from 3 to 5
|
|||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s]
|
0%| | 0/11 [00:00<?, ?it/s]
|
||||||
9%|▉ | 1/11 [00:23<03:56, 23.68s/it]>>> Step 0: generating actions ...
|
9%|▉ | 1/11 [00:23<03:56, 23.68s/it]>>> Step 0: generating actions ...
|
||||||
>>> Step 0: interacting with world model ...
|
>>> Step 0: interacting with world model ...
|
||||||
@@ -85,7 +87,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
|||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
18%|█▊ | 2/11 [00:47<03:31, 23.51s/it]
|
18%|█▊ | 2/11 [00:47<03:31, 23.51s/it]
|
||||||
@@ -116,6 +118,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 6: generating actions ...
|
>>> Step 6: generating actions ...
|
||||||
>>> Step 6: interacting with world model ...
|
>>> Step 6: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
|
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4",
|
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4",
|
||||||
"psnr": 25.21894470816415
|
"psnr": 27.185465604200047
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user