第一次完整测例跑完
This commit is contained in:
198
unitree_deploy/scripts/robot_client.py
Normal file
198
unitree_deploy/scripts/robot_client.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from typing import Any, Deque, MutableMapping, OrderedDict
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
|
||||
from unitree_deploy.real_unitree_env import make_real_env
|
||||
from unitree_deploy.utils.eval_utils import (
|
||||
ACTTemporalEnsembler,
|
||||
LongConnectionClient,
|
||||
populate_queues,
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
# Network & environment defaults
# -----------------------------------------------------------------------------
# Clear any inherited HTTP(S) proxy so requests to the local policy server
# connect directly instead of being routed through a proxy.
os.environ["http_proxy"] = ""
os.environ["https_proxy"] = ""
# Policy server endpoint (same machine by default).
HOST = "127.0.0.1"
PORT = 8000
BASE_URL = f"http://{HOST}:{PORT}"

# fmt: off
# Per-embodiment warm-start pose sent once at the beginning of each rollout.
# NOTE(review): vector lengths (16 / 14 / 7) match the corresponding
# ZERO_ACTION sizes below — presumably joint positions plus gripper
# commands; confirm against the robot embodiment definitions.
INIT_POSE = {
    'g1_dex1': np.array([0.10559805, 0.02726714, -0.01210221, -0.33341318, -0.22513399, -0.02627627, -0.15437093, 0.1273793 , -0.1674708 , -0.11544029, -0.40095493, 0.44332668, 0.11566751, 0.3936641, 5.4, 5.4], dtype=np.float32),
    'z1_dual_dex1_realsense': np.array([-1.0262332, 1.4281361, -1.2149128, 0.6473399, -0.12425245, 0.44945636, 0.89584476, 1.2593982, -1.0737865, 0.6672816, 0.39730102, -0.47400007, 0.9894176, 0.9817477 ], dtype=np.float32),
    'z1_realsense': np.array([-0.06940782, 1.4751548, -0.7554075, 1.0501366, 0.02931615, -0.02810347, -0.99238837], dtype=np.float32),
}
# Zero-filled action placeholder per embodiment; used to seed the "action"
# entry of the conditioning observation fed to the policy.
ZERO_ACTION = {
    'g1_dex1': torch.zeros(16, dtype=torch.float32),
    'z1_dual_dex1_realsense': torch.zeros(14, dtype=torch.float32),
    'z1_realsense': torch.zeros(7, dtype=torch.float32),
}
# Which camera stream from the env observation is fed to the policy,
# per embodiment.
CAM_KEY = {
    'g1_dex1': 'cam_right_high',
    'z1_dual_dex1_realsense': 'cam_high',
    'z1_realsense': 'cam_high',
}
# fmt: on
|
||||
|
||||
|
||||
def prepare_observation(args: argparse.Namespace, obs: Any) -> OrderedDict:
    """
    Convert a raw env observation into the model's expected input dict.

    Args:
        args: Parsed CLI options; only ``robot_type`` is read here.
        obs: Raw environment observation exposing ``obs.observation`` with
            an ``"images"`` mapping and a ``"qpos"`` array.

    Returns:
        An ``OrderedDict`` with the camera image (CHW tensor), the
        proprioceptive state, and a zero action placeholder.
    """
    robot = args.robot_type
    # The env delivers BGR frames (OpenCV convention); the model expects RGB.
    bgr_frame = obs.observation["images"][CAM_KEY[robot]]
    rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
    # HWC uint8 image -> channel-first CHW tensor.
    image_tensor = torch.from_numpy(rgb_frame).permute(2, 0, 1)
    state_tensor = torch.from_numpy(obs.observation["qpos"])

    formatted = OrderedDict()
    formatted["observation.images.top"] = image_tensor
    formatted["observation.state"] = state_tensor
    formatted["action"] = ZERO_ACTION[robot]
    return formatted
|
||||
|
||||
|
||||
def run_policy(
    args: argparse.Namespace,
    env: Any,
    client: LongConnectionClient,
    temporal_ensembler: ACTTemporalEnsembler,
    cond_obs_queues: MutableMapping[str, Deque[torch.Tensor]],
    output_dir: Path,
) -> None:
    """
    Single rollout loop:
      1) warm start the robot,
      2) stream observations,
      3) fetch actions from the policy server,
      4) execute with temporal ensembling for smoother control.

    Args:
        args: Parsed CLI options (robot_type, language_instruction,
            action_horizon, exe_steps, control_freq).
        env: Real-robot environment exposing ``step`` and ``get_observation``.
        client: Long-lived HTTP client to the policy server.
        temporal_ensembler: Smooths successive predicted action chunks.
        cond_obs_queues: Rolling queues of conditioning inputs, keyed by
            model input name; refilled after each executed step.
        output_dir: Per-episode output directory created by the caller.
            NOTE(review): currently unused inside this function.

    NOTE(review): this loop never terminates on its own — it runs until the
    process is interrupted or an env/client call raises.
    """

    # Warm start: drive to the per-embodiment home pose, then let the robot
    # settle before streaming policy actions.
    _ = env.step(INIT_POSE[args.robot_type])
    time.sleep(2.0)
    t = 0  # global control-step counter across action chunks

    while True:
        # Capture observation
        obs = env.get_observation(t)
        # Format observation
        obs = prepare_observation(args, obs)
        cond_obs_queues = populate_queues(cond_obs_queues, obs)
        # Call server to get actions
        pred_actions = client.predict_action(args.language_instruction,
                                             cond_obs_queues).unsqueeze(0)
        # Keep only the next horizon of actions and apply temporal ensemble smoothing
        actions = temporal_ensembler.update(
            pred_actions[:, :args.action_horizon])[0]

        # Execute the actions
        for n in range(args.exe_steps):
            action = actions[n].cpu().numpy()
            print(f">>> Exec => step {n} action: {action}", flush=True)
            print("---------------------------------------------")

            # Maintain real-time loop at `control_freq` Hz
            t1 = time.time()
            obs = env.step(action)
            # Sleep only for the remainder of the control period; clamp at 0
            # so an overrun never produces a negative sleep.
            time.sleep(max(0, 1 / args.control_freq - time.time() + t1))
            t += 1

            # Prime the queue for the next action step (except after the last one in this chunk)
            if n < args.exe_steps - 1:
                obs = prepare_observation(args, obs)
                cond_obs_queues = populate_queues(cond_obs_queues, obs)
|
||||
|
||||
|
||||
def run_eval(args: argparse.Namespace) -> None:
    """
    Run ``args.num_rollouts_planned`` evaluation rollouts on the real robot.

    Sets up the policy-server client, the ACT temporal ensembler, the
    conditioning-observation queues, and the real environment, then runs one
    ``run_policy`` rollout per episode. The environment is closed exactly
    once, even if a rollout raises.

    Args:
        args: Parsed CLI options from :func:`get_parser`.
    """
    client = LongConnectionClient(BASE_URL)

    # Initialize ACT temporal moving-average smoother.
    temporal_ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff=0.01,
                                              chunk_size=args.action_horizon,
                                              exe_steps=args.exe_steps)
    temporal_ensembler.reset()
    # NOTE(review): the ensembler and queues are shared across episodes and
    # never reset between them — confirm that carrying state over episode
    # boundaries is intended.

    # Initialize observation and action horizon queues.
    cond_obs_queues = {
        "observation.images.top": deque(maxlen=args.observation_horizon),
        "observation.state": deque(maxlen=args.observation_horizon),
        "action": deque(
            maxlen=16),  # NOTE: HAND CODED AS THE MODEL PREDICTS FUTURE 16 STEPS
    }

    env = make_real_env(
        robot_type=args.robot_type,
        dt=1 / args.control_freq,
    )
    env.connect()

    try:
        for episode_idx in tqdm.tqdm(range(0, args.num_rollouts_planned)):
            output_dir = Path(args.output_dir) / f"episode_{episode_idx:03d}"
            output_dir.mkdir(parents=True, exist_ok=True)
            run_policy(args, env, client, temporal_ensembler, cond_obs_queues,
                       output_dir)
    finally:
        # Close exactly once: the original called env.close() twice in this
        # finally block, which can raise or double-release hardware resources
        # depending on the env implementation.
        env.close()
|
||||
|
||||
|
||||
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for the robot eval client."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--robot_type",
        type=str,
        default="g1_dex1",
        help="The type of the robot embodiment.",
    )
    parser.add_argument(
        "--action_horizon",
        type=int,
        default=16,
        help="Number of future actions, predicted by the policy, to keep",
    )
    parser.add_argument(
        "--exe_steps",
        type=int,
        default=16,
        help="Number of future actions to execute, which must be less than the above action horizon.",
    )
    parser.add_argument(
        "--observation_horizon",
        type=int,
        default=2,
        help="Number of most recent frames/states to consider.",
    )
    parser.add_argument(
        "--language_instruction",
        type=str,
        default="Pack black camera into box",
        help="The language instruction provided to the policy server.",
    )
    parser.add_argument(
        "--num_rollouts_planned",
        type=int,
        default=10,
        help="The number of rollouts to run.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./results",
        help="The directory for saving results.",
    )
    parser.add_argument(
        "--control_freq",
        type=float,
        default=30,
        help="The Low-level control frequency in Hz.",
    )
    return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Entry point: parse CLI arguments, then launch the evaluation loop.
    cli_args = get_parser().parse_args()
    run_eval(cli_args)
|
||||
Reference in New Issue
Block a user