更改求解器 step=150时 成功率更高 step=125时 速度更快 成功率持平
This commit is contained in:
9
eval.py
9
eval.py
@@ -153,6 +153,12 @@ def get_inference_context(cfg, device):
|
||||
)
|
||||
|
||||
|
||||
def get_eval_grad_context(solver=None):
|
||||
if isinstance(solver, swm.solver.GradientSolver):
|
||||
return torch.enable_grad()
|
||||
return torch.inference_mode()
|
||||
|
||||
|
||||
def make_profiler(cfg, results_path):
|
||||
profile_cfg = get_profile_cfg(cfg)
|
||||
if not profile_cfg["enabled"]:
|
||||
@@ -345,6 +351,7 @@ def run_eval_subset(
|
||||
)
|
||||
else:
|
||||
policy = swm.policy.RandomPolicy()
|
||||
solver = None
|
||||
inference_ctx = nullcontext()
|
||||
inference_precision = "fp32"
|
||||
compile_cfg = get_compile_cfg(local_cfg)
|
||||
@@ -357,7 +364,7 @@ def run_eval_subset(
|
||||
if str(device).startswith("cuda") and torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
with torch.inference_mode():
|
||||
with get_eval_grad_context(solver):
|
||||
with profiler_ctx as profiler:
|
||||
with inference_ctx:
|
||||
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
|
||||
|
||||
Reference in New Issue
Block a user