upload real-robot deployment code
This commit is contained in:
281
unitree_deploy/unitree_deploy/utils/eval_utils.py
Normal file
281
unitree_deploy/unitree_deploy/utils/eval_utils.py
Normal file
@@ -0,0 +1,281 @@
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import requests
|
||||
import torch
|
||||
import torchvision
|
||||
from datasets import load_from_disk
|
||||
from datasets.features.features import register_feature
|
||||
from safetensors.torch import load_file
|
||||
|
||||
# Route all log records (DEBUG and up) to stdout for deployment-time visibility.
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
||||
|
||||
|
||||
class LongConnectionClient:
    """HTTP client that keeps one `requests.Session` alive for repeated POSTs
    to a remote action-prediction server."""

    def __init__(self, base_url):
        # base_url is prepended verbatim to every endpoint, e.g. "http://host:port".
        self.session = requests.Session()
        self.base_url = base_url

    def send_post(self, endpoint, json_data):
        """POST `json_data` to `base_url + endpoint`, retrying until the server
        answers HTTP 200 with ``{"result": "ok", ...}``; returns that JSON dict.

        NOTE(review): this loops forever on persistent failure — there is no
        retry limit or timeout; confirm that is the intended behavior for
        real-robot deployment.
        """
        url = f"{self.base_url}{endpoint}"
        response = None
        while True:
            try:
                response = self.session.post(url, json=json_data)
                if response.status_code == 200:
                    data = response.json()
                    if data["result"] == "ok":
                        # Success: hand back the decoded payload, not the Response.
                        response = data
                        break
                    else:
                        # Server accepted the request but reported a soft failure.
                        logging.info(data["desc"])

                # Back off one second before the next attempt.
                time.sleep(1)
            except Exception as e:
                # Connection errors / JSON decode errors: log and keep retrying.
                logging.error(f"An error occurred: {e}")
                logging.error(traceback.format_exc())

        return response

    def close(self):
        """Close the underlying HTTP session."""
        self.session.close()

    def predict_action(self, language_instruction, batch) -> torch.Tensor:
        """Send the current observation batch to the server and return the
        predicted action chunk as a tensor.

        `batch` is expected to map "observation.state", "observation.images.top"
        and "action" to iterables of tensors (assumption from the torch.stack
        calls below — confirm against the caller).
        """
        # collect data: stack per-step tensors and convert to JSON-serializable lists
        data = {
            "language_instruction": language_instruction,
            "observation.state": torch.stack(list(batch["observation.state"])).tolist(),
            "observation.images.top": torch.stack(list(batch["observation.images.top"])).tolist(),
            "action": torch.stack(list(batch["action"])).tolist(),
        }

        # send data (blocks until the server answers "ok" — see send_post)
        endpoint = "/predict_action"
        response = self.send_post(endpoint, data)
        action = torch.tensor(response["action"])
        return action
|
||||
|
||||
|
||||
class ACTTemporalEnsembler:
    def __init__(self, temporal_ensemble_coeff: float, chunk_size: int, exe_steps: int) -> None:
        """Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.

        The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action.
        They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the
        coefficient works:
            - Setting it to 0 uniformly weighs all actions.
            - Setting it positive gives more weight to older actions.
            - Setting it negative gives more weight to newer actions.
        NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This
        results in older actions being weighed more highly than newer actions (the experiments documented in
        https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be
        detrimental: doing so aggressively may diminish the benefits of action chunking).

        Here we use an online method for computing the average rather than caching a history of actions in
        order to compute the average offline. For a simple 1D sequence it looks something like:

        ```
        import torch

        seq = torch.linspace(8, 8.5, 100)
        print(seq)

        m = 0.01
        exp_weights = torch.exp(-m * torch.arange(len(seq)))
        print(exp_weights)

        # Calculate offline
        avg = (exp_weights * seq).sum() / exp_weights.sum()
        print("offline", avg)

        # Calculate online
        for i, item in enumerate(seq):
            if i == 0:
                avg = item
                continue
            avg *= exp_weights[:i].sum()
            avg += item * exp_weights[i]
            avg /= exp_weights[:i+1].sum()
        print("online", avg)
        ```

        Args:
            temporal_ensemble_coeff: exponential decay coefficient `m` above.
            chunk_size: number of future steps predicted per policy call.
            exe_steps: how many of those steps are actually executed before the
                next prediction (this many actions are popped per `update`).
        """
        self.chunk_size = chunk_size
        # Unnormalized weights w_i = exp(-m * i); normalization happens via the
        # cumulative sums in `update`.
        self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
        self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
        self.exe_steps = exe_steps
        self.reset()

    def reset(self):
        """Resets the online computation variables."""
        self.ensembled_actions = None
        # (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence.
        self.ensembled_actions_count = None

    def update(self, actions):
        """
        Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all
        time steps, and pop/return the next batch of actions in the sequence.

        Returns the first `exe_steps` ensembled actions, shape
        (batch, exe_steps, action_dim).
        """
        self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
        self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
        if self.ensembled_actions is None:
            # Initializes `self._ensembled_action` to the sequence of actions predicted during the first
            # time step of the episode.
            self.ensembled_actions = actions.clone()
            # Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
            # operations later.
            self.ensembled_actions_count = torch.ones(
                (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
            )
        else:
            # self.ensembled_actions will have shape (batch_size, chunk_size - exe_steps, action_dim).
            # Compute the online update for those entries: un-normalize the running
            # average, fold in the overlapping part of the new prediction, re-normalize.
            self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
            self.ensembled_actions += (
                actions[:, : -self.exe_steps] * self.ensemble_weights[self.ensembled_actions_count]
            )
            self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
            self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
            # The last `exe_steps` actions have no prior online average; concatenate them as-is
            # with a count of 1.
            self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -self.exe_steps :]], dim=1)
            self.ensembled_actions_count = torch.cat(
                [
                    self.ensembled_actions_count,
                    torch.ones((self.exe_steps, 1), dtype=torch.long, device=self.ensembled_actions_count.device),
                ]
            )
        # "Consume" the first exe_steps actions: return them and keep the remainder
        # (plus matching counts) for the next update.
        actions, self.ensembled_actions, self.ensembled_actions_count = (
            self.ensembled_actions[:, : self.exe_steps],
            self.ensembled_actions[:, self.exe_steps :],
            self.ensembled_actions_count[self.exe_steps :],
        )
        return actions
