diff --git a/scripts/chat.py b/scripts/chat.py
index a1d5245..7efe1b1 100755
--- a/scripts/chat.py
+++ b/scripts/chat.py
@@ -3,7 +3,7 @@
 # requires-python = ">=3.11"
 # dependencies = [
 #     "rich>=13.7.0",
-#     "httpx>=0.27.0",
+#     "openai>=1.0.0",
 # ]
 # ///
 
@@ -18,11 +18,10 @@ Usage:
 """
 
 import argparse
-import json
 import sys
 from typing import Optional
 
-import httpx
+from openai import OpenAI, APIStatusError
 from rich.console import Console
 from rich.live import Live
 from rich.markdown import Markdown
@@ -34,15 +33,12 @@ from rich.table import Table
 class ChatClient:
     def __init__(self, base_url: str, token: Optional[str] = None):
         self.base_url = base_url.rstrip("/")
-        self.token = token
+        self.client = OpenAI(
+            base_url=f"{self.base_url}/v1",
+            api_key=token or "no-key",
+        )
         self.messages = []
         self.console = Console()
-
-    def _headers(self) -> dict:
-        headers = {"Content-Type": "application/json"}
-        if self.token:
-            headers["Authorization"] = f"Bearer {self.token}"
-        return headers
 
     def chat(self, user_message: str, model: str, stream: bool = True):
         """Send a chat message and get response."""
@@ -52,35 +48,20 @@ class ChatClient:
             "content": [{"type": "input_text", "text": user_message}]
         })
 
-        payload = {
-            "model": model,
-            "input": self.messages,
-            "stream": stream
-        }
-
         if stream:
-            return self._stream_response(payload, model)
+            return self._stream_response(model)
         else:
-            return self._sync_response(payload, model)
+            return self._sync_response(model)
 
-    def _sync_response(self, payload: dict, model: str) -> str:
+    def _sync_response(self, model: str) -> str:
         """Non-streaming response."""
         with self.console.status(f"[bold blue]Thinking ({model})..."):
-            resp = httpx.post(
-                f"{self.base_url}/v1/responses",
-                json=payload,
-                headers=self._headers(),
-                timeout=60.0
+            response = self.client.responses.create(
+                model=model,
+                input=self.messages,
             )
-            resp.raise_for_status()
-            data = resp.json()
-
-            assistant_text = ""
-
-            for msg in data.get("output", []):
-                for block in msg.get("content", []):
-                    if block.get("type") == "output_text":
-                        assistant_text += block.get("text", "")
+
+        assistant_text = response.output_text
 
         # Add to history
         self.messages.append({
@@ -90,40 +71,19 @@ class ChatClient:
 
         return assistant_text
 
-    def _stream_response(self, payload: dict, model: str) -> str:
+    def _stream_response(self, model: str) -> str:
         """Streaming response with live rendering."""
         assistant_text = ""
 
-        with httpx.stream(
-            "POST",
-            f"{self.base_url}/v1/responses",
-            json=payload,
-            headers=self._headers(),
-            timeout=60.0
-        ) as resp:
-            resp.raise_for_status()
-
-            with Live(console=self.console, refresh_per_second=10) as live:
-                for line in resp.iter_lines():
-                    if not line.startswith("data: "):
-                        continue
-
-                    data_str = line[6:]  # Remove "data: " prefix
-
-                    try:
-                        chunk = json.loads(data_str)
-                    except json.JSONDecodeError:
-                        continue
-
-                    if chunk.get("done"):
-                        break
-
-                    delta = chunk.get("delta", {})
-                    for block in delta.get("content", []):
-                        if block.get("type") == "output_text":
-                            assistant_text += block.get("text", "")
-
-                    # Render markdown in real-time
+        with Live(console=self.console, refresh_per_second=10) as live:
+            stream = self.client.responses.create(
+                model=model,
+                input=self.messages,
+                stream=True,
+            )
+            for event in stream:
+                if event.type == "response.output_text.delta":
+                    assistant_text += event.delta
                     live.update(Markdown(assistant_text))
 
         # Add to history
@@ -139,23 +99,21 @@ class ChatClient:
         self.messages = []
 
 
-def print_models_table(base_url: str, headers: dict):
+def print_models_table(client: OpenAI):
     """Fetch and print available models from the gateway."""
     console = Console()
     try:
-        resp = httpx.get(f"{base_url}/v1/models", headers=headers, timeout=10)
-        resp.raise_for_status()
-        data = resp.json().get("data", [])
+        models = client.models.list()
     except Exception as e:
         console.print(f"[red]Failed to fetch models: {e}[/red]")
         return
 
     table = Table(title="Available Models", show_header=True, header_style="bold magenta")
-    table.add_column("Provider", style="cyan")
+    table.add_column("Owner", style="cyan")
     table.add_column("Model ID", style="green")
 
-    for model in data:
-        table.add_row(model.get("provider", ""), model.get("id", ""))
+    for model in models:
+        table.add_row(model.owned_by, model.id)
 
     console.print(table)
 
@@ -163,14 +121,29 @@
 def main():
     parser = argparse.ArgumentParser(description="Chat with go-llm-gateway")
     parser.add_argument("--url", default="http://localhost:8080", help="Gateway URL")
-    parser.add_argument("--model", default="gemini-2.0-flash-exp", help="Model to use")
+    parser.add_argument("--model", default=None, help="Model to use (defaults to first available)")
    parser.add_argument("--token", help="Auth token (Bearer)")
     parser.add_argument("--no-stream", action="store_true", help="Disable streaming")
     args = parser.parse_args()
 
     console = Console()
     client = ChatClient(args.url, args.token)
-    current_model = args.model
+
+    # Fetch available models and select default
+    try:
+        available_models = list(client.client.models.list())
+    except Exception as e:
+        console.print(f"[bold red]Failed to connect to gateway:[/bold red] {e}")
+        sys.exit(1)
+
+    if not available_models:
+        console.print("[bold red]Error:[/bold red] No models are configured on the gateway.")
+        sys.exit(1)
+
+    if args.model:
+        current_model = args.model
+    else:
+        current_model = available_models[0].id
     stream_enabled = not args.no_stream
 
     # Welcome banner
@@ -230,7 +203,7 @@
             ))
 
         elif cmd == "/models":
-            print_models_table(args.url, client._headers())
+            print_models_table(client.client)
 
         elif cmd == "/model":
             if len(cmd_parts) < 2:
@@ -265,8 +238,8 @@
                 # For non-streaming, render markdown
                 console.print(Markdown(response))
 
-        except httpx.HTTPStatusError as e:
-            console.print(f"[bold red]Error {e.response.status_code}:[/bold red] {e.response.text}")
+        except APIStatusError as e:
+            console.print(f"[bold red]Error {e.status_code}:[/bold red] {e.message}")
         except Exception as e:
             console.print(f"[bold red]Error:[/bold red] {e}")
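
Reviewer note: a minimal smoke test for the migrated SDK calls, kept out of the patch itself. This is a sketch under assumptions: the gateway is running at the script's default http://localhost:8080, at least one model is configured, and MODEL_ID is a placeholder for an id printed by the models loop. The "no-key" api_key mirrors the fallback the script now uses when no token is given.

    from openai import OpenAI

    # Point the OpenAI SDK at the local gateway, as ChatClient.__init__ now does.
    client = OpenAI(base_url="http://localhost:8080/v1", api_key="no-key")

    # Same call the /models table uses; models.list() returns an iterable page.
    for model in client.models.list():
        print(model.owned_by, model.id)

    # Streaming path, mirroring _stream_response.
    stream = client.responses.create(
        model="MODEL_ID",  # placeholder: substitute an id printed above
        input=[{"role": "user", "content": [{"type": "input_text", "text": "Say hello"}]}],
        stream=True,
    )
    for event in stream:
        if event.type == "response.output_text.delta":
            print(event.delta, end="", flush=True)
    print()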