速度变化不明显，PSNR 显著提升

This commit is contained in:
qhy
2026-02-11 16:38:21 +08:00
parent f386a5810b
commit 3101252c25
4 changed files with 58 additions and 37 deletions

View File

@@ -450,8 +450,9 @@ def image_guided_synthesis_sim_mode(
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
cond_img_emb = model.embedder(cond_img)
cond_img_emb = model.image_proj_model(cond_img_emb)
with torch.cuda.amp.autocast(dtype=torch.float16):
cond_img_emb = model.embedder(cond_img)
cond_img_emb = model.image_proj_model(cond_img_emb)
if model.model.conditioning_key == 'hybrid':
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
@@ -465,11 +466,12 @@ def image_guided_synthesis_sim_mode(
prompts = [""] * batch_size
cond_ins_emb = model.get_learned_conditioning(prompts)
cond_state_emb = model.state_projector(observation['observation.state'])
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
with torch.cuda.amp.autocast(dtype=torch.float16):
cond_state_emb = model.state_projector(observation['observation.state'])
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
cond_action_emb = model.action_projector(observation['action'])
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
cond_action_emb = model.action_projector(observation['action'])
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
if not sim_mode:
cond_action_emb = torch.zeros_like(cond_action_emb)
@@ -571,11 +573,22 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
torch.save(model, prepared_path)
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
# ---- BF16: only convert the diffusion backbone, keep VAE/CLIP/embedder in FP32 ----
# ---- FP16: convert diffusion backbone + conditioning modules ----
model.model.to(torch.float16)
model.model.diffusion_model.dtype = torch.float16
print(">>> Diffusion backbone (model.model) converted to FP16.")
# Projectors / MLP → FP16
model.image_proj_model.half()
model.state_projector.half()
model.action_projector.half()
print(">>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.")
# Text/image encoders → FP16
model.cond_stage_model.half()
model.embedder.half()
print(">>> Encoders (cond_stage_model, embedder) converted to FP16.")
# Build normalizer (always needed, independent of model loading path)
logging.info("***** Configing Data *****")
data = instantiate_from_config(config.data)

View File

@@ -988,7 +988,7 @@ class LatentDiffusion(DDPM):
def instantiate_cond_stage(self, config: OmegaConf) -> None:
"""
Build the conditioning stage model.
Build the conditioning stage model. Frozen models are converted to FP16.
Args:
config: OmegaConf config describing the conditioning model to instantiate.
@@ -1000,6 +1000,7 @@ class LatentDiffusion(DDPM):
self.cond_stage_model.train = disabled_train
for param in self.cond_stage_model.parameters():
param.requires_grad = False
self.cond_stage_model.half()
else:
model = instantiate_from_config(config)
self.cond_stage_model = model
@@ -1014,17 +1015,18 @@ class LatentDiffusion(DDPM):
Returns:
Conditioning embedding as a tensor (shape depends on cond model).
"""
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(
self.cond_stage_model.encode):
c = self.cond_stage_model.encode(c)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
with torch.cuda.amp.autocast(dtype=torch.float16):
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(
self.cond_stage_model.encode):
c = self.cond_stage_model.encode(c)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
else:
c = self.cond_stage_model(c)
else:
c = self.cond_stage_model(c)
else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
return c
def get_first_stage_encoding(
@@ -1957,6 +1959,7 @@ class LatentVisualDiffusion(LatentDiffusion):
self.image_proj_model.train = disabled_train
for param in self.image_proj_model.parameters():
param.requires_grad = False
self.image_proj_model.half()
def _init_embedder(self, config: OmegaConf, freeze: bool = True) -> None:
"""
@@ -1972,6 +1975,7 @@ class LatentVisualDiffusion(LatentDiffusion):
self.embedder.train = disabled_train
for param in self.embedder.parameters():
param.requires_grad = False
self.embedder.half()
def init_normalizers(self, normalize_config: OmegaConf,
dataset_stats: Mapping[str, Any]) -> None:
@@ -2175,8 +2179,9 @@ class LatentVisualDiffusion(LatentDiffusion):
(random_num < 3 * self.uncond_prob).float(), "n -> n 1 1 1")
cond_img = input_mask * img
cond_img_emb = self.embedder(cond_img)
cond_img_emb = self.image_proj_model(cond_img_emb)
with torch.cuda.amp.autocast(dtype=torch.float16):
cond_img_emb = self.embedder(cond_img)
cond_img_emb = self.image_proj_model(cond_img_emb)
if self.model.conditioning_key == 'hybrid':
if self.interp_mode:
@@ -2191,11 +2196,12 @@ class LatentVisualDiffusion(LatentDiffusion):
repeat=z.shape[2])
cond["c_concat"] = [img_cat_cond]
cond_action = self.action_projector(action)
cond_action_emb = self.agent_action_pos_emb + cond_action
# Get conditioning states
cond_state = self.state_projector(obs_state)
cond_state_emb = self.agent_state_pos_emb + cond_state
with torch.cuda.amp.autocast(dtype=torch.float16):
cond_action = self.action_projector(action)
cond_action_emb = self.agent_action_pos_emb + cond_action
# Get conditioning states
cond_state = self.state_projector(obs_state)
cond_state_emb = self.agent_state_pos_emb + cond_state
if self.decision_making_only:
is_sim_mode = False

View File

@@ -1,14 +1,16 @@
2026-02-11 16:14:08.942290: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-11 16:14:08.992267: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-11 16:14:08.992319: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-11 16:14:08.993621: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-11 16:14:09.001096: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
2026-02-11 16:32:03.555597: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-11 16:32:03.605506: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-11 16:32:03.605550: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-11 16:32:03.606879: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-11 16:32:03.614434: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-11 16:14:09.927986: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2026-02-11 16:32:04.545234: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
>>> Prepared model loaded.
>>> Diffusion backbone (model.model) converted to FP16.
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
>>> Encoders (cond_stage_model, embedder) converted to FP16.
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
@@ -32,7 +34,7 @@ DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]
9%|▉ | 1/11 [00:23<03:56, 23.68s/it]>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
@@ -85,7 +87,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
18%|█▊ | 2/11 [00:47<03:31, 23.51s/it]
@@ -116,6 +118,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -1,5 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4",
"psnr": 25.21894470816415
"psnr": 27.185465604200047
}