Update chat script to use openai lib

This commit is contained in:
2026-03-02 04:21:54 +00:00
parent 3e645a3525
commit cf47ad444a

View File

@@ -3,7 +3,7 @@
# requires-python = ">=3.11" # requires-python = ">=3.11"
# dependencies = [ # dependencies = [
# "rich>=13.7.0", # "rich>=13.7.0",
# "httpx>=0.27.0", # "openai>=1.0.0",
# ] # ]
# /// # ///
@@ -18,11 +18,10 @@ Usage:
""" """
import argparse import argparse
import json
import sys import sys
from typing import Optional from typing import Optional
import httpx from openai import OpenAI, APIStatusError
from rich.console import Console from rich.console import Console
from rich.live import Live from rich.live import Live
from rich.markdown import Markdown from rich.markdown import Markdown
@@ -34,16 +33,13 @@ from rich.table import Table
class ChatClient: class ChatClient:
def __init__(self, base_url: str, token: Optional[str] = None): def __init__(self, base_url: str, token: Optional[str] = None):
self.base_url = base_url.rstrip("/") self.base_url = base_url.rstrip("/")
self.token = token self.client = OpenAI(
base_url=f"{self.base_url}/v1",
api_key=token or "no-key",
)
self.messages = [] self.messages = []
self.console = Console() self.console = Console()
def _headers(self) -> dict:
headers = {"Content-Type": "application/json"}
if self.token:
headers["Authorization"] = f"Bearer {self.token}"
return headers
def chat(self, user_message: str, model: str, stream: bool = True): def chat(self, user_message: str, model: str, stream: bool = True):
"""Send a chat message and get response.""" """Send a chat message and get response."""
# Add user message to history # Add user message to history
@@ -52,35 +48,20 @@ class ChatClient:
"content": [{"type": "input_text", "text": user_message}] "content": [{"type": "input_text", "text": user_message}]
}) })
payload = {
"model": model,
"input": self.messages,
"stream": stream
}
if stream: if stream:
return self._stream_response(payload, model) return self._stream_response(model)
else: else:
return self._sync_response(payload, model) return self._sync_response(model)
def _sync_response(self, payload: dict, model: str) -> str: def _sync_response(self, model: str) -> str:
"""Non-streaming response.""" """Non-streaming response."""
with self.console.status(f"[bold blue]Thinking ({model})..."): with self.console.status(f"[bold blue]Thinking ({model})..."):
resp = httpx.post( response = self.client.responses.create(
f"{self.base_url}/v1/responses", model=model,
json=payload, input=self.messages,
headers=self._headers(),
timeout=60.0
) )
resp.raise_for_status()
data = resp.json() assistant_text = response.output_text
assistant_text = ""
for msg in data.get("output", []):
for block in msg.get("content", []):
if block.get("type") == "output_text":
assistant_text += block.get("text", "")
# Add to history # Add to history
self.messages.append({ self.messages.append({
@@ -90,40 +71,19 @@ class ChatClient:
return assistant_text return assistant_text
def _stream_response(self, payload: dict, model: str) -> str: def _stream_response(self, model: str) -> str:
"""Streaming response with live rendering.""" """Streaming response with live rendering."""
assistant_text = "" assistant_text = ""
with httpx.stream( with Live(console=self.console, refresh_per_second=10) as live:
"POST", stream = self.client.responses.create(
f"{self.base_url}/v1/responses", model=model,
json=payload, input=self.messages,
headers=self._headers(), stream=True,
timeout=60.0 )
) as resp: for event in stream:
resp.raise_for_status() if event.type == "response.output_text.delta":
assistant_text += event.delta
with Live(console=self.console, refresh_per_second=10) as live:
for line in resp.iter_lines():
if not line.startswith("data: "):
continue
data_str = line[6:] # Remove "data: " prefix
try:
chunk = json.loads(data_str)
except json.JSONDecodeError:
continue
if chunk.get("done"):
break
delta = chunk.get("delta", {})
for block in delta.get("content", []):
if block.get("type") == "output_text":
assistant_text += block.get("text", "")
# Render markdown in real-time
live.update(Markdown(assistant_text)) live.update(Markdown(assistant_text))
# Add to history # Add to history
@@ -139,23 +99,21 @@ class ChatClient:
self.messages = [] self.messages = []
def print_models_table(base_url: str, headers: dict): def print_models_table(client: OpenAI):
"""Fetch and print available models from the gateway.""" """Fetch and print available models from the gateway."""
console = Console() console = Console()
try: try:
resp = httpx.get(f"{base_url}/v1/models", headers=headers, timeout=10) models = client.models.list()
resp.raise_for_status()
data = resp.json().get("data", [])
except Exception as e: except Exception as e:
console.print(f"[red]Failed to fetch models: {e}[/red]") console.print(f"[red]Failed to fetch models: {e}[/red]")
return return
table = Table(title="Available Models", show_header=True, header_style="bold magenta") table = Table(title="Available Models", show_header=True, header_style="bold magenta")
table.add_column("Provider", style="cyan") table.add_column("Owner", style="cyan")
table.add_column("Model ID", style="green") table.add_column("Model ID", style="green")
for model in data: for model in models:
table.add_row(model.get("provider", ""), model.get("id", "")) table.add_row(model.owned_by, model.id)
console.print(table) console.print(table)
@@ -163,14 +121,29 @@ def print_models_table(base_url: str, headers: dict):
def main(): def main():
parser = argparse.ArgumentParser(description="Chat with go-llm-gateway") parser = argparse.ArgumentParser(description="Chat with go-llm-gateway")
parser.add_argument("--url", default="http://localhost:8080", help="Gateway URL") parser.add_argument("--url", default="http://localhost:8080", help="Gateway URL")
parser.add_argument("--model", default="gemini-2.0-flash-exp", help="Model to use") parser.add_argument("--model", default=None, help="Model to use (defaults to first available)")
parser.add_argument("--token", help="Auth token (Bearer)") parser.add_argument("--token", help="Auth token (Bearer)")
parser.add_argument("--no-stream", action="store_true", help="Disable streaming") parser.add_argument("--no-stream", action="store_true", help="Disable streaming")
args = parser.parse_args() args = parser.parse_args()
console = Console() console = Console()
client = ChatClient(args.url, args.token) client = ChatClient(args.url, args.token)
current_model = args.model
# Fetch available models and select default
try:
available_models = list(client.client.models.list())
except Exception as e:
console.print(f"[bold red]Failed to connect to gateway:[/bold red] {e}")
sys.exit(1)
if not available_models:
console.print("[bold red]Error:[/bold red] No models are configured on the gateway.")
sys.exit(1)
if args.model:
current_model = args.model
else:
current_model = available_models[0].id
stream_enabled = not args.no_stream stream_enabled = not args.no_stream
# Welcome banner # Welcome banner
@@ -230,7 +203,7 @@ def main():
)) ))
elif cmd == "/models": elif cmd == "/models":
print_models_table(args.url, client._headers()) print_models_table(client.client)
elif cmd == "/model": elif cmd == "/model":
if len(cmd_parts) < 2: if len(cmd_parts) < 2:
@@ -265,8 +238,8 @@ def main():
# For non-streaming, render markdown # For non-streaming, render markdown
console.print(Markdown(response)) console.print(Markdown(response))
except httpx.HTTPStatusError as e: except APIStatusError as e:
console.print(f"[bold red]Error {e.response.status_code}:[/bold red] {e.response.text}") console.print(f"[bold red]Error {e.status_code}:[/bold red] {e.message}")
except Exception as e: except Exception as e:
console.print(f"[bold red]Error:[/bold red] {e}") console.print(f"[bold red]Error:[/bold red] {e}")