速度变化不明显psnr显著提升

This commit is contained in:
qhy
2026-02-11 16:38:21 +08:00
parent f386a5810b
commit 3101252c25
4 changed files with 58 additions and 37 deletions

View File

@@ -450,8 +450,9 @@ def image_guided_synthesis_sim_mode(
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4) img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:] cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
cond_img_emb = model.embedder(cond_img) with torch.cuda.amp.autocast(dtype=torch.float16):
cond_img_emb = model.image_proj_model(cond_img_emb) cond_img_emb = model.embedder(cond_img)
cond_img_emb = model.image_proj_model(cond_img_emb)
if model.model.conditioning_key == 'hybrid': if model.model.conditioning_key == 'hybrid':
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4)) z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
@@ -465,11 +466,12 @@ def image_guided_synthesis_sim_mode(
prompts = [""] * batch_size prompts = [""] * batch_size
cond_ins_emb = model.get_learned_conditioning(prompts) cond_ins_emb = model.get_learned_conditioning(prompts)
cond_state_emb = model.state_projector(observation['observation.state']) with torch.cuda.amp.autocast(dtype=torch.float16):
cond_state_emb = cond_state_emb + model.agent_state_pos_emb cond_state_emb = model.state_projector(observation['observation.state'])
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
cond_action_emb = model.action_projector(observation['action']) cond_action_emb = model.action_projector(observation['action'])
cond_action_emb = cond_action_emb + model.agent_action_pos_emb cond_action_emb = cond_action_emb + model.agent_action_pos_emb
if not sim_mode: if not sim_mode:
cond_action_emb = torch.zeros_like(cond_action_emb) cond_action_emb = torch.zeros_like(cond_action_emb)
@@ -571,11 +573,22 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
torch.save(model, prepared_path) torch.save(model, prepared_path)
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).") print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
# ---- BF16: only convert the diffusion backbone, keep VAE/CLIP/embedder in FP32 ---- # ---- FP16: convert diffusion backbone + conditioning modules ----
model.model.to(torch.float16) model.model.to(torch.float16)
model.model.diffusion_model.dtype = torch.float16 model.model.diffusion_model.dtype = torch.float16
print(">>> Diffusion backbone (model.model) converted to FP16.") print(">>> Diffusion backbone (model.model) converted to FP16.")
# Projectors / MLP → FP16
model.image_proj_model.half()
model.state_projector.half()
model.action_projector.half()
print(">>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.")
# Text/image encoders → FP16
model.cond_stage_model.half()
model.embedder.half()
print(">>> Encoders (cond_stage_model, embedder) converted to FP16.")
# Build normalizer (always needed, independent of model loading path) # Build normalizer (always needed, independent of model loading path)
logging.info("***** Configing Data *****") logging.info("***** Configing Data *****")
data = instantiate_from_config(config.data) data = instantiate_from_config(config.data)

View File

