From 687e20c96aa84aba8576a0c56242afc4b5e22f04 Mon Sep 17 00:00:00 2001 From: abnerhexu <20591243+abnerhexu@users.noreply.github.com> Date: Mon, 12 Jan 2026 09:56:56 +0800 Subject: [PATCH] first commit --- .gitignore | 3 + llm_settings.json | 54 ++ main.py | 1153 +++++++++++++++++++++++++++++++++++++ template/__init__.py | 10 + template/agent_prompt.py | 92 +++ template/parsing_error.py | 18 + tp | 140 +++++ uav_agent.py | 671 +++++++++++++++++++++ uav_api_client.py | 365 ++++++++++++ uav_langchain_tools.py | 649 +++++++++++++++++++++ 10 files changed, 3155 insertions(+) create mode 100644 .gitignore create mode 100644 llm_settings.json create mode 100644 main.py create mode 100644 template/__init__.py create mode 100644 template/agent_prompt.py create mode 100644 template/parsing_error.py create mode 100644 tp create mode 100644 uav_agent.py create mode 100644 uav_api_client.py create mode 100644 uav_langchain_tools.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e0c2e03 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__/* +.DS_Store +*/__pycache__/* diff --git a/llm_settings.json b/llm_settings.json new file mode 100644 index 0000000..dbe98d6 --- /dev/null +++ b/llm_settings.json @@ -0,0 +1,54 @@ +{ + "selected_provider": "Kimi", + "provider_configs": { + "Ollama": { + "type": "ollama", + "base_url": "http://localhost:11434", + "models_endpoint": "/api/tags", + "chat_endpoint": "/api/chat", + "requires_api_key": false, + "api_key": "", + "encoding": "utf-8", + "default_model": "gpt-oss", + "default_models": [], + "allow_endpoint_edit": false, + "allow_api_toggle": false, + "system_prompt": "" + }, + "OpenAI": { + "type": "openai-compatible", + "base_url": "https://api.openai.com/v1", + "models_endpoint": "/models", + "chat_endpoint": "/chat/completions", + "requires_api_key": true, + "api_key": "", + "encoding": "utf-8", + "default_model": "gpt-4o-mini", + "default_models": [ + "gpt-4o-mini", + "gpt-4o", + "gpt-4.1-mini", + 
def save_llm_settings(settings: Dict[str, Any], settings_path: str = CONFIG_FILE) -> None:
    """Save LLM settings to a JSON file (best effort).

    The settings are serialized to a string *before* the destination is
    opened. Opening a file for writing truncates it immediately, so
    serializing straight into the handle would leave an empty or half-written
    config behind whenever ``settings`` holds a value ``json`` cannot encode.

    Args:
        settings: JSON-serializable mapping to persist.
        settings_path: Destination file; defaults to ``CONFIG_FILE``.

    Returns:
        None. Failures are reported with a printed warning, never raised.
    """
    try:
        payload = json.dumps(settings, indent=2)
        Path(settings_path).write_text(payload, encoding='utf-8')
    except Exception as e:
        # Deliberately broad: saving is best effort and must not take the
        # caller down; the problem is surfaced on stdout instead.
        print(f"Warning: Could not save LLM settings to {settings_path}: {e}")
def ensure_config_defaults(
    self,
    name: str,
    config: Dict[str, Any],
) -> Dict[str, Any]:
    """Fill in missing fields for a provider configuration.

    Keys absent from ``config`` are taken from the entry already stored in
    ``self.provider_configs`` under ``name`` (if any), then the well-known
    fields are coerced to predictable types.

    Args:
        name: Provider name used to look up fallback values.
        config: Provider settings; the dict itself is not modified.

    Returns:
        A new dict in which ``default_models`` is a list owned by the result,
        ``api_key``/``default_model``/``base_url`` are never ``None`` and
        ``requires_api_key`` is a bool.
    """
    merged = dict(config)
    defaults = self.provider_configs.get(name, {})
    for key, value in defaults.items():
        merged.setdefault(key, value)

    models = merged.get("default_models")
    if models is None:
        # Covers both a missing key and an explicit null in the config;
        # downstream code joins this value and needs a real list.
        merged["default_models"] = []
    elif isinstance(models, str):
        merged["default_models"] = [models]
    elif isinstance(models, (list, tuple)):
        # Copy so the result never shares a list with the stored defaults.
        merged["default_models"] = list(models)

    merged["api_key"] = str(merged.get("api_key") or "")
    merged["default_model"] = merged.get("default_model") or ""
    merged["base_url"] = merged.get("base_url") or defaults.get("base_url", "")
    merged["requires_api_key"] = bool(merged.get("requires_api_key", False))
    return merged
None: + main_frame = ttk.Frame(self.root, padding=12) + main_frame.grid(row=0, column=0, sticky="nsew") + self.root.rowconfigure(0, weight=1) + self.root.columnconfigure(0, weight=1) + main_frame.columnconfigure(0, weight=1) + + # title = ttk.Label( + # main_frame, + # text="UAV Control Interface", + # font=("Arial", 18, "bold"), + # ) + # title.grid(row=0, column=0, sticky="w", pady=(0, 10)) + + config_frame = ttk.LabelFrame(main_frame, text="LLM Provider", padding=10) + config_frame.grid(row=1, column=0, sticky="ew", pady=(0, 10)) + for col_idx in range(4): + config_frame.columnconfigure(col_idx, weight=1 if col_idx == 1 else 0) + + ttk.Label(config_frame, text="Provider:").grid(row=0, column=0, sticky="w") + self.provider_dropdown = ttk.Combobox( + config_frame, + textvariable=self.provider_var, + state="readonly", + width=15, + ) + self.provider_dropdown.grid(row=0, column=1, sticky="ew", pady=2, padx=(6, 0)) + self.provider_dropdown.bind("<>", lambda _: self.on_provider_change()) + + ttk.Button( + config_frame, + text="Configure", + command=self.open_provider_dialog, + width=10, + ).grid(row=0, column=2) + + ttk.Label(config_frame, text="Model:").grid(row=1, column=0, sticky="w") + self.model_dropdown = ttk.Combobox( + config_frame, + textvariable=self.model_var, + width=15, + ) + self.model_dropdown.grid(row=1, column=1, sticky="ew", pady=2, padx=(6, 0)) + + # Temperature label and spinbox combined in one frame, aligned with Configure button + temp_frame = ttk.Frame(config_frame) + temp_frame.grid(row=1, column=2, padx=(10, 0), sticky="e") + ttk.Label(temp_frame, text="Temperature:").pack(side=tk.LEFT, padx=(0, 5)) + temp_spin = ttk.Spinbox( + temp_frame, + textvariable=self.temperature_var, + from_=0.0, + to=1.0, + increment=0.05, + format="%.2f", + width=6, + ) + temp_spin.pack(side=tk.LEFT) + + check_frame = ttk.Frame(config_frame) + check_frame.grid(row=2, column=0, columnspan=4, sticky="w", pady=(4, 0)) + ttk.Checkbutton(check_frame, text="Verbose", 
variable=self.verbose_var).pack(side=tk.LEFT, padx=(0, 12)) + ttk.Checkbutton(check_frame, text="Debug", variable=self.debug_var).pack(side=tk.LEFT) + + uav_frame = ttk.LabelFrame(main_frame, text="UAV Connection", padding=10) + uav_frame.grid(row=2, column=0, sticky="ew", pady=(0, 10)) + uav_frame.columnconfigure(1, weight=1) + + ttk.Label(uav_frame, text="UAV API Base URL:").grid(row=0, column=0, sticky="w") + ttk.Entry(uav_frame, textvariable=self.uav_base_url_var).grid(row=0, column=1, sticky="ew", padx=(6, 0)) + ttk.Button(uav_frame, text="Reload Agent", command=self.initialize_agent).grid(row=0, column=2, padx=(10, 0)) + ttk.Button(uav_frame, text="Session Summary", command=lambda: self.refresh_session_summary()).grid(row=0, column=3, padx=(10, 0)) + + ttk.Label(uav_frame, text="API Key (Optional):").grid(row=1, column=0, sticky="w", pady=(6, 0)) + api_key_entry = ttk.Entry(uav_frame, textvariable=self.uav_api_key_var) + api_key_entry.grid(row=1, column=1, columnspan=3, sticky="ew", padx=(6, 0), pady=(6, 0)) + + # Add tooltip/hint label + hint_label = ttk.Label(uav_frame, text="Leave empty for AGENT role, or enter USER/SYSTEM/ADMIN key", font=("Arial", 9), foreground="gray") + hint_label.grid(row=2, column=1, columnspan=3, sticky="w", padx=(6, 0), pady=(2, 0)) + + notebook = ttk.Notebook(main_frame) + notebook.grid(row=3, column=0, sticky="nsew", pady=(0, 10)) + main_frame.rowconfigure(3, weight=4) + + chat_frame = ttk.Frame(notebook) + chat_frame.columnconfigure(0, weight=1) + chat_frame.rowconfigure(0, weight=1) + self.chat_output = scrolledtext.ScrolledText(chat_frame, wrap=tk.WORD, state=tk.DISABLED) + self.chat_output.grid(row=0, column=0, sticky="nsew") + self.chat_output.configure(height=22, font=("Arial", 11)) + notebook.add(chat_frame, text="Conversation") + + steps_frame = ttk.Frame(notebook) + steps_frame.columnconfigure(0, weight=1) + steps_frame.rowconfigure(0, weight=1) + self.steps_output = scrolledtext.ScrolledText(steps_frame, wrap=tk.WORD, 
state=tk.DISABLED, height=8) + self.steps_output.configure(font=("Courier New", 10)) + self.steps_output.grid(row=0, column=0, sticky="nsew") + notebook.add(steps_frame, text="Intermediate Steps") + + input_frame = ttk.LabelFrame(main_frame, text="Command", padding=3) + input_frame.grid(row=4, column=0, sticky="ew") + input_frame.columnconfigure(0, weight=1) + + self.command_input = tk.Text(input_frame, height=5, wrap=tk.WORD) + self.command_input.grid(row=0, column=0, sticky="nsew", pady=(0, 3)) + input_frame.rowconfigure(0, weight=1) + self.command_input.bind("", self.handle_command_return) + self.command_input.bind("", self.handle_command_return) + + button_bar = ttk.Frame(input_frame) + button_bar.grid(row=1, column=0, sticky="e") + + self.send_button = ttk.Button(button_bar, text="Send Command", command=self.send_command) + self.send_button.pack(side=tk.RIGHT, padx=(6, 0)) + + # Voice button + if SPEECH_AVAILABLE and AUDIO_AVAILABLE and WHISPER_AVAILABLE: + voice_text = "🎤 Loading.." 
def update_provider_dropdown(self) -> None:
    """Refresh the provider combobox from ``self.provider_configs``.

    The choices are the provider names in sorted order. If the currently
    selected name is not among them, the first name becomes the selection.
    """
    names = sorted(self.provider_configs)
    self.provider_dropdown["values"] = names
    selected = self.provider_var.get()
    if names and selected not in names:
        self.provider_var.set(names[0])
def _initialize_agent_worker(self, show_warnings: bool) -> None:
    """Worker thread to initialize the agent - delegates to UAVControlAgent.

    Runs under ``agent_lock``. Every UI update is scheduled through
    ``self.root.after``; on success the new agent is stored in ``self.agent``.

    Args:
        show_warnings: When true, problems are shown in an error dialog;
            otherwise they go to the status bar / chat log.
    """
    with self.agent_lock:
        config = self.get_current_provider_config()
        if not config:
            if show_warnings:
                self.root.after(0, lambda: messagebox.showerror("Provider", "No provider configuration found."))
            else:
                self.root.after(0, lambda: self.set_status("⚙️ Configure a provider to initialize the agent."))
            return

        # Extract configuration parameters
        llm_params = self._extract_llm_params(config)
        if llm_params is None:
            # Error already handled in _extract_llm_params
            return

        self.root.after(0, lambda: self.set_status("🛠️ Initializing UAV agent..."))

        try:
            # The widget-variable reads sit inside the try so that an
            # unparsable value (e.g. a non-numeric temperature) is reported
            # like any other initialization failure instead of silently
            # ending this worker thread.
            uav_base_url = self.uav_base_url_var.get().strip() or "http://localhost:8000"
            uav_api_key = self.uav_api_key_var.get().strip() or None
            temperature = float(self.temperature_var.get())
            verbose = bool(self.verbose_var.get())
            debug = bool(self.debug_var.get())

            # Delegate to UAVControlAgent - it handles all LLM initialization logic
            agent = UAVControlAgent(
                base_url=uav_base_url,
                uav_api_key=uav_api_key,
                llm_provider=llm_params['llm_provider'],
                llm_model=llm_params['llm_model'],
                llm_api_key=llm_params['llm_api_key'],
                llm_base_url=llm_params['llm_base_url'],
                temperature=temperature,
                verbose=verbose,
                debug=debug,
            )
        except Exception as exc:
            # Build the messages now: Python unbinds ``exc`` when the except
            # block ends, so a callback that reads it later raises NameError.
            dialog_text = f"Failed to initialize agent:\n{exc}"
            chat_text = f"⚠️ Agent initialization failed: {exc}"
            if show_warnings:
                self.root.after(
                    0,
                    lambda: messagebox.showerror("Agent Initialization", dialog_text),
                )
            else:
                self.root.after(0, lambda: self.append_chat("System", chat_text))
            self.root.after(0, lambda: self.set_status("❌ Agent initialization failed."))
            return

        self.agent = agent
        model_name = llm_params['llm_model']
        self.root.after(0, lambda: self.set_status("✅ Agent ready."))
        self.root.after(0, lambda: self.append_chat("System", f"🚀 Agent initialized with model '{model_name or 'default'}'."))
        self.root.after(0, lambda: self.refresh_session_summary(silent=True))
def _fetch_session_summary(self, silent: bool) -> None:
    """Fetch session summary - delegates to UAVControlAgent method.

    Holds ``agent_lock`` while calling ``self.agent.get_session_summary()``
    and schedules every UI update through ``self.root.after``. Does nothing
    when no agent is set.

    Args:
        silent: When true, a failure is appended to the chat as a "System"
            message; otherwise it is shown in an error dialog.
    """
    with self.agent_lock:
        if not self.agent:
            return
        self.root.after(0, lambda: self.set_status("📡 Fetching session summary..."))
        try:
            # Delegate to agent's get_session_summary method
            summary = self.agent.get_session_summary()
        except Exception as exc:
            # Build the messages now: Python unbinds ``exc`` when the except
            # block ends, so a callback that reads it later raises NameError.
            chat_text = f"⚠️ Failed to fetch session summary: {exc}"
            dialog_text = f"Failed to fetch session summary:\n{exc}"
            if silent:
                self.root.after(0, lambda: self.append_chat("System", chat_text))
            else:
                self.root.after(
                    0,
                    lambda: messagebox.showerror("Session Summary", dialog_text),
                )
            self.root.after(0, lambda: self.set_status("⚠️ Failed to fetch session summary."))
            return

        self.root.after(0, lambda: self.append_chat("Session Summary", summary.strip()))
        self.root.after(0, lambda: self.set_status("📋 Session summary updated."))
def _execute_command(self, command: str) -> None:
    """Execute command - delegates to UAVControlAgent.execute() method.

    Holds ``agent_lock`` while the agent runs. Every UI update is scheduled
    through ``self.root.after``, and the Send button is re-enabled on every
    exit path.

    Args:
        command: Text passed to ``self.agent.execute``.
    """
    with self.agent_lock:
        if not self.agent:
            self.root.after(0, lambda: self.set_status("ℹ️ Agent not initialized."))
            # Re-enable the button here as well so this early exit cannot
            # leave it disabled.
            self.root.after(0, lambda: self.send_button.configure(state=tk.NORMAL))
            return
        try:
            # Delegate to agent's execute method - it handles all LLM interaction
            result = self.agent.execute(command)
        except Exception as exc:
            # Build the message now: Python unbinds ``exc`` when the except
            # block ends, so a callback that reads it later raises NameError.
            error_text = f"Error executing command: {exc}"
            self.root.after(0, lambda: self.append_chat("System", error_text))
            self.root.after(0, lambda: self.set_status("⚠️ Command failed."))
            self.root.after(0, lambda: self.send_button.configure(state=tk.NORMAL))
            return

        success = result.get("success", False)
        output = result.get("output", "")
        steps_text = self._format_intermediate_steps(result.get("intermediate_steps", []))

        self.root.after(0, lambda: self.append_chat("UAV Agent", output if output else "(no response)"))
        self.root.after(0, lambda: self.append_steps(steps_text))
        self.root.after(0, lambda: self.set_status("✅ Command completed." if success else "⚠️ Command reported an error."))
        self.root.after(0, lambda: self.send_button.configure(state=tk.NORMAL))
def stringify(self, value: Any) -> str:
    """Return *value* as display text.

    Strings pass through unchanged. Anything else is pretty-printed as JSON
    with a two-space indent and sorted keys; values that ``json`` cannot
    encode fall back to ``str(value)``.

    Non-ASCII characters are written as-is instead of as JSON unicode
    escapes, so non-English text inside structured values stays readable
    when shown in the GUI.

    Args:
        value: Any object to render.

    Returns:
        The text representation described above.
    """
    if isinstance(value, str):
        return value
    try:
        return json.dumps(value, indent=2, sort_keys=True, ensure_ascii=False)
    except (TypeError, ValueError):
        return str(value)
dialog.transient(self.root) + dialog.grab_set() + dialog.resizable(False, False) + for idx in range(0, 6): + dialog.columnconfigure(idx % 2, weight=1 if idx % 2 == 1 else 0) + + ttk.Label(dialog, text="Provider Name:").grid(row=0, column=0, sticky="w", padx=10, pady=(10, 4)) + ttk.Label(dialog, text=name).grid(row=0, column=1, sticky="w", padx=10, pady=(10, 4)) + + ttk.Label(dialog, text="Type:").grid(row=1, column=0, sticky="w", padx=10, pady=4) + type_var = tk.StringVar(value=config.get("type", "ollama")) + type_combo = ttk.Combobox( + dialog, + textvariable=type_var, + values=["ollama", "openai-compatible"], + state="readonly", + width=20, + ) + type_combo.grid(row=1, column=1, sticky="ew", padx=10, pady=4) + + ttk.Label(dialog, text="Base URL:").grid(row=2, column=0, sticky="w", padx=10, pady=4) + base_var = tk.StringVar(value=config.get("base_url", "")) + ttk.Entry(dialog, textvariable=base_var).grid(row=2, column=1, sticky="ew", padx=10, pady=4) + + ttk.Label(dialog, text="Default Model:").grid(row=3, column=0, sticky="w", padx=10, pady=4) + default_model_var = tk.StringVar(value=config.get("default_model", "")) + ttk.Entry(dialog, textvariable=default_model_var).grid(row=3, column=1, sticky="ew", padx=10, pady=4) + + ttk.Label(dialog, text="Default Models (comma separated):").grid(row=4, column=0, sticky="w", padx=10, pady=4) + defaults_var = tk.StringVar(value=", ".join(config.get("default_models", []))) + ttk.Entry(dialog, textvariable=defaults_var).grid(row=4, column=1, sticky="ew", padx=10, pady=4) + + requires_key_var = tk.BooleanVar(value=config.get("requires_api_key", False)) + ttk.Checkbutton(dialog, text="Requires API Key", variable=requires_key_var).grid( + row=5, column=0, columnspan=2, sticky="w", padx=10, pady=4 + ) + + ttk.Label(dialog, text="API Key:").grid(row=6, column=0, sticky="w", padx=10, pady=4) + api_key_var = tk.StringVar(value=config.get("api_key", "")) + api_entry = ttk.Entry(dialog, textvariable=api_key_var, show="*", width=30) + 
def load_whisper_pipeline(self, selected_model=None, force_reload=False):
    """Load Whisper resources in a background thread to avoid blocking the UI.

    The model is built on a daemon thread; all attribute and widget updates
    are applied in callbacks scheduled with ``self.root.after``.

    Args:
        selected_model: "large", "medium", or anything else for the small
            model. ``None`` means the current value of
            ``self.whisper_model_var``.
        force_reload: Reload even if the requested model is already active.
            If a load is already in progress, the request is remembered in
            ``self.pending_model_reload`` and started once that load
            succeeds.
    """
    if not (SPEECH_AVAILABLE and AUDIO_AVAILABLE and WHISPER_AVAILABLE):
        def on_fail_missing():
            self.loading_whisper = False
            self.voice_enabled = False
            self.pending_voice_start = False
            if hasattr(self, "voice_btn"):
                self.voice_btn.config(text="🎤 Unavailable", state=tk.DISABLED)
            self.set_status("Voice recording unavailable (missing dependencies)")

        self.root.after(0, on_fail_missing)
        return
    if sr is None:
        def on_fail_sr():
            self.loading_whisper = False
            self.voice_enabled = False
            self.pending_voice_start = False
            if hasattr(self, "voice_btn"):
                self.voice_btn.config(text="🎤 Unavailable", state=tk.DISABLED)
            self.set_status("Voice recording unavailable (speech_recognition missing)")

        self.root.after(0, on_fail_sr)
        return
    if selected_model is None:
        selected_model = self.whisper_model_var.get()

    if self.loading_whisper:
        # A load is already running; only remember an explicit reload request.
        if force_reload:
            self.pending_model_reload = selected_model
        return

    if self.voice_enabled and not force_reload and selected_model == self.current_whisper_model:
        return

    self.loading_whisper = True
    self.voice_enabled = False
    if hasattr(self, "voice_btn"):
        self.voice_btn.config(text="🎤 Loading..", state=tk.DISABLED)
    self.set_status(f"Loading Whisper {selected_model} model...")

    def loader():
        try:
            recognizer = self.recognizer or sr.Recognizer()

            recognizer.dynamic_energy_threshold = True
            recognizer.energy_threshold = 150
            recognizer.pause_threshold = 0.5
            recognizer.phrase_threshold = 0.1
            recognizer.non_speaking_duration = 0.2

            if WHISPER_AVAILABLE:
                device = "cuda:0" if torch.cuda.is_available() else "cpu"
                dtype = torch.float16 if torch.cuda.is_available() else torch.float32

                if selected_model == "large":
                    model_path = "./whisper-large-v3-turbo"
                elif selected_model == "medium":
                    model_path = "./whisper-medium"
                else:
                    model_path = "./whisper-small"

                if not os.path.exists(model_path):
                    raise FileNotFoundError(f"Whisper {selected_model} model not found at {model_path}")

                model = AutoModelForSpeechSeq2Seq.from_pretrained(
                    model_path,
                    dtype=dtype,
                    low_cpu_mem_usage=True,
                    use_safetensors=True
                )
                model.to(device)

                processor = AutoProcessor.from_pretrained(model_path)
                pipeline_obj = pipeline(
                    "automatic-speech-recognition",
                    model=model,
                    tokenizer=processor.tokenizer,
                    feature_extractor=processor.feature_extractor,
                    dtype=dtype,
                    device=device,
                    generate_kwargs={
                        "task": "transcribe",
                        "language": None  # Auto-detect language (English or Chinese)
                    }
                )
            else:
                model = None
                processor = None
                pipeline_obj = None
                device = None
                dtype = None

            def on_success():
                self.recognizer = recognizer
                self.whisper_model = model
                self.whisper_processor = processor
                self.whisper_pipeline = pipeline_obj
                self.device = device
                self.model_dtype = dtype
                self.current_whisper_model = selected_model
                self.voice_enabled = pipeline_obj is not None
                self.loading_whisper = False
                # Take the queued request *before* clearing the slot;
                # clearing first would make the check below always false
                # and drop the reload.
                queued_model = self.pending_model_reload
                self.pending_model_reload = None
                if hasattr(self, "voice_btn"):
                    if self.voice_enabled:
                        self.voice_btn.config(text="🎤 Record", state=tk.NORMAL)
                    else:
                        self.voice_btn.config(text="🎤 Disabled", state=tk.DISABLED)
                if self.voice_enabled:
                    self.set_status(f"Whisper {selected_model} model ready")
                else:
                    self.set_status("Whisper model unavailable")

                if queued_model and queued_model != selected_model:
                    self.load_whisper_pipeline(selected_model=queued_model, force_reload=True)
                    return

                if getattr(self, "pending_voice_start", False):
                    if self.voice_enabled:
                        self.pending_voice_start = False
                        self.start_voice_input()
                    else:
                        self.pending_voice_start = False
                        messagebox.showwarning("Voice Recording", "Voice model is not available.")

            self.root.after(0, on_success)

        except Exception as e:
            # Capture the reason now: ``e`` is unbound once the except block
            # ends, and ``on_fail`` only runs later on the Tk event loop.
            reason = str(e) or type(e).__name__

            def on_fail():
                self.loading_whisper = False
                self.voice_enabled = False
                self.pending_model_reload = None
                if hasattr(self, "voice_btn"):
                    label = "🎤 Disabled"
                    self.voice_btn.config(text=label, state=tk.DISABLED)
                # Show why loading failed (e.g. the missing model directory)
                # instead of a bare "failed".
                self.set_status(f"Model loading failed: {reason}")
                self.pending_voice_start = False

            self.root.after(0, on_fail)

    threading.Thread(target=loader, daemon=True).start()
Please wait a moment and try again.") + return + + if not self.is_listening: + self.start_voice_input() + else: + self.finish_voice_input() + + def start_voice_input(self): + """Start recording voice input""" + if not self.voice_enabled: + self.pending_voice_start = True + if not self.loading_whisper: + self.load_whisper_pipeline() + self.create_voice_dialog(status_text="Loading voice model...", done_enabled=False) + self.set_status("Preparing voice model...") + return + self.pending_voice_start = False + if sr is None: + messagebox.showwarning("Voice Recording", "speech_recognition library not available") + return + + self.is_listening = True + self.voice_btn.config(text="🎤 Recording...", state=tk.DISABLED) + self.voice_transcribe_requested = False + self.voice_stop_event = threading.Event() + self.create_voice_dialog(status_text="🎤 Initializing microphone...", done_enabled=False) + self.set_status("🎤 Recording active") + + thread = threading.Thread(target=self.begin_voice_capture, daemon=True) + thread.start() + self.voice_recording_thread = thread + + def finish_voice_input(self, event=None): + if not self.is_listening: + self.cancel_voice_input() + return + self.voice_transcribe_requested = True + self.set_status("Processing recording...") + self.update_voice_dialog("Processing...", False) + self.disable_voice_dialog_buttons() + self.stop_voice_recording() + + def cancel_voice_input(self, event=None): + if not self.is_listening: + if self.voice_dialog: + self.voice_dialog.destroy() + self.voice_dialog = None + if self.voice_btn: + self.voice_btn.config(text="🎤 Record", state=tk.NORMAL) + self.set_status("Recording cancelled") + self.pending_voice_start = False + return + self.voice_transcribe_requested = False + self.set_status("Cancelling recording...") + self.update_voice_dialog("Cancelling...", False) + self.disable_voice_dialog_buttons() + self.stop_voice_recording() + + def stop_voice_recording(self): + if self.voice_stop_event: + 
self.voice_stop_event.set() + + def create_voice_dialog(self, status_text="🎤 Recording...", done_enabled=True): + """Create or refresh the voice input dialog.""" + self.voice_dialog = tk.Toplevel(self.root) + self.voice_dialog.title("Voice Input") + self.voice_dialog.geometry("320x120") + self.voice_dialog.resizable(False, False) + + self.voice_dialog.transient(self.root) + self.voice_dialog.grab_set() + + self.voice_status_label = ttk.Label(self.voice_dialog, text=status_text, font=('Arial', 14, 'bold')) + self.voice_status_label.pack(pady=(20, 10)) + + button_frame = ttk.Frame(self.voice_dialog) + button_frame.pack(pady=(0, 15)) + + self.voice_cancel_btn = ttk.Button(button_frame, text="Cancel", command=self.cancel_voice_input) + self.voice_cancel_btn.pack(side=tk.LEFT, padx=10) + + self.voice_done_btn = ttk.Button(button_frame, text="Done", command=self.finish_voice_input) + self.voice_done_btn.pack(side=tk.LEFT, padx=10) + if done_enabled: + self.voice_done_btn.focus_set() + else: + self.voice_done_btn.config(state=tk.DISABLED) + self.voice_cancel_btn.focus_set() + + self.voice_dialog.bind("", lambda e: self.finish_voice_input()) + self.voice_dialog.bind("", lambda e: self.finish_voice_input()) + + self.voice_dialog.protocol("WM_DELETE_WINDOW", self.cancel_voice_input) + + def update_voice_dialog(self, status_text=None, done_enabled=None): + if status_text and self.voice_status_label: + self.voice_status_label.config(text=status_text) + if done_enabled is not None and self.voice_done_btn: + self.voice_done_btn.config(state=tk.NORMAL if done_enabled else tk.DISABLED) + if done_enabled: + self.voice_done_btn.focus_set() + else: + if self.voice_cancel_btn: + self.voice_cancel_btn.focus_set() + + def disable_voice_dialog_buttons(self): + if self.voice_cancel_btn: + self.voice_cancel_btn.config(state=tk.DISABLED) + if self.voice_done_btn: + self.voice_done_btn.config(state=tk.DISABLED) + + def begin_voice_capture(self): + microphone = None + error_message = None + 
try: + microphone = sr.Microphone() + except Exception as mic_error: + error_message = f"Cannot access microphone: {mic_error}" + + if error_message or microphone is None: + self.root.after(0, lambda: self.on_voice_session_complete("", error_message or "Microphone error", False)) + return + + self.root.after(0, lambda: self.update_voice_dialog("🎤 Listening...", True)) + self.record_voice_segment(microphone) + + def record_voice_segment(self, microphone): + """Record a complete voice segment and then transcribe it""" + if not self.voice_enabled or self.whisper_pipeline is None or self.recognizer is None: + self.set_status("Voice model not ready") + self.is_listening = False + return + + sample_rate = getattr(microphone, "SAMPLE_RATE", 16000) + sample_width = getattr(microphone, "SAMPLE_WIDTH", 2) + chunk_size = getattr(microphone, "CHUNK", 1024) + frames = [] + max_duration = 120 # safety guard + start_time = time.time() + transcribe_requested = False + error_message = None + text_result = "" + + try: + with microphone as source: + self.recognizer.adjust_for_ambient_noise(source, duration=0.05) + stream = source.stream + + while not self.voice_stop_event.is_set() and (time.time() - start_time) < max_duration: + try: + data = stream.read(chunk_size) + except IOError as e: + # Handle buffer overflow errors gracefully + if e.errno == -9981: # Input overflowed + continue + error_message = f"Recording error: {e}" + break + except Exception as read_error: + error_message = f"Recording error: {read_error}" + break + + frames.append(data) + + except Exception as e: + error_message = f"Voice recording error: {e}" + + transcribe_requested = self.voice_transcribe_requested + + if not frames: + if not error_message: + error_message = "No audio recorded" + elif transcribe_requested and self.whisper_pipeline is None: + error_message = "Voice model not ready" + elif transcribe_requested and not error_message: + audio_bytes = b"".join(frames) + audio_data = 
sr.AudioData(audio_bytes, sample_rate, sample_width) + + import tempfile + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file: + tmp_filename = tmp_file.name + + try: + with open(tmp_filename, "wb") as f: + f.write(audio_data.get_wav_data()) + + result = self.whisper_pipeline( + tmp_filename, + return_timestamps=False, + generate_kwargs={ + "task": "transcribe", + "language": None # Auto-detect between English and Chinese + } + ) + text_result = result.get("text", "").strip() + except Exception as transcribe_error: + error_message = f"Transcription error: {transcribe_error}" + finally: + try: + os.unlink(tmp_filename) + except OSError: + pass + + self.root.after(0, lambda: self.on_voice_session_complete(text_result, error_message, transcribe_requested)) + + def on_voice_session_complete(self, text, error_message, transcribed): + self.is_listening = False + self.voice_stop_event = None + self.voice_recording_thread = None + self.voice_transcribe_requested = False + + if self.voice_dialog: + try: + self.voice_dialog.destroy() + except tk.TclError: + pass + self.voice_dialog = None + self.voice_status_label = None + self.voice_cancel_btn = None + self.voice_done_btn = None + + if self.voice_btn: + self.voice_btn.config(text="🎤 Record", state=tk.NORMAL) + + if error_message: + self.set_status("Recording error") + self.append_chat("System", error_message) + messagebox.showerror("Voice Recording", error_message) + else: + if transcribed and text: + self.command_input.insert(tk.END, text + " ") + snippet = text[:50] + ("..." 
if len(text) > 50 else "") + self.set_status(f"Added: {snippet}") + elif transcribed: + self.set_status("No speech detected") + else: + self.set_status("Recording cancelled") + + self.root.after(1000, lambda: self.set_status("✅ Agent ready.") if self.agent else self.set_status("🛠️ Configure connection and initialize the agent.")) + + # ------------------------------------------------------------------ # + # Main loop entry + # ------------------------------------------------------------------ # +def main() -> None: + root = tk.Tk() + app = UAVAgentGUI(root) + root.mainloop() + + +if __name__ == "__main__": + main() diff --git a/template/__init__.py b/template/__init__.py new file mode 100644 index 0000000..f976ce7 --- /dev/null +++ b/template/__init__.py @@ -0,0 +1,10 @@ +""" +UAV Agent Templates + +This package contains prompt templates for the UAV control agent. +""" + +from .agent_prompt import AGENT_PROMPT +from .parsing_error import PARSING_ERROR_TEMPLATE + +__all__ = ["AGENT_PROMPT", "PARSING_ERROR_TEMPLATE"] diff --git a/template/agent_prompt.py b/template/agent_prompt.py new file mode 100644 index 0000000..de78f0d --- /dev/null +++ b/template/agent_prompt.py @@ -0,0 +1,92 @@ +""" +UAV Agent Prompt Template + +This template defines the system prompt for the UAV control agent. +It provides guidelines, safety rules, task types, and response format instructions. +""" + +AGENT_PROMPT = """You are an intelligent UAV (drone) control agent. Your job is to understand user intentions and control drones safely and efficiently. + +IMPORTANT GUIDELINES: +0. ALWAYS Respond [TASK DONE] as a signal of finish task at the end of response. +1. ALWAYS check the current session status first to understand the mission task +2. ALWAYS list available drones before attempting to control them +3. ALWAYS check nearby entities of a drone before you control it, there are lot of obstacles. +4. Check weather conditions regularly - the weather will influence the battery usage +5. 
Be proactive in gathering information of obstacles and targets, by using nearby entities functions +6. Remember the information of obstacles and targets, because they are not always available +7. When visiting targets, get close enough within task_radius +9. Land drones safely when tasks are complete or battery is low +10. Monitor battery levels - if below 10%, consider charging before continuing + +SAFETY RULES: +- If you can not directly move the drone to a position, find a mediam waypoint to get there first, and then cosider the destination, repeat the process, until you can move directly to the destination. +- Always verify drone status and nearby entities before commands + + +AVAILABLE TOOLS: +You have access to these tools to accomplish your tasks: {tool_names} + +{tools} + +RESPONSE FORMAT: +Use this exact format for your responses: + +Question: the input question or command you must respond to +Thought: analyze what you need to do and what information you need +Action: the specific tool to use from the list above +Action Input: the input parameters for the tool (use proper JSON format) +Observation: the result from running the tool +... (repeat Thought/Action/Action Input/Observation as needed) +Thought: I now have enough information to provide a final answer +Final Answer: a clear, concise answer to the original question + +ACTION INPUT FORMAT RULES: +1. For tools with NO parameters (like list_drones, get_session_info): + Action Input: {{}} + +2. For tools with ONE string parameter (like get_drone_status): + Action Input: {{"drone_id": "drone-abc123"}} + +3. 
For tools with MULTIPLE parameters (like move_to): + Action Input: {{"drone_id": "drone-abc123", "x": 100.0, "y": 50.0, "z": 20.0}} + +CRITICAL: +- ALWAYS use proper JSON format with double quotes for keys and string values +- ALWAYS use curly braces for Action Input +- For tools with no parameters, use empty braces +- Numbers should NOT have quotes +- Strings MUST have quotes + +EXAMPLES: +Question: What drones are available? +Thought: I need to list all drones to see what's available +Action: list_drones +Action Input: {{}} +Observation: [result will be returned here] + +Question: Take off drone-001 to 15 meters +Thought: I need to take off the drone to the specified altitude +Action: take_off +Action Input: {{"drone_id": "drone-001", "altitude": 15.0}} +Observation: Drone took off successfully + +Question: Move drone-001 to position x=100, y=50, z=20 +Thought: I need to move the drone to the specified coordinates +Action: move_to +Action Input: {{"drone_id": "drone-001", "x": 100.0, "y": 50.0, "z": 20.0}} +Observation: Drone moved successfully + +Tips for you to finish task in the most efficient way: + +1. For a certain task, use the task-related API, do not get_session_info too many times. +2. If you want to to get the position of a target, use GET /targets API. If you want to get the position of an obstacle, use GET /obstacles API. DO NOT USE get_nearby_entities! +3. Move directly as much as possible. For example, when the task is moving from A to B via C, first try to Collision Detection between A to C, if there's no collision, then move directly to C, otherwise, detour. +4. Getting entities nearby do not always effective. You have only limited sensor range. Using /targets API to get targets and /obstacles API to get obstacles is more effective. +5. If battery is below 30, find the nerest waypoint, go there and land, then charge to 100. +6. Reaching to a higher latitude can help you see targets, but do not exceed the drone's limit. + +Begin! 
+ +Question: {input} +Thought:{agent_scratchpad}""" diff --git a/template/parsing_error.py b/template/parsing_error.py new file mode 100644 index 0000000..48643d0 --- /dev/null +++ b/template/parsing_error.py @@ -0,0 +1,18 @@ +""" +Parsing Error Template + +This template defines the error message shown to the LLM when it produces +invalid JSON in the Action Input field. +""" + +PARSING_ERROR_TEMPLATE = """Parsing error: {error} + +REMINDER - Action Input must be valid JSON: +- Use double quotes for keys and string values +- Use curly braces: {{}} +- For no parameters: {{}} +- For one parameter: {{"drone_id": "drone-001"}} +- For multiple parameters: {{"drone_id": "drone-001", "altitude": 15.0}} +- Numbers WITHOUT quotes, strings WITH quotes + +Please try again with proper JSON format.""" diff --git a/tp b/tp new file mode 100644 index 0000000..f1ee8dc --- /dev/null +++ b/tp @@ -0,0 +1,140 @@ +These are all the APIs you can use. + +Drone Management + +| Method | Endpoint | Description | +|--------|----------|-------------| +| GET | `/drones` | List all drones | +| POST | `/drones` | Register new drone | +| GET | `/drones/{{id}}` | Get drone details | +| PUT | `/drones/{{id}}` | Update drone properties (metadata, state, battery, position, home) | +| PUT | `/drones/{{id}}/position` | Update drone position only | +| DELETE | `/drones/{{id}}` | Delete drone | +| POST | `/drones/{{id}}/battery` | Update battery level | + +Command Management + +Generic Command Endpoint + +| Method | Endpoint | Description | +|--------|----------|-------------| +| POST | `/drones/{{id}}/command` | Send any command | +| GET | `/drones/{{id}}/commands` | Get command history | +| GET | `/commands/{{command_id}}` | Get command status | + +Direct Command Endpoints + +All commands use **POST** method with `/drones/{{id}}/command/{{command_name}}` + +Charge can only done at Waypoint target. 
+ +| Command | Parameters | Description | +|---------|-----------|-------------| +| `take_off` | `?altitude=10.0` | Takeoff to altitude | +| `land` | - | Land at position | +| `move_to` | `?x=50&y=50&z=15` | Move to coordinates | +| `move_towards` | `?distance=20&heading=90` | Move distance in direction (uses current heading if not specified) | +| `move_along_path` | Body: `{{waypoints:[...]}}` | Follow waypoints | +| `change_altitude` | `?altitude=20.0` | Change altitude only | +| `hover` | `duration` (optional) | Hold position | +| `rotate` | `?heading=180.0` | Change heading/orientation | +| `return_home` | - | Return to launch | +| `set_home` | - | Set home position | +| `calibrate` | - | Calibrate sensors | +| `take_photo` | - | Capture image | +| `send_message` | `?target_drone_id=X&message=Y` | Send to drone | +| `broadcast` | `?message=text` | Send to all | +| `charge` | `?charge_amount=30.0` | Charge battery | + +Target Management + +Target types: fixed, moving, waypoint, circle, polygon. - Note: the fixed type can also represent points of interest + +| Method | Endpoint | Description | +|--------|----------|-------------| +| GET | `/targets` | List all targets | +| GET | `/targets/{{id}}` | Get target details | +| GET | `/targets/type/{{type}}` | Get by type | + +Waypoint Endpoints + +| Method | Endpoint | Description | +|--------|----------|-------------| +| GET | `/targets/waypoints` | List charging stations | +| POST | `/targets/waypoints/{{id}}/check-drone` | Check if drone at waypoint | +| GET | `/targets/waypoints/nearest` | Find nearest waypoint | + +Obstacle Management + +Obstacle types: point, circle, ellipse, polygon. 
+ +| Method | Endpoint | Description | +|--------|----------|-------------| +| GET | `/obstacles` | List all obstacles | +| GET | `/obstacles/{{id}}` | Get obstacle | +| GET | `/obstacles/type/{{type}}` | Get by type | + +Collision Detection + +**Endpoint:** `POST /obstacles/path_collision` + +**Authentication:** Requires SYSTEM role (ADMIN inherits) + +**Description:** Checks if a flight path from start to end collides with any obstacles. Returns the **first** obstacle that collides with the path. + +**Parameters:** + +| Name | Type | Required | Description | +|------|------|----------|-------------| +| start | object | Yes | Start point {{x, y, z}} | +| end | object | Yes | End point {{x, y, z}} | +| safety_margin | float | No | Additional clearance distance (in meters) around the flight path (default: 0.0). Creates a corridor with specified width on each side. Use 0.0 for direct line path, or > 0.0 for safety corridor (e.g., 5.0 creates a 10m-wide corridor). Note: Drone movement commands use 0.0 by default | + +**Height Logic:** +- `height = 0`: Impassable at any altitude +- `height > 0`: Collision only if max flight altitude <= obstacle.height + +**Response:** Collision response object or null if no collision + +**Example Request:** +```bash +curl -X POST http://localhost:8000/obstacles/path_collision \\ + -H "Content-Type: application/json" \\ + -H "X-API-Key: system_secret_key_change_in_production" \\ + -d '{{ + "start": {{"x": 0.0, "y": 0.0, "z": 10.0}}, + "end": {{"x": 200.0, "y": 300.0, "z": 10.0}}, + "safety_margin": 2.0 + }}' +``` + +**Example Response (Collision):** +```json +{{ + "obstacle_id": "550e8400-e29b-41d4-a716-446655440001", + "obstacle_name": "Water Tower", + "obstacle_type": "circle", + "collision_type": "path_intersection", + "distance": 5.0 +}} +``` + +**Example Response (No Collision):** +```json +null +``` + +| Method | Endpoint | Description | Auth | +|--------|----------|-------------|------| +| POST | `/obstacles/path_collision` | 
Check if flight path collides with obstacles | SYSTEM | +| POST | `/obstacles/point_collision` | Check if point is inside any obstacles (returns all matches) | SYSTEM | + +Proximity + +| Method | Endpoint | Description | +|--------|----------|-------------| +| GET | `/drones/{{id}}/nearby` | Aggregated nearby drones, targets, obstacles (uses drone's perceived_radius) | +| GET | `/drones/{{id}}/nearby/drones` | Nearby drones (uses drone's perceived_radius) | +| GET | `/drones/{{id}}/nearby/targets` | Nearby targets (uses drone's perceived_radius) | +| GET | `/drones/{{id}}/nearby/obstacles` | Nearby obstacles (uses drone's perceived_radius) | +All proximity endpoints use the drone's `perceived_radius` to determine the search area. \ No newline at end of file diff --git a/uav_agent.py b/uav_agent.py new file mode 100644 index 0000000..211baed --- /dev/null +++ b/uav_agent.py @@ -0,0 +1,671 @@ +""" +UAV Control Agent +An intelligent agent that understands natural language commands and controls drones using the UAV API +Uses LangChain 1.0+ with modern @tool decorator pattern +""" +from langchain_classic.agents import create_react_agent +from langchain_classic.agents import AgentExecutor +from langchain_classic.prompts import PromptTemplate +from langchain_ollama import ChatOllama +from langchain_openai import ChatOpenAI +from uav_api_client import UAVAPIClient +from uav_langchain_tools import create_uav_tools +from template.agent_prompt import AGENT_PROMPT +from template.parsing_error import PARSING_ERROR_TEMPLATE +from typing import Optional, Dict, Any +import json +import os +from pathlib import Path + + +def load_llm_settings(settings_path: str = "llm_settings.json") -> Optional[Dict[str, Any]]: + """Load LLM settings from JSON file""" + try: + path = Path(settings_path) + if path.exists(): + with open(path, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + print(f"Warning: Could not load LLM settings from {settings_path}: {e}") + return 
None + + +def prompt_user_for_llm_config() -> Dict[str, Any]: + """Prompt user to select LLM provider and model""" + settings = load_llm_settings() + + if not settings or 'provider_configs' not in settings: + print("⚠️ No llm_settings.json found or invalid format. Using command line arguments.") + return {} + + provider_configs = settings['provider_configs'] + selected_provider = settings.get('selected_provider', '') + + print("\n" + "="*60) + print("🤖 LLM Provider Configuration") + print("="*60) + + # Show available providers + providers = list(provider_configs.keys()) + print("\nAvailable providers:") + for i, provider in enumerate(providers, 1): + config = provider_configs[provider] + default_marker = " (selected in settings)" if provider == selected_provider else "" + print(f" {i}. {provider}{default_marker}") + print(f" Type: {config.get('type', 'unknown')}") + print(f" Base URL: {config.get('base_url', 'N/A')}") + print(f" Requires API Key: {config.get('requires_api_key', False)}") + + # Prompt for provider selection + print(f"\nSelect a provider (1-{len(providers)}) [default: {selected_provider or providers[0]}]: ", end='') + provider_choice = input().strip() + + if not provider_choice: + # Use default + if selected_provider and selected_provider in providers: + chosen_provider = selected_provider + else: + chosen_provider = providers[0] + else: + try: + idx = int(provider_choice) - 1 + if 0 <= idx < len(providers): + chosen_provider = providers[idx] + else: + print(f"Invalid choice. Using default: {selected_provider or providers[0]}") + chosen_provider = selected_provider or providers[0] + except ValueError: + print(f"Invalid input. 
Using default: {selected_provider or providers[0]}") + chosen_provider = selected_provider or providers[0] + + config = provider_configs[chosen_provider] + print(f"\n✅ Selected provider: {chosen_provider}") + + # Show available models + default_models = config.get('default_models', []) + default_model = config.get('default_model', '') + + if default_models: + print("\nAvailable models:") + for i, model in enumerate(default_models, 1): + default_marker = " (default)" if model == default_model else "" + print(f" {i}. {model}{default_marker}") + print(f" {len(default_models) + 1}. Custom model (enter manually)") + + print(f"\nSelect a model (1-{len(default_models) + 1}) [default: {default_model}]: ", end='') + model_choice = input().strip() + + if not model_choice: + chosen_model = default_model + else: + try: + idx = int(model_choice) - 1 + if 0 <= idx < len(default_models): + chosen_model = default_models[idx] + elif idx == len(default_models): + # Custom model + print("Enter custom model name: ", end='') + chosen_model = input().strip() or default_model + else: + print(f"Invalid choice. Using default: {default_model}") + chosen_model = default_model + except ValueError: + print(f"Invalid input. 
Using default: {default_model}") + chosen_model = default_model + else: + # No predefined models, ask for custom input + print(f"\nEnter model name [default: {default_model}]: ", end='') + chosen_model = input().strip() or default_model + + print(f"✅ Selected model: {chosen_model}") + + # Determine provider type + provider_type = config.get('type', 'ollama') + if provider_type == 'openai-compatible': + if 'api.openai.com' in config.get('base_url', ''): + llm_provider = 'openai' + else: + llm_provider = 'openai-compatible' + else: + llm_provider = provider_type + + # Get API key if required + api_key = config.get('api_key', '').strip() + if config.get('requires_api_key', False) and not api_key: + print("\n⚠️ This provider requires an API key.") + print("Enter API key (or press Enter to use environment variable): ", end='') + api_key = input().strip() + + result = { + 'llm_provider': llm_provider, + 'llm_model': chosen_model, + 'llm_base_url': config.get('base_url'), + 'llm_api_key': api_key if api_key else None, + 'provider_name': chosen_provider + } + + print("\n" + "="*60) + print("✅ Configuration complete!") + print("="*60) + print(f"Provider: {chosen_provider}") + print(f"Type: {llm_provider}") + print(f"Model: {chosen_model}") + print(f"Base URL: {config.get('base_url')}") + if api_key: + print(f"API Key: {'*' * (len(api_key) - 4) + api_key[-4:] if len(api_key) > 4 else '****'}") + print("="*60 + "\n") + + return result + + +class UAVControlAgent: + """Intelligent agent for controlling UAVs using natural language""" + + def __init__( + self, + base_url: str = "http://localhost:8000", + uav_api_key: Optional[str] = None, + llm_provider: str = "ollama", + llm_model: str = "llama2", + llm_api_key: Optional[str] = None, + llm_base_url: Optional[str] = None, + temperature: float = 0.1, + verbose: bool = True, + debug: bool = False + ): + """ + Initialize the UAV Control Agent + + Args: + base_url: Base URL of the UAV API server + uav_api_key: API key for UAV server 
authentication (None = USER role, or provide SYSTEM/ADMIN key) + llm_provider: LLM provider ('ollama', 'openai', 'openai-compatible') + llm_model: Model name (e.g., 'llama2', 'gpt-4o-mini', 'deepseek-chat') + llm_api_key: API key for LLM provider (required for openai/openai-compatible) + llm_base_url: Custom base URL for LLM API (for openai-compatible providers) + temperature: LLM temperature (lower = more deterministic) + verbose: Enable verbose output for agent reasoning + debug: Enable debug output for connection and setup info + """ + self.client = UAVAPIClient(base_url, api_key=uav_api_key) + self.verbose = verbose + self.debug = debug + + if self.debug: + print("\n" + "="*60) + print("🔧 UAV Agent Initialization - Debug Mode") + print("="*60) + print(f"UAV API Server: {base_url}") + print(f"LLM Provider: {llm_provider}") + print(f"LLM Model: {llm_model}") + print(f"Temperature: {temperature}") + print(f"Verbose: {verbose}") + print() + + # Test UAV API connection + if self.debug: + print("🔌 Testing UAV API connection...") + try: + session = self.client.get_current_session() + if self.debug: + print(f"✅ Connected to UAV API") + print(f" Session: {session.get('name', 'Unknown')}") + print(f" Task: {session.get('task', 'Unknown')}") + print() + except Exception as e: + if self.debug: + print(f"⚠️ Warning: Could not connect to UAV API: {e}") + print(f" Make sure the UAV server is running at {base_url}") + print() + + # Initialize LLM based on provider + if self.debug: + print(f"🤖 Initializing LLM provider: {llm_provider}") + + if llm_provider == "ollama": + if self.debug: + print(f" Using Ollama with model: {llm_model}") + print(f" Ollama URL: http://localhost:11434 (default)") + + self.llm = ChatOllama( + model=llm_model, + temperature=temperature + ) + + if self.debug: + print(f"✅ Ollama LLM initialized") + print() + + elif llm_provider in ["openai", "openai-compatible"]: + if not llm_api_key: + raise ValueError(f"API key is required for {llm_provider} provider. 
Use --llm-api-key or set environment variable.") + + # Determine base URL + if llm_provider == "openai": + final_base_url = llm_base_url or "https://api.openai.com/v1" + provider_name = "OpenAI" + else: + if not llm_base_url: + raise ValueError("llm_base_url is required for openai-compatible provider") + final_base_url = llm_base_url + provider_name = "OpenAI-Compatible API" + + if self.debug: + print(f" Provider: {provider_name}") + print(f" Base URL: {final_base_url}") + print(f" Model: {llm_model}") + print(f" API Key: {'*' * (len(llm_api_key) - 4) + llm_api_key[-4:] if len(llm_api_key) > 4 else '****'}") + + # Create LLM instance + kwargs = { + "model": llm_model, + "temperature": temperature, + "api_key": llm_api_key, + "base_url": final_base_url + } + + self.llm = ChatOpenAI(**kwargs) + + if self.debug: + print(f"✅ {provider_name} LLM initialized") + print() + else: + raise ValueError( + f"Unknown LLM provider: {llm_provider}. " + f"Use 'ollama', 'openai', or 'openai-compatible'" + ) + + # Create tools using the new @tool decorator approach + if self.debug: + print("🔧 Creating UAV control tools...") + self.tools = create_uav_tools(self.client) + if self.debug: + print(f"✅ Created {len(self.tools)} tools") + print(f" Tools: {', '.join([tool.name for tool in self.tools[:5]])}...") + print() + + # Create prompt template + if self.debug: + print("📝 Creating agent prompt template...") + self.prompt = self._create_prompt() + if self.debug: + print("✅ Prompt template created") + print() + + # Create ReAct agent + if self.debug: + print("🤖 Creating ReAct agent...") + self.agent = create_react_agent( + llm=self.llm, + tools=self.tools, + prompt=self.prompt + ) + + if self.debug: + print("✅ ReAct agent created") + print() + + # Create agent executor with improved error handling + if self.debug: + print("⚙️ Creating agent executor...") + print(f" Max iterations: 20") + print(f" Verbose mode: {verbose}") + + # Custom error handler to help LLM fix formatting issues + def 
handle_parsing_error(error) -> str: + """Provide helpful feedback when Action Input parsing fails""" + return PARSING_ERROR_TEMPLATE.format(error=str(error)) + + self.agent_executor = AgentExecutor( + agent=self.agent, + tools=self.tools, + verbose=verbose, + handle_parsing_errors=handle_parsing_error, + max_iterations=50, # Increased for complex tasks + return_intermediate_steps=True, + early_stopping_method="generate" # Better handling of completion + ) + if self.debug: + print("✅ Agent executor created") + print() + + # Session context + if self.debug: + print("🔄 Refreshing session context...") + self.session_context = {} + self.refresh_session_context() + + if self.debug: + print("="*60) + print("✅ UAV Agent Initialization Complete!") + print("="*60) + print() + + def _create_prompt(self) -> PromptTemplate: + """Create the agent prompt template""" + prompt_template = PromptTemplate( + template=AGENT_PROMPT, + input_variables=["input", "agent_scratchpad"], + partial_variables={ + "tools": "\n".join([ + f"- {tool.name}: {tool.description}" + for tool in self.tools + ]), + "tool_names": ", ".join([tool.name for tool in self.tools]) + } + ) + return prompt_template + + def refresh_session_context(self): + """Refresh session context information""" + try: + session = self.client.get_current_session() + self.session_context = { + 'session_id': session.get('id'), + 'task_type': session.get('task'), + 'task_description': session.get('task_description'), + 'status': session.get('status') + } + except Exception as e: + if self.verbose: + print(f"Warning: Could not refresh session context: {e}") + + def get_session_summary(self) -> str: + """Get a summary of the current session""" + try: + session = self.client.get_current_session() + progress = self.client.get_task_progress() + drones = self.client.list_drones() + + summary = f""" +=== Current Session Summary === +Session: {session.get('name', 'Unknown')} +Task: {session.get('task', 'Unknown')} - 
def run_interactive(self):
    """Run a read-eval-print loop that feeds typed commands to the agent.

    Prints a banner and the initial session summary, then repeatedly reads
    one line from standard input:

    - ``quit`` / ``exit`` / ``q`` (case-insensitive) end the loop;
    - ``status`` prints the session summary again;
    - ``help`` prints the example commands;
    - an empty line is ignored;
    - anything else is passed to ``self.execute`` and its ``output`` is
      printed with a success or failure marker.

    The loop also ends on Ctrl-C (``KeyboardInterrupt``) and when standard
    input reaches end-of-file (``EOFError``).  Any other exception is
    reported and the loop continues with the next prompt.
    """
    print("\n" + "="*60)
    print("🚁 UAV Control Agent - Interactive Mode")
    print("="*60)
    print("\nType 'quit', 'exit', or 'q' to stop")
    print("Type 'status' to see session summary")
    print("Type 'help' for example commands\n")

    # Show initial session summary
    print(self.get_session_summary())
    print("\n" + "-"*60 + "\n")

    while True:
        try:
            user_input = input("\n🎮 Command: ").strip()

            if not user_input:
                continue

            keyword = user_input.lower()

            if keyword in ('quit', 'exit', 'q'):
                print("\n👋 Goodbye!")
                break

            if keyword == 'status':
                print(self.get_session_summary())
                continue

            if keyword == 'help':
                self._print_help()
                continue

            # Execute command
            print("\n🤖 Processing...\n")
            result = self.execute(user_input)

            if result['success']:
                print(f"\n✅ {result['output']}\n")
            else:
                print(f"\n❌ {result['output']}\n")

        except (KeyboardInterrupt, EOFError):
            # EOFError means stdin is exhausted or closed (Ctrl-D, piped
            # input).  Every later input() call would fail the same way, so
            # letting the generic handler below catch it would spin forever.
            print("\n\n👋 Goodbye!")
            break
        except Exception as e:
            # Top-level boundary of the REPL: report and keep the session alive.
            print(f"\n❌ Error: {e}\n")
def main():
    """Command-line entry point for the UAV control agent.

    Parses the command line, resolves the LLM configuration (either through
    the interactive ``prompt_user_for_llm_config`` helper or from the
    arguments, falling back to ``ollama`` / ``llama2``), builds a
    ``UAVControlAgent`` and then either executes the single ``--command`` or
    starts the interactive loop.

    Returns:
        Process exit code: ``0`` on success, ``1`` when the agent cannot be
        created or the single command reports failure.
    """
    import argparse
    import sys

    cli = argparse.ArgumentParser(
        description="UAV Control Agent - Natural Language Drone Control"
    )

    # (flags, add_argument keyword arguments) in the order they appear in --help.
    option_table = [
        (('--base-url',), {
            'default': 'http://localhost:8000',
            'help': 'UAV API base URL',
        }),
        (('--uav-api-key',), {
            'default': None,
            'help': 'API key for UAV server (defaults to USER role if not provided, or set UAV_API_KEY env var)',
        }),
        (('--llm-provider',), {
            'default': None,
            'choices': ['ollama', 'openai', 'openai-compatible'],
            'help': 'LLM provider (ollama, openai, or openai-compatible for DeepSeek, etc.)',
        }),
        (('--llm-model',), {
            'default': None,
            'help': 'LLM model name (e.g., llama2, gpt-4o-mini, deepseek-chat)',
        }),
        (('--llm-api-key',), {
            'default': None,
            'help': 'API key for LLM provider (or set via environment variable)',
        }),
        (('--llm-base-url',), {
            'default': None,
            'help': 'Custom base URL for LLM API (required for openai-compatible providers)',
        }),
        (('--temperature',), {
            'type': float,
            'default': 0.1,
            'help': 'LLM temperature (0.0-1.0)',
        }),
        (('--command', '-c'), {
            'default': None,
            'help': 'Single command to execute (non-interactive)',
        }),
        (('--quiet', '-q'), {
            'action': 'store_true',
            'help': 'Reduce verbosity',
        }),
        (('--debug', '-d'), {
            'action': 'store_true',
            'help': 'Enable debug output for connection and setup info',
        }),
        (('--no-prompt',), {
            'action': 'store_true',
            'help': 'Skip interactive provider/model selection (use command line args or defaults)',
        }),
    ]
    for flags, spec in option_table:
        cli.add_argument(*flags, **spec)

    opts = cli.parse_args()

    # Prompt only in interactive mode and only when neither a provider nor a
    # model was given explicitly.
    interactive_setup = not (
        opts.no_prompt
        or opts.command
        or opts.llm_provider is not None
        or opts.llm_model is not None
    )

    if interactive_setup:
        # A falsy answer from the prompt means "use the defaults".
        picked = prompt_user_for_llm_config() or {}
        llm_provider = picked.get('llm_provider', 'ollama')
        llm_model = picked.get('llm_model', 'llama2')
        llm_base_url = picked.get('llm_base_url')
        llm_api_key = picked.get('llm_api_key')
    else:
        llm_provider = opts.llm_provider or 'ollama'
        llm_model = opts.llm_model or 'llama2'
        llm_base_url = opts.llm_base_url
        llm_api_key = opts.llm_api_key

    # Environment variables are the last resort for both API keys.
    if not llm_api_key:
        llm_api_key = os.getenv("OPENAI_API_KEY") or os.getenv("LLM_API_KEY")
    uav_api_key = opts.uav_api_key or os.getenv("UAV_API_KEY")

    try:
        agent = UAVControlAgent(
            base_url=opts.base_url,
            uav_api_key=uav_api_key,
            llm_provider=llm_provider,
            llm_model=llm_model,
            llm_api_key=llm_api_key,
            llm_base_url=llm_base_url,
            temperature=opts.temperature,
            verbose=not opts.quiet,
            debug=opts.debug
        )
    except Exception as e:
        print(f"❌ Failed to create agent: {e}")
        for hint in (
            "\nMake sure:",
            " - Ollama is running (if using --llm-provider ollama)",
            " - OPENAI_API_KEY is set (if using --llm-provider openai)",
            " - UAV API server is accessible",
        ):
            print(hint)
        return 1

    if not opts.command:
        # Interactive mode
        agent.run_interactive()
        return 0

    # Single command mode
    outcome = agent.execute(opts.command)
    print(outcome['output'])
    if outcome['success']:
        return 0
    return 1
class UAVAPIClient:
    """Client for interacting with the UAV Control System API"""

    def __init__(self, base_url: str = "http://localhost:8000", api_key: Optional[str] = None,
                 timeout: Optional[float] = None):
        """
        Initialize UAV API Client

        Args:
            base_url: Base URL of the UAV API server (a trailing slash is stripped)
            api_key: Optional API key for authentication (defaults to USER role if not provided)
                - None or empty: USER role (basic access)
                - Valid key: SYSTEM or ADMIN role (based on key)
            timeout: Optional timeout in seconds applied to every HTTP request
                that does not specify its own. ``None`` (the default) sends no
                timeout, which is the library's wait-indefinitely behaviour.
        """
        self.base_url = base_url.rstrip('/')
        self.api_key = api_key
        self.timeout = timeout
        self.headers = {}
        if self.api_key:
            self.headers['X-API-Key'] = self.api_key

    def _request(self, method: str, endpoint: str, **kwargs) -> Any:
        """Make an HTTP request to the API and return the decoded JSON body.

        Args:
            method: HTTP method name, e.g. ``'GET'`` or ``'POST'``.
            endpoint: Path appended to ``base_url`` (expected to start with ``/``).
            **kwargs: Passed through to ``requests.request`` (``params``,
                ``json``, ``headers``, ``timeout``, ...).

        Returns:
            The parsed JSON response, or ``None`` for a 204 response.

        Raises:
            Exception: with an "Authentication failed" message on 401, a
                "Permission denied" message on 403, and an "API request
                failed" message for every other HTTP or transport error. The
                original ``requests`` exception is chained as ``__cause__``.
        """
        url = f"{self.base_url}{endpoint}"

        # Work on a copy so the caller's dict is never mutated; the client's
        # authentication headers win over caller-supplied ones.
        headers = dict(kwargs.pop('headers', None) or {})
        headers.update(self.headers)

        if self.timeout is not None:
            kwargs.setdefault('timeout', self.timeout)

        try:
            response = requests.request(method, url, headers=headers, **kwargs)
            response.raise_for_status()
            if response.status_code == 204:
                return None
            return response.json()
        except requests.exceptions.HTTPError as e:
            status_code = e.response.status_code
            if status_code == 401:
                raise Exception("Authentication failed: Invalid API key") from e
            if status_code == 403:
                # The 403 body is not guaranteed to be a JSON object; fall
                # back to a generic message instead of failing inside the
                # error handler.
                try:
                    error_detail = e.response.json().get('detail', 'Access denied')
                except (ValueError, AttributeError):
                    error_detail = 'Access denied'
                raise Exception(f"Permission denied: {error_detail}") from e
            raise Exception(f"API request failed: {e}") from e
        except requests.exceptions.RequestException as e:
            raise Exception(f"API request failed: {e}") from e

    # Drone Operations
    def list_drones(self) -> List[Dict[str, Any]]:
        """Get all drones in the current session"""
        return self._request('GET', '/drones')

    def get_all_waypoints(self) -> List[Dict[str, Any]]:
        """Get all waypoints in the current session"""
        return self._request('GET', '/targets/type/waypoint')

    def get_drone_status(self, drone_id: str) -> Dict[str, Any]:
        """Get detailed status of a specific drone"""
        return self._request('GET', f'/drones/{drone_id}')

    def take_off(self, drone_id: str, altitude: float = 10.0) -> Dict[str, Any]:
        """Command drone to take off to specified altitude"""
        return self._request('POST', f'/drones/{drone_id}/command/take_off',
                             params={'altitude': altitude})

    def land(self, drone_id: str) -> Dict[str, Any]:
        """Command drone to land at current position"""
        return self._request('POST', f'/drones/{drone_id}/command/land')

    def move_to(self, drone_id: str, x: float, y: float, z: float) -> Dict[str, Any]:
        """Move drone to specific coordinates"""
        return self._request('POST', f'/drones/{drone_id}/command/move_to',
                             params={'x': x, 'y': y, 'z': z})

    def optimal_way_to(self, drone_id: str, x: float, y: float, z: float,
                       min_safe_height: float = 0.5) -> List[Dict[str, float]]:
        """
        Plan a collision-free route to the target point and fly it.

        The route is searched recursively starting from the drone's current
        position; detours around obstacles are horizontal only (left/right of
        the flight direction). When a route is found the drone is moved to
        each point in order with ``move_to``.

        Args:
            drone_id: ID of the drone
            x, y, z: Target coordinates
            min_safe_height: Lowest altitude accepted for the target and for
                any intermediate point

        Returns:
            The flown route as a list of ``{"x", "y", "z"}`` dicts (the start
            position is not included), or an empty list when the target is
            below ``min_safe_height`` or no collision-free route was found.
        """
        # 1. Target altitude check
        if z < min_safe_height:
            print(f"Error: Target altitude {z}m is too low.")
            return []

        status = self.get_drone_status(drone_id)
        start_pos = status['position']
        start_coords = (start_pos['x'], start_pos['y'], start_pos['z'])
        end_coords = (x, y, z)

        # 2. Start altitude check (warning only; planning still proceeds)
        if start_coords[2] < min_safe_height:
            print("Warning: Drone is currently below safe height!")

        # 3. Recursive search
        path_points = self._find_path_recursive(
            start_coords,
            end_coords,
            avoidance_radius=2.0,  # initial detour radius
            depth=0,
            max_depth=4,           # maximum recursion depth
            min_safe_height=min_safe_height
        )

        if path_points is None:
            print("Error: Unable to find a collision-free path.")
            return []

        # 4. Format the result and execute it leg by leg
        formatted_path = [{"x": p[0], "y": p[1], "z": p[2]} for p in path_points]
        for point in formatted_path:
            self.move_to(drone_id, **point)
        return formatted_path

    # --- Recursive core ---
    def _find_path_recursive(self, start: Tuple[float, float, float], end: Tuple[float, float, float],
                             avoidance_radius: float, depth: int, max_depth: int,
                             min_safe_height: float) -> Optional[List[Tuple[float, float, float]]]:
        """Search for a collision-free chain of points from ``start`` to ``end``.

        If the straight segment is reported clear, ``[end]`` is returned.
        Otherwise the segment is split at a point offset sideways from its
        midpoint and both halves are solved recursively.

        Returns:
            The points to fly through, ending with ``end`` (``start`` itself
            is not included), or ``None`` when ``max_depth`` is reached
            without finding a clear route.
        """
        sx, sy, sz = start
        ex, ey, ez = end

        # 1. Is the direct segment free of collisions?
        # NOTE(review): any falsy result is treated as "clear" - confirm the
        # server returns null/empty (rather than a populated object) when the
        # path does not collide.
        collision = self.check_path_collision(sx, sy, sz, ex, ey, ez)
        if not collision:
            return [end]

        # 2. Stop at the maximum depth
        if depth >= max_depth:
            return None

        # 3. Midpoint of the blocked segment
        mid_point = ((sx + ex) / 2, (sy + ey) / 2, (sz + ez) / 2)

        # Shrink the detour radius as depth grows, for a finer search
        current_radius = avoidance_radius / (1 + 0.5 * depth)

        # 4. Candidate detour points: right and left only
        candidates = self._get_horizontal_avoidance_points(start, end, mid_point, current_radius)

        # 5. Try each candidate
        for candidate in candidates:
            # Drop points below the safe height (a horizontal offset should
            # not change the altitude, but check anyway)
            if candidate[2] < min_safe_height:
                continue

            # Recurse: start -> candidate
            path_first = self._find_path_recursive(start, candidate, avoidance_radius,
                                                   depth + 1, max_depth, min_safe_height)

            if path_first is not None:
                # Recurse: candidate -> end
                path_second = self._find_path_recursive(candidate, end, avoidance_radius,
                                                        depth + 1, max_depth, min_safe_height)

                if path_second is not None:
                    # Join the two halves
                    return path_first + path_second

        # Both the right and the left attempt failed
        return None

    # --- Vector math ---
    def _get_horizontal_avoidance_points(self, start, end, mid, radius) -> List[Tuple[float, float, float]]:
        """
        Generate the two candidate detour points: ``mid`` shifted right and
        left of the flight direction, strictly within the horizontal plane.

        Args:
            start, end: ``(x, y, z)`` end points of the blocked segment
            mid: ``(x, y, z)`` point to offset from
            radius: Sideways offset distance

        Returns:
            ``[right_point, left_point]``; both keep the altitude of ``mid``.
        """
        # 1. Flight direction D = end - start. The vertical component is
        #    ignored because the perpendicular is taken in the horizontal plane.
        dx = end[0] - start[0]
        dy = end[1] - start[1]

        # Length of the horizontal projection
        dist_horizontal = (dx * dx + dy * dy) ** 0.5

        # 2. Unit "right" vector
        if dist_horizontal == 0:
            # Purely vertical segment (start and end share x, y): left/right
            # is undefined, so the X axis is picked arbitrarily.
            rx, ry = 1.0, 0.0
        else:
            # Rotating (x, y) clockwise by 90 degrees gives (y, -x); normalise it.
            rx = dy / dist_horizontal
            ry = -dx / dist_horizontal

        mx, my, mz = mid

        # 3. Right is (rx, ry), left is (-rx, -ry); z stays at the midpoint altitude.
        right_point = (mx + rx * radius, my + ry * radius, mz)
        left_point = (mx - rx * radius, my - ry * radius, mz)
        return [right_point, left_point]

    def move_along_path(self, drone_id: str, waypoints: List[Dict[str, float]]) -> Dict[str, Any]:
        """Move drone along a path of waypoints"""
        return self._request('POST', f'/drones/{drone_id}/command/move_along_path',
                             json={'waypoints': waypoints})

    def change_altitude(self, drone_id: str, altitude: float) -> Dict[str, Any]:
        """Change drone altitude while maintaining X/Y position"""
        return self._request('POST', f'/drones/{drone_id}/command/change_altitude',
                             params={'altitude': altitude})

    def hover(self, drone_id: str, duration: Optional[float] = None) -> Dict[str, Any]:
        """
        Command drone to hover at current position.

        Args:
            drone_id: ID of the drone
            duration: Optional duration to hover in seconds
        """
        params = {}
        if duration is not None:
            params['duration'] = duration
        return self._request('POST', f'/drones/{drone_id}/command/hover', params=params)

    def rotate(self, drone_id: str, heading: float) -> Dict[str, Any]:
        """Rotate drone to face specific direction (0-360 degrees)"""
        return self._request('POST', f'/drones/{drone_id}/command/rotate',
                             params={'heading': heading})

    def move_towards(self, drone_id: str, distance: float, heading: Optional[float] = None,
                     dz: Optional[float] = None) -> Dict[str, Any]:
        """
        Move drone a specific distance in a direction.

        Args:
            drone_id: ID of the drone
            distance: Distance to move in meters
            heading: Optional heading direction (0-360). If None, uses current heading.
            dz: Optional vertical component (altitude change)
        """
        params = {'distance': distance}
        if heading is not None:
            params['heading'] = heading
        if dz is not None:
            params['dz'] = dz
        return self._request('POST', f'/drones/{drone_id}/command/move_towards', params=params)

    def return_home(self, drone_id: str) -> Dict[str, Any]:
        """Command drone to return to home position"""
        return self._request('POST', f'/drones/{drone_id}/command/return_home')

    def set_home(self, drone_id: str) -> Dict[str, Any]:
        """Set current position as home position"""
        return self._request('POST', f'/drones/{drone_id}/command/set_home')

    def calibrate(self, drone_id: str) -> Dict[str, Any]:
        """Calibrate drone sensors"""
        return self._request('POST', f'/drones/{drone_id}/command/calibrate')

    def charge(self, drone_id: str, charge_amount: float) -> Dict[str, Any]:
        """Charge drone battery (when landed)"""
        return self._request('POST', f'/drones/{drone_id}/command/charge',
                             params={'charge_amount': charge_amount})

    def take_photo(self, drone_id: str) -> Dict[str, Any]:
        """Take a photo with drone camera"""
        return self._request('POST', f'/drones/{drone_id}/command/take_photo')

    def send_message(self, drone_id: str, target_drone_id: str, message: str) -> Dict[str, Any]:
        """
        Send a message to another drone.

        Args:
            drone_id: ID of the sender drone
            target_drone_id: ID of the recipient drone
            message: Content of the message
        """
        return self._request('POST', f'/drones/{drone_id}/command/send_message',
                             params={'target_drone_id': target_drone_id, 'message': message})

    def broadcast(self, drone_id: str, message: str) -> Dict[str, Any]:
        """
        Broadcast a message to all other drones.

        Args:
            drone_id: ID of the sender drone
            message: Content of the message
        """
        return self._request('POST', f'/drones/{drone_id}/command/broadcast',
                             params={'message': message})

    # Session Operations
    def get_current_session(self) -> Dict[str, Any]:
        """Get information about current mission session"""
        return self._request('GET', '/sessions/current')

    def get_session_data(self, session_id: str = 'current') -> Dict[str, Any]:
        """Get all entities in a session (drones, targets, obstacles, environment)"""
        return self._request('GET', f'/sessions/{session_id}/data')

    def get_task_progress(self, session_id: str = 'current') -> Dict[str, Any]:
        """Get mission task completion progress"""
        return self._request('GET', f'/sessions/{session_id}/task-progress')

    # Environment Operations
    def get_weather(self) -> Dict[str, Any]:
        """Get current weather conditions"""
        return self._request('GET', '/environments/current')

    def get_targets(self) -> List[Dict[str, Any]]:
        """Get all targets in the session (fixed, moving, waypoint, circle and polygon)"""
        fixed = self._request('GET', '/targets/type/fixed')
        moving = self._request('GET', '/targets/type/moving')
        waypoint = self._request('GET', '/targets/type/waypoint')
        circle = self._request('GET', '/targets/type/circle')
        polygon = self._request('GET', '/targets/type/polygon')
        return fixed + moving + waypoint + circle + polygon

    def get_waypoints(self) -> List[Dict[str, Any]]:
        """Get all charging station waypoints"""
        # Same endpoint as get_all_waypoints; kept as a separate name for callers.
        return self.get_all_waypoints()

    def get_nearest_waypoint(self, x: str, y: str, z: str) -> Dict[str, Any]:
        """Get nearest charging station waypoint"""
        # NOTE(review): the coordinates are annotated as str and sent as a
        # JSON body on a GET request - confirm against the server that this
        # is what the endpoint expects (query parameters are the usual form).
        return self._request('GET', '/targets/waypoints/nearest',
                             json={'x': x, 'y': y, 'z': z})

    def get_obstacles(self) -> List[Dict[str, Any]]:
        """Get all obstacles in the session (point, circle, polygon and ellipse)"""
        point = self._request('GET', '/obstacles/type/point')
        circle = self._request('GET', '/obstacles/type/circle')
        polygon = self._request('GET', '/obstacles/type/polygon')
        ellipse = self._request('GET', '/obstacles/type/ellipse')
        return point + circle + polygon + ellipse

    def get_nearby_entities(self, drone_id: str) -> Dict[str, Any]:
        """Get entities near a drone (within perceived radius)"""
        return self._request('GET', f'/drones/{drone_id}/nearby')

    # Safety Operations
    def check_point_collision(self, x: float, y: float, z: float,
                              safety_margin: float = 0.0) -> Optional[Dict[str, Any]]:
        """Check if a point collides with any obstacle"""
        return self._request('POST', '/obstacles/collision/check',
                             json={
                                 'point': {'x': x, 'y': y, 'z': z},
                                 'safety_margin': safety_margin
                             })

    def check_path_collision(self, start_x: float, start_y: float, start_z: float,
                             end_x: float, end_y: float, end_z: float,
                             safety_margin: float = 1.0) -> Optional[Dict[str, Any]]:
        """Check if a path intersects any obstacle"""
        return self._request('POST', '/obstacles/collision/path',
                             json={
                                 'start': {'x': start_x, 'y': start_y, 'z': start_z},
                                 'end': {'x': end_x, 'y': end_y, 'z': end_z},
                                 'safety_margin': safety_margin
                             })
+ Use this to see what drones are available before trying to control them. + + No input required.""" + try: + drones = client.list_drones() + return json.dumps(drones, indent=2) + except Exception as e: + return f"Error listing drones: {str(e)}" + + @tool + def get_session_info() -> str: + """Get current session information including task type, statistics, and status. + Use this to understand what mission you need to complete. + + No input required.""" + try: + session = client.get_current_session() + return json.dumps(session, indent=2) + except Exception as e: + return f"Error getting session info: {str(e)}" + + @tool + def get_session_data() -> str: + """Get all session data including drones, targets, and obstacles. + Use this to understand the environment and plan your mission. + + No input required.""" + try: + session_data = client.get_session_data() + return json.dumps(session_data, indent=2) + except Exception as e: + return f"Error getting session data: {str(e)}" + + @tool + def get_task_progress() -> str: + """Get mission task progress including completion percentage and status. + Use this to track mission completion and see how close you are to finishing. + + No input required.""" + try: + progress = client.get_task_progress() + return json.dumps(progress, indent=2) + except Exception as e: + return f"Error getting task progress: {str(e)}" + + @tool + def get_weather() -> str: + """Get current weather conditions including wind speed, visibility, and weather type. + Check this before takeoff to ensure safe flying conditions. + + No input required.""" + try: + weather = client.get_weather() + return json.dumps(weather, indent=2) + except Exception as e: + return f"Error getting weather: {str(e)}" + + @tool + def get_targets() -> str: + """Get all targets in the session including fixed, moving, waypoint, circle and polygon to search or patrol. + Use this to see what targets you need to visit. 
+ + No input required.""" + try: + targets = client.get_targets() + return json.dumps(targets, indent=2) + except Exception as e: + return f"Error getting targets: {str(e)}" + + @tool + def get_all_waypoints() -> str: + """Get all waypoints in the session including coordinates and altitude. + Use this to understand the where to charge that drones will follow. + + No input required.""" + try: + waypoints = client.get_all_waypoints() + return json.dumps(waypoints, indent=2) + except Exception as e: + return f"Error getting waypoints: {str(e)}" + + @tool + def get_obstacles() -> str: + """Get all obstacles in the session that drones must avoid. + Use this to understand what obstacles exist in the environment. + + No input required.""" + try: + obstacles = client.get_obstacles() + return json.dumps(obstacles, indent=2) + except Exception as e: + return f"Error getting obstacles: {str(e)}" + + + @tool + def get_drone_status(input_json: str) -> str: + """Get detailed status of a specific drone including position, battery, heading, and visited targets. + + Input should be a JSON string with: + - drone_id: The ID of the drone (required) + + Example: {{"drone_id": "drone-001"}} + """ + try: + params = json.loads(input_json) if isinstance(input_json, str) else input_json + drone_id = params.get('drone_id') + + if not drone_id: + return "Error: drone_id is required" + + status = client.get_drone_status(drone_id) + return json.dumps(status, indent=2) + except json.JSONDecodeError as e: + return f"Error parsing JSON input: {str(e)}. Expected format: {{\"drone_id\": \"drone-001\"}}" + except Exception as e: + return f"Error getting drone status: {str(e)}" + + @tool + def get_nearby_entities(input_json: str) -> str: + """Get drones, targets, and obstacles near a specific drone (within its perception radius). 
+ + Input should be a JSON string with: + - drone_id: The ID of the drone (required) + + Example: {{"drone_id": "drone-001"}} + """ + try: + params = json.loads(input_json) if isinstance(input_json, str) else input_json + drone_id = params.get('drone_id') + + if not drone_id: + return "Error: drone_id is required" + + nearby = client.get_nearby_entities(drone_id) + return json.dumps(nearby, indent=2) + except json.JSONDecodeError as e: + return f"Error parsing JSON input: {str(e)}. Expected format: {{\"drone_id\": \"drone-001\"}}" + except Exception as e: + return f"Error getting nearby entities: {str(e)}" + + @tool + def land(input_json: str) -> str: + """Command a drone to land at its current position. + + Input should be a JSON string with: + - drone_id: The ID of the drone (required) + + Example: {{"drone_id": "drone-001"}} + """ + try: + params = json.loads(input_json) if isinstance(input_json, str) else input_json + drone_id = params.get('drone_id') + + if not drone_id: + return "Error: drone_id is required" + + result = client.land(drone_id) + return json.dumps(result, indent=2) + except json.JSONDecodeError as e: + return f"Error parsing JSON input: {str(e)}. Expected format: {{\"drone_id\": \"drone-001\"}}" + except Exception as e: + return f"Error during landing: {str(e)}" + + @tool + def hover(input_json: str) -> str: + """Command a drone to hover at its current position. 
+ + Input should be a JSON string with: + - drone_id: The ID of the drone (required) + - duration: Optional duration in seconds to hover (optional) + + Example: {{"drone_id": "drone-001", "duration": 5.0}} + """ + try: + params = json.loads(input_json) if isinstance(input_json, str) else input_json + drone_id = params.get('drone_id') + duration = params.get('duration') + + if not drone_id: + return "Error: drone_id is required" + + result = client.hover(drone_id, duration) + return json.dumps(result, indent=2) + except json.JSONDecodeError as e: + return f"Error parsing JSON input: {str(e)}. Expected format: {{\"drone_id\": \"drone-001\"}}" + except Exception as e: + return f"Error hovering: {str(e)}" + + @tool + def return_home(input_json: str) -> str: + """Command a drone to return to its home position. + + Input should be a JSON string with: + - drone_id: The ID of the drone (required) + + Example: {{"drone_id": "drone-001"}} + """ + try: + params = json.loads(input_json) if isinstance(input_json, str) else input_json + drone_id = params.get('drone_id') + + if not drone_id: + return "Error: drone_id is required" + + result = client.return_home(drone_id) + return json.dumps(result, indent=2) + except json.JSONDecodeError as e: + return f"Error parsing JSON input: {str(e)}. Expected format: {{\"drone_id\": \"drone-001\"}}" + except Exception as e: + return f"Error returning home: {str(e)}" + + @tool + def set_home(input_json: str) -> str: + """Set the drone's current position as its new home position. 
+ + Input should be a JSON string with: + - drone_id: The ID of the drone (required) + + Example: {{"drone_id": "drone-001"}} + """ + try: + params = json.loads(input_json) if isinstance(input_json, str) else input_json + drone_id = params.get('drone_id') + + if not drone_id: + return "Error: drone_id is required" + + result = client.set_home(drone_id) + return json.dumps(result, indent=2) + except json.JSONDecodeError as e: + return f"Error parsing JSON input: {str(e)}. Expected format: {{\"drone_id\": \"drone-001\"}}" + except Exception as e: + return f"Error setting home: {str(e)}" + + @tool + def calibrate(input_json: str) -> str: + """Calibrate the drone's sensors. + + Input should be a JSON string with: + - drone_id: The ID of the drone (required) + + Example: {{"drone_id": "drone-001"}} + """ + try: + params = json.loads(input_json) if isinstance(input_json, str) else input_json + drone_id = params.get('drone_id') + + if not drone_id: + return "Error: drone_id is required" + + result = client.calibrate(drone_id) + return json.dumps(result, indent=2) + except json.JSONDecodeError as e: + return f"Error parsing JSON input: {str(e)}. Expected format: {{\"drone_id\": \"drone-001\"}}" + except Exception as e: + return f"Error calibrating: {str(e)}" + + @tool + def take_photo(input_json: str) -> str: + """Command a drone to take a photo. + + Input should be a JSON string with: + - drone_id: The ID of the drone (required) + + Example: {{"drone_id": "drone-001"}} + """ + try: + params = json.loads(input_json) if isinstance(input_json, str) else input_json + drone_id = params.get('drone_id') + + if not drone_id: + return "Error: drone_id is required" + + result = client.take_photo(drone_id) + return json.dumps(result, indent=2) + except json.JSONDecodeError as e: + return f"Error parsing JSON input: {str(e)}. 
Expected format: {{\"drone_id\": \"drone-001\"}}" + except Exception as e: + return f"Error taking photo: {str(e)}" + + # ========== Two Parameter Tools ========== + + @tool + def take_off(input_json: str) -> str: + """Command a drone to take off to a specified altitude. + Drone must be on ground (idle or ready status). + + Input should be a JSON string with: + - drone_id: The ID of the drone (required) + - altitude: Target altitude in meters (optional, default: 10.0) + + Example: {{"drone_id": "drone-001", "altitude": 15.0}} + """ + try: + params = json.loads(input_json) if isinstance(input_json, str) else input_json + drone_id = params.get('drone_id') + altitude = params.get('altitude', 10.0) + + if not drone_id: + return "Error: drone_id is required" + + result = client.take_off(drone_id, altitude) + return json.dumps(result, indent=2) + except json.JSONDecodeError as e: + return f"Error parsing JSON input: {str(e)}. Expected format: {{\"drone_id\": \"drone-001\", \"altitude\": 15.0}}" + except Exception as e: + return f"Error during takeoff: {str(e)}" + + @tool + def change_altitude(input_json: str) -> str: + """Change a drone's altitude while maintaining X/Y position. + + Input should be a JSON string with: + - drone_id: The ID of the drone (required) + - altitude: Target altitude in meters (required) + + Example: {{"drone_id": "drone-001", "altitude": 20.0}} + """ + try: + params = json.loads(input_json) if isinstance(input_json, str) else input_json + drone_id = params.get('drone_id') + altitude = params.get('altitude') + + if not drone_id: + return "Error: drone_id is required" + if altitude is None: + return "Error: altitude is required" + + result = client.change_altitude(drone_id, altitude) + return json.dumps(result, indent=2) + except json.JSONDecodeError as e: + return f"Error parsing JSON input: {str(e)}. 
Expected format: {{\"drone_id\": \"drone-001\", \"altitude\": 20.0}}" + except Exception as e: + return f"Error changing altitude: {str(e)}" + + @tool + def rotate(input_json: str) -> str: + """Rotate a drone to face a specific direction. + 0=North, 90=East, 180=South, 270=West. + + Input should be a JSON string with: + - drone_id: The ID of the drone (required) + - heading: Target heading in degrees 0-360 (required) + + Example: {{"drone_id": "drone-001", "heading": 90.0}} + """ + try: + params = json.loads(input_json) if isinstance(input_json, str) else input_json + drone_id = params.get('drone_id') + heading = params.get('heading') + + if not drone_id: + return "Error: drone_id is required" + if heading is None: + return "Error: heading is required" + + result = client.rotate(drone_id, heading) + return json.dumps(result, indent=2) + except json.JSONDecodeError as e: + return f"Error parsing JSON input: {str(e)}. Expected format: {{\"drone_id\": \"drone-001\", \"heading\": 90.0}}" + except Exception as e: + return f"Error rotating: {str(e)}" + + @tool + def send_message(input_json: str) -> str: + """Send a message from one drone to another. 
+ + Input should be a JSON string with: + - drone_id: The ID of the sender drone (required) + - target_drone_id: The ID of the recipient drone (required) + - message: The message content (required) + + Example: {{"drone_id": "drone-001", "target_drone_id": "drone-002", "message": "Hello"}} + """ + try: + params = json.loads(input_json) if isinstance(input_json, str) else input_json + drone_id = params.get('drone_id') + target_drone_id = params.get('target_drone_id') + message = params.get('message') + + if not drone_id: + return "Error: drone_id is required" + if not target_drone_id: + return "Error: target_drone_id is required" + if not message: + return "Error: message is required" + + result = client.send_message(drone_id, target_drone_id, message) + return json.dumps(result, indent=2) + except json.JSONDecodeError as e: + return f"Error parsing JSON input: {str(e)}. Expected format: {{\"drone_id\": \"drone-001\", \"target_drone_id\": \"drone-002\", \"message\": \"...\"}}" + except Exception as e: + return f"Error sending message: {str(e)}" + + @tool + def broadcast(input_json: str) -> str: + """Broadcast a message from one drone to all other drones. + + Input should be a JSON string with: + - drone_id: The ID of the sender drone (required) + - message: The message content (required) + + Example: {{"drone_id": "drone-001", "message": "Alert"}} + """ + try: + params = json.loads(input_json) if isinstance(input_json, str) else input_json + drone_id = params.get('drone_id') + message = params.get('message') + + if not drone_id: + return "Error: drone_id is required" + if not message: + return "Error: message is required" + + result = client.broadcast(drone_id, message) + return json.dumps(result, indent=2) + except json.JSONDecodeError as e: + return f"Error parsing JSON input: {str(e)}. 
Expected format: {{\"drone_id\": \"drone-001\", \"message\": \"...\"}}" + except Exception as e: + return f"Error broadcasting: {str(e)}" + + @tool + def charge(input_json: str) -> str: + """Command a drone to charge its battery. + Drone must be landed at a charging station. + + Input should be a JSON string with: + - drone_id: The ID of the drone (required) + - charge_amount: Amount to charge in percent (required) + + Example: {{"drone_id": "drone-001", "charge_amount": 25.0}} + """ + try: + params = json.loads(input_json) if isinstance(input_json, str) else input_json + drone_id = params.get('drone_id') + charge_amount = params.get('charge_amount') + + if not drone_id: + return "Error: drone_id is required" + if charge_amount is None: + return "Error: charge_amount is required" + + result = client.charge(drone_id, charge_amount) + return json.dumps(result, indent=2) + except json.JSONDecodeError as e: + return f"Error parsing JSON input: {str(e)}. Expected format: {{\"drone_id\": \"drone-001\", \"charge_amount\": 25.0}}" + except Exception as e: + return f"Error charging: {str(e)}" + + @tool + def move_towards(input_json: str) -> str: + """Move a drone a specific distance in a direction. 
+ + Input should be a JSON string with: + - drone_id: The ID of the drone (required) + - distance: Distance to move in meters (required) + - heading: Heading direction in degrees 0-360 (optional, default: current heading) + - dz: Vertical component in meters (optional) + + Example: {{"drone_id": "drone-001", "distance": 10.0, "heading": 90.0}} + """ + try: + params = json.loads(input_json) if isinstance(input_json, str) else input_json + drone_id = params.get('drone_id') + distance = params.get('distance') + heading = params.get('heading') + dz = params.get('dz') + + if not drone_id: + return "Error: drone_id is required" + if distance is None: + return "Error: distance is required" + + result = client.move_towards(drone_id, distance, heading, dz) + return json.dumps(result, indent=2) + except json.JSONDecodeError as e: + return f"Error parsing JSON input: {str(e)}. Expected format: {{\"drone_id\": \"drone-001\", \"distance\": 10.0}}" + except Exception as e: + return f"Error moving towards: {str(e)}" + + # @tool + # def move_along_path(input_json: str) -> str: + # """Move a drone along a path of waypoints. + + # Input should be a JSON string with: + # - drone_id: The ID of the drone (required) + # - waypoints: List of points with x, y, z coordinates (required) + + # Example: {{"drone_id": "drone-001", "waypoints": [{{"x": 10, "y": 10, "z": 10}}, {{"x": 20, "y": 20, "z": 10}}]}} + # """ + # try: + # params = json.loads(input_json) if isinstance(input_json, str) else input_json + # drone_id = params.get('drone_id') + # waypoints = params.get('waypoints') + + # if not drone_id: + # return "Error: drone_id is required" + # if not waypoints: + # return "Error: waypoints list is required" + + # result = client.move_along_path(drone_id, waypoints) + # return json.dumps(result, indent=2) + # except json.JSONDecodeError as e: + # return f"Error parsing JSON input: {str(e)}. 
Expected format: {{\"drone_id\": \"drone-001\", \"waypoints\": [...]}}" + # except Exception as e: + # return f"Error moving along path: {str(e)}" + + # ========== Multi-Parameter Tools ========== + + @tool + def get_nearest_waypoint(input_json: str) -> str: + """Get the nearest waypoint to a specific drone. + Input should be a JSON string with: + - x: The x-coordinate of the drone (required) + - y: The y-coordinate of the drone (required) + - z: The z-coordinate of the drone (required) + + Example: {{"x": 0.0, "y": 0.0, "z": 0.0}}""" + try: + params = json.loads(input_json) if isinstance(input_json, str) else input_json + x = params.get('x') + y = params.get('y') + z = params.get('z') + + if x is None or y is None or z is None: + return "Error: x, y, and z coordinates are required" + nearest = client.get_nearest_waypoint(x, y, z) + return json.dumps(nearest, indent=2) + except json.JSONDecodeError as e: + return f"Error parsing JSON input: {str(e)}. Expected format: {{\"x\": 0.0, \"y\": 0.0, \"z\": 0.0}}" + except Exception as e: + return f"Error getting nearest waypoint: {str(e)}" + + @tool + def move_to(input_json: str) -> str: + """Move a drone to specific 3D coordinates (x, y, z). + Always check for collisions first using check_path_collision. 
+ + Input should be a JSON string with: + - drone_id: The ID of the drone (required) + - x: Target X coordinate in meters (required) + - y: Target Y coordinate in meters (required) + - z: Target Z coordinate (altitude) in meters (required) + + Example: {{"drone_id": "drone-001", "x": 100.0, "y": 50.0, "z": 20.0}} + """ + try: + params = json.loads(input_json) if isinstance(input_json, str) else input_json + drone_id = params.get('drone_id') + x = params.get('x') + y = params.get('y') + z = params.get('z') + + if not drone_id: + return "Error: drone_id is required" + if x is None or y is None or z is None: + return "Error: x, y, and z coordinates are required" + + result = client.move_to(drone_id, x, y, z) + return json.dumps(result, indent=2) + except json.JSONDecodeError as e: + return f"Error parsing JSON input: {str(e)}. Expected format: {{\"drone_id\": \"drone-001\", \"x\": 100.0, \"y\": 50.0, \"z\": 20.0}}" + except Exception as e: + return f"Error moving drone: {str(e)}" + + @tool + def optimal_way_to(input_json: str) -> str: + """Get the optimal path to a specific 3D coordinates (x, y, z). + Always check for collisions first using check_path_collision. 
+ + Input should be a JSON string with: + - drone_id: The ID of the drone (required) + - x: Target X coordinate in meters (required) + - y: Target Y coordinate in meters (required) + - z: Target Z coordinate (altitude) in meters (required) + + Example: {{"drone_id": "drone-001", "x": 100.0, "y": 50.0, "z": 20.0}} + """ + try: + params = json.loads(input_json) if isinstance(input_json, str) else input_json + drone_id = params.get('drone_id') + x = params.get('x') + y = params.get('y') + z = params.get('z') + + if not drone_id: + return "Error: drone_id is required" + if x is None or y is None or z is None: + return "Error: x, y, and z coordinates are required" + + result = client.optimal_way_to(drone_id, x, y, z) + return json.dumps(result, indent=2) + except json.JSONDecodeError as e: + return f"Error parsing JSON input: {str(e)}. Expected format: {{\"drone_id\": \"drone-001\", \"x\": 100.0, \"y\": 50.0, \"z\": 20.0}}" + except Exception as e: + return f"Error moving drone: {str(e)}" + + + # Return all tools + return [ + list_drones, + get_drone_status, + get_session_info, + # get_session_data, + get_task_progress, + get_weather, + # get_targets, + get_obstacles, + get_nearby_entities, + take_off, + land, + move_to, + # optimal_way_to, + move_towards, + change_altitude, + hover, + rotate, + return_home, + set_home, + calibrate, + take_photo, + send_message, + broadcast, + charge, + get_nearest_waypoint, + get_all_waypoints, + get_targets + ]