remove profile
This commit is contained in:
@@ -9,12 +9,9 @@ import logging
|
|||||||
import einops
|
import einops
|
||||||
import warnings
|
import warnings
|
||||||
import imageio
|
import imageio
|
||||||
import time
|
|
||||||
import json
|
|
||||||
import atexit
|
import atexit
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from contextlib import contextmanager, nullcontext
|
from contextlib import nullcontext
|
||||||
from dataclasses import dataclass, field, asdict
|
|
||||||
from typing import Optional, Dict, List, Any, Mapping
|
from typing import Optional, Dict, List, Any, Mapping
|
||||||
|
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
@@ -26,375 +23,12 @@ from torch import nn
|
|||||||
from eval_utils import populate_queues
|
from eval_utils import populate_queues
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||||
from unifolm_wma.utils.utils import instantiate_from_config
|
from unifolm_wma.utils.utils import instantiate_from_config
|
||||||
|
|
||||||
|
|
||||||
# ========== Profiling Infrastructure ==========
|
|
||||||
@dataclass
|
|
||||||
class TimingRecord:
|
|
||||||
"""Record for a single timing measurement."""
|
|
||||||
name: str
|
|
||||||
start_time: float = 0.0
|
|
||||||
end_time: float = 0.0
|
|
||||||
cuda_time_ms: float = 0.0
|
|
||||||
count: int = 0
|
|
||||||
children: List['TimingRecord'] = field(default_factory=list)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def cpu_time_ms(self) -> float:
|
|
||||||
return (self.end_time - self.start_time) * 1000
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return {
|
|
||||||
'name': self.name,
|
|
||||||
'cpu_time_ms': self.cpu_time_ms,
|
|
||||||
'cuda_time_ms': self.cuda_time_ms,
|
|
||||||
'count': self.count,
|
|
||||||
'children': [c.to_dict() for c in self.children]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ProfilerManager:
|
|
||||||
"""Manages macro and micro-level profiling."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
enabled: bool = False,
|
|
||||||
output_dir: str = "./profile_output",
|
|
||||||
profile_detail: str = "light",
|
|
||||||
):
|
|
||||||
self.enabled = enabled
|
|
||||||
self.output_dir = output_dir
|
|
||||||
self.profile_detail = profile_detail
|
|
||||||
self.macro_timings: Dict[str, List[float]] = {}
|
|
||||||
self.cuda_events: Dict[str, List[tuple]] = {}
|
|
||||||
self.memory_snapshots: List[Dict] = []
|
|
||||||
self.pytorch_profiler = None
|
|
||||||
self.current_iteration = 0
|
|
||||||
self.operator_stats: Dict[str, Dict] = {}
|
|
||||||
self.profiler_config = self._build_profiler_config(profile_detail)
|
|
||||||
|
|
||||||
if enabled:
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
def _build_profiler_config(self, profile_detail: str) -> Dict[str, Any]:
|
|
||||||
"""Return profiler settings based on the requested detail level."""
|
|
||||||
if profile_detail not in ("light", "full"):
|
|
||||||
raise ValueError(f"Unsupported profile_detail: {profile_detail}")
|
|
||||||
if profile_detail == "full":
|
|
||||||
return {
|
|
||||||
"record_shapes": True,
|
|
||||||
"profile_memory": True,
|
|
||||||
"with_stack": True,
|
|
||||||
"with_flops": True,
|
|
||||||
"with_modules": True,
|
|
||||||
"group_by_input_shape": True,
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
"record_shapes": False,
|
|
||||||
"profile_memory": False,
|
|
||||||
"with_stack": False,
|
|
||||||
"with_flops": False,
|
|
||||||
"with_modules": False,
|
|
||||||
"group_by_input_shape": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def profile_section(self, name: str, sync_cuda: bool = True):
|
|
||||||
"""Context manager for profiling a code section."""
|
|
||||||
if not self.enabled:
|
|
||||||
yield
|
|
||||||
return
|
|
||||||
|
|
||||||
if sync_cuda and torch.cuda.is_available():
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
start_event = None
|
|
||||||
end_event = None
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
start_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
end_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
start_event.record()
|
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
if sync_cuda and torch.cuda.is_available():
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
end_time = time.perf_counter()
|
|
||||||
cpu_time_ms = (end_time - start_time) * 1000
|
|
||||||
|
|
||||||
cuda_time_ms = 0.0
|
|
||||||
if start_event is not None and end_event is not None:
|
|
||||||
end_event.record()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
cuda_time_ms = start_event.elapsed_time(end_event)
|
|
||||||
|
|
||||||
if name not in self.macro_timings:
|
|
||||||
self.macro_timings[name] = []
|
|
||||||
self.macro_timings[name].append(cpu_time_ms)
|
|
||||||
|
|
||||||
if name not in self.cuda_events:
|
|
||||||
self.cuda_events[name] = []
|
|
||||||
self.cuda_events[name].append((cpu_time_ms, cuda_time_ms))
|
|
||||||
|
|
||||||
def record_memory(self, tag: str = ""):
|
|
||||||
"""Record current GPU memory state."""
|
|
||||||
if not self.enabled or not torch.cuda.is_available():
|
|
||||||
return
|
|
||||||
|
|
||||||
snapshot = {
|
|
||||||
'tag': tag,
|
|
||||||
'iteration': self.current_iteration,
|
|
||||||
'allocated_mb': torch.cuda.memory_allocated() / 1024**2,
|
|
||||||
'reserved_mb': torch.cuda.memory_reserved() / 1024**2,
|
|
||||||
'max_allocated_mb': torch.cuda.max_memory_allocated() / 1024**2,
|
|
||||||
}
|
|
||||||
self.memory_snapshots.append(snapshot)
|
|
||||||
|
|
||||||
def start_pytorch_profiler(self, wait: int = 1, warmup: int = 1, active: int = 3):
|
|
||||||
"""Start PyTorch profiler for operator-level analysis."""
|
|
||||||
if not self.enabled:
|
|
||||||
return nullcontext()
|
|
||||||
|
|
||||||
self.pytorch_profiler = torch.profiler.profile(
|
|
||||||
activities=[
|
|
||||||
torch.profiler.ProfilerActivity.CPU,
|
|
||||||
torch.profiler.ProfilerActivity.CUDA,
|
|
||||||
],
|
|
||||||
schedule=torch.profiler.schedule(
|
|
||||||
wait=wait, warmup=warmup, active=active, repeat=1
|
|
||||||
),
|
|
||||||
on_trace_ready=self._trace_handler,
|
|
||||||
record_shapes=self.profiler_config["record_shapes"],
|
|
||||||
profile_memory=self.profiler_config["profile_memory"],
|
|
||||||
with_stack=self.profiler_config["with_stack"],
|
|
||||||
with_flops=self.profiler_config["with_flops"],
|
|
||||||
with_modules=self.profiler_config["with_modules"],
|
|
||||||
)
|
|
||||||
return self.pytorch_profiler
|
|
||||||
|
|
||||||
def _trace_handler(self, prof):
|
|
||||||
"""Handle profiler trace output."""
|
|
||||||
trace_path = os.path.join(
|
|
||||||
self.output_dir,
|
|
||||||
f"trace_iter_{self.current_iteration}.json"
|
|
||||||
)
|
|
||||||
prof.export_chrome_trace(trace_path)
|
|
||||||
|
|
||||||
# Extract operator statistics
|
|
||||||
key_averages = prof.key_averages(
|
|
||||||
group_by_input_shape=self.profiler_config["group_by_input_shape"]
|
|
||||||
)
|
|
||||||
for evt in key_averages:
|
|
||||||
op_name = evt.key
|
|
||||||
if op_name not in self.operator_stats:
|
|
||||||
self.operator_stats[op_name] = {
|
|
||||||
'count': 0,
|
|
||||||
'cpu_time_total_us': 0,
|
|
||||||
'cuda_time_total_us': 0,
|
|
||||||
'self_cpu_time_total_us': 0,
|
|
||||||
'self_cuda_time_total_us': 0,
|
|
||||||
'cpu_memory_usage': 0,
|
|
||||||
'cuda_memory_usage': 0,
|
|
||||||
'flops': 0,
|
|
||||||
}
|
|
||||||
stats = self.operator_stats[op_name]
|
|
||||||
stats['count'] += evt.count
|
|
||||||
stats['cpu_time_total_us'] += evt.cpu_time_total
|
|
||||||
stats['cuda_time_total_us'] += evt.cuda_time_total
|
|
||||||
stats['self_cpu_time_total_us'] += evt.self_cpu_time_total
|
|
||||||
stats['self_cuda_time_total_us'] += evt.self_cuda_time_total
|
|
||||||
if hasattr(evt, 'cpu_memory_usage'):
|
|
||||||
stats['cpu_memory_usage'] += evt.cpu_memory_usage
|
|
||||||
if hasattr(evt, 'cuda_memory_usage'):
|
|
||||||
stats['cuda_memory_usage'] += evt.cuda_memory_usage
|
|
||||||
if hasattr(evt, 'flops') and evt.flops:
|
|
||||||
stats['flops'] += evt.flops
|
|
||||||
|
|
||||||
def step_profiler(self):
|
|
||||||
"""Step the PyTorch profiler."""
|
|
||||||
if self.pytorch_profiler is not None:
|
|
||||||
self.pytorch_profiler.step()
|
|
||||||
|
|
||||||
def generate_report(self) -> str:
|
|
||||||
"""Generate comprehensive profiling report."""
|
|
||||||
if not self.enabled:
|
|
||||||
return "Profiling disabled."
|
|
||||||
|
|
||||||
report_lines = []
|
|
||||||
report_lines.append("=" * 80)
|
|
||||||
report_lines.append("PERFORMANCE PROFILING REPORT")
|
|
||||||
report_lines.append("=" * 80)
|
|
||||||
report_lines.append("")
|
|
||||||
|
|
||||||
# Macro-level timing summary
|
|
||||||
report_lines.append("-" * 40)
|
|
||||||
report_lines.append("MACRO-LEVEL TIMING SUMMARY")
|
|
||||||
report_lines.append("-" * 40)
|
|
||||||
report_lines.append(f"{'Section':<40} {'Count':>8} {'Total(ms)':>12} {'Avg(ms)':>12} {'CUDA Avg(ms)':>14}")
|
|
||||||
report_lines.append("-" * 86)
|
|
||||||
|
|
||||||
total_time = 0
|
|
||||||
timing_data = []
|
|
||||||
for name, times in sorted(self.macro_timings.items()):
|
|
||||||
cuda_times = [ct for _, ct in self.cuda_events.get(name, [])]
|
|
||||||
avg_time = np.mean(times)
|
|
||||||
avg_cuda = np.mean(cuda_times) if cuda_times else 0
|
|
||||||
total = sum(times)
|
|
||||||
total_time += total
|
|
||||||
timing_data.append({
|
|
||||||
'name': name,
|
|
||||||
'count': len(times),
|
|
||||||
'total_ms': total,
|
|
||||||
'avg_ms': avg_time,
|
|
||||||
'cuda_avg_ms': avg_cuda,
|
|
||||||
'times': times,
|
|
||||||
'cuda_times': cuda_times,
|
|
||||||
})
|
|
||||||
report_lines.append(f"{name:<40} {len(times):>8} {total:>12.2f} {avg_time:>12.2f} {avg_cuda:>14.2f}")
|
|
||||||
|
|
||||||
report_lines.append("-" * 86)
|
|
||||||
report_lines.append(f"{'TOTAL':<40} {'':<8} {total_time:>12.2f}")
|
|
||||||
report_lines.append("")
|
|
||||||
|
|
||||||
# Memory summary
|
|
||||||
if self.memory_snapshots:
|
|
||||||
report_lines.append("-" * 40)
|
|
||||||
report_lines.append("GPU MEMORY SUMMARY")
|
|
||||||
report_lines.append("-" * 40)
|
|
||||||
max_alloc = max(s['max_allocated_mb'] for s in self.memory_snapshots)
|
|
||||||
avg_alloc = np.mean([s['allocated_mb'] for s in self.memory_snapshots])
|
|
||||||
report_lines.append(f"Peak allocated: {max_alloc:>10.2f} MB")
|
|
||||||
report_lines.append(f"Average allocated: {avg_alloc:>10.2f} MB")
|
|
||||||
report_lines.append("")
|
|
||||||
|
|
||||||
# Top operators by CUDA time
|
|
||||||
if self.operator_stats:
|
|
||||||
report_lines.append("-" * 40)
|
|
||||||
report_lines.append("TOP 30 OPERATORS BY CUDA TIME")
|
|
||||||
report_lines.append("-" * 40)
|
|
||||||
sorted_ops = sorted(
|
|
||||||
self.operator_stats.items(),
|
|
||||||
key=lambda x: x[1]['cuda_time_total_us'],
|
|
||||||
reverse=True
|
|
||||||
)[:30]
|
|
||||||
|
|
||||||
report_lines.append(f"{'Operator':<50} {'Count':>8} {'CUDA(ms)':>12} {'CPU(ms)':>12} {'Self CUDA(ms)':>14}")
|
|
||||||
report_lines.append("-" * 96)
|
|
||||||
|
|
||||||
for op_name, stats in sorted_ops:
|
|
||||||
# Truncate long operator names
|
|
||||||
display_name = op_name[:47] + "..." if len(op_name) > 50 else op_name
|
|
||||||
report_lines.append(
|
|
||||||
f"{display_name:<50} {stats['count']:>8} "
|
|
||||||
f"{stats['cuda_time_total_us']/1000:>12.2f} "
|
|
||||||
f"{stats['cpu_time_total_us']/1000:>12.2f} "
|
|
||||||
f"{stats['self_cuda_time_total_us']/1000:>14.2f}"
|
|
||||||
)
|
|
||||||
report_lines.append("")
|
|
||||||
|
|
||||||
# Compute category breakdown
|
|
||||||
report_lines.append("-" * 40)
|
|
||||||
report_lines.append("OPERATOR CATEGORY BREAKDOWN")
|
|
||||||
report_lines.append("-" * 40)
|
|
||||||
|
|
||||||
categories = {
|
|
||||||
'Attention': ['attention', 'softmax', 'bmm', 'baddbmm'],
|
|
||||||
'Convolution': ['conv', 'cudnn'],
|
|
||||||
'Normalization': ['norm', 'layer_norm', 'batch_norm', 'group_norm'],
|
|
||||||
'Activation': ['relu', 'gelu', 'silu', 'sigmoid', 'tanh'],
|
|
||||||
'Linear/GEMM': ['linear', 'addmm', 'mm', 'gemm'],
|
|
||||||
'Memory': ['copy', 'contiguous', 'view', 'reshape', 'permute', 'transpose'],
|
|
||||||
'Elementwise': ['add', 'mul', 'div', 'sub', 'pow', 'exp', 'sqrt'],
|
|
||||||
}
|
|
||||||
|
|
||||||
category_times = {cat: 0.0 for cat in categories}
|
|
||||||
category_times['Other'] = 0.0
|
|
||||||
|
|
||||||
for op_name, stats in self.operator_stats.items():
|
|
||||||
op_lower = op_name.lower()
|
|
||||||
categorized = False
|
|
||||||
for cat, keywords in categories.items():
|
|
||||||
if any(kw in op_lower for kw in keywords):
|
|
||||||
category_times[cat] += stats['cuda_time_total_us']
|
|
||||||
categorized = True
|
|
||||||
break
|
|
||||||
if not categorized:
|
|
||||||
category_times['Other'] += stats['cuda_time_total_us']
|
|
||||||
|
|
||||||
total_op_time = sum(category_times.values())
|
|
||||||
report_lines.append(f"{'Category':<30} {'CUDA Time(ms)':>15} {'Percentage':>12}")
|
|
||||||
report_lines.append("-" * 57)
|
|
||||||
for cat, time_us in sorted(category_times.items(), key=lambda x: -x[1]):
|
|
||||||
pct = (time_us / total_op_time * 100) if total_op_time > 0 else 0
|
|
||||||
report_lines.append(f"{cat:<30} {time_us/1000:>15.2f} {pct:>11.1f}%")
|
|
||||||
report_lines.append("")
|
|
||||||
|
|
||||||
report = "\n".join(report_lines)
|
|
||||||
return report
|
|
||||||
|
|
||||||
def save_results(self):
|
|
||||||
"""Save all profiling results to files."""
|
|
||||||
if not self.enabled:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Save report
|
|
||||||
report = self.generate_report()
|
|
||||||
report_path = os.path.join(self.output_dir, "profiling_report.txt")
|
|
||||||
with open(report_path, 'w') as f:
|
|
||||||
f.write(report)
|
|
||||||
print(f">>> Profiling report saved to: {report_path}")
|
|
||||||
|
|
||||||
# Save detailed JSON data
|
|
||||||
data = {
|
|
||||||
'macro_timings': {
|
|
||||||
name: {
|
|
||||||
'times': times,
|
|
||||||
'cuda_times': [ct for _, ct in self.cuda_events.get(name, [])]
|
|
||||||
}
|
|
||||||
for name, times in self.macro_timings.items()
|
|
||||||
},
|
|
||||||
'memory_snapshots': self.memory_snapshots,
|
|
||||||
'operator_stats': self.operator_stats,
|
|
||||||
}
|
|
||||||
json_path = os.path.join(self.output_dir, "profiling_data.json")
|
|
||||||
with open(json_path, 'w') as f:
|
|
||||||
json.dump(data, f, indent=2)
|
|
||||||
print(f">>> Detailed profiling data saved to: {json_path}")
|
|
||||||
|
|
||||||
# Print summary to console
|
|
||||||
print("\n" + report)
|
|
||||||
|
|
||||||
|
|
||||||
# Global profiler instance
|
|
||||||
_profiler: Optional[ProfilerManager] = None
|
|
||||||
|
|
||||||
def get_profiler() -> ProfilerManager:
|
|
||||||
"""Get the global profiler instance."""
|
|
||||||
global _profiler
|
|
||||||
if _profiler is None:
|
|
||||||
_profiler = ProfilerManager(enabled=False)
|
|
||||||
return _profiler
|
|
||||||
|
|
||||||
def init_profiler(enabled: bool, output_dir: str, profile_detail: str) -> ProfilerManager:
|
|
||||||
"""Initialize the global profiler."""
|
|
||||||
global _profiler
|
|
||||||
_profiler = ProfilerManager(
|
|
||||||
enabled=enabled,
|
|
||||||
output_dir=output_dir,
|
|
||||||
profile_detail=profile_detail,
|
|
||||||
)
|
|
||||||
return _profiler
|
|
||||||
|
|
||||||
|
|
||||||
# ========== Async I/O ==========
|
# ========== Async I/O ==========
|
||||||
_io_executor: Optional[ThreadPoolExecutor] = None
|
_io_executor: Optional[ThreadPoolExecutor] = None
|
||||||
_io_futures: List[Any] = []
|
_io_futures: List[Any] = []
|
||||||
@@ -447,28 +81,6 @@ def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
|
|||||||
_io_futures.append(fut)
|
_io_futures.append(fut)
|
||||||
|
|
||||||
|
|
||||||
def _log_to_tb_sync(writer, video_cpu: Tensor, tag: str, fps: int) -> None:
|
|
||||||
"""Synchronous TensorBoard log on CPU tensor (runs in background thread)."""
|
|
||||||
if video_cpu.dim() == 5:
|
|
||||||
n = video_cpu.shape[0]
|
|
||||||
video = video_cpu.permute(2, 0, 1, 3, 4)
|
|
||||||
frame_grids = [
|
|
||||||
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
|
|
||||||
for framesheet in video
|
|
||||||
]
|
|
||||||
grid = torch.stack(frame_grids, dim=0)
|
|
||||||
grid = (grid + 1.0) / 2.0
|
|
||||||
grid = grid.unsqueeze(dim=0)
|
|
||||||
writer.add_video(tag, grid, fps=fps)
|
|
||||||
|
|
||||||
|
|
||||||
def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> None:
|
|
||||||
"""Submit TensorBoard logging to background thread pool."""
|
|
||||||
if isinstance(data, torch.Tensor) and data.dim() == 5:
|
|
||||||
data_cpu = data.detach().cpu()
|
|
||||||
fut = _get_io_executor().submit(_log_to_tb_sync, writer, data_cpu, tag, fps)
|
|
||||||
_io_futures.append(fut)
|
|
||||||
|
|
||||||
|
|
||||||
# ========== Original Functions ==========
|
# ========== Original Functions ==========
|
||||||
def get_device_from_parameters(module: nn.Module) -> torch.device:
|
def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||||
@@ -590,9 +202,7 @@ def load_model_checkpoint(model: nn.Module,
|
|||||||
|
|
||||||
def maybe_cast_module(module: nn.Module | None,
|
def maybe_cast_module(module: nn.Module | None,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
label: str,
|
label: str) -> None:
|
||||||
profiler: Optional[ProfilerManager] = None,
|
|
||||||
profile_name: Optional[str] = None) -> None:
|
|
||||||
if module is None:
|
if module is None:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
@@ -603,10 +213,6 @@ def maybe_cast_module(module: nn.Module | None,
|
|||||||
if param.dtype == dtype:
|
if param.dtype == dtype:
|
||||||
print(f">>> {label} already {dtype}; skip cast")
|
print(f">>> {label} already {dtype}; skip cast")
|
||||||
return
|
return
|
||||||
ctx = nullcontext()
|
|
||||||
if profiler is not None and profile_name:
|
|
||||||
ctx = profiler.profile_section(profile_name)
|
|
||||||
with ctx:
|
|
||||||
module.to(dtype=dtype)
|
module.to(dtype=dtype)
|
||||||
print(f">>> {label} cast to {dtype}")
|
print(f">>> {label} cast to {dtype}")
|
||||||
|
|
||||||
@@ -825,8 +431,6 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
|
|||||||
Returns:
|
Returns:
|
||||||
Tensor: Latent video tensor of shape [B, C, T, H, W].
|
Tensor: Latent video tensor of shape [B, C, T, H, W].
|
||||||
"""
|
"""
|
||||||
profiler = get_profiler()
|
|
||||||
with profiler.profile_section("get_latent_z/encode"):
|
|
||||||
b, c, t, h, w = videos.shape
|
b, c, t, h, w = videos.shape
|
||||||
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
||||||
vae_ctx = nullcontext()
|
vae_ctx = nullcontext()
|
||||||
@@ -941,8 +545,6 @@ def image_guided_synthesis_sim_mode(
|
|||||||
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
|
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
|
||||||
states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
|
states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
|
||||||
"""
|
"""
|
||||||
profiler = get_profiler()
|
|
||||||
|
|
||||||
b, _, t, _, _ = noise_shape
|
b, _, t, _, _ = noise_shape
|
||||||
ddim_sampler = getattr(model, "_ddim_sampler", None)
|
ddim_sampler = getattr(model, "_ddim_sampler", None)
|
||||||
if ddim_sampler is None:
|
if ddim_sampler is None:
|
||||||
@@ -952,7 +554,6 @@ def image_guided_synthesis_sim_mode(
|
|||||||
|
|
||||||
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
||||||
|
|
||||||
with profiler.profile_section("synthesis/conditioning_prep"):
|
|
||||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
||||||
if getattr(model, "encoder_bf16", False) and model.device.type == "cuda":
|
if getattr(model, "encoder_bf16", False) and model.device.type == "cuda":
|
||||||
@@ -1038,7 +639,6 @@ def image_guided_synthesis_sim_mode(
|
|||||||
cond_z0 = None
|
cond_z0 = None
|
||||||
|
|
||||||
if ddim_sampler is not None:
|
if ddim_sampler is not None:
|
||||||
with profiler.profile_section("synthesis/ddim_sampling"):
|
|
||||||
autocast_ctx = nullcontext()
|
autocast_ctx = nullcontext()
|
||||||
if diffusion_autocast_dtype is not None and model.device.type == "cuda":
|
if diffusion_autocast_dtype is not None and model.device.type == "cuda":
|
||||||
autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype)
|
autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype)
|
||||||
@@ -1061,7 +661,6 @@ def image_guided_synthesis_sim_mode(
|
|||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
# Reconstruct from latent to pixel space
|
# Reconstruct from latent to pixel space
|
||||||
with profiler.profile_section("synthesis/decode_first_stage"):
|
|
||||||
if getattr(model, "vae_bf16", False):
|
if getattr(model, "vae_bf16", False):
|
||||||
if samples.dtype != torch.bfloat16:
|
if samples.dtype != torch.bfloat16:
|
||||||
samples = samples.to(dtype=torch.bfloat16)
|
samples = samples.to(dtype=torch.bfloat16)
|
||||||
@@ -1091,13 +690,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
profiler = get_profiler()
|
# Create inference dir
|
||||||
|
|
||||||
# Create inference and tensorboard dirs
|
|
||||||
os.makedirs(args.savedir + '/inference', exist_ok=True)
|
os.makedirs(args.savedir + '/inference', exist_ok=True)
|
||||||
log_dir = args.savedir + f"/tensorboard"
|
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
|
||||||
writer = SummaryWriter(log_dir=log_dir)
|
|
||||||
|
|
||||||
# Load prompt
|
# Load prompt
|
||||||
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
||||||
@@ -1110,7 +704,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
if os.path.exists(prepared_path):
|
if os.path.exists(prepared_path):
|
||||||
# ---- Fast path: load the fully-prepared model ----
|
# ---- Fast path: load the fully-prepared model ----
|
||||||
print(f">>> Loading prepared model from {prepared_path} ...")
|
print(f">>> Loading prepared model from {prepared_path} ...")
|
||||||
with profiler.profile_section("model_loading/prepared"):
|
|
||||||
model = torch.load(prepared_path,
|
model = torch.load(prepared_path,
|
||||||
map_location=f"cuda:{gpu_no}",
|
map_location=f"cuda:{gpu_no}",
|
||||||
weights_only=False,
|
weights_only=False,
|
||||||
@@ -1122,7 +715,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
print(f">>> Prepared model loaded.")
|
print(f">>> Prepared model loaded.")
|
||||||
else:
|
else:
|
||||||
# ---- Normal path: construct + checkpoint + casting ----
|
# ---- Normal path: construct + checkpoint + casting ----
|
||||||
with profiler.profile_section("model_loading/config"):
|
|
||||||
config['model']['params']['wma_config']['params'][
|
config['model']['params']['wma_config']['params'][
|
||||||
'use_checkpoint'] = False
|
'use_checkpoint'] = False
|
||||||
model = instantiate_from_config(config.model)
|
model = instantiate_from_config(config.model)
|
||||||
@@ -1130,7 +722,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
|
|
||||||
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
||||||
|
|
||||||
with profiler.profile_section("model_loading/checkpoint"):
|
|
||||||
model = load_model_checkpoint(model, args.ckpt_path,
|
model = load_model_checkpoint(model, args.ckpt_path,
|
||||||
device=f"cuda:{gpu_no}")
|
device=f"cuda:{gpu_no}")
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -1143,8 +734,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
model.model,
|
model.model,
|
||||||
torch.bfloat16,
|
torch.bfloat16,
|
||||||
"diffusion backbone",
|
"diffusion backbone",
|
||||||
profiler=profiler,
|
|
||||||
profile_name="model_loading/diffusion_bf16",
|
|
||||||
)
|
)
|
||||||
diffusion_autocast_dtype = torch.bfloat16
|
diffusion_autocast_dtype = torch.bfloat16
|
||||||
print(">>> diffusion backbone set to bfloat16")
|
print(">>> diffusion backbone set to bfloat16")
|
||||||
@@ -1155,8 +744,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
model.first_stage_model,
|
model.first_stage_model,
|
||||||
vae_weight_dtype,
|
vae_weight_dtype,
|
||||||
"VAE",
|
"VAE",
|
||||||
profiler=profiler,
|
|
||||||
profile_name="model_loading/vae_cast",
|
|
||||||
)
|
)
|
||||||
model.vae_bf16 = args.vae_dtype == "bf16"
|
model.vae_bf16 = args.vae_dtype == "bf16"
|
||||||
print(f">>> VAE dtype set to {args.vae_dtype}")
|
print(f">>> VAE dtype set to {args.vae_dtype}")
|
||||||
@@ -1195,16 +782,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
model.cond_stage_model,
|
model.cond_stage_model,
|
||||||
encoder_weight_dtype,
|
encoder_weight_dtype,
|
||||||
"cond_stage_model",
|
"cond_stage_model",
|
||||||
profiler=profiler,
|
|
||||||
profile_name="model_loading/encoder_cond_cast",
|
|
||||||
)
|
)
|
||||||
if hasattr(model, "embedder") and model.embedder is not None:
|
if hasattr(model, "embedder") and model.embedder is not None:
|
||||||
maybe_cast_module(
|
maybe_cast_module(
|
||||||
model.embedder,
|
model.embedder,
|
||||||
encoder_weight_dtype,
|
encoder_weight_dtype,
|
||||||
"embedder",
|
"embedder",
|
||||||
profiler=profiler,
|
|
||||||
profile_name="model_loading/encoder_embedder_cast",
|
|
||||||
)
|
)
|
||||||
model.encoder_bf16 = encoder_bf16
|
model.encoder_bf16 = encoder_bf16
|
||||||
model.encoder_mode = encoder_mode
|
model.encoder_mode = encoder_mode
|
||||||
@@ -1220,24 +803,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
model.image_proj_model,
|
model.image_proj_model,
|
||||||
projector_weight_dtype,
|
projector_weight_dtype,
|
||||||
"image_proj_model",
|
"image_proj_model",
|
||||||
profiler=profiler,
|
|
||||||
profile_name="model_loading/projector_image_cast",
|
|
||||||
)
|
)
|
||||||
if hasattr(model, "state_projector") and model.state_projector is not None:
|
if hasattr(model, "state_projector") and model.state_projector is not None:
|
||||||
maybe_cast_module(
|
maybe_cast_module(
|
||||||
model.state_projector,
|
model.state_projector,
|
||||||
projector_weight_dtype,
|
projector_weight_dtype,
|
||||||
"state_projector",
|
"state_projector",
|
||||||
profiler=profiler,
|
|
||||||
profile_name="model_loading/projector_state_cast",
|
|
||||||
)
|
)
|
||||||
if hasattr(model, "action_projector") and model.action_projector is not None:
|
if hasattr(model, "action_projector") and model.action_projector is not None:
|
||||||
maybe_cast_module(
|
maybe_cast_module(
|
||||||
model.action_projector,
|
model.action_projector,
|
||||||
projector_weight_dtype,
|
projector_weight_dtype,
|
||||||
"action_projector",
|
"action_projector",
|
||||||
profiler=profiler,
|
|
||||||
profile_name="model_loading/projector_action_cast",
|
|
||||||
)
|
)
|
||||||
if hasattr(model, "projector_bf16"):
|
if hasattr(model, "projector_bf16"):
|
||||||
model.projector_bf16 = projector_bf16
|
model.projector_bf16 = projector_bf16
|
||||||
@@ -1269,14 +846,11 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
|
|
||||||
# Build normalizer (always needed, independent of model loading path)
|
# Build normalizer (always needed, independent of model loading path)
|
||||||
logging.info("***** Configing Data *****")
|
logging.info("***** Configing Data *****")
|
||||||
with profiler.profile_section("data_loading"):
|
|
||||||
data = instantiate_from_config(config.data)
|
data = instantiate_from_config(config.data)
|
||||||
data.setup()
|
data.setup()
|
||||||
print(">>> Dataset is successfully loaded ...")
|
print(">>> Dataset is successfully loaded ...")
|
||||||
device = get_device_from_parameters(model)
|
device = get_device_from_parameters(model)
|
||||||
|
|
||||||
profiler.record_memory("after_model_load")
|
|
||||||
|
|
||||||
# Run over data
|
# Run over data
|
||||||
assert (args.height % 16 == 0) and (
|
assert (args.height % 16 == 0) and (
|
||||||
args.width % 16
|
args.width % 16
|
||||||
@@ -1290,10 +864,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
print(f'>>> Generate {n_frames} frames under each generation ...')
|
print(f'>>> Generate {n_frames} frames under each generation ...')
|
||||||
noise_shape = [args.bs, channels, n_frames, h, w]
|
noise_shape = [args.bs, channels, n_frames, h, w]
|
||||||
|
|
||||||
# Determine profiler iterations
|
|
||||||
profile_active_iters = getattr(args, 'profile_iterations', 3)
|
|
||||||
use_pytorch_profiler = profiler.enabled and profile_active_iters > 0
|
|
||||||
|
|
||||||
# Start inference
|
# Start inference
|
||||||
for idx in range(0, len(df)):
|
for idx in range(0, len(df)):
|
||||||
sample = df.iloc[idx]
|
sample = df.iloc[idx]
|
||||||
@@ -1309,7 +879,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
|
|
||||||
# Load transitions to get the initial state later
|
# Load transitions to get the initial state later
|
||||||
transition_path = get_transition_path(args.prompt_dir, sample)
|
transition_path = get_transition_path(args.prompt_dir, sample)
|
||||||
with profiler.profile_section("load_transitions"):
|
|
||||||
with h5py.File(transition_path, 'r') as h5f:
|
with h5py.File(transition_path, 'r') as h5f:
|
||||||
transition_dict = {}
|
transition_dict = {}
|
||||||
for key in h5f.keys():
|
for key in h5f.keys():
|
||||||
@@ -1337,7 +906,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Obtain initial frame and state
|
# Obtain initial frame and state
|
||||||
with profiler.profile_section("prepare_init_input"):
|
|
||||||
start_idx = 0
|
start_idx = 0
|
||||||
model_input_fs = ori_fps // fs
|
model_input_fs = ori_fps // fs
|
||||||
batch, ori_state_dim, ori_action_dim = prepare_init_input(
|
batch, ori_state_dim, ori_action_dim = prepare_init_input(
|
||||||
@@ -1360,24 +928,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
# Update observation queues
|
# Update observation queues
|
||||||
cond_obs_queues = populate_queues(cond_obs_queues, observation)
|
cond_obs_queues = populate_queues(cond_obs_queues, observation)
|
||||||
|
|
||||||
# Setup PyTorch profiler context if enabled
|
|
||||||
pytorch_prof_ctx = nullcontext()
|
|
||||||
if use_pytorch_profiler:
|
|
||||||
pytorch_prof_ctx = profiler.start_pytorch_profiler(
|
|
||||||
wait=1, warmup=1, active=profile_active_iters
|
|
||||||
)
|
|
||||||
|
|
||||||
# Multi-round interaction with the world-model
|
# Multi-round interaction with the world-model
|
||||||
with pytorch_prof_ctx:
|
|
||||||
for itr in tqdm(range(args.n_iter)):
|
for itr in tqdm(range(args.n_iter)):
|
||||||
log_every = max(1, args.step_log_every)
|
log_every = max(1, args.step_log_every)
|
||||||
log_step = (itr % log_every == 0)
|
log_step = (itr % log_every == 0)
|
||||||
profiler.current_iteration = itr
|
|
||||||
profiler.record_memory(f"iter_{itr}_start")
|
|
||||||
|
|
||||||
with profiler.profile_section("iteration_total"):
|
|
||||||
# Get observation
|
# Get observation
|
||||||
with profiler.profile_section("prepare_observation"):
|
|
||||||
observation = {
|
observation = {
|
||||||
'observation.images.top':
|
'observation.images.top':
|
||||||
torch.stack(list(
|
torch.stack(list(
|
||||||
@@ -1394,7 +950,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
# Use world-model in policy to generate action
|
# Use world-model in policy to generate action
|
||||||
if log_step:
|
if log_step:
|
||||||
print(f'>>> Step {itr}: generating actions ...')
|
print(f'>>> Step {itr}: generating actions ...')
|
||||||
with profiler.profile_section("action_generation"):
|
|
||||||
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
|
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
|
||||||
model,
|
model,
|
||||||
sample['instruction'],
|
sample['instruction'],
|
||||||
@@ -1412,7 +967,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
diffusion_autocast_dtype=diffusion_autocast_dtype)
|
diffusion_autocast_dtype=diffusion_autocast_dtype)
|
||||||
|
|
||||||
# Update future actions in the observation queues
|
# Update future actions in the observation queues
|
||||||
with profiler.profile_section("update_action_queues"):
|
|
||||||
for act_idx in range(len(pred_actions[0])):
|
for act_idx in range(len(pred_actions[0])):
|
||||||
obs_update = {'action': pred_actions[0][act_idx:act_idx + 1]}
|
obs_update = {'action': pred_actions[0][act_idx:act_idx + 1]}
|
||||||
obs_update['action'][:, ori_action_dim:] = 0.0
|
obs_update['action'][:, ori_action_dim:] = 0.0
|
||||||
@@ -1420,7 +974,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
obs_update)
|
obs_update)
|
||||||
|
|
||||||
# Collect data for interacting the world-model using the predicted actions
|
# Collect data for interacting the world-model using the predicted actions
|
||||||
with profiler.profile_section("prepare_wm_observation"):
|
|
||||||
observation = {
|
observation = {
|
||||||
'observation.images.top':
|
'observation.images.top':
|
||||||
torch.stack(list(
|
torch.stack(list(
|
||||||
@@ -1437,7 +990,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
# Interaction with the world-model
|
# Interaction with the world-model
|
||||||
if log_step:
|
if log_step:
|
||||||
print(f'>>> Step {itr}: interacting with world model ...')
|
print(f'>>> Step {itr}: interacting with world model ...')
|
||||||
with profiler.profile_section("world_model_interaction"):
|
|
||||||
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
|
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
|
||||||
model,
|
model,
|
||||||
"",
|
"",
|
||||||
@@ -1454,7 +1006,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
guidance_rescale=args.guidance_rescale,
|
guidance_rescale=args.guidance_rescale,
|
||||||
diffusion_autocast_dtype=diffusion_autocast_dtype)
|
diffusion_autocast_dtype=diffusion_autocast_dtype)
|
||||||
|
|
||||||
with profiler.profile_section("update_state_queues"):
|
|
||||||
for step_idx in range(args.exe_steps):
|
for step_idx in range(args.exe_steps):
|
||||||
obs_update = {
|
obs_update = {
|
||||||
'observation.images.top':
|
'observation.images.top':
|
||||||
@@ -1470,20 +1021,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
obs_update)
|
obs_update)
|
||||||
|
|
||||||
# Save the imagen videos for decision-making (async)
|
# Save the imagen videos for decision-making (async)
|
||||||
with profiler.profile_section("save_results"):
|
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
|
||||||
log_to_tensorboard_async(writer,
|
|
||||||
pred_videos_0,
|
|
||||||
sample_tag,
|
|
||||||
fps=args.save_fps)
|
|
||||||
# Save videos environment changes via world-model interaction
|
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
|
||||||
log_to_tensorboard_async(writer,
|
|
||||||
pred_videos_1,
|
|
||||||
sample_tag,
|
|
||||||
fps=args.save_fps)
|
|
||||||
|
|
||||||
# Save the imagen videos for decision-making
|
|
||||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
||||||
save_results_async(pred_videos_0,
|
save_results_async(pred_videos_0,
|
||||||
sample_video_file,
|
sample_video_file,
|
||||||
@@ -1498,24 +1035,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
# Collect the result of world-model interactions
|
# Collect the result of world-model interactions
|
||||||
wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())
|
wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())
|
||||||
|
|
||||||
profiler.record_memory(f"iter_{itr}_end")
|
|
||||||
profiler.step_profiler()
|
|
||||||
|
|
||||||
full_video = torch.cat(wm_video, dim=2)
|
full_video = torch.cat(wm_video, dim=2)
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
|
||||||
log_to_tensorboard_async(writer,
|
|
||||||
full_video,
|
|
||||||
sample_tag,
|
|
||||||
fps=args.save_fps)
|
|
||||||
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
||||||
save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
|
save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
|
||||||
|
|
||||||
# Wait for all async I/O to complete before profiling report
|
# Wait for all async I/O to complete
|
||||||
_flush_io()
|
_flush_io()
|
||||||
|
|
||||||
# Save profiling results
|
|
||||||
profiler.save_results()
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@@ -1704,32 +1230,6 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
help="fps for the saving video")
|
help="fps for the saving video")
|
||||||
# Profiling arguments
|
|
||||||
parser.add_argument(
|
|
||||||
"--profile",
|
|
||||||
action='store_true',
|
|
||||||
default=False,
|
|
||||||
help="Enable performance profiling (macro and operator-level analysis)."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--profile_output_dir",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Directory to save profiling results. Defaults to {savedir}/profile_output."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--profile_iterations",
|
|
||||||
type=int,
|
|
||||||
default=3,
|
|
||||||
help="Number of iterations to run PyTorch profiler's active phase for operator-level analysis."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--profile_detail",
|
|
||||||
type=str,
|
|
||||||
choices=["light", "full"],
|
|
||||||
default="light",
|
|
||||||
help="Profiling detail level. Use 'full' for shapes/stacks/memory/flops."
|
|
||||||
)
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@@ -1741,15 +1241,5 @@ if __name__ == '__main__':
|
|||||||
seed = random.randint(0, 2**31)
|
seed = random.randint(0, 2**31)
|
||||||
seed_everything(seed)
|
seed_everything(seed)
|
||||||
|
|
||||||
# Initialize profiler
|
|
||||||
profile_output_dir = args.profile_output_dir
|
|
||||||
if profile_output_dir is None:
|
|
||||||
profile_output_dir = os.path.join(args.savedir, "profile_output")
|
|
||||||
init_profiler(
|
|
||||||
enabled=args.profile,
|
|
||||||
output_dir=profile_output_dir,
|
|
||||||
profile_detail=args.profile_detail,
|
|
||||||
)
|
|
||||||
|
|
||||||
rank, gpu_num = 0, 1
|
rank, gpu_num = 0, 1
|
||||||
run_inference(args, gpu_num, rank)
|
run_inference(args, gpu_num, rank)
|
||||||
|
|||||||
Reference in New Issue
Block a user