106 lines
3.3 KiB
Python
106 lines
3.3 KiB
Python
import collections
|
|
import time
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
|
|
from unitree_deploy.utils.rerun_visualizer import RerunLogger, visualization_data
|
|
|
|
|
|
def extract_observation(step: dict):
    """Pick the camera images and the robot state out of one dataset frame.

    Only keys starting with ``"observation.images."`` and the exact key
    ``"observation.state"`` are kept; every other entry of ``step`` is ignored.

    An image value that is a 3-D ``np.ndarray`` whose last axis has length 1
    or 3 is transposed with axes ``(2, 0, 1)`` (last axis moved to the front).
    Any other image value, and the state value, are passed through unchanged.

    Args:
        step: Mapping from string keys to frame values.

    Returns:
        A new dict holding the selected entries in the iteration order of
        ``step``.
    """
    image_prefix = "observation.images."
    state_key = "observation.state"

    def _channels_first(frame):
        # Only arrays that look channel-last (trailing axis of 1 or 3) are moved.
        is_channel_last = isinstance(frame, np.ndarray) and frame.ndim == 3 and frame.shape[-1] in [1, 3]
        return np.transpose(frame, (2, 0, 1)) if is_channel_last else frame

    selected = {}
    for name, payload in step.items():
        if name.startswith(image_prefix):
            selected[name] = _channels_first(payload)
        elif name == state_key:
            selected[name] = payload
    return selected
