199 lines
6.8 KiB
Python
199 lines
6.8 KiB
Python
import argparse
|
|
import os
|
|
import time
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
import tqdm
|
|
|
|
from typing import Any, Deque, MutableMapping, OrderedDict
|
|
from collections import deque
|
|
from pathlib import Path
|
|
|
|
from unitree_deploy.real_unitree_env import make_real_env
|
|
from unitree_deploy.utils.eval_utils import (
|
|
ACTTemporalEnsembler,
|
|
LongConnectionClient,
|
|
populate_queues,
|
|
)
|
|
|
|
# -----------------------------------------------------------------------------
# Network & environment defaults
# -----------------------------------------------------------------------------
# Blank out any inherited HTTP(S) proxy settings; presumably so the connection
# to the local policy server is not routed through a proxy (library-dependent).
os.environ.update({"http_proxy": "", "https_proxy": ""})

# Address of the policy server handed to LongConnectionClient.
HOST, PORT = "127.0.0.1", 8000
BASE_URL = "http://{}:{}".format(HOST, PORT)

# fmt: off
# Joint-space command sent through env.step() once at the start of each
# rollout, keyed by robot type.
# NOTE(review): the trailing 5.4 entries of 'g1_dex1' look like gripper values — confirm.
INIT_POSE = {
    'g1_dex1': np.array(
        [0.10559805, 0.02726714, -0.01210221, -0.33341318, -0.22513399, -0.02627627,
         -0.15437093, 0.1273793, -0.1674708, -0.11544029, -0.40095493, 0.44332668,
         0.11566751, 0.3936641, 5.4, 5.4],
        dtype=np.float32,
    ),
    'z1_dual_dex1_realsense': np.array(
        [-1.0262332, 1.4281361, -1.2149128, 0.6473399, -0.12425245, 0.44945636,
         0.89584476, 1.2593982, -1.0737865, 0.6672816, 0.39730102, -0.47400007,
         0.9894176, 0.9817477],
        dtype=np.float32,
    ),
    'z1_realsense': np.array(
        [-0.06940782, 1.4751548, -0.7554075, 1.0501366, 0.02931615, -0.02810347,
         -0.99238837],
        dtype=np.float32,
    ),
}

# All-zero placeholder for the "action" entry of each observation. INIT_POSE is
# itself sent as an action, so its length is the action dimension (16 / 14 / 7).
ZERO_ACTION = {
    robot: torch.zeros(len(pose), dtype=torch.float32)
    for robot, pose in INIT_POSE.items()
}

# Camera whose frame is fed to the policy as "observation.images.top".
CAM_KEY = {
    'g1_dex1': 'cam_right_high',
    'z1_dual_dex1_realsense': 'cam_high',
    'z1_realsense': 'cam_high',
}
# fmt: on
|
|
|
|
|
|
def prepare_observation(args: argparse.Namespace, obs: Any) -> OrderedDict:
    """
    Turn one raw environment observation into the policy's input mapping.

    Args:
        args: Parsed CLI arguments; only ``args.robot_type`` is read, to pick
            the camera stream (``CAM_KEY``) and the zero-action placeholder
            (``ZERO_ACTION``).
        obs: Environment observation exposing ``obs.observation["images"]``
            (BGR frames keyed by camera name) and ``obs.observation["qpos"]``
            (a NumPy array).

    Returns:
        An ordered mapping with keys ``"observation.images.top"`` (RGB image
        tensor, channels moved to the front), ``"observation.state"`` and
        ``"action"`` (the shared all-zeros tensor for this robot type).
    """
    robot = args.robot_type
    frames = obs.observation["images"]
    bgr_frame = frames[CAM_KEY[robot]]
    rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)

    # Move the channel axis first; torch.from_numpy shares memory with the array.
    image_chw = torch.from_numpy(rgb_frame).permute(2, 0, 1)
    state = torch.from_numpy(obs.observation["qpos"])

    return OrderedDict([
        ("observation.images.top", image_chw),
        ("observation.state", state),
        ("action", ZERO_ACTION[robot]),
    ])
