性能剖析 (performance profiling)

This commit is contained in:
2026-01-18 00:31:39 +08:00
parent 25c6fc04db
commit c86c2be5ff
26 changed files with 272 additions and 54 deletions

View File

@@ -56,21 +56,50 @@ class TimingRecord:
}
class ProfilerManager:
    """Manages macro and micro-level profiling."""

    def __init__(self, enabled: bool = False, output_dir: str = "./profile_output"):
        # Master switch: when False the manager is inert and creates no directories.
        self.enabled = enabled
        # Destination for traces/reports; created eagerly only when profiling is on.
        self.output_dir = output_dir
        # Wall-clock durations keyed by section name.
        self.macro_timings: Dict[str, List[float]] = {}
        # Paired CUDA start/stop events keyed by section name.
        self.cuda_events: Dict[str, List[tuple]] = {}
        # Chronological memory snapshots taken during profiling.
        self.memory_snapshots: List[Dict] = []
        # Lazily-created torch.profiler.profile instance (None until started).
        self.pytorch_profiler = None
        # Iteration counter used to tag collected data.
        self.current_iteration = 0
        # Per-operator aggregates extracted from profiler traces.
        self.operator_stats: Dict[str, Dict] = {}
        if enabled:
            os.makedirs(output_dir, exist_ok=True)
class ProfilerManager:
    """Manages macro and micro-level profiling."""

    def __init__(
        self,
        enabled: bool = False,
        output_dir: str = "./profile_output",
        profile_detail: str = "light",
    ):
        # Master switch: when False the manager is inert and creates no directories.
        self.enabled = enabled
        # Destination for traces/reports; created eagerly only when profiling is on.
        self.output_dir = output_dir
        # Requested detail level ("light" or "full"); kept for later inspection.
        self.profile_detail = profile_detail
        # Wall-clock durations keyed by section name.
        self.macro_timings: Dict[str, List[float]] = {}
        # Paired CUDA start/stop events keyed by section name.
        self.cuda_events: Dict[str, List[tuple]] = {}
        # Chronological memory snapshots taken during profiling.
        self.memory_snapshots: List[Dict] = []
        # Lazily-created torch.profiler.profile instance (None until started).
        self.pytorch_profiler = None
        # Iteration counter used to tag collected data.
        self.current_iteration = 0
        # Per-operator aggregates extracted from profiler traces.
        self.operator_stats: Dict[str, Dict] = {}
        # Resolve the torch.profiler flag set once, up front; an invalid
        # detail level therefore fails fast at construction time.
        self.profiler_config = self._build_profiler_config(profile_detail)
        if enabled:
            os.makedirs(output_dir, exist_ok=True)

    def _build_profiler_config(self, profile_detail: str) -> Dict[str, Any]:
        """Return profiler settings based on the requested detail level.

        Raises:
            ValueError: if ``profile_detail`` is neither "light" nor "full".
        """
        if profile_detail not in ("light", "full"):
            raise ValueError(f"Unsupported profile_detail: {profile_detail}")
        # Every flag is uniformly on for "full" and off for "light", so a
        # single boolean drives the whole mapping.
        verbose = profile_detail == "full"
        return {
            "record_shapes": verbose,
            "profile_memory": verbose,
            "with_stack": verbose,
            "with_flops": verbose,
            "with_modules": verbose,
            "group_by_input_shape": verbose,
        }
@contextmanager
def profile_section(self, name: str, sync_cuda: bool = True):
@@ -133,22 +162,22 @@ class ProfilerManager:
if not self.enabled:
return nullcontext()
self.pytorch_profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=wait, warmup=warmup, active=active, repeat=1
),
on_trace_ready=self._trace_handler,
record_shapes=True,
profile_memory=True,
with_stack=True,
with_flops=True,
with_modules=True,
)
return self.pytorch_profiler
self.pytorch_profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=wait, warmup=warmup, active=active, repeat=1
),
on_trace_ready=self._trace_handler,
record_shapes=self.profiler_config["record_shapes"],
profile_memory=self.profiler_config["profile_memory"],
with_stack=self.profiler_config["with_stack"],
with_flops=self.profiler_config["with_flops"],
with_modules=self.profiler_config["with_modules"],
)
return self.pytorch_profiler
def _trace_handler(self, prof):
"""Handle profiler trace output."""
@@ -158,8 +187,10 @@ class ProfilerManager:
)
prof.export_chrome_trace(trace_path)
# Extract operator statistics
key_averages = prof.key_averages(group_by_input_shape=True)
# Extract operator statistics
key_averages = prof.key_averages(
group_by_input_shape=self.profiler_config["group_by_input_shape"]
)
for evt in key_averages:
op_name = evt.key
if op_name not in self.operator_stats:
@@ -344,18 +375,22 @@ class ProfilerManager:
# Global profiler instance
_profiler: Optional[ProfilerManager] = None
def get_profiler() -> ProfilerManager:
    """Return the process-wide profiler, lazily creating a disabled one.

    Call sites therefore always receive a usable instance, never None.
    """
    global _profiler
    instance = _profiler
    if instance is None:
        instance = ProfilerManager(enabled=False)
        _profiler = instance
    return instance
def init_profiler(enabled: bool, output_dir: str) -> ProfilerManager:
    """Create the global profiler with explicit settings and return it."""
    global _profiler
    profiler = ProfilerManager(enabled=enabled, output_dir=output_dir)
    # Install as the shared instance used by get_profiler().
    _profiler = profiler
    return profiler
def get_profiler() -> ProfilerManager:
    """Return the global profiler instance.

    A disabled ProfilerManager is created on first access so callers
    never have to handle a missing profiler.
    """
    global _profiler
    if _profiler is None:
        _profiler = ProfilerManager(enabled=False)
    return _profiler
def init_profiler(
    enabled: bool,
    output_dir: str,
    profile_detail: str = "light",
) -> ProfilerManager:
    """Initialize the global profiler.

    Args:
        enabled: whether profiling is active.
        output_dir: directory for profiling artifacts.
        profile_detail: "light" or "full"; defaults to "light" so existing
            two-argument callers of the previous signature keep working
            (the default matches ProfilerManager's own default).

    Returns:
        The newly installed global ProfilerManager.
    """
    global _profiler
    _profiler = ProfilerManager(
        enabled=enabled,
        output_dir=output_dir,
        profile_detail=profile_detail,
    )
    return _profiler
# ========== Original Functions ==========
@@ -1193,13 +1228,20 @@ def get_parser():
default=None,
help="Directory to save profiling results. Defaults to {savedir}/profile_output."
)
parser.add_argument(
"--profile_iterations",
type=int,
default=3,
help="Number of iterations to run PyTorch profiler's active phase for operator-level analysis."
)
return parser
parser.add_argument(
"--profile_iterations",
type=int,
default=3,
help="Number of iterations to run PyTorch profiler's active phase for operator-level analysis."
)
parser.add_argument(
"--profile_detail",
type=str,
choices=["light", "full"],
default="light",
help="Profiling detail level. Use 'full' for shapes/stacks/memory/flops."
)
return parser
if __name__ == '__main__':
@@ -1214,7 +1256,11 @@ if __name__ == '__main__':
profile_output_dir = args.profile_output_dir
if profile_output_dir is None:
profile_output_dir = os.path.join(args.savedir, "profile_output")
init_profiler(enabled=args.profile, output_dir=profile_output_dir)
init_profiler(
enabled=args.profile,
output_dir=profile_output_dir,
profile_detail=args.profile_detail,
)
rank, gpu_num = 0, 1
run_inference(args, gpu_num, rank)