#!/usr/bin/env python """Convert LeWM HuggingFace weights into eval-compatible object checkpoints.""" from __future__ import annotations import argparse import json import sys from pathlib import Path import stable_pretraining as spt import torch REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from jepa import JEPA from module import ARPredictor, Embedder, MLP def _load_json(path: Path) -> dict: with path.open("r", encoding="utf-8") as f: return json.load(f) def _strip_target(config: dict) -> dict: return {key: value for key, value in config.items() if key != "_target_"} def infer_config_from_state_dict(state_dict: dict) -> dict: action_dim = state_dict["action_encoder.patch_embed.weight"].shape[1] return { "encoder": { "size": "tiny", "patch_size": 14, "image_size": 224, "pretrained": False, "use_mask_token": False, }, "predictor": { "num_frames": 3, "input_dim": 192, "hidden_dim": 192, "output_dim": 192, "depth": 6, "heads": 16, "mlp_dim": 2048, "dim_head": 64, "dropout": 0.1, "emb_dropout": 0.0, }, "action_encoder": { "input_dim": action_dim, "emb_dim": 192, }, "projector": { "input_dim": 192, "output_dim": 192, "hidden_dim": 2048, }, "pred_proj": { "input_dim": 192, "output_dim": 192, "hidden_dim": 2048, }, } def build_model(config: dict) -> JEPA: encoder = spt.backbone.utils.vit_hf(**_strip_target(config["encoder"])) predictor = ARPredictor(**_strip_target(config["predictor"])) action_encoder = Embedder(**_strip_target(config["action_encoder"])) projector_cfg = _strip_target(config["projector"]) projector_cfg["norm_fn"] = torch.nn.BatchNorm1d projector = MLP(**projector_cfg) pred_proj_cfg = _strip_target(config["pred_proj"]) pred_proj_cfg["norm_fn"] = torch.nn.BatchNorm1d pred_proj = MLP(**pred_proj_cfg) return JEPA( encoder=encoder, predictor=predictor, action_encoder=action_encoder, projector=projector, pred_proj=pred_proj, ) def convert_checkpoint(input_dir: Path, output_name: str) -> tuple[Path, Path]: config_path = input_dir / "config.json" weights_path = input_dir / "weights.pt" if not weights_path.exists(): raise FileNotFoundError(f"Missing weights file: {weights_path}") state_dict = torch.load(weights_path, map_location="cpu") config = _load_json(config_path) if config_path.exists() else infer_config_from_state_dict(state_dict) model = build_model(config) missing, unexpected = model.load_state_dict(state_dict, strict=True) if missing or unexpected: raise RuntimeError( f"State dict mismatch: missing={missing}, unexpected={unexpected}" ) model.eval() object_path = input_dir / f"{output_name}_object.ckpt" weight_path = input_dir / f"{output_name}_weight.ckpt" torch.save(model, object_path) torch.save(model.state_dict(), weight_path) return object_path, weight_path def main() -> None: parser = argparse.ArgumentParser() parser.add_argument( "input_dir", type=Path, help="Directory containing weights.pt and optionally config.json.", ) parser.add_argument("--output-name", default="lewm") args = parser.parse_args() object_path, weight_path = convert_checkpoint(args.input_dir, args.output_name) print(f"wrote {object_path}") print(f"wrote {weight_path}") if __name__ == "__main__": main()