add profile frame and bf16/fp16 switch

qihuanye
2026-03-31 11:09:02 +00:00
parent ca231f9f9d
commit 8b84251eb9
4 changed files with 249 additions and 88 deletions


@@ -84,6 +84,33 @@ python eval.py --config-name=pusht.yaml policy=pusht/lewm
python eval.py --config-name=pusht.yaml policy=pusht/lewm_object.ckpt
```
## Profiling
`eval.py` now supports optional inference profiling with PyTorch's native profiler.
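Example (assuming the usual Hydra override syntax, where a `+` prefix adds a key that is not present in the base config):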
```bash
python eval.py --config-name=pusht.yaml policy=pusht/lewm \
inference_precision=bf16 \
+profile.enabled=true \
+profile.with_stack=true \
+profile.record_shapes=true \
+profile.profile_memory=true
```
Supported inference precision modes:
- `inference_precision=fp32`
- `inference_precision=bf16`
- `inference_precision=fp16`
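As a rough illustration of what such a precision switch can look like, here is a minimal sketch that maps the three option values to torch dtypes and wraps inference in `torch.autocast`. The helper name `inference_autocast` and the use of autocast are assumptions made for this example; `eval.py` may instead cast the model weights or use a different mechanism.

```python
import contextlib

import torch

# Assumed mapping from the documented option values to torch dtypes.
PRECISION_DTYPES = {
    "fp32": torch.float32,
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
}


def inference_autocast(precision: str, device_type: str):
    """Return a context manager for the requested precision (hypothetical helper)."""
    if precision not in PRECISION_DTYPES:
        raise ValueError(f"unsupported inference_precision: {precision!r}")
    dtype = PRECISION_DTYPES[precision]
    if dtype is torch.float32:
        # Full precision needs no autocast, so return a no-op context.
        return contextlib.nullcontext()
    return torch.autocast(device_type=device_type, dtype=dtype)


# Tiny stand-in model, run on CPU so the sketch works without a GPU.
model = torch.nn.Linear(4, 2).eval()
with torch.no_grad(), inference_autocast("bf16", "cpu"):
    out = model(torch.randn(1, 4))
print(out.dtype)
```

Reduced precision can change numerical results slightly, so it is worth comparing evaluation metrics against an `fp32` run before relying on `bf16` or `fp16` timings.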
Outputs are written next to the evaluation results:
- `torch_profile/key_averages.txt` for the aggregated operator table
- `torch_profile/trace.json` for Chrome tracing
- TensorBoard trace files under `torch_profile/`
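The Chrome trace can also be inspected programmatically. The sketch below sums the duration of named regions, assuming the file follows the Chrome trace event format: a top-level `traceEvents` list in which complete events have `ph == "X"` and a `dur` field in microseconds. The function names and the `eval.`/`lewm.` prefix filter are illustrative choices, not part of the repository.

```python
import json
from collections import defaultdict


def summarize_trace(trace: dict, prefixes=("eval.", "lewm.")) -> dict:
    """Sum durations (microseconds) of complete events whose name starts with a prefix."""
    totals = defaultdict(float)
    for event in trace.get("traceEvents", []):
        name = event.get("name", "")
        # "X" marks a complete event that carries its own duration.
        if event.get("ph") == "X" and name.startswith(tuple(prefixes)):
            totals[name] += event.get("dur", 0)
    return dict(totals)


def summarize_trace_file(path: str) -> dict:
    """Load a trace file from disk and summarize it."""
    with open(path) as f:
        return summarize_trace(json.load(f))
```

For example, `summarize_trace_file("torch_profile/trace.json")` run from the results directory would return a dictionary of region names to total microseconds. Nested regions overlap in time, so the totals are inclusive and should not be added together across names.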
The trace includes custom regions such as `eval.world_evaluate_from_dataset`, `lewm.get_cost`, `lewm.rollout`, and `lewm.predict` to make the planning path easier to inspect.
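Custom regions like these are typically created with `torch.profiler.record_function`. The following is a minimal sketch of the pattern, not the repository's code: `predict` and `rollout` are toy stand-ins that only reuse the region names listed above, and the profiler runs on CPU only.

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function


def predict(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a model forward pass, labeled as its own region.
    with record_function("lewm.predict"):
        return x @ x.T


def rollout(x: torch.Tensor, steps: int = 3) -> torch.Tensor:
    # The outer region encloses the repeated predict calls.
    with record_function("lewm.rollout"):
        for _ in range(steps):
            x = torch.tanh(predict(x))
        return x


with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    rollout(torch.randn(8, 8))

# Labeled regions show up alongside the operator-level entries.
region_names = {event.key for event in prof.key_averages()}
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```

Because regions nest, the time reported for `lewm.rollout` includes the time spent in `lewm.predict`, which makes it easier to see where the planning path spends its time.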
## Pretrained Checkpoints
Pretrained checkpoints are available on [Google Drive](https://drive.google.com/drive/folders/1r31os0d4-rR0mdHc7OlY_e5nh3XT4r4e). Download the checkpoint archive and place the extracted files under `$STABLEWM_HOME/`.