速度变化不明显psnr显著提升
This commit is contained in:
@@ -450,6 +450,7 @@ def image_guided_synthesis_sim_mode(
|
||||
|
||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
||||
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||
cond_img_emb = model.embedder(cond_img)
|
||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
||||
|
||||
@@ -465,6 +466,7 @@ def image_guided_synthesis_sim_mode(
|
||||
prompts = [""] * batch_size
|
||||
cond_ins_emb = model.get_learned_conditioning(prompts)
|
||||
|
||||
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||
cond_state_emb = model.state_projector(observation['observation.state'])
|
||||
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
||||
|
||||
@@ -571,11 +573,22 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
torch.save(model, prepared_path)
|
||||
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
|
||||
|
||||
# ---- BF16: only convert the diffusion backbone, keep VAE/CLIP/embedder in FP32 ----
|
||||
# ---- FP16: convert diffusion backbone + conditioning modules ----
|
||||
model.model.to(torch.float16)
|
||||
model.model.diffusion_model.dtype = torch.float16
|
||||
print(">>> Diffusion backbone (model.model) converted to FP16.")
|
||||
|
||||
# Projectors / MLP → FP16
|
||||
model.image_proj_model.half()
|
||||
model.state_projector.half()
|
||||
model.action_projector.half()
|
||||
print(">>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.")
|
||||
|
||||
# Text/image encoders → FP16
|
||||
model.cond_stage_model.half()
|
||||
model.embedder.half()
|
||||
print(">>> Encoders (cond_stage_model, embedder) converted to FP16.")
|
||||
|
||||
# Build normalizer (always needed, independent of model loading path)
|
||||
logging.info("***** Configing Data *****")
|
||||
data = instantiate_from_config(config.data)
|
||||
|
||||
@@ -988,7 +988,7 @@ class LatentDiffusion(DDPM):
|
||||
|
||||
def instantiate_cond_stage(self, config: OmegaConf) -> None:
|
||||
"""
|
||||
Build the conditioning stage model.
|
||||
Build the conditioning stage model. Frozen models are converted to FP16.
|
||||
|
||||
Args:
|
||||
config: OmegaConf config describing the conditioning model to instantiate.
|
||||
@@ -1000,6 +1000,7 @@ class LatentDiffusion(DDPM):
|
||||
self.cond_stage_model.train = disabled_train
|
||||
for param in self.cond_stage_model.parameters():
|
||||
param.requires_grad = False
|
||||
self.cond_stage_model.half()
|
||||
else:
|
||||
model = instantiate_from_config(config)
|
||||
self.cond_stage_model = model
|
||||
@@ -1014,6 +1015,7 @@ class LatentDiffusion(DDPM):
|
||||
Returns:
|
||||
Conditioning embedding as a tensor (shape depends on cond model).
|
||||
"""
|
||||
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||
if self.cond_stage_forward is None:
|
||||
if hasattr(self.cond_stage_model, 'encode') and callable(
|
||||
self.cond_stage_model.encode):
|
||||
@@ -1957,6 +1959,7 @@ class LatentVisualDiffusion(LatentDiffusion):
|
||||
self.image_proj_model.train = disabled_train
|
||||
for param in self.image_proj_model.parameters():
|
||||
param.requires_grad = False
|
||||
self.image_proj_model.half()
|
||||
|
||||
def _init_embedder(self, config: OmegaConf, freeze: bool = True) -> None:
|
||||
"""
|
||||
@@ -1972,6 +1975,7 @@ class LatentVisualDiffusion(LatentDiffusion):
|
||||
self.embedder.train = disabled_train
|
||||
for param in self.embedder.parameters():
|
||||
param.requires_grad = False
|
||||
self.embedder.half()
|
||||
|
||||
def init_normalizers(self, normalize_config: OmegaConf,
|
||||
dataset_stats: Mapping[str, Any]) -> None:
|
||||
@@ -2175,6 +2179,7 @@ class LatentVisualDiffusion(LatentDiffusion):
|
||||
(random_num < 3 * self.uncond_prob).float(), "n -> n 1 1 1")
|
||||
|
||||
cond_img = input_mask * img
|
||||
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||
cond_img_emb = self.embedder(cond_img)
|
||||
cond_img_emb = self.image_proj_model(cond_img_emb)
|
||||
|
||||
@@ -2191,6 +2196,7 @@ class LatentVisualDiffusion(LatentDiffusion):
|
||||
repeat=z.shape[2])
|
||||
cond["c_concat"] = [img_cat_cond]
|
||||
|
||||
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||
cond_action = self.action_projector(action)
|
||||
cond_action_emb = self.agent_action_pos_emb + cond_action
|
||||
# Get conditioning states
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
2026-02-11 16:14:08.942290: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-11 16:14:08.992267: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-11 16:14:08.992319: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-11 16:14:08.993621: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-11 16:14:09.001096: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
2026-02-11 16:32:03.555597: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-11 16:32:03.605506: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-11 16:32:03.605550: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-11 16:32:03.606879: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-11 16:32:03.614434: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-11 16:14:09.927986: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
2026-02-11 16:32:04.545234: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
Global seed set to 123
|
||||
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||
>>> Prepared model loaded.
|
||||
>>> Diffusion backbone (model.model) converted to FP16.
|
||||
>>> Projectors (image_proj_model, state_projector, action_projector) converted to FP16.
|
||||
>>> Encoders (cond_stage_model, embedder) converted to FP16.
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
@@ -32,7 +34,7 @@ DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
|
||||
0%| | 0/11 [00:00<?, ?it/s]
|
||||
9%|▉ | 1/11 [00:23<03:56, 23.68s/it]>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
@@ -85,7 +87,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
18%|█▊ | 2/11 [00:47<03:31, 23.51s/it]
|
||||
@@ -116,6 +118,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
|
||||
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4",
|
||||
"psnr": 25.21894470816415
|
||||
"psnr": 27.185465604200047
|
||||
}
|
||||
Reference in New Issue
Block a user