|
|
|
|
|
|
class DatasetEvalEnv:
    """Replay one recorded episode of a ``LeRobotDataset`` as an evaluation environment.

    ``get_observation`` serves the recorded frames one at a time and logs the
    recorded action of each frame; ``step`` logs the action supplied by the
    caller. Once every frame of the episode has been served, ``step`` plots
    recorded vs. supplied actions to ``figure.png`` and terminates the process.
    """

    def __init__(self, repo_id: str, episode_index: int = 0, visualization: bool = True):
        """Load the dataset and position the cursor at the start of an episode.

        Args:
            repo_id: Dataset identifier forwarded to ``LeRobotDataset``.
            episode_index: Index of the episode to replay.
            visualization: When true, a ``RerunLogger`` is created and every
                served frame is forwarded to ``visualization_data``.
        """
        self.dataset = LeRobotDataset(repo_id=repo_id)

        self.visualization = visualization
        # Always define the attribute so it can be inspected safely; it is
        # only a real logger when visualization is enabled.
        self.rerun_logger = RerunLogger() if visualization else None

        episode_bounds = self.dataset.episode_data_index
        self.from_idx = episode_bounds["from"][episode_index].item()
        # Treated as exclusive: the episode is finished once step_idx reaches it.
        self.to_idx = episode_bounds["to"][episode_index].item()
        self.step_idx = self.from_idx

        self.ground_truth_actions = []
        self.predicted_actions = []

    @staticmethod
    def _as_numpy(value):
        """Return ``value`` unchanged if it is an ndarray, else ``value.numpy()``."""
        return value if isinstance(value, np.ndarray) else value.numpy()

    def get_observation(self):
        """Serve the frame under the cursor and advance the cursor by one.

        The frame's ``"action"`` entry is appended to ``ground_truth_actions``.

        Returns:
            An ``OrderedDict`` with ``"qpos"`` (the frame's
            ``"observation.state"`` as a NumPy array) and ``"images"`` (a dict
            of the frame's ``"observation.images.*"`` entries as NumPy arrays).

        Raises:
            IndexError: If every frame of the episode has already been served.
        """
        # Without this guard the cursor would silently run into the frames of
        # the following episode and corrupt the comparison.
        if self.step_idx >= self.to_idx:
            raise IndexError(
                f"episode exhausted: step index {self.step_idx} is not below end index {self.to_idx}"
            )

        step = self.dataset[self.step_idx]
        observation = extract_observation(step)

        state = self._as_numpy(step["observation.state"])
        action = self._as_numpy(step["action"])
        self.ground_truth_actions.append(action)

        if self.visualization:
            visualization_data(
                self.step_idx,
                observation,
                observation["observation.state"],
                action,
                self.rerun_logger,
            )

        # extract_observation may already hand back ndarrays for images, so
        # convert only what still needs converting.
        images_observation = {
            key: self._as_numpy(value)
            for key, value in observation.items()
            if key.startswith("observation.images.")
        }

        obs = collections.OrderedDict()
        obs["qpos"] = state
        obs["images"] = images_observation

        self.step_idx += 1
        return obs

    def step(self, action):
        """Log ``action``; at the end of the episode, plot and exit.

        Args:
            action: Action to compare against the recorded one. It is stored
                as given and converted with ``np.array`` when plotting.

        Raises:
            SystemExit: After the results figure has been written, once the
                cursor has reached the end of the episode.
        """
        self.predicted_actions.append(action)

        # ">=" rather than "==" so an overshooting cursor still terminates.
        if self.step_idx >= self.to_idx:
            self._plot_results()
            # The exit() helper is injected by the site module and is not
            # guaranteed to exist; raising SystemExit is the portable form.
            raise SystemExit

    def _plot_results(self):
        """Write ``figure.png`` with one subplot per action dimension.

        Each subplot overlays the recorded action (blue, solid) and the
        supplied action (red, dashed). Nothing is written when no recorded
        action has been logged.
        """
        if not self.ground_truth_actions:
            return

        ground_truth_actions = np.array(self.ground_truth_actions)
        predicted_actions = np.array(self.predicted_actions)

        _, n_dims = ground_truth_actions.shape

        # squeeze=False always yields a 2-D array of Axes, so a single action
        # dimension no longer breaks the indexing below.
        fig, axes_grid = plt.subplots(n_dims, 1, figsize=(12, 4 * n_dims), sharex=True, squeeze=False)
        axes = axes_grid[:, 0]
        fig.suptitle("Ground Truth vs Predicted Actions")

        for i, ax in enumerate(axes):
            ax.plot(ground_truth_actions[:, i], label="Ground Truth", color="blue")
            ax.plot(predicted_actions[:, i], label="Predicted", color="red", linestyle="--")
            ax.set_ylabel(f"Dim {i + 1}")
            ax.legend()

        axes[-1].set_xlabel("Timestep")
        fig.tight_layout()
        fig.savefig("figure.png")
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
        # NOTE(review): pause kept from the original; its purpose is not
        # evident from this file — confirm whether it is still needed.
        time.sleep(1)
|
|
|
|
|
|
def make_dataset_eval_env(
    repo_id: str = "unitreerobotics/G1_Brainco_PickApple_Dataset",
    episode_index: int = 0,
    visualization: bool = True,
) -> DatasetEvalEnv:
    """Build a ``DatasetEvalEnv``.

    The original factory called ``DatasetEvalEnv()`` without the required
    ``repo_id`` and therefore always raised ``TypeError``. All arguments are
    now forwarded, with defaults so existing zero-argument calls keep working.

    Args:
        repo_id: Dataset identifier forwarded to ``DatasetEvalEnv``.
            NOTE(review): the default is the dataset used by this module's
            script entry point — confirm it is the intended default.
        episode_index: Index of the episode to replay.
        visualization: Whether frames are forwarded to the visualizer.

    Returns:
        The constructed environment.
    """
    return DatasetEvalEnv(
        repo_id=repo_id,
        episode_index=episode_index,
        visualization=visualization,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Sanity run: replay the default episode and feed each frame's recorded
    # state back in as the action. The loop has no exit condition of its own;
    # DatasetEvalEnv.step ends the process after the last frame.
    replay_env = DatasetEvalEnv(repo_id="unitreerobotics/G1_Brainco_PickApple_Dataset")
    while True:
        frame = replay_env.get_observation()
        replay_env.step(frame["qpos"])
|