打印推理权重精度信息 (Print inference weight-precision info)

This commit is contained in:
2026-01-18 11:19:10 +08:00
parent c86c2be5ff
commit 7b499284bf
9 changed files with 256 additions and 143 deletions

View File

@@ -441,7 +441,7 @@ def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
return file_list
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
"""Load model weights from checkpoint file.
Args:
@@ -472,11 +472,43 @@ def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
for key in state_dict['module'].keys():
new_pl_sd[key[16:]] = state_dict['module'][key]
model.load_state_dict(new_pl_sd)
print('>>> model checkpoint loaded.')
return model
def is_inferenced(save_dir: str, filename: str) -> bool:
print('>>> model checkpoint loaded.')
return model
def _module_param_dtype(module: nn.Module | None) -> str:
if module is None:
return "None"
for param in module.parameters():
return str(param.dtype)
return "no_params"
def log_inference_precision(model: nn.Module) -> None:
    """Print a precision summary for *model* before inference.

    Reports the dtype/device of the model's first parameter, the
    parameter dtype of a few well-known submodules (when present), and
    the process-wide autocast GPU dtype default.
    """
    first_param = next(model.parameters(), None)
    if first_param is None:
        # Parameter-less model: nothing concrete to report.
        device, model_dtype = "unknown", "no_params"
    else:
        device = str(first_param.device)
        model_dtype = str(first_param.dtype)
    print(f">>> inference precision: model={model_dtype}, device={device}")

    # Submodules commonly attached to latent-diffusion-style wrappers;
    # only those actually present on this model are reported.
    submodule_names = (
        "model", "first_stage_model", "cond_stage_model", "embedder",
        "image_proj_model",
    )
    for name in submodule_names:
        if hasattr(model, name):
            dtype_repr = _module_param_dtype(getattr(model, name))
            print(f">>> {name} param dtype: {dtype_repr}")

    autocast_dtype = torch.get_autocast_gpu_dtype()
    autocast_on = torch.is_autocast_enabled()
    print(
        ">>> autocast gpu dtype default: "
        f"{autocast_dtype} "
        f"(enabled={autocast_on})")
def is_inferenced(save_dir: str, filename: str) -> bool:
"""Check if a given filename has already been processed and saved.
Args:
@@ -853,11 +885,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
data.setup()
print(">>> Dataset is successfully loaded ...")
with profiler.profile_section("model_to_cuda"):
model = model.cuda(gpu_no)
device = get_device_from_parameters(model)
profiler.record_memory("after_model_load")
with profiler.profile_section("model_to_cuda"):
model = model.cuda(gpu_no)
device = get_device_from_parameters(model)
log_inference_precision(model)
profiler.record_memory("after_model_load")
# Run over data
assert (args.height % 16 == 0) and (