|
||||
|
||||
|
||||
@dataclass
class VideoFrame:
    """
    Provides a type for a dataset containing video frames.

    Example:

    ```python
    data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
    features = {"image": VideoFrame()}
    Dataset.from_dict(data_dict, features=Features(features))
    ```
    """

    # Arrow storage type: a struct of the video file path and frame timestamp (seconds).
    pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()})
    # Marker used by `datasets` to identify this feature during (de)serialization.
    _type: str = field(default="VideoFrame", init=False, repr=False)

    def __call__(self):
        # `datasets` calls the feature instance to obtain its Arrow type.
        return self.pa_type
|
||||
|
||||
|
||||
# Registering a custom feature triggers an "experimental" UserWarning from
# `datasets`; suppress exactly that message while performing the registration.
with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        "'register_feature' is experimental and might be subject to breaking changes in the future.",
        category=UserWarning,
    )
    # to make VideoFrame available in HuggingFace `datasets`
    register_feature(VideoFrame, "VideoFrame")
|
||||
|
||||
|
||||
def get_image(cam_list, target_shape=None, save_image=False):
    """Grab one frame from every camera and stack them into a single RGB array.

    Args:
        cam_list: iterable of camera objects exposing ``get_frame() -> (bgr, _)``.
        target_shape: optional (width, height) to resize each frame to.
        save_image: when True, dump each raw BGR frame to a fixed debug path
            (the file is overwritten per camera).

    Returns:
        np.ndarray of shape (num_cams, H, W, 3) in RGB order.
    """
    frames = []
    for camera in cam_list:
        bgr_frame, _ = camera.get_frame()
        if save_image:
            # Debug dump of the raw BGR frame straight off the camera.
            cv2.imwrite("/home/world-model-x/output.png", bgr_frame)
        rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
        if target_shape:
            rgb_frame = cv2.resize(rgb_frame, target_shape)
        frames.append(rgb_frame)
    return np.stack(frames, axis=0)
|
||||
|
||||
|
||||
def load_action_from_dataset(dataset_dir, episode_id):
    """Load the action sequence of one episode from a saved HuggingFace dataset.

    Episode boundaries come from the ``episode_data_index.safetensors`` file
    ("from"/"to" tensors indexed by episode id).

    Returns:
        torch.FloatTensor of shape (episode_len, action_dim).
    """
    dataset = load_from_disk(dataset_dir + "/train")
    episode_index = load_file(dataset_dir + "/meta_data/episode_data_index.safetensors")
    first = episode_index["from"][episode_id]
    last = episode_index["to"][episode_id]
    return torch.FloatTensor(dataset["action"][first:last])
|
||||
|
||||
|
||||
def load_stats_from_prompt_dir(dataset_dir, prompt_dir, subdir=""):
    """Load normalization statistics from ``<dataset_dir><subdir>/meta_data/stats.safetensors``.

    NOTE(review): `prompt_dir` is accepted but never used — kept for caller
    compatibility; confirm whether it can be dropped.
    """
    meta_dir = dataset_dir + subdir + "/meta_data"
    return load_file(meta_dir + "/stats.safetensors")
|
||||
|
||||
|
||||
def populate_queues(queues, batch):
    """Push the latest batch values into their bounded observation queues.

    Keys present in `batch` but absent from `queues` are ignored (the caller
    decides which keys get queued). A queue that is not yet full is padded with
    repeats of the current value until it reaches `maxlen` (episode start);
    a full queue gets a single append, letting `deque(maxlen=...)` drop the
    oldest entry.

    Returns the (mutated) `queues` mapping for convenience.
    """
    for key, value in batch.items():
        if key not in queues:
            continue
        queue = queues[key]
        if len(queue) == queue.maxlen:
            # Steady state: append the newest value, oldest falls off.
            queue.append(value)
        else:
            # Warm-up: replicate the current value until the queue is full.
            while len(queue) != queue.maxlen:
                queue.append(value)
    return queues
|
||||
|
||||
|
||||
def action_safe_checking(action, action_max, action_min, threshold=0.01):
    """Return True when every action component stays within [action_min, action_max],
    allowing `threshold` of slack on each side.

    `action` is an array-like of joint commands; `action_max`/`action_min` are
    torch tensors (moved to CPU before comparison).
    """
    upper = action_max.cpu().numpy()
    lower = action_min.cpu().numpy()
    exceeds_upper = any(action - threshold > upper)
    exceeds_lower = any(action + threshold < lower)
    return not exceeds_upper and not exceeds_lower
|
||||
|
||||
|
||||
def get_init_pose(dataset_dir, start_id=0):
    """Collect the first action row of every episode parquet and return one of
    them as the initial robot pose, shape (1, action_dim).

    NOTE(review): `start_id` is accepted but never used — the returned row is
    hard-coded to index 192 of the stacked first-rows array. Confirm whether
    this should be `start_id` (and what 192 refers to) before relying on it.
    """
    # load all parquet files for chunk-000, in sorted (episode) order
    dataset_dir_path = Path(dataset_dir) / "data" / "chunk-000"
    parquet_files = list(dataset_dir_path.glob("*.parquet"))
    parquet_files = sorted([str(f) for f in parquet_files])
    # Read only the first row of each episode file.
    first_rows = [pd.read_parquet(f, engine="pyarrow").iloc[[0]] for f in parquet_files]
    df = pd.concat(first_rows, ignore_index=True)
    action_array = np.stack(df["action"].values)
    # Hard-coded episode index — see NOTE above.
    init_pose = action_array[192:193, ...]
    return init_pose
|
||||
|
||||
|
||||
def save_image(obs, num_step=None, output_dir=None):
    """Write the 'cam_left_high' frame of `obs` to ``<output_dir>/top_<step>.png``.

    The frame is color-converted with BGR2RGB before writing (note cv2.imwrite
    expects BGR, so the saved file's channels are swapped relative to the
    source — preserved as-is from the original behavior).
    """
    raw_frame = obs.observation["images"]["cam_left_high"]
    rgb_image = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2RGB)
    cv2.imwrite(f"{output_dir}/top_{num_step:06d}.png", rgb_image)
