打印推理权重精度信息
This commit is contained in:
@@ -441,7 +441,7 @@ def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
|
||||
return file_list
|
||||
|
||||
|
||||
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
|
||||
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
|
||||
"""Load model weights from checkpoint file.
|
||||
|
||||
Args:
|
||||
@@ -472,11 +472,43 @@ def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
|
||||
for key in state_dict['module'].keys():
|
||||
new_pl_sd[key[16:]] = state_dict['module'][key]
|
||||
model.load_state_dict(new_pl_sd)
|
||||
print('>>> model checkpoint loaded.')
|
||||
return model
|
||||
|
||||
|
||||
def is_inferenced(save_dir: str, filename: str) -> bool:
|
||||
print('>>> model checkpoint loaded.')
|
||||
return model
|
||||
|
||||
|
||||
def _module_param_dtype(module: nn.Module | None) -> str:
|
||||
if module is None:
|
||||
return "None"
|
||||
for param in module.parameters():
|
||||
return str(param.dtype)
|
||||
return "no_params"
|
||||
|
||||
|
||||
def log_inference_precision(model: nn.Module) -> None:
    """Print the inference precision of *model* and key submodules.

    Reports the dtype/device of the model's first parameter, the
    parameter dtype of a fixed set of well-known submodule attributes
    (when present), and the current GPU autocast defaults.
    """
    device = "unknown"
    model_dtype = "no_params"
    first = next(model.parameters(), None)
    if first is not None:
        device = str(first.device)
        model_dtype = str(first.dtype)

    print(f">>> inference precision: model={model_dtype}, device={device}")

    # Probe the submodule attributes commonly found on diffusion wrappers.
    for attr in ("model", "first_stage_model", "cond_stage_model",
                 "embedder", "image_proj_model"):
        if not hasattr(model, attr):
            continue
        print(f">>> {attr} param dtype: "
              f"{_module_param_dtype(getattr(model, attr))}")

    autocast_dtype = torch.get_autocast_gpu_dtype()
    autocast_on = torch.is_autocast_enabled()
    print(">>> autocast gpu dtype default: "
          f"{autocast_dtype} (enabled={autocast_on})")
|
||||
|
||||
|
||||
def is_inferenced(save_dir: str, filename: str) -> bool:
|
||||
"""Check if a given filename has already been processed and saved.
|
||||
|
||||
Args:
|
||||
@@ -853,11 +885,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
data.setup()
|
||||
print(">>> Dataset is successfully loaded ...")
|
||||
|
||||
with profiler.profile_section("model_to_cuda"):
|
||||
model = model.cuda(gpu_no)
|
||||
device = get_device_from_parameters(model)
|
||||
|
||||
profiler.record_memory("after_model_load")
|
||||
with profiler.profile_section("model_to_cuda"):
|
||||
model = model.cuda(gpu_no)
|
||||
device = get_device_from_parameters(model)
|
||||
|
||||
log_inference_precision(model)
|
||||
|
||||
profiler.record_memory("after_model_load")
|
||||
|
||||
# Run over data
|
||||
assert (args.height % 16 == 0) and (
|
||||
|
||||
Reference in New Issue
Block a user