Files
llmiotsafe/EGPv2_1/api_client.py
2026-05-12 17:01:39 +08:00

92 lines
3.5 KiB
Python

import json
from typing import Optional
import requests
class OpenAICompatClient:
    """Minimal blocking client for an OpenAI-compatible chat-completions API.

    The client posts a single system/user exchange to
    ``<api_base>/chat/completions`` and returns the assistant text of the
    first choice.

    Attributes:
        model: Model name sent in the ``model`` field of every request.
        api_base: Base URL of the server (with or without the
            ``/chat/completions`` suffix).
        api_key: Bearer token; empty or ``"not-needed"`` disables the
            ``Authorization`` header.
        temperature: Sampling temperature sent with every request.
        max_tokens: Default completion token limit.
        timeout: Timeout passed to ``requests.post``.
        extra_body: Extra top-level request fields merged into the payload,
            or ``None`` when there are none.
    """

    def __init__(
        self,
        model: str,
        api_base: str,
        api_key: str = "",
        temperature: float = 0.0,
        max_tokens: int = 2048,
        timeout: int = 300,
        thinking: bool = False,
        no_thinking: bool = False,
        extra_json: Optional[str] = None,
    ) -> None:
        """Store the connection settings and precompute the extra fields.

        Args:
            model: Model name to request.
            api_base: Base URL of the OpenAI-compatible server.
            api_key: Optional bearer token.
            temperature: Sampling temperature for every request.
            max_tokens: Default completion token limit.
            timeout: Timeout forwarded to ``requests.post``.
            thinking: Request the ``thinking`` chat-template switch.
            no_thinking: Disable thinking and apply fixed sampling settings;
                ignored when ``thinking`` is true.
            extra_json: JSON object (as a string) merged into the request
                body after the switches above.

        Raises:
            json.JSONDecodeError: If ``extra_json`` is not valid JSON.
        """
        self.model = model
        self.api_base = api_base
        self.api_key = api_key or ""
        # Honour the caller's value; this used to be hard-coded to 0.0, which
        # silently ignored the ``temperature`` argument.
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout
        # Must run after ``self.max_tokens`` is set: the builder reads it.
        self.extra_body = self._build_extra_body(thinking, no_thinking, extra_json)

    def _build_extra_body(
        self,
        thinking: bool,
        no_thinking: bool,
        extra_json: Optional[str],
    ) -> Optional[dict]:
        """Assemble the extra request fields, or ``None`` if there are none.

        ``thinking`` takes precedence over ``no_thinking``. Keys from
        ``extra_json`` override the switches, and a non-default
        ``self.max_tokens`` is written last.
        """
        extra_body: dict = {}
        if thinking:
            # NOTE(review): this branch uses the key "thinking" while the
            # branch below uses "enable_thinking" - confirm against the
            # served model's chat template.
            extra_body["chat_template_kwargs"] = {"thinking": True}
        elif no_thinking:
            extra_body["chat_template_kwargs"] = {"enable_thinking": False}
            extra_body["top_p"] = 0.95
            extra_body["top_k"] = 20
            extra_body["presence_penalty"] = 1.5
        if extra_json:
            extra_body.update(json.loads(extra_json))
        if self.max_tokens != 2048:
            extra_body["max_tokens"] = self.max_tokens
        return extra_body or None

    def _endpoint(self) -> str:
        """Return the chat-completions URL derived from ``api_base``."""
        url = self.api_base.rstrip("/")
        if not url.endswith("/chat/completions"):
            url = url + "/chat/completions"
        return url

    def _headers(self) -> dict:
        """Return the HTTP headers, adding bearer auth when a key is set."""
        headers = {"Content-Type": "application/json"}
        if self.api_key and self.api_key != "not-needed":
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _build_payload(
        self,
        system: str,
        user: str,
        max_tokens: Optional[int] = None,
    ) -> dict:
        """Build the JSON request body for one system/user exchange."""
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": user})
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": max_tokens or self.max_tokens,
        }
        if self.extra_body:
            payload.update(self.extra_body)
        if max_tokens:
            # ``extra_body`` may carry its own "max_tokens" (the instance
            # default or one from ``extra_json``); an explicit per-call limit
            # must win over it instead of being overwritten by the merge.
            payload["max_tokens"] = max_tokens
        return payload

    @staticmethod
    def _extract_text(data: dict) -> str:
        """Pull the assistant text out of a decoded chat-completions response.

        String content is returned as is. List content is flattened by
        joining the ``text``/``content`` of each part with newlines. When no
        content is found, ``reasoning_content`` (on the message or on the
        choice) is used as a fallback.

        Raises:
            ValueError: If the response has no choices, or the first choice
                has neither content nor reasoning content.
        """
        choices = data.get("choices") if isinstance(data, dict) else None
        if not choices:
            raise ValueError(f"Unexpected response format: {json.dumps(data)[:500]}")

        first = choices[0]
        # ``"message": null`` would make a plain ``.get("message", {})``
        # return None and crash on the next lookup.
        message = first.get("message") or {}
        content = message.get("content")
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for item in content:
                if isinstance(item, dict):
                    text = item.get("text") or item.get("content")
                    if text:
                        parts.append(str(text))
                elif item:
                    parts.append(str(item))
            if parts:
                return "\n".join(parts)

        reasoning = message.get("reasoning_content") or first.get("reasoning_content")
        if reasoning:
            return str(reasoning)
        raise ValueError(f"Empty assistant content; response preview: {json.dumps(data)[:500]}")

    def chat(self, system: str, user: str, max_tokens: Optional[int] = None) -> str:
        """Send one system/user exchange and return the assistant's reply.

        Args:
            system: System prompt; omitted from the request when empty.
            user: User message.
            max_tokens: Per-call completion limit; falls back to the
                instance default when omitted.

        Returns:
            The assistant text of the first choice.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            ValueError: If the response carries no usable assistant text.
        """
        payload = self._build_payload(system, user, max_tokens)
        resp = requests.post(
            self._endpoint(),
            headers=self._headers(),
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return self._extract_text(resp.json())

    def test_connection(self) -> str:
        """Send a tiny request and return the model's reply."""
        return self.chat("You are a test assistant.", "Say OK.", max_tokens=64)