old:get_nearby_obstacles_in_session; new:get_nearby_obstacles
This commit is contained in:
@@ -8,6 +8,100 @@ from uav_api_client import UAVAPIClient
|
||||
import math
|
||||
import heapq
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
class TargetInfo:
|
||||
def __init__(self, data: Dict[str, Any]):
|
||||
self.id: str = data.get("id")
|
||||
self.name: str = data.get("name")
|
||||
self.type: str = data.get("type")
|
||||
self.position: Dict[str, float] = data.get("position")
|
||||
self.description: str = data.get("description")
|
||||
self.velocity: Optional[Dict[str, float]] = data.get("velocity")
|
||||
self.radius: Optional[float] = data.get("radius")
|
||||
self.created_at: float = data.get("created_at")
|
||||
self.last_updated: float = data.get("last_updated")
|
||||
self.moving_path: Optional[List[Dict[str, float]]] = data.get("moving_path")
|
||||
self.moving_duration: Optional[float] = data.get("moving_duration")
|
||||
self.current_path_index: Optional[int] = data.get("current_path_index")
|
||||
self.path_direction: Optional[int] = data.get("path_direction")
|
||||
self.time_in_direction: Optional[float] = data.get("time_in_direction")
|
||||
self.calculated_speed: Optional[float] = data.get("calculated_speed")
|
||||
self.charge_amount: Optional[float] = data.get("charge_amount")
|
||||
self.vertices: Optional[List[Dict[str, float]]] = data.get("vertices")
|
||||
self.is_reached: bool = data.get("is_reached")
|
||||
self.reached_by: List[str] = data.get("reached_by")
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.id)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TargetInfo) and self.id == other.id
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"type": self.type,
|
||||
"position": self.position,
|
||||
"description": self.description,
|
||||
"velocity": self.velocity,
|
||||
"radius": self.radius,
|
||||
"created_at": self.created_at,
|
||||
"last_updated": self.last_updated,
|
||||
"moving_path": self.moving_path,
|
||||
"moving_duration": self.moving_duration,
|
||||
"current_path_index": self.current_path_index,
|
||||
"path_direction": self.path_direction,
|
||||
"time_in_direction": self.time_in_direction,
|
||||
"calculated_speed": self.calculated_speed,
|
||||
"charge_amount": self.charge_amount,
|
||||
"vertices": self.vertices,
|
||||
"is_reached": self.is_reached,
|
||||
"reached_by": self.reached_by,
|
||||
}
|
||||
|
||||
class ObstacleInfo:
|
||||
def __init__(self, data: Dict[str, Any]):
|
||||
self.id: str = data.get("id")
|
||||
self.name: str = data.get("name")
|
||||
self.type: str = data.get("type")
|
||||
self.position: Dict[str, float] = data.get("position")
|
||||
self.description: str = data.get("description")
|
||||
self.radius: Optional[float] = data.get("radius")
|
||||
self.vertices: Optional[List[Dict[str, float]]] = data.get("vertices")
|
||||
self.width: Optional[float] = data.get("width")
|
||||
self.length: Optional[float] = data.get("length")
|
||||
self.height: Optional[float] = data.get("height")
|
||||
self.area: Optional[float] = data.get("area")
|
||||
self.created_at: float = data.get("created_at")
|
||||
self.last_updated: float = data.get("last_updated")
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.id)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, ObstacleInfo) and self.id == other.id
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"type": self.type,
|
||||
"position": self.position,
|
||||
"description": self.description,
|
||||
"radius": self.radius,
|
||||
"vertices": self.vertices,
|
||||
"width": self.width,
|
||||
"length": self.length,
|
||||
"height": self.height,
|
||||
"area": self.area,
|
||||
"created_at": self.created_at,
|
||||
"last_updated": self.last_updated,
|
||||
}
|
||||
|
||||
targets_expolred : set[TargetInfo] = set()
|
||||
obstacles_detected : set[ObstacleInfo] = set()
|
||||
|
||||
# --- 内部几何算法类 ---
|
||||
class GeometryUtils:
|
||||
@@ -216,7 +310,7 @@ def create_uav_tools(client: UAVAPIClient) -> list:
|
||||
return f"Error getting weather: {str(e)}"
|
||||
|
||||
@tool
|
||||
def get_targets() -> str:
|
||||
def get_targets_in_session() -> str:
|
||||
"""Get all targets in the session including fixed, moving, waypoint, circle and polygon to search or patrol.
|
||||
Use this to see what targets you need to visit.
|
||||
|
||||
@@ -226,7 +320,21 @@ def create_uav_tools(client: UAVAPIClient) -> list:
|
||||
return json.dumps(targets, indent=2)
|
||||
except Exception as e:
|
||||
return f"Error getting targets: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
def get_targets() -> str:
|
||||
"""Get all explored targets that have been detected so far.
|
||||
This returns targets from the agent's memory.
|
||||
|
||||
No input required."""
|
||||
try:
|
||||
global targets_expolred
|
||||
targets_list = [target.to_dict() for target in targets_expolred]
|
||||
print(len(targets_list))
|
||||
return json.dumps(targets_list, indent=2)
|
||||
except Exception as e:
|
||||
return f"Error getting explored targets: {str(e)}"
|
||||
|
||||
@tool
|
||||
def get_all_waypoints() -> str:
|
||||
"""Get all waypoints in the session including coordinates and altitude.
|
||||
@@ -240,7 +348,7 @@ def create_uav_tools(client: UAVAPIClient) -> list:
|
||||
return f"Error getting waypoints: {str(e)}"
|
||||
|
||||
@tool
|
||||
def get_obstacles() -> str:
|
||||
def get_obstacles_in_session() -> str:
|
||||
"""Get all obstacles in the session that drones must avoid.
|
||||
Use this to understand what obstacles exist in the environment.
|
||||
|
||||
@@ -249,8 +357,21 @@ def create_uav_tools(client: UAVAPIClient) -> list:
|
||||
obstacles = client.get_obstacles()
|
||||
return json.dumps(obstacles, indent=2)
|
||||
except Exception as e:
|
||||
return f"Error getting obstacles: {str(e)}"
|
||||
return f"Error getting detected obstacles: {str(e)}"
|
||||
|
||||
@tool
|
||||
def get_obstacles() -> str:
|
||||
"""Get all obstacles that have been detected so far.
|
||||
This returns obstacles from the agent's memory.
|
||||
|
||||
No input required."""
|
||||
try:
|
||||
global obstacles_detected
|
||||
obstacles_list = [obstacle.to_dict() for obstacle in obstacles_detected]
|
||||
print(len(obstacles_list))
|
||||
return json.dumps(obstacles_list, indent=2)
|
||||
except Exception as e:
|
||||
return f"Error getting detected obstacles: {str(e)}"
|
||||
|
||||
@tool
|
||||
def get_drone_status(input_json: str) -> str:
|
||||
@@ -278,13 +399,9 @@ def create_uav_tools(client: UAVAPIClient) -> list:
|
||||
@tool
|
||||
def get_nearby_entities(input_json: str) -> str:
|
||||
"""Get drones, targets, and obstacles near a specific drone (within its perception radius).
|
||||
|
||||
Input should be a JSON string with:
|
||||
- drone_id: The ID of the drone (required, get from Action list_drones)
|
||||
|
||||
Example: {{"drone_id": "drone-001"}}
|
||||
"""
|
||||
This also updates the internal sets of explored targets and detected obstacles."""
|
||||
try:
|
||||
global targets_expolred, obstacles_detected
|
||||
params = json.loads(input_json) if isinstance(input_json, str) else input_json
|
||||
drone_id = params.get('drone_id')
|
||||
|
||||
@@ -292,6 +409,17 @@ def create_uav_tools(client: UAVAPIClient) -> list:
|
||||
return "Error: drone_id is required"
|
||||
|
||||
nearby = client.get_nearby_entities(drone_id)
|
||||
|
||||
# Update explored targets
|
||||
if 'targets' in nearby:
|
||||
for target_data in nearby['targets']:
|
||||
targets_expolred.add(TargetInfo(target_data))
|
||||
|
||||
# Update detected obstacles
|
||||
if 'obstacles' in nearby:
|
||||
for obstacle_data in nearby['obstacles']:
|
||||
obstacles_detected.add(ObstacleInfo(obstacle_data))
|
||||
|
||||
return json.dumps(nearby, indent=2)
|
||||
except json.JSONDecodeError as e:
|
||||
return f"Error parsing JSON input: {str(e)}. Expected format: {{\"drone_id\": \"drone-001\"}}"
|
||||
|
||||
Reference in New Issue
Block a user