复制模型对象,跳过加载模型
This commit is contained in:
@@ -1026,9 +1026,26 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
||||||
df = pd.read_csv(csv_path)
|
df = pd.read_csv(csv_path)
|
||||||
|
|
||||||
# Load config
|
# Load config (always needed for data setup)
|
||||||
with profiler.profile_section("model_loading/config"):
|
|
||||||
config = OmegaConf.load(args.config)
|
config = OmegaConf.load(args.config)
|
||||||
|
|
||||||
|
prepared_path = args.ckpt_path + ".prepared.pt"
|
||||||
|
if os.path.exists(prepared_path):
|
||||||
|
# ---- Fast path: load the fully-prepared model ----
|
||||||
|
print(f">>> Loading prepared model from {prepared_path} ...")
|
||||||
|
with profiler.profile_section("model_loading/prepared"):
|
||||||
|
model = torch.load(prepared_path,
|
||||||
|
map_location=f"cuda:{gpu_no}",
|
||||||
|
weights_only=False,
|
||||||
|
mmap=True)
|
||||||
|
model.eval()
|
||||||
|
diffusion_autocast_dtype = (torch.bfloat16
|
||||||
|
if args.diffusion_dtype == "bf16"
|
||||||
|
else None)
|
||||||
|
print(f">>> Prepared model loaded.")
|
||||||
|
else:
|
||||||
|
# ---- Normal path: construct + checkpoint + casting ----
|
||||||
|
with profiler.profile_section("model_loading/config"):
|
||||||
config['model']['params']['wma_config']['params'][
|
config['model']['params']['wma_config']['params'][
|
||||||
'use_checkpoint'] = False
|
'use_checkpoint'] = False
|
||||||
model = instantiate_from_config(config.model)
|
model = instantiate_from_config(config.model)
|
||||||
@@ -1043,14 +1060,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
model = model.cuda(gpu_no) # move residual buffers not in state_dict
|
model = model.cuda(gpu_no) # move residual buffers not in state_dict
|
||||||
print(f'>>> Load pre-trained model ...')
|
print(f'>>> Load pre-trained model ...')
|
||||||
|
|
||||||
# Build unnomalizer
|
|
||||||
logging.info("***** Configing Data *****")
|
|
||||||
with profiler.profile_section("data_loading"):
|
|
||||||
data = instantiate_from_config(config.data)
|
|
||||||
data.setup()
|
|
||||||
print(">>> Dataset is successfully loaded ...")
|
|
||||||
device = get_device_from_parameters(model)
|
|
||||||
|
|
||||||
diffusion_autocast_dtype = None
|
diffusion_autocast_dtype = None
|
||||||
if args.diffusion_dtype == "bf16":
|
if args.diffusion_dtype == "bf16":
|
||||||
maybe_cast_module(
|
maybe_cast_module(
|
||||||
@@ -1175,6 +1184,20 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
print(">>> export_only set; skipping inference.")
|
print(">>> export_only set; skipping inference.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Save prepared model for fast loading next time
|
||||||
|
if prepared_path:
|
||||||
|
print(f">>> Saving prepared model to {prepared_path} ...")
|
||||||
|
torch.save(model, prepared_path)
|
||||||
|
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
|
||||||
|
|
||||||
|
# Build normalizer (always needed, independent of model loading path)
|
||||||
|
logging.info("***** Configing Data *****")
|
||||||
|
with profiler.profile_section("data_loading"):
|
||||||
|
data = instantiate_from_config(config.data)
|
||||||
|
data.setup()
|
||||||
|
print(">>> Dataset is successfully loaded ...")
|
||||||
|
device = get_device_from_parameters(model)
|
||||||
|
|
||||||
profiler.record_memory("after_model_load")
|
profiler.record_memory("after_model_load")
|
||||||
|
|
||||||
# Run over data
|
# Run over data
|
||||||
|
|||||||
Reference in New Issue
Block a user