|
|
|
|
|
|
def run_policy(
    args: argparse.Namespace,
    env: Any,
    client: LongConnectionClient,
    temporal_ensembler: ACTTemporalEnsembler,
    cond_obs_queues: MutableMapping[str, Deque[torch.Tensor]],
    output_dir: Path,
) -> None:
    """
    Run one closed-loop rollout against the policy server.

    Steps:
      1) warm start: send ``INIT_POSE[args.robot_type]`` through ``env.step``
         and wait 2 s,
      2) capture and format an observation and push it into the queues,
      3) request an action chunk from the policy server,
      4) keep the first ``args.action_horizon`` actions, smooth them with the
         temporal ensembler, and execute ``args.exe_steps`` of them.

    The loop has no exit condition of its own: it only leaves via an
    exception (e.g. KeyboardInterrupt, or an error from env/client).

    Args:
        args: Parsed CLI arguments (robot_type, language_instruction,
            action_horizon, exe_steps and control_freq are read).
        env: Robot environment exposing ``step(action)`` and
            ``get_observation(t)``.
        client: Policy-server client, queried via ``predict_action``.
        temporal_ensembler: Smoother applied to each predicted chunk.
        cond_obs_queues: Observation/action history queues, updated through
            ``populate_queues``.
        output_dir: Per-episode results directory; currently unused here.

    Raises:
        ValueError: If a smoothed chunk holds fewer than ``args.exe_steps``
            actions. Raised before any action of that chunk reaches the robot.
    """
    _ = env.step(INIT_POSE[args.robot_type])
    time.sleep(2.0)

    # Minimum duration of each env.step call, in seconds.
    period = 1.0 / args.control_freq
    t = 0

    while True:
        # Capture observation
        obs = env.get_observation(t)
        # Format observation
        obs = prepare_observation(args, obs)
        cond_obs_queues = populate_queues(cond_obs_queues, obs)

        # Call server to get actions
        pred_actions = client.predict_action(args.language_instruction,
                                             cond_obs_queues).unsqueeze(0)
        # Keep only the next horizon of actions and apply temporal ensemble smoothing
        actions = temporal_ensembler.update(
            pred_actions[:, :args.action_horizon])[0]

        # Refuse a too-short chunk up front instead of moving the robot part
        # way and then failing with an IndexError mid-chunk.
        if len(actions) < args.exe_steps:
            raise ValueError(
                f"Temporal ensembler returned {len(actions)} actions, but "
                f"exe_steps={args.exe_steps}; exe_steps must not exceed the "
                f"number of available actions.")

        # Execute the actions
        for n in range(args.exe_steps):
            action = actions[n].cpu().numpy()
            print(f">>> Exec => step {n} action: {action}", flush=True)
            print("---------------------------------------------")

            # Pace env.step at `control_freq` Hz. perf_counter is monotonic, so
            # wall-clock adjustments cannot distort the computed sleep.
            t1 = time.perf_counter()
            obs = env.step(action)
            time.sleep(max(0.0, period - (time.perf_counter() - t1)))
            t += 1

            # Prime the queue for the next action step (except after the last
            # one in this chunk; the outer loop captures a fresh observation).
            if n < args.exe_steps - 1:
                obs = prepare_observation(args, obs)
                cond_obs_queues = populate_queues(cond_obs_queues, obs)
