Update chat script to use openai lib
This commit is contained in:
123
scripts/chat.py
123
scripts/chat.py
@@ -3,7 +3,7 @@
|
|||||||
# requires-python = ">=3.11"
|
# requires-python = ">=3.11"
|
||||||
# dependencies = [
|
# dependencies = [
|
||||||
# "rich>=13.7.0",
|
# "rich>=13.7.0",
|
||||||
# "httpx>=0.27.0",
|
# "openai>=1.0.0",
|
||||||
# ]
|
# ]
|
||||||
# ///
|
# ///
|
||||||
|
|
||||||
@@ -18,11 +18,10 @@ Usage:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
|
||||||
import sys
|
import sys
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import httpx
|
from openai import OpenAI, APIStatusError
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.live import Live
|
from rich.live import Live
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
@@ -34,15 +33,12 @@ from rich.table import Table
|
|||||||
class ChatClient:
|
class ChatClient:
|
||||||
def __init__(self, base_url: str, token: Optional[str] = None):
|
def __init__(self, base_url: str, token: Optional[str] = None):
|
||||||
self.base_url = base_url.rstrip("/")
|
self.base_url = base_url.rstrip("/")
|
||||||
self.token = token
|
self.client = OpenAI(
|
||||||
|
base_url=f"{self.base_url}/v1",
|
||||||
|
api_key=token or "no-key",
|
||||||
|
)
|
||||||
self.messages = []
|
self.messages = []
|
||||||
self.console = Console()
|
self.console = Console()
|
||||||
|
|
||||||
def _headers(self) -> dict:
|
|
||||||
headers = {"Content-Type": "application/json"}
|
|
||||||
if self.token:
|
|
||||||
headers["Authorization"] = f"Bearer {self.token}"
|
|
||||||
return headers
|
|
||||||
|
|
||||||
def chat(self, user_message: str, model: str, stream: bool = True):
|
def chat(self, user_message: str, model: str, stream: bool = True):
|
||||||
"""Send a chat message and get response."""
|
"""Send a chat message and get response."""
|
||||||
@@ -52,35 +48,20 @@ class ChatClient:
|
|||||||
"content": [{"type": "input_text", "text": user_message}]
|
"content": [{"type": "input_text", "text": user_message}]
|
||||||
})
|
})
|
||||||
|
|
||||||
payload = {
|
|
||||||
"model": model,
|
|
||||||
"input": self.messages,
|
|
||||||
"stream": stream
|
|
||||||
}
|
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._stream_response(payload, model)
|
return self._stream_response(model)
|
||||||
else:
|
else:
|
||||||
return self._sync_response(payload, model)
|
return self._sync_response(model)
|
||||||
|
|
||||||
def _sync_response(self, payload: dict, model: str) -> str:
|
def _sync_response(self, model: str) -> str:
|
||||||
"""Non-streaming response."""
|
"""Non-streaming response."""
|
||||||
with self.console.status(f"[bold blue]Thinking ({model})..."):
|
with self.console.status(f"[bold blue]Thinking ({model})..."):
|
||||||
resp = httpx.post(
|
response = self.client.responses.create(
|
||||||
f"{self.base_url}/v1/responses",
|
model=model,
|
||||||
json=payload,
|
input=self.messages,
|
||||||
headers=self._headers(),
|
|
||||||
timeout=60.0
|
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
|
||||||
|
|
||||||
data = resp.json()
|
assistant_text = response.output_text
|
||||||
assistant_text = ""
|
|
||||||
|
|
||||||
for msg in data.get("output", []):
|
|
||||||
for block in msg.get("content", []):
|
|
||||||
if block.get("type") == "output_text":
|
|
||||||
assistant_text += block.get("text", "")
|
|
||||||
|
|
||||||
# Add to history
|
# Add to history
|
||||||
self.messages.append({
|
self.messages.append({
|
||||||
@@ -90,40 +71,19 @@ class ChatClient:
|
|||||||
|
|
||||||
return assistant_text
|
return assistant_text
|
||||||
|
|
||||||
def _stream_response(self, payload: dict, model: str) -> str:
|
def _stream_response(self, model: str) -> str:
|
||||||
"""Streaming response with live rendering."""
|
"""Streaming response with live rendering."""
|
||||||
assistant_text = ""
|
assistant_text = ""
|
||||||
|
|
||||||
with httpx.stream(
|
with Live(console=self.console, refresh_per_second=10) as live:
|
||||||
"POST",
|
stream = self.client.responses.create(
|
||||||
f"{self.base_url}/v1/responses",
|
model=model,
|
||||||
json=payload,
|
input=self.messages,
|
||||||
headers=self._headers(),
|
stream=True,
|
||||||
timeout=60.0
|
)
|
||||||
) as resp:
|
for event in stream:
|
||||||
resp.raise_for_status()
|
if event.type == "response.output_text.delta":
|
||||||
|
assistant_text += event.delta
|
||||||
with Live(console=self.console, refresh_per_second=10) as live:
|
|
||||||
for line in resp.iter_lines():
|
|
||||||
if not line.startswith("data: "):
|
|
||||||
continue
|
|
||||||
|
|
||||||
data_str = line[6:] # Remove "data: " prefix
|
|
||||||
|
|
||||||
try:
|
|
||||||
chunk = json.loads(data_str)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if chunk.get("done"):
|
|
||||||
break
|
|
||||||
|
|
||||||
delta = chunk.get("delta", {})
|
|
||||||
for block in delta.get("content", []):
|
|
||||||
if block.get("type") == "output_text":
|
|
||||||
assistant_text += block.get("text", "")
|
|
||||||
|
|
||||||
# Render markdown in real-time
|
|
||||||
live.update(Markdown(assistant_text))
|
live.update(Markdown(assistant_text))
|
||||||
|
|
||||||
# Add to history
|
# Add to history
|
||||||
@@ -139,23 +99,21 @@ class ChatClient:
|
|||||||
self.messages = []
|
self.messages = []
|
||||||
|
|
||||||
|
|
||||||
def print_models_table(base_url: str, headers: dict):
|
def print_models_table(client: OpenAI):
|
||||||
"""Fetch and print available models from the gateway."""
|
"""Fetch and print available models from the gateway."""
|
||||||
console = Console()
|
console = Console()
|
||||||
try:
|
try:
|
||||||
resp = httpx.get(f"{base_url}/v1/models", headers=headers, timeout=10)
|
models = client.models.list()
|
||||||
resp.raise_for_status()
|
|
||||||
data = resp.json().get("data", [])
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"[red]Failed to fetch models: {e}[/red]")
|
console.print(f"[red]Failed to fetch models: {e}[/red]")
|
||||||
return
|
return
|
||||||
|
|
||||||
table = Table(title="Available Models", show_header=True, header_style="bold magenta")
|
table = Table(title="Available Models", show_header=True, header_style="bold magenta")
|
||||||
table.add_column("Provider", style="cyan")
|
table.add_column("Owner", style="cyan")
|
||||||
table.add_column("Model ID", style="green")
|
table.add_column("Model ID", style="green")
|
||||||
|
|
||||||
for model in data:
|
for model in models:
|
||||||
table.add_row(model.get("provider", ""), model.get("id", ""))
|
table.add_row(model.owned_by, model.id)
|
||||||
|
|
||||||
console.print(table)
|
console.print(table)
|
||||||
|
|
||||||
@@ -163,14 +121,29 @@ def print_models_table(base_url: str, headers: dict):
|
|||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Chat with go-llm-gateway")
|
parser = argparse.ArgumentParser(description="Chat with go-llm-gateway")
|
||||||
parser.add_argument("--url", default="http://localhost:8080", help="Gateway URL")
|
parser.add_argument("--url", default="http://localhost:8080", help="Gateway URL")
|
||||||
parser.add_argument("--model", default="gemini-2.0-flash-exp", help="Model to use")
|
parser.add_argument("--model", default=None, help="Model to use (defaults to first available)")
|
||||||
parser.add_argument("--token", help="Auth token (Bearer)")
|
parser.add_argument("--token", help="Auth token (Bearer)")
|
||||||
parser.add_argument("--no-stream", action="store_true", help="Disable streaming")
|
parser.add_argument("--no-stream", action="store_true", help="Disable streaming")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
client = ChatClient(args.url, args.token)
|
client = ChatClient(args.url, args.token)
|
||||||
current_model = args.model
|
|
||||||
|
# Fetch available models and select default
|
||||||
|
try:
|
||||||
|
available_models = list(client.client.models.list())
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[bold red]Failed to connect to gateway:[/bold red] {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if not available_models:
|
||||||
|
console.print("[bold red]Error:[/bold red] No models are configured on the gateway.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if args.model:
|
||||||
|
current_model = args.model
|
||||||
|
else:
|
||||||
|
current_model = available_models[0].id
|
||||||
stream_enabled = not args.no_stream
|
stream_enabled = not args.no_stream
|
||||||
|
|
||||||
# Welcome banner
|
# Welcome banner
|
||||||
@@ -230,7 +203,7 @@ def main():
|
|||||||
))
|
))
|
||||||
|
|
||||||
elif cmd == "/models":
|
elif cmd == "/models":
|
||||||
print_models_table(args.url, client._headers())
|
print_models_table(client.client)
|
||||||
|
|
||||||
elif cmd == "/model":
|
elif cmd == "/model":
|
||||||
if len(cmd_parts) < 2:
|
if len(cmd_parts) < 2:
|
||||||
@@ -265,8 +238,8 @@ def main():
|
|||||||
# For non-streaming, render markdown
|
# For non-streaming, render markdown
|
||||||
console.print(Markdown(response))
|
console.print(Markdown(response))
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except APIStatusError as e:
|
||||||
console.print(f"[bold red]Error {e.response.status_code}:[/bold red] {e.response.text}")
|
console.print(f"[bold red]Error {e.status_code}:[/bold red] {e.message}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
console.print(f"[bold red]Error:[/bold red] {e}")
|
console.print(f"[bold red]Error:[/bold red] {e}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user