add profile frame and bf15/fp16 switch
This commit is contained in:
19
.gitignore
vendored
Normal file
19
.gitignore
vendored
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
.venv/
|
||||||
|
outputs/
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
|
||||||
|
.pytest_cache/
|
||||||
|
.mypy_cache/
|
||||||
|
.ruff_cache/
|
||||||
|
|
||||||
|
torch_profile/
|
||||||
|
trace.json
|
||||||
|
key_averages.txt
|
||||||
|
eval_tmp_*.npy
|
||||||
|
*.mp4
|
||||||
|
*.gif
|
||||||
|
|
||||||
|
.DS_Store
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
27
README.md
27
README.md
@@ -84,6 +84,33 @@ python eval.py --config-name=pusht.yaml policy=pusht/lewm
|
|||||||
python eval.py --config-name=pusht.yaml policy=pusht/lewm_object.ckpt
|
python eval.py --config-name=pusht.yaml policy=pusht/lewm_object.ckpt
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Profiling
|
||||||
|
|
||||||
|
`eval.py` now supports optional inference profiling with PyTorch's native profiler.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python eval.py --config-name=pusht.yaml policy=pusht/lewm \
|
||||||
|
inference_precision=bf16 \
|
||||||
|
+profile.enabled=true \
|
||||||
|
+profile.with_stack=true \
|
||||||
|
+profile.record_shapes=true \
|
||||||
|
+profile.profile_memory=true
|
||||||
|
```
|
||||||
|
|
||||||
|
Supported inference precision modes:
|
||||||
|
- `inference_precision=fp32`
|
||||||
|
- `inference_precision=bf16`
|
||||||
|
- `inference_precision=fp16`
|
||||||
|
|
||||||
|
Outputs are written next to the evaluation results:
|
||||||
|
- `torch_profile/key_averages.txt` for the aggregated operator table
|
||||||
|
- `torch_profile/trace.json` for Chrome tracing
|
||||||
|
- TensorBoard trace files under `torch_profile/`
|
||||||
|
|
||||||
|
The trace includes custom regions such as `eval.world_evaluate_from_dataset`, `lewm.get_cost`, `lewm.rollout`, and `lewm.predict` to make the planning path easier to inspect.
|
||||||
|
|
||||||
## Pretrained Checkpoints
|
## Pretrained Checkpoints
|
||||||
|
|
||||||
Pre-trained checkpoints are available on [Google Drive](https://drive.google.com/drive/folders/1r31os0d4-rR0mdHc7OlY_e5nh3XT4r4e). Download the checkpoint archive and place the extracted files under `$STABLEWM_HOME/`.
|
Pre-trained checkpoints are available on [Google Drive](https://drive.google.com/drive/folders/1r31os0d4-rR0mdHc7OlY_e5nh3XT4r4e). Download the checkpoint archive and place the extracted files under `$STABLEWM_HOME/`.
|
||||||
|
|||||||
135
eval.py
135
eval.py
@@ -3,6 +3,7 @@ import os
|
|||||||
os.environ["MUJOCO_GL"] = "egl"
|
os.environ["MUJOCO_GL"] = "egl"
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
from contextlib import nullcontext
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
@@ -46,6 +47,99 @@ def get_dataset(cfg, dataset_name):
|
|||||||
)
|
)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def get_profile_cfg(cfg):
|
||||||
|
profile_cfg = {
|
||||||
|
"enabled": False,
|
||||||
|
"trace_dirname": "torch_profile",
|
||||||
|
"record_shapes": True,
|
||||||
|
"profile_memory": True,
|
||||||
|
"with_stack": False,
|
||||||
|
"with_flops": False,
|
||||||
|
"row_limit": 40,
|
||||||
|
"worker_name": "eval",
|
||||||
|
"export_chrome_trace": True,
|
||||||
|
"export_tensorboard": True,
|
||||||
|
}
|
||||||
|
cfg_profile = cfg.get("profile")
|
||||||
|
if cfg_profile is not None:
|
||||||
|
profile_cfg.update(OmegaConf.to_container(cfg_profile, resolve=True))
|
||||||
|
return profile_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def get_inference_context(cfg, device):
|
||||||
|
precision = str(cfg.get("inference_precision", "fp32")).lower()
|
||||||
|
device_type = "cuda" if device.startswith("cuda") else "cpu"
|
||||||
|
|
||||||
|
if precision == "fp32":
|
||||||
|
return nullcontext(), "fp32"
|
||||||
|
|
||||||
|
if precision in {"bf16", "bfloat16"}:
|
||||||
|
return (
|
||||||
|
torch.autocast(device_type=device_type, dtype=torch.bfloat16),
|
||||||
|
"bf16",
|
||||||
|
)
|
||||||
|
|
||||||
|
if precision in {"fp16", "float16"}:
|
||||||
|
if device_type != "cuda":
|
||||||
|
print("fp16 inference is only supported on CUDA, falling back to fp32.")
|
||||||
|
return nullcontext(), "fp32"
|
||||||
|
return (
|
||||||
|
torch.autocast(device_type=device_type, dtype=torch.float16),
|
||||||
|
"fp16",
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported inference_precision={precision}. Expected one of: fp32, bf16, fp16."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_profiler(cfg, results_path):
|
||||||
|
profile_cfg = get_profile_cfg(cfg)
|
||||||
|
if not profile_cfg["enabled"]:
|
||||||
|
return nullcontext(), None, profile_cfg
|
||||||
|
|
||||||
|
activities = [torch.profiler.ProfilerActivity.CPU]
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
activities.append(torch.profiler.ProfilerActivity.CUDA)
|
||||||
|
|
||||||
|
profile_dir = results_path / profile_cfg["trace_dirname"]
|
||||||
|
profile_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
profiler = torch.profiler.profile(
|
||||||
|
activities=activities,
|
||||||
|
record_shapes=profile_cfg["record_shapes"],
|
||||||
|
profile_memory=profile_cfg["profile_memory"],
|
||||||
|
with_stack=profile_cfg["with_stack"],
|
||||||
|
with_flops=profile_cfg["with_flops"],
|
||||||
|
)
|
||||||
|
return profiler, profile_dir, profile_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def dump_profiler_results(profiler, profile_dir, profile_cfg):
|
||||||
|
if profiler is None or profile_dir is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
has_cuda = torch.cuda.is_available()
|
||||||
|
table = profiler.key_averages().table(
|
||||||
|
sort_by="self_cuda_time_total" if has_cuda else "self_cpu_time_total",
|
||||||
|
row_limit=profile_cfg["row_limit"],
|
||||||
|
)
|
||||||
|
|
||||||
|
summary_path = profile_dir / "key_averages.txt"
|
||||||
|
summary_path.write_text(table)
|
||||||
|
|
||||||
|
if profile_cfg["export_chrome_trace"]:
|
||||||
|
profiler.export_chrome_trace(str(profile_dir / "trace.json"))
|
||||||
|
|
||||||
|
if profile_cfg["export_tensorboard"]:
|
||||||
|
trace_handler = torch.profiler.tensorboard_trace_handler(
|
||||||
|
str(profile_dir), worker_name=profile_cfg["worker_name"]
|
||||||
|
)
|
||||||
|
trace_handler(profiler)
|
||||||
|
|
||||||
|
return summary_path
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
|
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
|
||||||
def run(cfg: DictConfig):
|
def run(cfg: DictConfig):
|
||||||
"""Run evaluation of dinowm vs random policy."""
|
"""Run evaluation of dinowm vs random policy."""
|
||||||
@@ -83,12 +177,15 @@ def run(cfg: DictConfig):
|
|||||||
|
|
||||||
# -- run evaluation
|
# -- run evaluation
|
||||||
policy = cfg.get("policy", "random")
|
policy = cfg.get("policy", "random")
|
||||||
|
|
||||||
if policy != "random":
|
if policy != "random":
|
||||||
model = swm.policy.AutoCostModel(cfg.policy)
|
model = swm.policy.AutoCostModel(cfg.policy)
|
||||||
model = model.to("cuda")
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = model.to(device)
|
||||||
model = model.eval()
|
model = model.eval()
|
||||||
model.requires_grad_(False)
|
model.requires_grad_(False)
|
||||||
|
print(f"model parameter dtype: {next(model.parameters()).dtype}")
|
||||||
|
inference_ctx, inference_precision = get_inference_context(cfg, device)
|
||||||
|
print(f"inference execution precision: {inference_precision}")
|
||||||
model.interpolate_pos_encoding = True
|
model.interpolate_pos_encoding = True
|
||||||
config = swm.PlanConfig(**cfg.plan_config)
|
config = swm.PlanConfig(**cfg.plan_config)
|
||||||
solver = hydra.utils.instantiate(cfg.solver, model=model)
|
solver = hydra.utils.instantiate(cfg.solver, model=model)
|
||||||
@@ -98,12 +195,15 @@ def run(cfg: DictConfig):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
policy = swm.policy.RandomPolicy()
|
policy = swm.policy.RandomPolicy()
|
||||||
|
inference_ctx = nullcontext()
|
||||||
|
inference_precision = "fp32"
|
||||||
|
|
||||||
results_path = (
|
results_path = (
|
||||||
Path(swm.data.utils.get_cache_dir(), cfg.policy).parent
|
Path(swm.data.utils.get_cache_dir(), cfg.policy).parent
|
||||||
if cfg.policy != "random"
|
if cfg.policy != "random"
|
||||||
else Path(__file__).parent
|
else Path(__file__).parent
|
||||||
)
|
)
|
||||||
|
profiler_ctx, profile_dir, profile_cfg = make_profiler(cfg, results_path)
|
||||||
|
|
||||||
# sample the episodes and the starting indices
|
# sample the episodes and the starting indices
|
||||||
episode_len = get_episodes_length(dataset, ep_indices)
|
episode_len = get_episodes_length(dataset, ep_indices)
|
||||||
@@ -138,17 +238,25 @@ def run(cfg: DictConfig):
|
|||||||
|
|
||||||
world.set_policy(policy)
|
world.set_policy(policy)
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
metrics = world.evaluate_from_dataset(
|
with profiler_ctx as profiler:
|
||||||
dataset,
|
with inference_ctx:
|
||||||
start_steps=eval_start_idx.tolist(),
|
with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
|
||||||
goal_offset_steps=cfg.eval.goal_offset_steps,
|
metrics = world.evaluate_from_dataset(
|
||||||
eval_budget=cfg.eval.eval_budget,
|
dataset,
|
||||||
episodes_idx=eval_episodes.tolist(),
|
start_steps=eval_start_idx.tolist(),
|
||||||
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
|
goal_offset_steps=cfg.eval.goal_offset_steps,
|
||||||
video_path=results_path,
|
eval_budget=cfg.eval.eval_budget,
|
||||||
)
|
episodes_idx=eval_episodes.tolist(),
|
||||||
|
callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
|
||||||
|
video_path=results_path,
|
||||||
|
)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
profile_summary_path = dump_profiler_results(profiler, profile_dir, profile_cfg)
|
||||||
|
|
||||||
print(metrics)
|
print(metrics)
|
||||||
|
|
||||||
@@ -165,6 +273,11 @@ def run(cfg: DictConfig):
|
|||||||
f.write("==== RESULTS ====\n")
|
f.write("==== RESULTS ====\n")
|
||||||
f.write(f"metrics: {metrics}\n")
|
f.write(f"metrics: {metrics}\n")
|
||||||
f.write(f"evaluation_time: {end_time - start_time} seconds\n")
|
f.write(f"evaluation_time: {end_time - start_time} seconds\n")
|
||||||
|
f.write(f"inference_precision: {inference_precision}\n")
|
||||||
|
if profile_cfg["enabled"]:
|
||||||
|
f.write(f"profile_dir: {profile_dir}\n")
|
||||||
|
if profile_summary_path is not None:
|
||||||
|
f.write(f"profile_summary: {profile_summary_path}\n")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
156
jepa.py
156
jepa.py
@@ -30,29 +30,30 @@ class JEPA(nn.Module):
|
|||||||
"""Encode observations and actions into embeddings.
|
"""Encode observations and actions into embeddings.
|
||||||
info: dict with pixels and action keys
|
info: dict with pixels and action keys
|
||||||
"""
|
"""
|
||||||
|
with torch.profiler.record_function("lewm.encode"):
|
||||||
|
pixels = info['pixels'].float()
|
||||||
|
b = pixels.size(0)
|
||||||
|
pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding
|
||||||
|
output = self.encoder(pixels, interpolate_pos_encoding=True)
|
||||||
|
pixels_emb = output.last_hidden_state[:, 0] # cls token
|
||||||
|
emb = self.projector(pixels_emb)
|
||||||
|
info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b)
|
||||||
|
|
||||||
pixels = info['pixels'].float()
|
if "action" in info:
|
||||||
b = pixels.size(0)
|
info["act_emb"] = self.action_encoder(info["action"])
|
||||||
pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding
|
|
||||||
output = self.encoder(pixels, interpolate_pos_encoding=True)
|
|
||||||
pixels_emb = output.last_hidden_state[:, 0] # cls token
|
|
||||||
emb = self.projector(pixels_emb)
|
|
||||||
info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b)
|
|
||||||
|
|
||||||
if "action" in info:
|
return info
|
||||||
info["act_emb"] = self.action_encoder(info["action"])
|
|
||||||
|
|
||||||
return info
|
|
||||||
|
|
||||||
def predict(self, emb, act_emb):
|
def predict(self, emb, act_emb):
|
||||||
"""Predict next state embedding
|
"""Predict next state embedding
|
||||||
emb: (B, T, D)
|
emb: (B, T, D)
|
||||||
act_emb: (B, T, A_emb)
|
act_emb: (B, T, A_emb)
|
||||||
"""
|
"""
|
||||||
preds = self.predictor(emb, act_emb)
|
with torch.profiler.record_function("lewm.predict"):
|
||||||
preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d"))
|
preds = self.predictor(emb, act_emb)
|
||||||
preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0))
|
preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d"))
|
||||||
return preds
|
preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0))
|
||||||
|
return preds
|
||||||
|
|
||||||
####################
|
####################
|
||||||
## Inference only ##
|
## Inference only ##
|
||||||
@@ -65,89 +66,90 @@ class JEPA(nn.Module):
|
|||||||
- S is the number of action plan samples
|
- S is the number of action plan samples
|
||||||
- T is the time horizon
|
- T is the time horizon
|
||||||
"""
|
"""
|
||||||
|
with torch.profiler.record_function("lewm.rollout"):
|
||||||
|
assert "pixels" in info, "pixels not in info_dict"
|
||||||
|
H = info["pixels"].size(2)
|
||||||
|
B, S, T = action_sequence.shape[:3]
|
||||||
|
act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2)
|
||||||
|
info["action"] = act_0
|
||||||
|
n_steps = T - H
|
||||||
|
|
||||||
assert "pixels" in info, "pixels not in info_dict"
|
# copy and encode initial info dict
|
||||||
H = info["pixels"].size(2)
|
_init = {k: v[:, 0] for k, v in info.items() if torch.is_tensor(v)}
|
||||||
B, S, T = action_sequence.shape[:3]
|
_init = self.encode(_init)
|
||||||
act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2)
|
emb = info["emb"] = _init["emb"].unsqueeze(1).expand(B, S, -1, -1)
|
||||||
info["action"] = act_0
|
_init = {k: detach_clone(v) for k, v in _init.items()}
|
||||||
n_steps = T - H
|
|
||||||
|
|
||||||
# copy and encode initial info dict
|
# flatten batch and sample dimensions for rollout
|
||||||
_init = {k: v[:, 0] for k, v in info.items() if torch.is_tensor(v)}
|
emb = rearrange(emb, "b s ... -> (b s) ...").clone()
|
||||||
_init = self.encode(_init)
|
act = rearrange(act_0, "b s ... -> (b s) ...")
|
||||||
emb = info["emb"] = _init["emb"].unsqueeze(1).expand(B, S, -1, -1)
|
act_future = rearrange(act_future, "b s ... -> (b s) ...")
|
||||||
_init = {k: detach_clone(v) for k, v in _init.items()}
|
|
||||||
|
|
||||||
# flatten batch and sample dimensions for rollout
|
# rollout predictor autoregressively for n_steps
|
||||||
emb = rearrange(emb, "b s ... -> (b s) ...").clone()
|
HS = history_size
|
||||||
act = rearrange(act_0, "b s ... -> (b s) ...")
|
for t in range(n_steps):
|
||||||
act_future = rearrange(act_future, "b s ... -> (b s) ...")
|
act_emb = self.action_encoder(act)
|
||||||
|
emb_trunc = emb[:, -HS:] # (BS, HS, D)
|
||||||
|
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
|
||||||
|
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
|
||||||
|
emb = torch.cat([emb, pred_emb], dim=1) # (BS, T+1, D)
|
||||||
|
|
||||||
# rollout predictor autoregressively for n_steps
|
next_act = act_future[:, t : t + 1, :] # (BS, 1, action_dim)
|
||||||
HS = history_size
|
act = torch.cat([act, next_act], dim=1) # (BS, T+1, action_dim)
|
||||||
for t in range(n_steps):
|
|
||||||
act_emb = self.action_encoder(act)
|
# predict the last state
|
||||||
|
act_emb = self.action_encoder(act) # (BS, T, A_emb)
|
||||||
emb_trunc = emb[:, -HS:] # (BS, HS, D)
|
emb_trunc = emb[:, -HS:] # (BS, HS, D)
|
||||||
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
|
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
|
||||||
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
|
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
|
||||||
emb = torch.cat([emb, pred_emb], dim=1) # (BS, T+1, D)
|
emb = torch.cat([emb, pred_emb], dim=1)
|
||||||
|
|
||||||
next_act = act_future[:, t : t + 1, :] # (BS, 1, action_dim)
|
# unflatten batch and sample dimensions
|
||||||
act = torch.cat([act, next_act], dim=1) # (BS, T+1, action_dim)
|
pred_rollout = rearrange(emb, "(b s) ... -> b s ...", b=B, s=S)
|
||||||
|
info["predicted_emb"] = pred_rollout
|
||||||
|
|
||||||
# predict the last state
|
return info
|
||||||
act_emb = self.action_encoder(act) # (BS, T, A_emb)
|
|
||||||
emb_trunc = emb[:, -HS:] # (BS, HS, D)
|
|
||||||
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
|
|
||||||
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
|
|
||||||
emb = torch.cat([emb, pred_emb], dim=1)
|
|
||||||
|
|
||||||
# unflatten batch and sample dimensions
|
|
||||||
pred_rollout = rearrange(emb, "(b s) ... -> b s ...", b=B, s=S)
|
|
||||||
info["predicted_emb"] = pred_rollout
|
|
||||||
|
|
||||||
return info
|
|
||||||
|
|
||||||
def criterion(self, info_dict: dict):
|
def criterion(self, info_dict: dict):
|
||||||
"""Compute the cost between predicted embeddings and goal embeddings."""
|
"""Compute the cost between predicted embeddings and goal embeddings."""
|
||||||
pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim)
|
with torch.profiler.record_function("lewm.criterion"):
|
||||||
goal_emb = info_dict["goal_emb"] # (B, S, T, dim)
|
pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim)
|
||||||
|
goal_emb = info_dict["goal_emb"] # (B, S, T, dim)
|
||||||
|
|
||||||
goal_emb = goal_emb[..., -1:, :].expand_as(pred_emb)
|
goal_emb = goal_emb[..., -1:, :].expand_as(pred_emb)
|
||||||
|
|
||||||
# return last-step cost per action candidate
|
# return last-step cost per action candidate
|
||||||
cost = F.mse_loss(
|
cost = F.mse_loss(
|
||||||
pred_emb[..., -1:, :],
|
pred_emb[..., -1:, :],
|
||||||
goal_emb[..., -1:, :].detach(),
|
goal_emb[..., -1:, :].detach(),
|
||||||
reduction="none",
|
reduction="none",
|
||||||
).sum(dim=tuple(range(2, pred_emb.ndim))) # (B, S)
|
).sum(dim=tuple(range(2, pred_emb.ndim))) # (B, S)
|
||||||
|
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
def get_cost(self, info_dict: dict, action_candidates: torch.Tensor):
|
def get_cost(self, info_dict: dict, action_candidates: torch.Tensor):
|
||||||
""" Compute the cost of action candidates given an info dict with goal and initial state."""
|
""" Compute the cost of action candidates given an info dict with goal and initial state."""
|
||||||
|
with torch.profiler.record_function("lewm.get_cost"):
|
||||||
|
assert "goal" in info_dict, "goal not in info_dict"
|
||||||
|
|
||||||
assert "goal" in info_dict, "goal not in info_dict"
|
device = next(self.parameters()).device
|
||||||
|
for k in list(info_dict.keys()):
|
||||||
|
if torch.is_tensor(info_dict[k]):
|
||||||
|
info_dict[k] = info_dict[k].to(device)
|
||||||
|
|
||||||
device = next(self.parameters()).device
|
goal = {k: v[:, 0] for k, v in info_dict.items() if torch.is_tensor(v)}
|
||||||
for k in list(info_dict.keys()):
|
goal["pixels"] = goal["goal"]
|
||||||
if torch.is_tensor(info_dict[k]):
|
|
||||||
info_dict[k] = info_dict[k].to(device)
|
|
||||||
|
|
||||||
goal = {k: v[:, 0] for k, v in info_dict.items() if torch.is_tensor(v)}
|
for k in info_dict:
|
||||||
goal["pixels"] = goal["goal"]
|
if k.startswith("goal_"):
|
||||||
|
goal[k[len("goal_") :]] = goal.pop(k)
|
||||||
|
|
||||||
for k in info_dict:
|
goal.pop("action")
|
||||||
if k.startswith("goal_"):
|
goal = self.encode(goal)
|
||||||
goal[k[len("goal_") :]] = goal.pop(k)
|
|
||||||
|
|
||||||
goal.pop("action")
|
info_dict["goal_emb"] = goal["emb"]
|
||||||
goal = self.encode(goal)
|
info_dict = self.rollout(info_dict, action_candidates)
|
||||||
|
|
||||||
info_dict["goal_emb"] = goal["emb"]
|
cost = self.criterion(info_dict)
|
||||||
info_dict = self.rollout(info_dict, action_candidates)
|
|
||||||
|
return cost
|
||||||
cost = self.criterion(info_dict)
|
|
||||||
|
|
||||||
return cost
|
|
||||||
|
|||||||
Reference in New Issue
Block a user