增加脚本
This commit is contained in:
@@ -24,6 +24,7 @@ dataset:
|
|||||||
|
|
||||||
seed: 42
|
seed: 42
|
||||||
policy: random # ckpt name or random
|
policy: random # ckpt name or random
|
||||||
|
inference_precision: fp16
|
||||||
|
|
||||||
plan_config:
|
plan_config:
|
||||||
horizon: 5
|
horizon: 5
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ dataset:
|
|||||||
|
|
||||||
seed: 42
|
seed: 42
|
||||||
policy: random # ckpt name or random
|
policy: random # ckpt name or random
|
||||||
|
inference_precision: fp16
|
||||||
|
|
||||||
plan_config:
|
plan_config:
|
||||||
horizon: 5
|
horizon: 5
|
||||||
|
|||||||
131
scripts/convert_hf_checkpoint.py
Normal file
131
scripts/convert_hf_checkpoint.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""Convert LeWM HuggingFace weights into eval-compatible object checkpoints."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import stable_pretraining as spt
|
||||||
|
import torch
|
||||||
|
|
||||||
|
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
if str(REPO_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(REPO_ROOT))
|
||||||
|
|
||||||
|
from jepa import JEPA
|
||||||
|
from module import ARPredictor, Embedder, MLP
|
||||||
|
|
||||||
|
|
||||||
|
def _load_json(path: Path) -> dict:
|
||||||
|
with path.open("r", encoding="utf-8") as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_target(config: dict) -> dict:
|
||||||
|
return {key: value for key, value in config.items() if key != "_target_"}
|
||||||
|
|
||||||
|
|
||||||
|
def infer_config_from_state_dict(state_dict: dict) -> dict:
|
||||||
|
action_dim = state_dict["action_encoder.patch_embed.weight"].shape[1]
|
||||||
|
return {
|
||||||
|
"encoder": {
|
||||||
|
"size": "tiny",
|
||||||
|
"patch_size": 14,
|
||||||
|
"image_size": 224,
|
||||||
|
"pretrained": False,
|
||||||
|
"use_mask_token": False,
|
||||||
|
},
|
||||||
|
"predictor": {
|
||||||
|
"num_frames": 3,
|
||||||
|
"input_dim": 192,
|
||||||
|
"hidden_dim": 192,
|
||||||
|
"output_dim": 192,
|
||||||
|
"depth": 6,
|
||||||
|
"heads": 16,
|
||||||
|
"mlp_dim": 2048,
|
||||||
|
"dim_head": 64,
|
||||||
|
"dropout": 0.1,
|
||||||
|
"emb_dropout": 0.0,
|
||||||
|
},
|
||||||
|
"action_encoder": {
|
||||||
|
"input_dim": action_dim,
|
||||||
|
"emb_dim": 192,
|
||||||
|
},
|
||||||
|
"projector": {
|
||||||
|
"input_dim": 192,
|
||||||
|
"output_dim": 192,
|
||||||
|
"hidden_dim": 2048,
|
||||||
|
},
|
||||||
|
"pred_proj": {
|
||||||
|
"input_dim": 192,
|
||||||
|
"output_dim": 192,
|
||||||
|
"hidden_dim": 2048,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_model(config: dict) -> JEPA:
|
||||||
|
encoder = spt.backbone.utils.vit_hf(**_strip_target(config["encoder"]))
|
||||||
|
predictor = ARPredictor(**_strip_target(config["predictor"]))
|
||||||
|
action_encoder = Embedder(**_strip_target(config["action_encoder"]))
|
||||||
|
|
||||||
|
projector_cfg = _strip_target(config["projector"])
|
||||||
|
projector_cfg["norm_fn"] = torch.nn.BatchNorm1d
|
||||||
|
projector = MLP(**projector_cfg)
|
||||||
|
|
||||||
|
pred_proj_cfg = _strip_target(config["pred_proj"])
|
||||||
|
pred_proj_cfg["norm_fn"] = torch.nn.BatchNorm1d
|
||||||
|
pred_proj = MLP(**pred_proj_cfg)
|
||||||
|
|
||||||
|
return JEPA(
|
||||||
|
encoder=encoder,
|
||||||
|
predictor=predictor,
|
||||||
|
action_encoder=action_encoder,
|
||||||
|
projector=projector,
|
||||||
|
pred_proj=pred_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_checkpoint(input_dir: Path, output_name: str) -> tuple[Path, Path]:
|
||||||
|
config_path = input_dir / "config.json"
|
||||||
|
weights_path = input_dir / "weights.pt"
|
||||||
|
if not weights_path.exists():
|
||||||
|
raise FileNotFoundError(f"Missing weights file: {weights_path}")
|
||||||
|
|
||||||
|
state_dict = torch.load(weights_path, map_location="cpu")
|
||||||
|
config = _load_json(config_path) if config_path.exists() else infer_config_from_state_dict(state_dict)
|
||||||
|
model = build_model(config)
|
||||||
|
missing, unexpected = model.load_state_dict(state_dict, strict=True)
|
||||||
|
if missing or unexpected:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"State dict mismatch: missing={missing}, unexpected={unexpected}"
|
||||||
|
)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
object_path = input_dir / f"{output_name}_object.ckpt"
|
||||||
|
weight_path = input_dir / f"{output_name}_weight.ckpt"
|
||||||
|
torch.save(model, object_path)
|
||||||
|
torch.save(model.state_dict(), weight_path)
|
||||||
|
return object_path, weight_path
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"input_dir",
|
||||||
|
type=Path,
|
||||||
|
help="Directory containing weights.pt and optionally config.json.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--output-name", default="lewm")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
object_path, weight_path = convert_checkpoint(args.input_dir, args.output_name)
|
||||||
|
print(f"wrote {object_path}")
|
||||||
|
print(f"wrote {weight_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
97
scripts/warmup_eval.sh
Executable file
97
scripts/warmup_eval.sh
Executable file
@@ -0,0 +1,97 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# Warm up LeWM evaluation before a formal run.
|
||||||
|
#
|
||||||
|
# This script intentionally does a small eval for each task so ROCm/PyTorch can
|
||||||
|
# initialize GPU contexts, compile predictor graphs, populate kernel caches, and
|
||||||
|
# touch dataset/checkpoint paths before the timed run.
|
||||||
|
#
|
||||||
|
# Site-specific things to check before using this at the competition:
|
||||||
|
# 1. STABLEWM_HOME points to the directory containing datasets/checkpoints.
|
||||||
|
# 2. The policy names below match the checkpoint folders at STABLEWM_HOME.
|
||||||
|
# 3. The dataset names in config/eval/*.yaml match the onsite dataset files.
|
||||||
|
# 4. The GPU visibility variables match the GPUs allocated to this job.
|
||||||
|
# 5. WARMUP_NUM_EVAL is close enough to the formal shape to trigger useful
|
||||||
|
# compilation, but small enough not to waste much time.
|
||||||
|
|
||||||
|
REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||||
|
cd "${REPO_ROOT}"
|
||||||
|
|
||||||
|
PYTHON_BIN="${PYTHON_BIN:-${REPO_ROOT}/.venv/bin/python}"
|
||||||
|
STABLEWM_HOME="${STABLEWM_HOME:-/mnt/ASC1637/stablewm}"
|
||||||
|
export STABLEWM_HOME
|
||||||
|
|
||||||
|
# If Slurm allocates multiple GPUs, set these to the allocated physical GPU ids.
|
||||||
|
# Example for physical GPU 2 and 3:
|
||||||
|
# ROCR_VISIBLE_DEVICES=2,3 HIP_VISIBLE_DEVICES=0,1 CUDA_VISIBLE_DEVICES=0,1
|
||||||
|
#
|
||||||
|
# Important ROCm detail:
|
||||||
|
# ROCR_VISIBLE_DEVICES uses physical ids.
|
||||||
|
# HIP_VISIBLE_DEVICES/CUDA_VISIBLE_DEVICES use ids after ROCR remapping.
|
||||||
|
export ROCR_VISIBLE_DEVICES="${ROCR_VISIBLE_DEVICES:-0}"
|
||||||
|
export HIP_VISIBLE_DEVICES="${HIP_VISIBLE_DEVICES:-0}"
|
||||||
|
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
|
||||||
|
|
||||||
|
WARMUP_NUM_EVAL="${WARMUP_NUM_EVAL:-10}"
|
||||||
|
INFERENCE_PRECISION="${INFERENCE_PRECISION:-fp16}"
|
||||||
|
OUTPUT_DIR="${OUTPUT_DIR:-/tmp/lewm_warmup}"
|
||||||
|
mkdir -p "${OUTPUT_DIR}"
|
||||||
|
|
||||||
|
# Enable multi-GPU warmup by setting MULTI_GPU=1.
|
||||||
|
# MULTI_GPU_DEVICES are process-local ids, not physical ids after ROCR remapping.
|
||||||
|
# Example:
|
||||||
|
# ROCR_VISIBLE_DEVICES=2,3 HIP_VISIBLE_DEVICES=0,1 MULTI_GPU=1 MULTI_GPU_DEVICES='[0,1]'
|
||||||
|
MULTI_GPU="${MULTI_GPU:-0}"
|
||||||
|
MULTI_GPU_DEVICES="${MULTI_GPU_DEVICES:-[0,1]}"
|
||||||
|
|
||||||
|
COMMON_ARGS=(
|
||||||
|
"eval.num_eval=${WARMUP_NUM_EVAL}"
|
||||||
|
"inference_precision=${INFERENCE_PRECISION}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if [[ "${MULTI_GPU}" == "1" ]]; then
|
||||||
|
COMMON_ARGS+=(
|
||||||
|
"+multi_gpu.enabled=true"
|
||||||
|
"+multi_gpu.devices=${MULTI_GPU_DEVICES}"
|
||||||
|
)
|
||||||
|
fi
|
||||||
|
|
||||||
|
run_warmup() {
|
||||||
|
local config_name="$1"
|
||||||
|
local policy="$2"
|
||||||
|
local output_name="$3"
|
||||||
|
|
||||||
|
echo
|
||||||
|
echo "== Warmup ${config_name} policy=${policy} =="
|
||||||
|
"${PYTHON_BIN}" eval.py \
|
||||||
|
"--config-name=${config_name}" \
|
||||||
|
"policy=${policy}" \
|
||||||
|
"output.filename=${OUTPUT_DIR}/${output_name}" \
|
||||||
|
"${COMMON_ARGS[@]}"
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "LeWM warmup"
|
||||||
|
echo " repo: ${REPO_ROOT}"
|
||||||
|
echo " python: ${PYTHON_BIN}"
|
||||||
|
echo " STABLEWM_HOME: ${STABLEWM_HOME}"
|
||||||
|
echo " ROCR_VISIBLE_DEVICES: ${ROCR_VISIBLE_DEVICES}"
|
||||||
|
echo " HIP_VISIBLE_DEVICES: ${HIP_VISIBLE_DEVICES}"
|
||||||
|
echo " CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES}"
|
||||||
|
echo " WARMUP_NUM_EVAL: ${WARMUP_NUM_EVAL}"
|
||||||
|
echo " INFERENCE_PRECISION: ${INFERENCE_PRECISION}"
|
||||||
|
echo " MULTI_GPU: ${MULTI_GPU}"
|
||||||
|
if [[ "${MULTI_GPU}" == "1" ]]; then
|
||||||
|
echo " MULTI_GPU_DEVICES: ${MULTI_GPU_DEVICES}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Defaults match the checkpoint names used in this repo. If onsite checkpoint
|
||||||
|
# folders differ, override by editing these calls or passing the equivalent
|
||||||
|
# eval.py command manually.
|
||||||
|
run_warmup "pusht.yaml" "pusht/lewm" "warmup_pusht.txt"
|
||||||
|
run_warmup "reacher.yaml" "reacher/lewm" "warmup_reacher.txt"
|
||||||
|
run_warmup "cube.yaml" "cube/lewm" "warmup_cube.txt"
|
||||||
|
run_warmup "tworoom.yaml" "tworoom/lewm" "warmup_tworoom.txt"
|
||||||
|
|
||||||
|
echo
|
||||||
|
echo "Warmup complete. Logs were appended under ${OUTPUT_DIR}."
|
||||||
Reference in New Issue
Block a user