|
||||
|
||||
|
||||
def log_to_tensorboard(writer, data, tag, fps=10):
    """Log a batched video tensor to TensorBoard as one tiled clip.

    Only acts when `data` is a 5-D torch tensor; the permute below implies the
    layout is (n, c, t, h, w) — confirm against callers. Values are assumed to
    be in [-1, 1] and are remapped to [0, 1] before logging.
    """
    if not (isinstance(data, torch.Tensor) and data.dim() == 5):
        # Silently ignore anything that is not a 5-D video tensor.
        return
    batch = data.shape[0]
    frames = data.permute(2, 0, 1, 3, 4)  # -> (t, n, c, h, w)
    # One grid image per time step, batch laid out horizontally: [3, n*h, w]
    frame_grids = [torchvision.utils.make_grid(sheet, nrow=int(batch), padding=0) for sheet in frames]
    clip = torch.stack(frame_grids, dim=0)  # (t, 3, n*h, w)
    clip = (clip + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    writer.add_video(tag, clip.unsqueeze(dim=0), fps=fps)
|
||||
196
unitree_deploy/unitree_deploy/utils/joint_trajcetory_inter.py
Normal file
196
unitree_deploy/unitree_deploy/utils/joint_trajcetory_inter.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""The modification is derived from diffusion_policy/common/pose_trajectory_interpolator.py. Thank you for the outstanding contribution."""
|
||||
|
||||
import numbers
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import scipy.interpolate as si
|
||||
|
||||
|
||||
def joint_pose_distance(start_joint_angles, end_joint_angles):
    """Euclidean (L2) distance between two joint-angle vectors."""
    delta = np.array(end_joint_angles) - np.array(start_joint_angles)
    return np.linalg.norm(delta)
|
||||
|
||||
|
||||
class JointTrajectoryInterpolator:
    """Piecewise-linear interpolator over timed joint-position waypoints.

    A single waypoint degenerates to a constant trajectory (`single_step`);
    two or more waypoints use `scipy.interpolate.interp1d` along the time axis.
    Derived from diffusion_policy's pose_trajectory_interpolator.
    """

    def __init__(self, times: np.ndarray, joint_positions: np.ndarray):
        # times: (n,) monotonically non-decreasing; joint_positions: (n, num_joints)
        assert len(times) >= 1
        assert len(joint_positions) == len(times)
        self.num_joints = len(joint_positions[0])
        if not isinstance(times, np.ndarray):
            times = np.array(times)
        if not isinstance(joint_positions, np.ndarray):
            joint_positions = np.array(joint_positions)
        if len(times) == 1:
            # Constant trajectory: store the raw arrays, no interpolator needed.
            self.single_step = True
            self._times = times
            self._joint_positions = joint_positions
        else:
            self.single_step = False
            assert np.all(times[1:] >= times[:-1])
            self.interpolators = si.interp1d(times, joint_positions, axis=0, assume_sorted=True)

    @property
    def times(self) -> np.ndarray:
        # Knot times of the trajectory, shape (n,).
        if self.single_step:
            return self._times
        else:
            return self.interpolators.x

    @property
    def joint_positions(self) -> np.ndarray:
        # Knot joint positions, shape (n, num_joints).
        if self.single_step:
            return self._joint_positions
        else:
            n = len(self.times)
            # NOTE(review): this zeros allocation is immediately overwritten and
            # appears to be dead code — kept for byte-compatibility.
            joint_positions = np.zeros((n, self.num_joints))
            joint_positions = self.interpolators.y
            return joint_positions

    def trim(self, start_t: float, end_t: float) -> "JointTrajectoryInterpolator":
        """Return a new interpolator restricted to [start_t, end_t].

        Interior knots strictly inside the window are kept; the window
        endpoints are sampled from this trajectory (values outside the original
        time range clamp to the endpoints — see __call__).
        """
        assert start_t <= end_t
        times = self.times
        should_keep = (start_t < times) & (times < end_t)
        keep_times = times[should_keep]
        all_times = np.concatenate([[start_t], keep_times, [end_t]])
        all_times = np.unique(all_times)
        # Sample this trajectory at the retained times to build the new knots.
        all_joint_positions = self(all_times)
        return JointTrajectoryInterpolator(times=all_times, joint_positions=all_joint_positions)

    def drive_to_waypoint(
        self,
        pose,
        time,
        curr_time,
        max_pos_speed=np.inf,
    ) -> "JointTrajectoryInterpolator":
        """Return a new trajectory that moves from the current pose to `pose`,
        arriving no earlier than `time` and no faster than `max_pos_speed`
        (joint-space L2 speed).
        """
        assert max_pos_speed > 0
        time = max(time, curr_time)

        curr_pose = self(curr_time)
        pos_dist = joint_pose_distance(curr_pose, pose)
        # Stretch the duration so the L2 joint speed never exceeds max_pos_speed.
        pos_min_duration = pos_dist / max_pos_speed
        duration = time - curr_time
        duration = max(duration, pos_min_duration)
        assert duration >= 0
        last_waypoint_time = curr_time + duration

        # insert new pose: collapse the current trajectory to the single knot at
        # curr_time, then append the target waypoint.
        trimmed_interp = self.trim(curr_time, curr_time)
        times = np.append(trimmed_interp.times, [last_waypoint_time], axis=0)
        poses = np.append(trimmed_interp.joint_positions, [pose], axis=0)

        # create new interpolator
        final_interp = JointTrajectoryInterpolator(times, poses)
        return final_interp

    def schedule_waypoint(
        self, pose, time, max_pos_speed=np.inf, curr_time=None, last_waypoint_time=None
    ) -> "JointTrajectoryInterpolator":
        """Insert `pose` as a future waypoint at (or after) `time`, preserving the
        already-committed portion of the trajectory and respecting `max_pos_speed`.

        Returns self unchanged when `time` is not in the future of `curr_time`.
        """
        assert max_pos_speed > 0
        if last_waypoint_time is not None:
            assert curr_time is not None

        # trim current interpolator to between curr_time and last_waypoint_time
        start_time = self.times[0]
        end_time = self.times[-1]
        assert start_time <= end_time

        if curr_time is not None:
            if time <= curr_time:
                # if insert time is earlier than current time,
                # no effect should be done to the interpolator
                return self
            # now, curr_time < time
            start_time = max(curr_time, start_time)

            if last_waypoint_time is not None:
                # if last_waypoint_time is earlier than start_time, use start_time
                end_time = curr_time if time <= last_waypoint_time else max(last_waypoint_time, curr_time)
            else:
                end_time = curr_time

        end_time = min(end_time, time)
        start_time = min(start_time, end_time)

        # end time should be the latest of all times except `time`;
        # after this we can assume order (proven by zhenjia, due to the 2 min operations)
        # Constraints:
        # start_time <= end_time <= time (proven by zhenjia)
        # curr_time <= start_time (proven by zhenjia)
        # curr_time <= time (proven by zhenjia)

        assert start_time <= end_time
        assert end_time <= time
        if last_waypoint_time is not None:
            if time <= last_waypoint_time:
                assert end_time == curr_time
            else:
                assert end_time == max(last_waypoint_time, curr_time)

        if curr_time is not None:
            assert curr_time <= start_time
            assert curr_time <= time

        trimmed_interp = self.trim(start_time, end_time)

        # determine speed: stretch arrival time if the commanded joint-space
        # speed would exceed max_pos_speed.
        duration = time - end_time
        end_pose = trimmed_interp(end_time)
        pos_dist = joint_pose_distance(pose, end_pose)

        joint_min_duration = pos_dist / max_pos_speed

        duration = max(duration, joint_min_duration)
        assert duration >= 0
        last_waypoint_time = end_time + duration

        # insert new pose
        times = np.append(trimmed_interp.times, [last_waypoint_time], axis=0)
        poses = np.append(trimmed_interp.joint_positions, [pose], axis=0)

        # create new interpolator
        final_interp = JointTrajectoryInterpolator(times, poses)
        return final_interp

    def __call__(self, t: Union[numbers.Number, np.ndarray]) -> np.ndarray:
        """Sample joint positions at time(s) `t`.

        Scalar input returns shape (num_joints,); array input returns
        (len(t), num_joints). Times outside the knot range clamp to the ends.
        """
        is_single = False
        if isinstance(t, numbers.Number):
            is_single = True
            t = np.array([t])

        joint_positions = np.zeros((len(t), self.num_joints))

        if self.single_step:
            # Constant trajectory: every query returns the lone waypoint.
            joint_positions[:] = self._joint_positions[0]
        else:
            start_time = self.times[0]
            end_time = self.times[-1]
            # Clamp instead of extrapolating outside the trajectory.
            t = np.clip(t, start_time, end_time)
            joint_positions[:, :] = self.interpolators(t)

        if is_single:
            joint_positions = joint_positions[0]
        return joint_positions
|
||||
|
||||
|
||||
def generate_joint_positions(
    num_rows: int, num_cols: int, start: float = 0.0, step: float = 0.1, row_offset: float = 0.1
) -> np.ndarray:
    """Build a (num_rows, num_cols) grid of synthetic joint positions.

    Row i, column j holds ``start + j * step + i * row_offset``.

    Fix: the original used ``np.arange(start, start + step * num_cols, step)``,
    which can produce ``num_cols + 1`` elements when floating-point rounding
    leaves the computed stop value slightly above the last sample. Generating
    exactly ``num_cols`` integer indices and scaling them guarantees the shape.

    Args:
        num_rows: number of waypoints (rows) to generate.
        num_cols: number of joints (columns) per waypoint.
        start: value of the first column in row 0.
        step: increment between adjacent columns.
        row_offset: increment added to an entire row per row index.

    Returns:
        np.ndarray of shape (num_rows, num_cols).
    """
    base_row = start + step * np.arange(num_cols)
    return np.vstack([base_row + i * row_offset for i in range(num_rows)])
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: joint trajectory data (time in seconds, joint positions as an
    # array of NUM_JOINTS joint angles per waypoint).
    times = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
    joint_positions = generate_joint_positions(num_rows=5, num_cols=7, start=0.0, step=0.1, row_offset=0.1)
    interpolator = JointTrajectoryInterpolator(times, joint_positions)
    # Query the trajectory at a specific time (t = 0.1 s here).
    t = 0.1
    joint_pos_at_t = interpolator(t)
    print("Joint positions at time", t, ":", joint_pos_at_t)
|
||||
175
unitree_deploy/unitree_deploy/utils/rerun_visualizer.py
Normal file
175
unitree_deploy/unitree_deploy/utils/rerun_visualizer.py
Normal file
@@ -0,0 +1,175 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import rerun as rr
|
||||
import rerun.blueprint as rrb
|
||||
import torch
|
||||
|
||||
|
||||
class RerunLogger:
    """
    A fully automatic Rerun logger designed to parse and visualize step
    dictionaries directly from a LeRobotDataset.

    On the first `log_step` call the logger inspects the step dict to discover
    image/state/action keys, then (optionally) sends a Rerun blueprint laying
    out one 2D view per camera plus time-series views for state and action.
    """

    def __init__(
        self,
        prefix: str = "",
        memory_limit: str = "200MB",
        idxrangeboundary: Optional[int] = 300,
    ):
        """Initializes the Rerun logger.

        Args:
            prefix: entity-path prefix prepended to every logged path.
            memory_limit: memory cap handed to the spawned Rerun viewer.
            idxrangeboundary: visible history window (in frames) for the
                time-series views; falsy disables the auto blueprint.
        """
        # Use a descriptive name for the Rerun recording
        rr.init(f"Dataset_Log_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
        rr.spawn(memory_limit=memory_limit)

        self.prefix = prefix
        self.blueprint_sent = False
        self.idxrangeboundary = idxrangeboundary

        # --- Internal cache for discovered keys ---
        self._image_keys: Tuple[str, ...] = ()
        self._state_key: str = ""
        self._action_key: str = ""
        self._index_key: str = "index"
        self._task_key: str = "task"
        self._episode_index_key: str = "episode_index"

        # Last episode index seen; -1 forces a "Starting Episode" log on the first step.
        self.current_episode = -1

    def _initialize_from_data(self, step_data: Dict[str, Any]):
        """Inspects the first data dictionary to discover components and set up the blueprint."""
        print("RerunLogger: First data packet received. Auto-configuring...")

        image_keys = []
        for key, value in step_data.items():
            # Image keys must be tensors with at least 3 dims (C/H/W in some order).
            if key.startswith("observation.images.") and isinstance(value, torch.Tensor) and value.ndim > 2:
                image_keys.append(key)
            elif key == "observation.state":
                self._state_key = key
            elif key == "action":
                self._action_key = key

        self._image_keys = tuple(sorted(image_keys))

        # Prefer the global "index" counter; fall back to per-episode "frame_index".
        if "index" in step_data:
            self._index_key = "index"
        elif "frame_index" in step_data:
            self._index_key = "frame_index"

        print(f" - Using '{self._index_key}' for time sequence.")
        print(f" - Detected State Key: '{self._state_key}'")
        print(f" - Detected Action Key: '{self._action_key}'")
        print(f" - Detected Image Keys: {self._image_keys}")
        if self.idxrangeboundary:
            self.setup_blueprint()

    def setup_blueprint(self):
        """Sets up and sends the Rerun blueprint based on detected components."""
        views = []

        # One 2D view per detected camera stream.
        for key in self._image_keys:
            clean_name = key.replace("observation.images.", "")
            entity_path = f"{self.prefix}images/{clean_name}"
            views.append(rrb.Spatial2DView(origin=entity_path, name=clean_name))

        # Scrolling time-series view for the observation state.
        if self._state_key:
            entity_path = f"{self.prefix}state"
            views.append(
                rrb.TimeSeriesView(
                    origin=entity_path,
                    name="Observation State",
                    time_ranges=[
                        rrb.VisibleTimeRange(
                            "frame",
                            start=rrb.TimeRangeBoundary.cursor_relative(seq=-self.idxrangeboundary),
                            end=rrb.TimeRangeBoundary.cursor_relative(),
                        )
                    ],
                    plot_legend=rrb.PlotLegend(visible=True),
                )
            )

        # Scrolling time-series view for the action.
        if self._action_key:
            entity_path = f"{self.prefix}action"
            views.append(
                rrb.TimeSeriesView(
                    origin=entity_path,
                    name="Action",
                    time_ranges=[
                        rrb.VisibleTimeRange(
                            "frame",
                            start=rrb.TimeRangeBoundary.cursor_relative(seq=-self.idxrangeboundary),
                            end=rrb.TimeRangeBoundary.cursor_relative(),
                        )
                    ],
                    plot_legend=rrb.PlotLegend(visible=True),
                )
            )

        if not views:
            print("Warning: No visualizable components detected in the data.")
            return

        grid = rrb.Grid(contents=views)
        rr.send_blueprint(grid)
        self.blueprint_sent = True

    def log_step(self, step_data: Dict[str, Any]):
        """Logs a single step dictionary from your dataset."""
        # Lazy auto-configuration on the first packet.
        if not self.blueprint_sent:
            self._initialize_from_data(step_data)

        # Advance the "frame" timeline using the detected index key (expects a
        # 0-d tensor — .item() below).
        if self._index_key in step_data:
            current_index = step_data[self._index_key].item()
            rr.set_time_sequence("frame", current_index)

        # Emit a text marker whenever the episode index changes.
        episode_idx = step_data.get(self._episode_index_key, torch.tensor(-1)).item()
        if episode_idx != self.current_episode:
            self.current_episode = episode_idx
            task_name = step_data.get(self._task_key, "Unknown Task")
            log_text = f"Starting Episode {self.current_episode}: {task_name}"
            rr.log(f"{self.prefix}info/task", rr.TextLog(log_text, level=rr.TextLogLevel.INFO))

        # Log camera frames; CHW tensors are transposed to HWC for rr.Image.
        for key in self._image_keys:
            if key in step_data:
                image_tensor = step_data[key]
                if image_tensor.ndim > 2:
                    clean_name = key.replace("observation.images.", "")
                    entity_path = f"{self.prefix}images/{clean_name}"
                    if image_tensor.shape[0] in [1, 3, 4]:
                        image_tensor = image_tensor.permute(1, 2, 0)
                    rr.log(entity_path, rr.Image(image_tensor))

        # One scalar series per state dimension.
        if self._state_key in step_data:
            state_tensor = step_data[self._state_key]
            entity_path = f"{self.prefix}state"
            for i, val in enumerate(state_tensor):
                rr.log(f"{entity_path}/joint_{i}", rr.Scalar(val.item()))

        # One scalar series per action dimension.
        if self._action_key in step_data:
            action_tensor = step_data[self._action_key]
            entity_path = f"{self.prefix}action"
            for i, val in enumerate(action_tensor):
                rr.log(f"{entity_path}/joint_{i}", rr.Scalar(val.item()))
|
||||
|
||||
|
||||
def visualization_data(idx, observation, state, action, online_logger):
    """Assemble one step dictionary and forward it to the online Rerun logger.

    `observation` entries are merged in, except keys that would collide with
    the explicitly supplied index/state/action.
    """
    reserved = ("index", "observation.state", "action")
    step: Dict[str, Any] = {
        "index": torch.tensor(idx),
        "observation.state": state,
        "action": action,
    }
    step.update({key: value for key, value in observation.items() if key not in reserved})
    online_logger.log_step(step)
|
||||
|
||||
|
||||
def flatten_images(obs: dict) -> dict:
    """Lift entries of obs["images"] to top-level 'observation.images.<name>' keys,
    converting each numpy frame to a torch tensor. Returns {} when no images."""
    if "images" not in obs:
        return {}
    return {
        f"observation.images.{name}": torch.from_numpy(frame)
        for name, frame in obs["images"].items()
    }
|
||||
180
unitree_deploy/unitree_deploy/utils/rich_logger.py
Normal file
180
unitree_deploy/unitree_deploy/utils/rich_logger.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import time
|
||||
|
||||
from rich.console import Console
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
)
|
||||
from rich.text import Text
|
||||
|
||||
|
||||
class RichLogger:
    """Leveled console logger built on `rich`, with styled, emoji-capable output
    plus loading-spinner and progress-bar helpers.

    Fixes over the original:
      * `_log` was defined twice with near-identical bodies; the second
        definition silently shadowed the first. Only one remains.
      * `_log` gated records against the fixed INFO priority
        (``self.levels[level] < self.levels["INFO"]``) instead of the
        configured ``self.level``, so `debug()` could never print even when
        the logger was constructed with ``level="DEBUG"``. Gating is now
        centralized in `_log` against ``self.level``.
    """

    def __init__(self, level: str = "INFO"):
        # Console used for all rich output.
        self.console = Console()

        # Priority per level; lower numbers are more verbose.
        self.levels = {
            "DEBUG": 0,  # Lowest level, all logs are displayed
            "INFO": 1,  # Standard level, displays Info and higher
            "SUCCESS": 2,  # Displays success and higher priority logs
            "WARNING": 3,  # Displays warnings and errors
            "ERROR": 4,  # Highest level, only errors are shown
        }

        # Set default log level; fall back to INFO on an unknown level name.
        self.level = self.levels.get(level.upper(), 1)

    def _log(self, level: str, message: str, style: str, emoji: str = None):
        """Render one record if `level` meets the configured threshold."""
        # Single gating point: suppress records below the configured level.
        if self.levels[level] < self.level:
            return

        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")

        # If emoji is provided, prepend it to the message.
        if emoji:
            message = f"{emoji} {message}"

        text = Text(f"[{timestamp}] [{level}] {message}", style=style)
        self.console.print(text)

    # Basic log methods — each delegates to _log, which handles level gating.
    def info(self, message: str, emoji: str | None = None):
        self._log("INFO", message, "bold cyan", emoji)

    def warning(self, message: str, emoji: str = "⚠️"):
        self._log("WARNING", message, "bold yellow", emoji)

    def error(self, message: str, emoji: str = "❌"):
        self._log("ERROR", message, "bold red", emoji)

    def success(self, message: str, emoji: str = "🚀"):
        self._log("SUCCESS", message, "bold green", emoji)

    def debug(self, message: str, emoji: str = "🔍"):
        self._log("DEBUG", message, "dim", emoji)

    # ========== Extended Features ==========
    def emoji(self, message: str, emoji: str = "🚀"):
        """Print a message with a leading emoji, bypassing level gating."""
        self.console.print(f"{emoji} {message}", style="bold magenta")

    def loading(self, message: str, seconds: float = 2.0):
        """Show a spinner with `message` for `seconds` seconds (blocking)."""
        with self.console.status(f"[bold blue]{message}...", spinner="dots"):
            time.sleep(seconds)

    def progress(self, task_description: str, total: int = 100, speed: float = 0.02):
        """Animate a progress bar of `total` steps, sleeping `speed` s per step."""
        with Progress(
            SpinnerColumn(),
            BarColumn(bar_width=None),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            TimeElapsedColumn(),
            console=self.console,
        ) as progress:
            task = progress.add_task(f"[cyan]{task_description}", total=total)
            while not progress.finished:
                progress.update(task, advance=1)
                time.sleep(speed)
|
||||
|
||||
|
||||
# ========== Singleton Logger Instance ==========
# Module-level default logger (level INFO) backing the function-style API below.
_logger = RichLogger()


# ========== Function-style API ==========
def log_info(message: str, emoji: str | None = None):
    """Log `message` at INFO level via the module singleton."""
    _logger.info(message=message, emoji=emoji)


def log_success(message: str, emoji: str = "🚀"):
    """Log `message` at SUCCESS level via the module singleton."""
    _logger.success(message=message, emoji=emoji)


def log_warning(message: str, emoji: str = "⚠️"):
    """Log `message` at WARNING level via the module singleton."""
    _logger.warning(message=message, emoji=emoji)


def log_error(message: str, emoji: str = "❌"):
    """Log `message` at ERROR level via the module singleton."""
    _logger.error(message=message, emoji=emoji)


def log_debug(message: str, emoji: str = "🔍"):
    """Log `message` at DEBUG level via the module singleton."""
    _logger.debug(message=message, emoji=emoji)


def log_emoji(message: str, emoji: str = "🚀"):
    """Print an emoji-prefixed message, bypassing level gating."""
    _logger.emoji(message, emoji)


def log_loading(message: str, seconds: float = 2.0):
    """Show a blocking spinner for `seconds` seconds."""
    _logger.loading(message, seconds)


def log_progress(task_description: str, total: int = 100, speed: float = 0.02):
    """Animate a progress bar of `total` steps via the module singleton."""
    _logger.progress(task_description, total, speed)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Example usage: initialize a logger at INFO level.
    logger = RichLogger(level="INFO")

    # Log at different levels
    logger.info("System initialization complete.")
    logger.success("Robot started successfully!")
    logger.warning("Warning: Joint temperature high!")
    logger.error("Error: Failed to connect to robot")
    logger.debug("Debug: Initializing motor controllers")

    # Display an emoji message
    logger.emoji("This is a fun message with an emoji!", emoji="🔥")

    # Display loading animation for 3 seconds
    logger.loading("Loading motor control data...", seconds=3)

    # Show progress bar for a task with 100 steps
    logger.progress("Processing task", total=100, speed=0.05)

    # You can also use different log levels with a higher level than INFO, like ERROR:
    logger = RichLogger(level="ERROR")

    # Only error and higher priority logs will be shown (INFO, SUCCESS, WARNING will be hidden)
    logger.info("This won't be displayed because the level is set to ERROR")
    logger.error("This error will be displayed!")
|
||||
171
unitree_deploy/unitree_deploy/utils/run_simulation.py
Normal file
171
unitree_deploy/unitree_deploy/utils/run_simulation.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import Process, Queue
|
||||
from queue import Empty
|
||||
|
||||
import mujoco
|
||||
import mujoco.viewer
|
||||
import numpy as np
|
||||
|
||||
from unitree_deploy.utils.rich_logger import log_info, log_success
|
||||
|
||||
|
||||
@dataclass
class MujocoSimulationConfig:
    """Static per-robot configuration for the MuJoCo simulation."""

    xml_path: str  # model file passed to mujoco.MjModel.from_xml_path
    dof: int  # total joint count used when padding commands
    robot_type: str  # selects the padding branch in MujicoSimulation.set_positions
    ctr_dof: int  # number of joints the caller commands via set_positions
    stop_dof: int  # length of the zero fallback returned by get_current_positions on timeout
|
||||
|
||||
|
||||
def get_mujoco_sim_config(robot_type: str) -> MujocoSimulationConfig:
    """Return the MuJoCo simulation configuration for *robot_type*.

    Args:
        robot_type: One of ``"g1"``, ``"z1"`` or ``"h1_2"``.

    Returns:
        The matching :class:`MujocoSimulationConfig`.

    Raises:
        ValueError: If *robot_type* is not supported.
    """
    configs = {
        "g1": MujocoSimulationConfig(
            xml_path="unitree_deploy/robot_devices/assets/g1/g1_body29.xml",
            dof=30,
            robot_type="g1",
            ctr_dof=14,
            stop_dof=35,
        ),
        "z1": MujocoSimulationConfig(
            xml_path="unitree_deploy/robot_devices/assets/z1/z1.xml",
            dof=6,
            robot_type="z1",
            ctr_dof=6,
            stop_dof=6,
        ),
        "h1_2": MujocoSimulationConfig(
            # TODO(review): original code pointed the h1_2 entry at the z1 URDF —
            # looks like a copy/paste placeholder; confirm the correct h1_2 asset path.
            xml_path="unitree_deploy/robot_devices/assets/z1/z1.urdf",
            dof=30,
            # BUGFIX: was "g1", which made the "h1_2" padding branch in
            # MujicoSimulation.set_positions unreachable for this config.
            robot_type="h1_2",
            ctr_dof=14,
            stop_dof=35,
        ),
    }
    try:
        return configs[robot_type]
    except KeyError:
        raise ValueError(f"Unsupported robot_type: {robot_type}") from None
|
||||
|
||||
|
||||
class MujicoSimulation:
    """Runs a passive MuJoCo viewer in a child process, exchanging joint
    commands and simulated state through multiprocessing queues.

    NOTE(review): the class name is spelled "Mujico" rather than "Mujoco";
    kept as-is because external callers reference this name.
    """

    def __init__(self, config: MujocoSimulationConfig):
        """Store the configuration and start the simulation subprocess immediately."""
        self.xml_path = config.xml_path

        self.robot_type = config.robot_type

        self.dof = config.dof
        self.ctr_dof = config.ctr_dof
        self.stop_dof = config.stop_dof

        # Commands flow in through action_queue; qpos snapshots flow back
        # through state_queue.
        self.action_queue = Queue()
        self.state_queue = Queue()
        self.process = Process(target=self._run_simulation, args=(self.xml_path, self.action_queue, self.state_queue))
        # Daemonized so the viewer process dies with the parent.
        self.process.daemon = True
        self.process.start()

    def set_positions(self, joint_positions: np.ndarray):
        """Queue a joint-position command for the simulation process.

        Args:
            joint_positions: Array of exactly ``ctr_dof`` target positions.

        Raises:
            ValueError: On wrong input length or unknown robot type.
        """
        if joint_positions.shape[0] != self.ctr_dof:
            raise ValueError(f"joint_positions must contain {self.ctr_dof} values!")

        if self.robot_type == "g1":
            # Prepend zeros for the (dof - ctr_dof) uncontrolled leading joints.
            joint_positions = np.concatenate([np.zeros(self.dof - self.ctr_dof, dtype=np.float32), joint_positions])
        elif self.robot_type == "z1":
            pass
        elif self.robot_type == "h1_2":
            # NOTE(review): this zeroes the CALLER's array in place, and with
            # dof=30 / ctr_dof=14 the slice [:16] covers the entire 14-element
            # input — verify this branch against a real h1_2 configuration.
            joint_positions[: self.dof - self.ctr_dof] = 0.0
        else:
            raise ValueError(f"Unsupported robot_type: {self.robot_type}")

        self.action_queue.put(joint_positions.tolist())

    def get_current_positions(self, timeout=0.01):
        """Return the next queued qpos snapshot from the simulation, or a
        zero vector of length ``stop_dof`` if nothing arrives within
        *timeout* seconds.

        NOTE(review): the fallback length (stop_dof) may differ from the
        model's actual qpos length — confirm against the loaded model.
        """
        try:
            return self.state_queue.get(timeout=timeout)
        except Empty:
            return [0.0] * self.stop_dof

    def stop(self):
        """Terminate the simulation subprocess and release both queues.

        Safe to call multiple times; failures are reported as warnings
        rather than raised.
        """
        if hasattr(self, "process") and self.process is not None and self.process.is_alive():
            try:
                self.process.terminate()
                self.process.join()
            except Exception as e:
                print(f"[WARN] Failed to stop process: {e}")
            self.process = None

        for qname in ["action_queue", "state_queue"]:
            queue = getattr(self, qname, None)
            if queue is not None:
                try:
                    # close()/join_thread() shut down the queue's feeder thread.
                    if hasattr(queue, "close") and callable(queue.close):
                        queue.close()
                    if hasattr(queue, "join_thread") and callable(queue.join_thread):
                        queue.join_thread()
                except Exception as e:
                    print(f"[WARN] Failed to cleanup {qname}: {e}")
                setattr(self, qname, None)

    def __del__(self):
        # Best-effort cleanup; stop() tolerates partially-initialized state.
        self.stop()

    @staticmethod
    def _run_simulation(xml_path: str, action_queue: Queue, state_queue: Queue):
        """Subprocess entry point: run the passive-viewer loop.

        Applies the newest command from *action_queue* to qpos each tick and
        publishes a copy of qpos to *state_queue*.
        """
        model = mujoco.MjModel.from_xml_path(xml_path)
        data = mujoco.MjData(model)

        joint_names = [mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_JOINT, i) for i in range(model.njnt)]
        joints_indices = [
            model.jnt_qposadr[mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, name)] for name in joint_names
        ]
        log_info(f"len joints indices: {len(joints_indices)}")

        viewer = mujoco.viewer.launch_passive(model, data)

        current_positions = np.zeros(len(joints_indices), dtype=np.float32)
        try:
            while viewer.is_running():
                try:
                    new_pos = action_queue.get_nowait()
                    # Silently ignore commands whose length doesn't match the model.
                    if len(new_pos) == len(joints_indices):
                        current_positions = new_pos
                except Empty:
                    pass

                for idx, pos in zip(joints_indices, current_positions, strict=True):
                    data.qpos[idx] = pos

                # Kinematic replay: zero velocities and recompute, no dynamics step.
                data.qvel[:] = 0
                mujoco.mj_forward(model, data)

                # NOTE(review): if the consumer never drains state_queue this
                # grows without bound — consider a bounded queue.
                state_queue.put(data.qpos.copy())

                viewer.sync()
                time.sleep(0.001)

        except KeyboardInterrupt:
            log_success("The simulation process was interrupted.")
        finally:
            viewer.close()
|
||||
|
||||
|
||||
def main():
    """Drive the g1 MuJoCo simulation with random joint targets at 50 Hz."""
    sim_config = get_mujoco_sim_config(robot_type="g1")
    simulation = MujicoSimulation(sim_config)
    time.sleep(1)  # give the simulation process a moment to come up
    try:
        while True:
            random_targets = np.random.uniform(-1.0, 1.0, simulation.ctr_dof)
            simulation.set_positions(random_targets)
            # print(simulation.get_current_positions())
            time.sleep(1 / 50)
    except KeyboardInterrupt:
        print("Simulation stopped.")
        simulation.stop()


if __name__ == "__main__":
    main()
|
||||
36
unitree_deploy/unitree_deploy/utils/trajectory_generator.py
Normal file
36
unitree_deploy/unitree_deploy/utils/trajectory_generator.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import pinocchio as pin
|
||||
|
||||
|
||||
def generate_rotation(step: int, rotation_speed: float, max_step: int = 240):
    """Generate rotation (quaternions) and translation deltas for left and right arm motions."""
    half_cycle = max_step // 2
    in_second_half = step > half_cycle

    # The angle ramps up during the first half of the cycle, then back down.
    if in_second_half:
        angle = rotation_speed * (max_step - step)
    else:
        angle = rotation_speed * step

    cos_half = np.cos(angle / 2)
    sin_half = np.sin(angle / 2)
    l_quat = pin.Quaternion(cos_half, 0, sin_half, 0)  # left arm: rotation about Y
    r_quat = pin.Quaternion(cos_half, 0, 0, sin_half)  # right arm: rotation about Z

    # Translation increments reverse direction in the second half of the cycle.
    direction = -1 if in_second_half else 1
    delta_l = direction * (np.array([0.001, 0.001, 0.001]) * 1.2)
    delta_r = direction * (np.array([0.001, -0.001, 0.001]) * 1.2)

    return l_quat, r_quat, delta_l, delta_r
|
||||
|
||||
|
||||
def sinusoidal_single_gripper_motion(period: float, amplitude: float, current_time: float) -> np.ndarray:
    """Return a one-element sinusoidal gripper command.

    The sine is shifted into [0, amplitude], then multiplied by 5.
    NOTE(review): the x5 scale contrasts with the sibling
    sinusoidal_gripper_motion, which repeats the value 5 times — confirm
    the scale factor is intentional and not a typo for ``[value] * 5``.
    """
    value = amplitude * (math.sin(2 * math.pi * current_time / period) + 1) / 2
    scaled = value * 5
    return np.array([scaled])
|
||||
|
||||
|
||||
def sinusoidal_gripper_motion(period: float, amplitude: float, current_time: float) -> np.ndarray:
    """Return the same sinusoidal command, in [0, amplitude], replicated
    for all five gripper joints."""
    value = amplitude * (math.sin(2 * math.pi * current_time / period) + 1) / 2
    return np.full(5, value)
|
||||
@@ -0,0 +1,40 @@
|
||||
from collections import deque

import numpy as np
|
||||
|
||||
|
||||
class WeightedMovingFilter:
    """Weighted moving-average filter over fixed-size data vectors.

    Keeps a sliding window of the last ``len(weights)`` samples and produces
    a per-dimension weighted average. Weight orientation matches the original
    ``np.convolve(column, weights, mode="valid")[-1]``: ``weights[0]`` is
    applied to the NEWEST sample and ``weights[-1]`` to the oldest.
    """

    def __init__(self, weights, data_size=14):
        """Initialize the filter.

        Args:
            weights: Window weights; their sum must be 1.0.
            data_size: Dimensionality of each input vector.

        Raises:
            ValueError: If the weights do not sum to 1.0.
        """
        self._window_size = len(weights)
        self._weights = np.array(weights)
        # Explicit raise instead of `assert`: asserts are stripped under -O.
        if not np.isclose(np.sum(self._weights), 1.0):
            raise ValueError("[WeightedMovingFilter] the sum of weights list must be 1.0!")
        self._data_size = data_size
        self._filtered_data = np.zeros(self._data_size)
        # deque(maxlen=...) evicts the oldest sample in O(1); the original
        # list.pop(0) was O(window_size) per insertion.
        self._data_queue = deque(maxlen=self._window_size)

    def _apply_filter(self):
        """Return the weighted average of the current window."""
        if len(self._data_queue) < self._window_size:
            # Not enough history yet: pass the latest sample through unchanged.
            return self._data_queue[-1]

        window = np.array(self._data_queue)  # shape: (window_size, data_size)
        # One vectorized matrix-vector product replaces the per-dimension
        # convolve loop. Weights are reversed so weights[0] multiplies the
        # newest row, preserving the convolve orientation.
        return self._weights[::-1] @ window

    def add_data(self, new_data):
        """Push one sample into the window and refresh ``filtered_data``.

        Consecutive duplicate samples are ignored.

        Args:
            new_data: Sequence of exactly ``data_size`` values.

        Raises:
            ValueError: If ``new_data`` has the wrong length.
        """
        if len(new_data) != self._data_size:
            raise ValueError(f"[WeightedMovingFilter] expected {self._data_size} values, got {len(new_data)}")

        if len(self._data_queue) > 0 and np.array_equal(new_data, self._data_queue[-1]):
            return  # skip duplicate data

        self._data_queue.append(new_data)  # maxlen drops the oldest automatically
        self._filtered_data = self._apply_filter()

    @property
    def filtered_data(self):
        """Most recent filter output (length ``data_size``)."""
        return self._filtered_data
|
||||
Reference in New Issue
Block a user