增加脚本
This commit is contained in:
131
scripts/convert_hf_checkpoint.py
Normal file
131
scripts/convert_hf_checkpoint.py
Normal file
@@ -0,0 +1,131 @@
|
||||
#!/usr/bin/env python
|
||||
"""Convert LeWM HuggingFace weights into eval-compatible object checkpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import stable_pretraining as spt
|
||||
import torch
|
||||
|
||||
# The repository root is one directory above the folder holding this script.
# It is put at the front of sys.path (only if not already present) so the
# project-local imports that follow resolve when the script is run directly.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from jepa import JEPA
|
||||
from module import ARPredictor, Embedder, MLP
|
||||
|
||||
|
||||
def _load_json(path: Path) -> dict:
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def _strip_target(config: dict) -> dict:
|
||||
return {key: value for key, value in config.items() if key != "_target_"}
|
||||
|
||||
|
||||
def infer_config_from_state_dict(state_dict: dict) -> dict:
    """Build a fallback model configuration from a checkpoint's state dict.

    Only the action input dimensionality is read from the checkpoint: it is
    the size of axis 1 of the ``"action_encoder.patch_embed.weight"`` entry.
    Every other value is a fixed constant hard-coded below.

    Args:
        state_dict: Mapping of parameter names to objects exposing ``.shape``.

    Returns:
        A dict with the sections ``"encoder"``, ``"predictor"``,
        ``"action_encoder"``, ``"projector"`` and ``"pred_proj"``.

    Raises:
        KeyError: If ``"action_encoder.patch_embed.weight"`` is absent.
    """
    action_dim = state_dict["action_encoder.patch_embed.weight"].shape[1]

    # The same two widths recur across the sections below.
    embed_dim = 192
    wide_dim = 2048

    encoder_cfg = {
        "size": "tiny",
        "patch_size": 14,
        "image_size": 224,
        "pretrained": False,
        "use_mask_token": False,
    }
    predictor_cfg = {
        "num_frames": 3,
        "input_dim": embed_dim,
        "hidden_dim": embed_dim,
        "output_dim": embed_dim,
        "depth": 6,
        "heads": 16,
        "mlp_dim": wide_dim,
        "dim_head": 64,
        "dropout": 0.1,
        "emb_dropout": 0.0,
    }
    action_encoder_cfg = {
        "input_dim": action_dim,
        "emb_dim": embed_dim,
    }
    # The projector and pred_proj sections carry identical settings; each
    # gets its own copy so the returned sections stay independent objects.
    head_template = {
        "input_dim": embed_dim,
        "output_dim": embed_dim,
        "hidden_dim": wide_dim,
    }

    return {
        "encoder": encoder_cfg,
        "predictor": predictor_cfg,
        "action_encoder": action_encoder_cfg,
        "projector": dict(head_template),
        "pred_proj": dict(head_template),
    }
|
||||
|
||||
|
||||
def build_model(config: dict) -> JEPA:
    """Instantiate a ``JEPA`` model from a sectioned configuration dict.

    Each section (``"encoder"``, ``"predictor"``, ``"action_encoder"``,
    ``"projector"``, ``"pred_proj"``) has its ``"_target_"`` key removed and
    the rest is passed as keyword arguments to the matching constructor.
    For the two MLP sections, ``norm_fn`` is always set to
    ``torch.nn.BatchNorm1d``, overriding any value present in the config.

    Args:
        config: Configuration containing the five sections listed above.

    Returns:
        The assembled ``JEPA`` instance.

    Raises:
        KeyError: If a required section is missing from ``config``.
    """

    def _section_kwargs(section: str) -> dict:
        # Constructor kwargs for one section, minus the "_target_" marker.
        return _strip_target(config[section])

    def _batchnorm_mlp(section: str) -> MLP:
        # MLP head whose normalisation layer is forced to BatchNorm1d.
        mlp_kwargs = _section_kwargs(section)
        mlp_kwargs["norm_fn"] = torch.nn.BatchNorm1d
        return MLP(**mlp_kwargs)

    # Keyword arguments are evaluated top to bottom, so the sub-modules are
    # created in the order: encoder, predictor, action encoder, projector,
    # pred_proj.
    return JEPA(
        encoder=spt.backbone.utils.vit_hf(**_section_kwargs("encoder")),
        predictor=ARPredictor(**_section_kwargs("predictor")),
        action_encoder=Embedder(**_section_kwargs("action_encoder")),
        projector=_batchnorm_mlp("projector"),
        pred_proj=_batchnorm_mlp("pred_proj"),
    )
|
||||
|
||||
|
||||
def convert_checkpoint(input_dir: Path, output_name: str) -> tuple[Path, Path]:
    """Convert ``weights.pt`` in ``input_dir`` into object and weight checkpoints.

    The state dict is loaded on CPU, a model is built from ``config.json``
    when that file exists (otherwise from a configuration inferred from the
    state dict), the weights are loaded strictly, and the model is switched
    to eval mode before being saved twice next to the input files.

    Args:
        input_dir: Directory containing ``weights.pt`` and, optionally,
            ``config.json``. Outputs are written into this same directory.
        output_name: Filename prefix for the generated checkpoints.

    Returns:
        ``(object_path, weight_path)``: ``<output_name>_object.ckpt`` holds
        the whole model object and ``<output_name>_weight.ckpt`` holds only
        its state dict.

    Raises:
        FileNotFoundError: If ``weights.pt`` does not exist in ``input_dir``.
        RuntimeError: If the state dict keys do not match the built model.
    """
    config_path = input_dir / "config.json"
    weights_path = input_dir / "weights.pt"
    if not weights_path.exists():
        raise FileNotFoundError(f"Missing weights file: {weights_path}")

    # The weights file is downloaded content. weights_only=True restricts
    # unpickling to tensors and plain containers, so a tampered file cannot
    # execute arbitrary code while being loaded.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)

    if config_path.exists():
        config = _load_json(config_path)
    else:
        config = infer_config_from_state_dict(state_dict)

    model = build_model(config)
    # strict=True makes load_state_dict raise RuntimeError itself, listing
    # the missing/unexpected keys, so no separate mismatch check is needed.
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    object_path = input_dir / f"{output_name}_object.ckpt"
    weight_path = input_dir / f"{output_name}_weight.ckpt"
    # NOTE(review): saving the module object pickles it by class reference;
    # consumers presumably need the same project modules importable and must
    # load it with weights_only=False. Confirm against the eval code.
    torch.save(model, object_path)
    torch.save(model.state_dict(), weight_path)
    return object_path, weight_path
|
||||
|
||||
|
||||
def main() -> None:
    """Parse the command line, run the conversion and report the written files.

    Accepts one positional ``input_dir`` and an optional ``--output-name``
    prefix (default ``"lewm"``), then prints one ``wrote <path>`` line for
    each checkpoint produced by ``convert_checkpoint``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "input_dir",
        type=Path,
        help="Directory containing weights.pt and optionally config.json.",
    )
    cli.add_argument("--output-name", default="lewm")
    options = cli.parse_args()

    # convert_checkpoint returns (object_path, weight_path); report in order.
    for written_path in convert_checkpoint(options.input_dir, options.output_name):
        print(f"wrote {written_path}")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user