113 lines
3.8 KiB
Python
113 lines
3.8 KiB
Python
import json
|
|
from typing import Optional
|
|
|
|
import requests
|
|
|
|
|
|
class OpenAICompatClient:
    """Lightweight OpenAI-compatible chat client with optional API key.

    The client sends a single system/user exchange to a
    ``/chat/completions`` endpoint via ``requests.post`` and returns the
    assistant text extracted from the first choice of the response.
    """

    def __init__(
        self,
        model: str,
        api_base: str,
        api_key: str = "",
        temperature: float = 0.0,
        max_tokens: int = 2048,
        timeout: int = 300,
        thinking: bool = False,
        no_thinking: bool = False,
        extra_json: Optional[str] = None,
    ) -> None:
        """Store connection settings and precompute extra request fields.

        Args:
            model: Model name sent in the ``model`` field of each request.
            api_base: Base URL; ``/chat/completions`` is appended unless the
                URL already ends with it.
            api_key: Bearer token. Empty, ``None`` or ``"not-needed"`` means
                no ``Authorization`` header is sent.
            temperature: Sampling temperature sent with each request. May be
                raised from ``0.0`` to ``0.7`` when ``no_thinking`` is set.
            max_tokens: Default completion limit used when ``chat`` is not
                given its own ``max_tokens``.
            timeout: Value passed as ``timeout`` to ``requests.post``.
            thinking: Adds ``chat_template_kwargs={"thinking": True}``.
                Takes precedence over ``no_thinking``.
            no_thinking: Adds ``chat_template_kwargs={"enable_thinking":
                False}`` plus fixed ``top_p``/``top_k``/``presence_penalty``
                values.
            extra_json: JSON object (as a string) merged into every request
                body.

        Raises:
            ValueError: If ``extra_json`` is not valid JSON or does not
                decode to a JSON object.
        """
        self.model = model
        self.api_base = api_base
        self.api_key = api_key or ""
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout
        # Must run after temperature/max_tokens are assigned: the helper
        # reads both and may adjust self.temperature.
        self.extra_body = self._build_extra_body(thinking, no_thinking, extra_json)

    def _build_extra_body(
        self,
        thinking: bool,
        no_thinking: bool,
        extra_json: Optional[str],
    ) -> Optional[dict]:
        """Assemble the extra fields merged into every request payload.

        Side effect: when ``no_thinking`` applies and ``self.temperature``
        is ``0.0``, the temperature is changed to ``0.7``.

        Returns:
            The extra fields, or ``None`` when there are none.

        Raises:
            ValueError: If ``extra_json`` is not valid JSON or does not
                decode to a JSON object.
        """
        extra_body = {}
        if thinking:
            extra_body["chat_template_kwargs"] = {"thinking": True}
        elif no_thinking:
            # NOTE(review): this branch uses the key "enable_thinking" while
            # the branch above uses "thinking"; presumably they target
            # different chat templates -- confirm against the served model.
            extra_body["chat_template_kwargs"] = {"enable_thinking": False}
            extra_body["top_p"] = 0.95
            extra_body["top_k"] = 20
            extra_body["presence_penalty"] = 1.5
            if self.temperature == 0.0:
                self.temperature = 0.7
        if extra_json:
            overrides = json.loads(extra_json)
            # dict.update() on a non-mapping fails with an opaque TypeError
            # (or silently accepts a list of pairs); reject it explicitly.
            if not isinstance(overrides, dict):
                raise ValueError(
                    "extra_json must decode to a JSON object, got "
                    f"{type(overrides).__name__}"
                )
            extra_body.update(overrides)
        if self.max_tokens != 2048:
            # Applied after extra_json, so the constructor value wins over
            # any "max_tokens" supplied there.
            extra_body["max_tokens"] = self.max_tokens
        return extra_body or None

    def _endpoint(self) -> str:
        """Return the chat-completions URL derived from ``api_base``."""
        url = self.api_base.rstrip("/")
        if not url.endswith("/chat/completions"):
            url = url + "/chat/completions"
        return url

    def _build_headers(self) -> dict:
        """Return request headers, adding a bearer token when a key is set."""
        headers = {"Content-Type": "application/json"}
        # "not-needed" is treated as a placeholder meaning "no key".
        if self.api_key and self.api_key != "not-needed":
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _build_payload(
        self, system: str, user: str, max_tokens: Optional[int] = None
    ) -> dict:
        """Build the JSON body for one chat request.

        Args:
            system: System prompt; omitted from ``messages`` when empty.
            user: User message.
            max_tokens: Per-call completion limit; falls back to
                ``self.max_tokens`` when ``None`` or ``0``.
        """
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": user})

        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": max_tokens or self.max_tokens,
        }
        if self.extra_body:
            payload.update(self.extra_body)
        if max_tokens:
            # extra_body may carry its own "max_tokens" (set whenever the
            # instance default is not 2048); an explicit per-call limit must
            # still take precedence over it.
            payload["max_tokens"] = max_tokens
        return payload

    @staticmethod
    def _extract_text(data) -> str:
        """Pull the assistant text out of a decoded chat-completions response.

        Looks at the first choice only. String content is returned as-is
        (including an empty string). List content is flattened by joining
        the ``text``/``content`` of dict items and the string form of other
        truthy items with newlines. If no content is found, falls back to
        ``reasoning_content`` (on the message, then on the choice) and
        finally to the choice's ``text`` field.

        Raises:
            ValueError: If the response has no usable first choice, or if
                neither content nor any fallback field holds text.
        """
        choices = data.get("choices") if isinstance(data, dict) else None
        if (
            not isinstance(choices, list)
            or not choices
            or not isinstance(choices[0], dict)
        ):
            raise ValueError(f"Unexpected response format: {json.dumps(data)[:500]}")

        choice = choices[0]
        message = choice.get("message")
        if not isinstance(message, dict):
            # A null/odd "message" should not crash; the choice-level
            # fallbacks below may still yield text.
            message = {}

        content = message.get("content")
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for item in content:
                if isinstance(item, dict):
                    text = item.get("text") or item.get("content")
                    if text:
                        parts.append(str(text))
                elif item:
                    parts.append(str(item))
            if parts:
                return "\n".join(parts)

        reasoning = (
            message.get("reasoning_content")
            or choice.get("reasoning_content")
            or choice.get("text")
        )
        if reasoning:
            return str(reasoning)
        raise ValueError(
            f"Empty assistant content; response preview: {json.dumps(data)[:500]}"
        )

    def chat(self, system: str, user: str, max_tokens: Optional[int] = None) -> str:
        """Send one system/user exchange and return the assistant's reply.

        Args:
            system: System prompt; skipped when empty.
            user: User message.
            max_tokens: Optional per-call completion limit that overrides
                the instance default for this request only.

        Returns:
            The text extracted from the first choice of the response.

        Raises:
            requests.HTTPError: If the server answers with an error status
                (via ``raise_for_status``).
            ValueError: If the response body has an unexpected shape or
                contains no assistant text.
        """
        resp = requests.post(
            self._endpoint(),
            headers=self._build_headers(),
            json=self._build_payload(system, user, max_tokens),
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return self._extract_text(resp.json())

    def test_connection(self) -> str:
        """Send a tiny request and return the reply, to check connectivity."""
        return self.chat("You are a test assistant.", "Say OK.", max_tokens=64)
|