@@ -988,7 +988,7 @@ class LatentDiffusion(DDPM):
def instantiate_cond_stage(self, config: OmegaConf) -> None: def instantiate_cond_stage(self, config: OmegaConf) -> None:
""" """
Build the conditioning stage model. Build the conditioning stage model. Frozen models are converted to FP16.
Args: Args:
config: OmegaConf config describing the conditioning model to instantiate. config: OmegaConf config describing the conditioning model to instantiate.
@@ -1000,6 +1000,7 @@ class LatentDiffusion(DDPM):
self.cond_stage_model.train = disabled_train self.cond_stage_model.train = disabled_train
for param in self.cond_stage_model.parameters(): for param in self.cond_stage_model.parameters():
param.requires_grad = False param.requires_grad = False
self.cond_stage_model.half()
else: else:
model = instantiate_from_config(config) model = instantiate_from_config(config)
self.cond_stage_model = model self.cond_stage_model = model
@@ -1014,17 +1015,18 @@ class LatentDiffusion(DDPM):
Returns: Returns:
Conditioning embedding as a tensor (shape depends on cond model). Conditioning embedding as a tensor (shape depends on cond model).
""" """
if self.cond_stage_forward is None: with torch.cuda.amp.autocast(dtype=torch.float16):
if hasattr(self.cond_stage_model, 'encode') and callable( if self.cond_stage_forward is None:
self.cond_stage_model.encode): if hasattr(self.cond_stage_model, 'encode') and callable(
c = self.cond_stage_model.encode(c) self.cond_stage_model.encode):
if isinstance(c, DiagonalGaussianDistribution): c = self.cond_stage_model.encode(c)
c = c.mode() if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
else:
c = self.cond_stage_model(c)
else: else:
c = self.cond_stage_model(c) assert hasattr(self.cond_stage_model, self.cond_stage_forward)
else: c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
return c return c
def get_first_stage_encoding( def get_first_stage_encoding(
@@ -1957,6 +1959,7 @@ class LatentVisualDiffusion(LatentDiffusion):
self.image_proj_model.train = disabled_train self.image_proj_model.train = disabled_train
for param in self.image_proj_model.parameters(): for param in self.image_proj_model.parameters():
param.requires_grad = False param.requires_grad = False
self.image_proj_model.half()
def _init_embedder(self, config: OmegaConf, freeze: bool = True) -> None: def _init_embedder(self, config: OmegaConf, freeze: bool = True) -> None:
""" """
@@ -1972,6 +1975,7 @@ class LatentVisualDiffusion(LatentDiffusion):
self.embedder.train = disabled_train self.embedder.train = disabled_train
for param in self.embedder.parameters(): for param in self.embedder.parameters():
param.requires_grad = False param.requires_grad = False
self.embedder.half()
def init_normalizers(self, normalize_config: OmegaConf, def init_normalizers(self, normalize_config: OmegaConf,
dataset_stats: Mapping[str, Any]) -> None: dataset_stats: Mapping[str, Any]) -> None:
@@ -2175,8 +2179,9 @@ class LatentVisualDiffusion(LatentDiffusion):
(random_num < 3 * self.uncond_prob).float(), "n -> n 1 1 1") (random_num < 3 * self.uncond_prob).float(), "n -> n 1 1 1")
cond_img = input_mask * img cond_img = input_mask * img
cond_img_emb = self.embedder(cond_img) with torch.cuda.amp.autocast(dtype=torch.float16):
cond_img_emb = self.image_proj_model(cond_img_emb) cond_img_emb = self.embedder(cond_img)
cond_img_emb = self.image_proj_model(cond_img_emb)
if self.model.conditioning_key == 'hybrid': if self.model.conditioning_key == 'hybrid':
if self.interp_mode: if self.interp_mode:
@@ -2191,11 +2196,12 @@ class LatentVisualDiffusion(LatentDiffusion):
repeat=z.shape[2]) repeat=z.shape[2])
cond["c_concat"] = [img_cat_cond] cond["c_concat"] = [img_cat_cond]
cond_action = self.action_projector(action) with torch.cuda.amp.autocast(dtype=torch.float16):
cond_action_emb = self.agent_action_pos_emb + cond_action cond_action = self.action_projector(action)
# Get conditioning states cond_action_emb = self.agent_action_pos_emb + cond_action
cond_state = self.state_projector(obs_state) # Get conditioning states
cond_state_emb = self.agent_state_pos_emb + cond_state cond_state = self.state_projector(obs_state)
cond_state_emb = self.agent_state_pos_emb + cond_state
if self.decision_making_only: if self.decision_making_only:
is_sim_mode = False is_sim_mode = False

View File

@@ -1,14 +1,16 @@
2026-02-11 16:14:08.942290: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. 2026-02-11 16:32:03.555597: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-11 16:14:08.992267: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2026-02-11 16:32:03.605506: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-11 16:14:08.992319: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2026-02-11 16:32:03.605550: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-11 16:14:08.993621: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2026-02-11 16:32:03.606879: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-11 16:14:09.001096: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. 2026-02-11 16:32:03.614434: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-11 16:14:09.927986: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT 2026-02-11 16:32:04.545234: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123 Global seed set to 123
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ... >>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
>>> Prepared model loaded. >>> Prepared model loaded.
>>> Diffusion backbone (model.model) converted to FP16. >>> Diffusion backbone (model.model) converted to FP16.
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
>>> Encoders (cond_stage_model, embedder) converted to FP16.
INFO:root:***** Configing Data ***** INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded. >>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded. >>> unitree_z1_stackbox: data stats loaded.
@@ -32,7 +34,7 @@ DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13 DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9 DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096 DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s] 0%| | 0/11 [00:00<?, ?it/s]
9%|▉ | 1/11 [00:23<03:56, 23.68s/it]>>> Step 0: generating actions ... 9%|▉ | 1/11 [00:23<03:56, 23.68s/it]>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ... >>> Step 0: interacting with world model ...
@@ -85,7 +87,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin DEBUG:PIL.Image:Importing XVThumbImagePlugin
18%|█▊ | 2/11 [00:47<03:31, 23.51s/it] 18%|█▊ | 2/11 [00:47<03:31, 23.51s/it]
@@ -116,6 +118,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
>>> Step 6: generating actions ... >>> Step 6: generating actions ...
>>> Step 6: interacting with world model ... >>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>> >>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ... >>> Step 7: generating actions ...
>>> Step 7: interacting with world model ... >>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>> >>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -1,5 +1,5 @@
{ {
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4", "gt_video": "unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4", "pred_video": "unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4",
"psnr": 25.21894470816415 "psnr": 27.185465604200047
} }