更改求解器 step=150时 成功率更高 step=125时 速度更快 成功率持平

This commit is contained in:
qihuanye
2026-05-04 07:55:13 +00:00
parent 4c3fdbcce6
commit cf43af0729
8 changed files with 558 additions and 3 deletions

View File

@@ -153,6 +153,12 @@ def get_inference_context(cfg, device):
)
def get_eval_grad_context(solver=None):
if isinstance(solver, swm.solver.GradientSolver):
return torch.enable_grad()
return torch.inference_mode()
def make_profiler(cfg, results_path):
profile_cfg = get_profile_cfg(cfg)
if not profile_cfg["enabled"]:
@@ -345,6 +351,7 @@ def run_eval_subset(
)
else:
policy = swm.policy.RandomPolicy()
solver = None
inference_ctx = nullcontext()
inference_precision = "fp32"
compile_cfg = get_compile_cfg(local_cfg)
@@ -357,7 +364,7 @@ def run_eval_subset(
if str(device).startswith("cuda") and torch.cuda.is_available():
torch.cuda.synchronize()
start_time = time.time()
with torch.inference_mode():
with get_eval_grad_context(solver):
with profiler_ctx as profiler:
with inference_ctx:
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):