|
|
|
|
|
|
def run_eval(args: argparse.Namespace) -> None:
    """
    Run ``args.num_rollouts_planned`` rollouts on the real robot.

    Each episode writes to ``<args.output_dir>/episode_XXX`` and starts with a
    freshly reset temporal ensembler and empty observation queues, so no state
    leaks from one rollout into the next. The environment is closed exactly
    once, whether the rollouts finish or raise.

    NOTE(review): run_policy loops with ``while True`` and has no normal exit,
    so as written only the first episode ever runs. An episode-termination
    condition is needed for ``num_rollouts_planned > 1`` to take effect.

    Raises:
        ValueError: If the CLI arguments are inconsistent (unknown robot type,
            non-positive horizon/steps/frequency, or exe_steps > action_horizon).
            This is checked before the robot environment is created.
    """
    # Validate up front, before anything connects to the robot.
    if args.robot_type not in INIT_POSE:
        raise ValueError(f"Unsupported robot_type {args.robot_type!r}; "
                         f"expected one of {sorted(INIT_POSE)}.")
    if min(args.action_horizon, args.exe_steps, args.observation_horizon) <= 0:
        raise ValueError("action_horizon, exe_steps and observation_horizon "
                         "must all be positive.")
    if args.exe_steps > args.action_horizon:
        raise ValueError(f"exe_steps ({args.exe_steps}) must not exceed "
                         f"action_horizon ({args.action_horizon}).")
    if args.control_freq <= 0:
        raise ValueError(f"control_freq must be positive, got {args.control_freq}.")

    client = LongConnectionClient(BASE_URL)

    # ACT temporal moving-average smoother over successive action chunks.
    temporal_ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff=0.01,
                                              chunk_size=args.action_horizon,
                                              exe_steps=args.exe_steps)

    # Observation and action history queues passed to the policy server.
    cond_obs_queues = {
        "observation.images.top": deque(maxlen=args.observation_horizon),
        "observation.state": deque(maxlen=args.observation_horizon),
        # NOTE: hard-coded because the model predicts 16 future steps.
        "action": deque(maxlen=16),
    }

    env = make_real_env(
        robot_type=args.robot_type,
        dt=1 / args.control_freq,
    )
    env.connect()

    try:
        for episode_idx in tqdm.tqdm(range(args.num_rollouts_planned)):
            output_dir = Path(args.output_dir) / f"episode_{episode_idx:03d}"
            output_dir.mkdir(parents=True, exist_ok=True)

            # Start every rollout from a clean slate: no smoothing state or
            # stale observations carried over from the previous episode.
            temporal_ensembler.reset()
            for queue in cond_obs_queues.values():
                queue.clear()

            run_policy(args, env, client, temporal_ensembler, cond_obs_queues,
                       output_dir)
    finally:
        # Single close point for both the normal and the error path.
        env.close()
|
|
|
|
|
|
def get_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser for the real-robot evaluation client.

    Every option has a default, so ``get_parser().parse_args([])`` succeeds.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description="Run policy rollouts on a real robot, querying actions "
                    "from a policy server.")
    parser.add_argument("--robot_type",
                        type=str,
                        default="g1_dex1",
                        help="The type of the robot embodiment.")
    parser.add_argument(
        "--action_horizon",
        type=int,
        default=16,
        help="Number of future actions, predicted by the policy, to keep",
    )
    parser.add_argument(
        "--exe_steps",
        type=int,
        default=16,
        # The defaults are equal, so the constraint is "not more than", not
        # "strictly less than".
        help="Number of future actions to execute per policy query; "
             "must not exceed --action_horizon.",
    )
    parser.add_argument(
        "--observation_horizon",
        type=int,
        default=2,
        help="Number of most recent frames/states to consider.",
    )
    parser.add_argument(
        "--language_instruction",
        type=str,
        default="Pack black camera into box",
        help="The language instruction provided to the policy server.",
    )
    parser.add_argument("--num_rollouts_planned",
                        type=int,
                        default=10,
                        help="The number of rollouts to run.")
    parser.add_argument("--output_dir",
                        type=str,
                        default="./results",
                        help="The directory for saving results.")
    parser.add_argument("--control_freq",
                        type=float,
                        default=30,
                        help="The low-level control frequency in Hz.")
    return parser
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the CLI options, then run the evaluation.
    cli_args = get_parser().parse_args()
    run_eval(cli_args)
|