Files
unifolm-world-model-action/unitree_deploy/unitree_deploy/eval_dataset_env.py
2025-09-23 15:13:22 +08:00

106 lines
3.3 KiB
Python

import collections
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
from lerobot.datasets.lerobot_dataset import LeRobotDataset

from unitree_deploy.utils.rerun_visualizer import RerunLogger, visualization_data
def extract_observation(step: dict):
    """Select the observation entries of a dataset sample.

    Keeps every ``observation.images.*`` entry and ``observation.state``, in the
    order they appear in *step*; every other key (e.g. ``action``) is dropped.
    NumPy images with 3 dims whose last axis has size 1 or 3 (channel-last)
    are transposed to channel-first; all other values are passed through as-is.
    """

    def _channel_first(img):
        # Only channel-last NumPy images are rearranged; anything else
        # (other array layouts, non-NumPy values) is returned unchanged.
        is_channel_last = (
            isinstance(img, np.ndarray) and img.ndim == 3 and img.shape[-1] in (1, 3)
        )
        return np.transpose(img, (2, 0, 1)) if is_channel_last else img

    selected = {}
    for name, item in step.items():
        if name == "observation.state":
            selected[name] = item
        elif name.startswith("observation.images."):
            selected[name] = _channel_first(item)
    return selected
class DatasetEvalEnv:
    """Replay one recorded episode of a LeRobot dataset as an evaluation environment.

    Each ``get_observation()`` call returns the current recorded frame, stores
    that frame's recorded action as ground truth, and advances to the next frame.
    Each ``step(action)`` call stores the caller's predicted action. Once the
    last frame of the episode has been consumed, ``step`` plots ground truth
    against predictions to ``figure.png`` and terminates the process.

    Args:
        repo_id: Dataset repo id passed to ``LeRobotDataset``.
        episode_index: Index of the episode to replay.
        visualization: If True, every replayed frame is also sent to
            ``visualization_data`` through a ``RerunLogger``.
    """

    def __init__(self, repo_id: str, episode_index: int = 0, visualization: bool = True):
        self.dataset = LeRobotDataset(repo_id=repo_id)
        self.visualization = visualization
        if self.visualization:
            self.rerun_logger = RerunLogger()
        # Global frame range of the chosen episode. ``to_idx`` itself is never
        # read: ``step`` finishes once ``step_idx`` reaches it.
        self.from_idx = self.dataset.episode_data_index["from"][episode_index].item()
        self.to_idx = self.dataset.episode_data_index["to"][episode_index].item()
        self.step_idx = self.from_idx
        self.ground_truth_actions = []
        self.predicted_actions = []

    @staticmethod
    def _to_numpy(value):
        """Return *value* as a NumPy array.

        Values exposing ``.numpy()`` (e.g. torch tensors) are converted with it;
        anything else, including arrays already produced by
        ``extract_observation``, goes through ``np.asarray``.
        """
        return value.numpy() if hasattr(value, "numpy") else np.asarray(value)

    def get_observation(self):
        """Return the observation for the current frame and advance one frame.

        The frame's recorded ``action`` is appended to ``ground_truth_actions``.

        Returns:
            OrderedDict with ``"qpos"`` (the ``observation.state`` array) and
            ``"images"`` (dict mapping each ``observation.images.*`` key to an array).
        """
        sample = self.dataset[self.step_idx]
        observation = extract_observation(sample)
        state = self._to_numpy(sample["observation.state"])
        action = self._to_numpy(sample["action"])
        self.ground_truth_actions.append(action)

        if self.visualization:
            visualization_data(
                self.step_idx,
                observation,
                observation["observation.state"],
                action,
                self.rerun_logger,
            )

        images_observation = {
            key: self._to_numpy(value)
            for key, value in observation.items()
            if key.startswith("observation.images.")
        }
        obs = collections.OrderedDict()
        obs["qpos"] = state
        obs["images"] = images_observation
        self.step_idx += 1
        return obs

    def step(self, action):
        """Record the predicted *action* for the frame just observed.

        Once every frame of the episode has been observed, plot the results and
        terminate the process with ``sys.exit()``.
        """
        self.predicted_actions.append(action)
        # ``>=`` rather than ``==`` so an overshoot can never skip termination.
        if self.step_idx >= self.to_idx:
            self._plot_results()
            sys.exit()

    def _plot_results(self):
        """Save ground-truth vs. predicted actions, one subplot per dim, to ``figure.png``.

        Does nothing if no ground-truth actions were collected.
        """
        if not self.ground_truth_actions:
            return
        ground_truth_actions = np.array(self.ground_truth_actions)
        predicted_actions = np.array(self.predicted_actions)
        n_dims = ground_truth_actions.shape[1]

        # squeeze=False keeps ``axes`` a 2-D array even when n_dims == 1, so
        # indexing works uniformly (a bare Axes is not subscriptable).
        fig, axes = plt.subplots(n_dims, 1, figsize=(12, 4 * n_dims), sharex=True, squeeze=False)
        axes = axes[:, 0]
        fig.suptitle("Ground Truth vs Predicted Actions")
        for dim, ax in enumerate(axes):
            ax.plot(ground_truth_actions[:, dim], label="Ground Truth", color="blue")
            ax.plot(predicted_actions[:, dim], label="Predicted", color="red", linestyle="--")
            ax.set_ylabel(f"Dim {dim + 1}")
            ax.legend()
        axes[-1].set_xlabel("Timestep")
        fig.tight_layout()
        fig.savefig("figure.png")
        plt.close(fig)  # release the figure; callers exit right after, but don't leak if reused
        # NOTE(review): the purpose of this pause is not visible here (possibly
        # to let asynchronous loggers flush before the process exits) — confirm.
        time.sleep(1)
def make_dataset_eval_env(
    repo_id: str,
    episode_index: int = 0,
    visualization: bool = True,
) -> DatasetEvalEnv:
    """Build a ``DatasetEvalEnv`` that replays *episode_index* of dataset *repo_id*.

    Args:
        repo_id: Dataset repo id passed through to ``DatasetEvalEnv``.
        episode_index: Index of the episode to replay.
        visualization: Whether the environment logs frames for visualization.
    """
    # DatasetEvalEnv requires ``repo_id``; calling it with no arguments always
    # raised TypeError, so the arguments are forwarded explicitly.
    return DatasetEvalEnv(
        repo_id=repo_id,
        episode_index=episode_index,
        visualization=visualization,
    )
if __name__ == "__main__":
    # Replay the episode and feed each frame's observed state ("qpos") back as
    # the placeholder prediction. The loop never breaks on its own:
    # DatasetEvalEnv.step() terminates the process after the last frame.
    env = DatasetEvalEnv(repo_id="unitreerobotics/G1_Brainco_PickApple_Dataset")
    while True:
        env.step(env.get_observation()["qpos"])