Skip to content

openai_server

ipw.clients.openai_server

OpenAI-compatible profiling client for remote inference servers (vLLM, llama.cpp, etc.).

OpenAIServerClient

Bases: InferenceClient

Profiling client for any OpenAI-compatible inference server.

Unlike the eval-only OpenAIClient, this client supports streaming chat completions with TTFT measurement and token usage tracking, making it suitable for ipw profile workloads.

Source code in intelligence-per-watt/src/ipw/clients/openai_server.py
@ClientRegistry.register("openai-server")
class OpenAIServerClient(InferenceClient):
    """Profiling client for any OpenAI-compatible inference server.

    Unlike the eval-only ``OpenAIClient``, this client supports streaming
    chat completions with TTFT measurement and token usage tracking,
    making it suitable for ``ipw profile`` workloads.
    """

    client_id, client_name = "openai-server", "OpenAI Server"

    def __init__(self, base_url: str | None = None, **config: Any) -> None:
        """Initialize the client.

        Args:
            base_url: Server root URL; ``/v1`` is appended if missing.
                Falls back to ``config["base_url"]``, then
                ``http://localhost:8000/v1``.
            **config: Remaining options forwarded to ``InferenceClient``.
                ``api_key`` (default ``"EMPTY"``) and ``timeout_seconds``
                (default ``120.0``) are read here.
        """
        host = (base_url or config.get("base_url") or "http://localhost:8000/v1").rstrip("/")
        # Normalize so self.base_url always carries the /v1 API prefix.
        if not host.endswith("/v1"):
            host = f"{host}/v1"
        super().__init__(host, **config)
        # Local servers typically ignore auth; "EMPTY" keeps the
        # Authorization header well-formed either way.
        self.api_key = config.get("api_key") or "EMPTY"
        self.timeout_seconds = float(config.get("timeout_seconds", 120.0))

    def stream_chat_completion(
        self, model: str, prompt: str, **params: Any
    ) -> Response:
        """Stream a chat completion and return the aggregated result.

        Measures time-to-first-token (from request start to the first
        non-empty content delta) and collects token usage from any
        ``usage`` fields the server emits in the SSE stream.

        Args:
            model: Model identifier to request.
            prompt: User prompt; ignored if ``params["messages"]`` is given.
            **params: Extra payload fields. ``messages`` overrides the
                prompt; ``max_tokens`` / ``num_predict`` set the completion
                budget (default 4096). Other keys are forwarded verbatim.

        Returns:
            A ``Response`` with concatenated content, token usage, and
            TTFT in milliseconds (0.0 if no content arrived).

        Raises:
            RuntimeError: On any HTTP or connection failure.
        """
        url = f"{self.base_url}/chat/completions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        messages = params.pop("messages", None) or [{"role": "user", "content": prompt}]
        payload: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "stream": True,
            # Ask the server to include token usage in the stream; most
            # OpenAI-compatible servers (vLLM, recent llama.cpp) only emit
            # usage chunks when this is set. Callers can override it via
            # params below. NOTE(review): confirm the oldest supported
            # server accepts stream_options.
            "stream_options": {"include_usage": True},
        }
        # Pop BOTH aliases unconditionally. The original `pop(...) or
        # pop(...)` short-circuited, so when max_tokens was set,
        # num_predict stayed in params and leaked into the payload via the
        # forwarding loop below — sending a non-OpenAI field to the server.
        max_tokens = params.pop("max_tokens", None)
        num_predict = params.pop("num_predict", None)
        payload["max_tokens"] = int(max_tokens or num_predict or 4096)

        # Forward any remaining caller-supplied fields (temperature,
        # top_p, ...) without clobbering the core request fields.
        for k, v in params.items():
            if k not in ("model", "messages", "prompt", "stream"):
                payload[k] = v

        start = time.perf_counter()
        try:
            resp = requests.post(
                url,
                headers=headers,
                json=payload,
                stream=True,
                timeout=self.timeout_seconds,
            )
            resp.raise_for_status()
        except req_exc.HTTPError as exc:
            # Surface the server's error body — it usually names the bad
            # field or model, which the bare status line does not.
            detail = ""
            if exc.response is not None:
                try:
                    detail = f" | body: {exc.response.json()}"
                except Exception:
                    detail = f" | body: {exc.response.text[:500]}"
            raise RuntimeError(f"OpenAI server request failed: {exc}{detail}") from exc
        except Exception as exc:
            raise RuntimeError(f"OpenAI server request failed: {exc}") from exc

        content_parts: list[str] = []
        prompt_tokens = 0
        completion_tokens = 0
        ttft_ms: float | None = None

        try:
            for line in resp.iter_lines(decode_unicode=True):
                if not line:
                    continue
                # SSE data lines are "data: <json>"; ignore everything else
                # (comments, event names).
                if not line.startswith("data: "):
                    continue
                data_str = line[6:]
                if data_str.strip() == "[DONE]":
                    break
                try:
                    chunk = json.loads(data_str)
                except json.JSONDecodeError:
                    # Tolerate malformed keep-alive/partial lines.
                    continue

                choices = chunk.get("choices") or []
                if choices:
                    delta = choices[0].get("delta") or {}
                    text = delta.get("content")
                    if text:
                        if ttft_ms is None:
                            ttft_ms = (time.perf_counter() - start) * 1000
                        content_parts.append(text)

                # Usage may arrive on a trailing chunk with empty choices;
                # keep the latest values seen.
                usage = chunk.get("usage")
                if usage:
                    prompt_tokens = usage.get("prompt_tokens", prompt_tokens)
                    completion_tokens = usage.get("completion_tokens", completion_tokens)
        finally:
            # Release the pooled connection even if iteration raised
            # (the original leaked it on mid-stream errors).
            resp.close()

        return Response(
            content="".join(content_parts),
            usage=ChatUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            ),
            time_to_first_token_ms=ttft_ms or 0.0,
        )

    def list_models(self) -> Sequence[str]:
        """Return the model ids exposed by the server's ``/models`` endpoint.

        Raises:
            RuntimeError: If the request fails or the payload is malformed.
        """
        url = f"{self.base_url}/models"
        headers = {"Authorization": f"Bearer {self.api_key}"}
        try:
            resp = requests.get(url, headers=headers, timeout=10)
            resp.raise_for_status()
            data = resp.json()
            return [m["id"] for m in data.get("data", [])]
        except Exception as exc:
            raise RuntimeError(f"Failed to list models: {exc}") from exc

    def health(self) -> bool:
        """Return True if the server answers ``/models``, else False."""
        try:
            self.list_models()
            return True
        except Exception:
            return False

    def chat(
        self,
        system_prompt: str,
        user_prompt: str,
        *,
        temperature: float | None = None,
        max_output_tokens: int | None = None,
    ) -> str:
        """Run a non-interactive chat turn and return the text content.

        Args:
            system_prompt: System message content.
            user_prompt: User message content.
            temperature: Sampling temperature; forwarded when given.
            max_output_tokens: Completion token budget; forwarded when given.

        Returns:
            The assistant message content as a string.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        params: dict[str, Any] = {"messages": messages}
        # BUG FIX: temperature was accepted but silently dropped before.
        if temperature is not None:
            params["temperature"] = temperature
        if max_output_tokens:
            params["max_tokens"] = max_output_tokens
        resp = self.stream_chat_completion(
            model=self._config.get("model", ""),
            prompt="",
            **params,
        )
        return resp.content