Skip to content

vllm

ipw.clients.vllm

Offline vLLM client backed by AsyncLLM.

VLLMClient

Bases: InferenceClient

Offline AsyncLLM client.

Source code in intelligence-per-watt/src/ipw/clients/vllm.py
@ClientRegistry.register("vllm")
class VLLMClient(InferenceClient):
    """Offline AsyncLLM client.

    Wraps a vLLM ``AsyncLLM`` engine that is created lazily on first use
    (see ``_ensure_engine``) and drives its async streaming API from a
    dedicated loop runner so the public interface stays synchronous.
    Each client instance is pinned to a single model; requesting a
    different model after the engine exists raises ``RuntimeError``.
    """

    # Identity used by the ClientRegistry lookup and by callers.
    client_id = "vllm"
    client_name = "vLLM Offline"
    # Placeholder URL: the offline engine has no HTTP endpoint.
    DEFAULT_BASE_URL = "offline://vllm"

    def __init__(self, base_url: str | None = None, **config: Any) -> None:
        """Initialize the client without loading an engine.

        Args:
            base_url: Optional override; defaults to ``DEFAULT_BASE_URL``.
            **config: Forwarded unchanged to ``InferenceClient.__init__``.
        """
        super().__init__(base_url or self.DEFAULT_BASE_URL, **config)
        # Extra kwargs merged into AsyncEngineArgs when the engine is built.
        self._engine_kwargs: dict[str, Any] = {}
        # Baseline sampling settings; per-request params override these in
        # _build_sampling_params.
        self._sampling_defaults: dict[str, Any] = {
            "max_tokens": 4096,
            "temperature": 0.6,
            "top_p": 0.95,
            "top_k": 20,
            "min_p": 0.0,
        }
        self._engine = None  # AsyncLLM instance, created lazily.
        self._engine_args = None  # AsyncEngineArgs used to build the engine.
        self._model_name = None  # Name of the single loaded model, if any.
        # Background event loop used to run async engine calls synchronously.
        self._loop_runner: _AsyncLoopRunner | None = _AsyncLoopRunner()
        self._closed = False
        # Best-effort cleanup at interpreter exit; close() is idempotent.
        # NOTE(review): registering a bound method keeps this instance alive
        # until interpreter shutdown — confirm that is acceptable here.
        atexit.register(self.close)

    def prepare(self, model: str) -> None:
        """Eagerly load *model* so later requests skip engine startup.

        Raises:
            RuntimeError: If the client has already been closed.
        """
        if self._closed:
            raise RuntimeError("vLLM client has been closed")
        self._ensure_engine(model)

    def stream_chat_completion(
        self, model: str, prompt: str, **params: Any
    ) -> Response:
        """Generate a completion for *prompt* and return the full result.

        Streaming happens internally: the async generation coroutine is
        driven to completion on the loop runner, then the aggregated
        content, usage, and time-to-first-token are returned.

        Args:
            model: Model name; must match any previously loaded model.
            prompt: Raw prompt text handed to the engine.
            **params: Sampling overrides (see ``_build_sampling_params``)
                plus an optional ``request_id``.

        Returns:
            A ``Response`` with content, token usage, and TTFT in ms.

        Raises:
            RuntimeError: If the client is closed or already shut down.
        """
        if self._closed:
            raise RuntimeError("vLLM client has been closed")
        self._ensure_engine(model)

        sampling_params = self._build_sampling_params(params)
        # Reuse a caller-supplied request_id when present; otherwise mint one.
        request_id = str(params.get("request_id", uuid.uuid4()))
        runner = self._loop_runner
        if runner is None:
            raise RuntimeError("vLLM client is shut down")
        # Blocks until the async streaming coroutine completes.
        return runner.run(
            self._stream_response(
                prompt=prompt, request_id=request_id, sampling_params=sampling_params
            )
        )

    def list_models(self) -> Sequence[str]:
        """Return the loaded model name, or an empty list if none loaded."""
        return [self._model_name] if self._model_name else []

    def health(self) -> bool:
        """Report liveness: True until ``close`` has been called."""
        return not self._closed

    def close(self) -> None:
        """Shut down engine and loop runner. Idempotent and best-effort."""
        if self._closed:
            return
        self._closed = True
        try:
            if self._engine is not None:
                self._engine.shutdown()
        except Exception:  # pragma: no cover - shutdown best-effort
            # Engine teardown failures are deliberately swallowed so that
            # loop-runner shutdown below still runs.
            pass
        finally:
            self._engine = None
            if self._loop_runner is not None:
                self._loop_runner.shutdown()
                self._loop_runner = None

    def _ensure_engine(self, model: str) -> None:
        """Create the AsyncLLM engine for *model* on first use.

        The client is pinned to one model: asking for a different one after
        the engine exists raises instead of reloading.

        Raises:
            ValueError: If *model* is empty.
            RuntimeError: On model mismatch or engine initialization failure.
        """
        if not model:
            raise ValueError("model name is required")
        if self._engine is not None:
            if model != self._model_name:
                raise RuntimeError(
                    f"vLLM client already loaded model '{self._model_name}', cannot switch to '{model}'"
                )
            return

        kwargs = dict(self._engine_kwargs)
        kwargs["model"] = model
        try:
            self._engine_args = AsyncEngineArgs(**kwargs)
            self._engine = AsyncLLM.from_engine_args(self._engine_args)
        except Exception as exc:  # pragma: no cover - forwarded to caller
            # Normalize any construction failure into RuntimeError, chaining
            # the original cause for debugging.
            raise RuntimeError(f"Failed to initialize vLLM engine: {exc}") from exc
        self._model_name = model

    def _build_sampling_params(self, params: Mapping[str, Any]):
        """Merge defaults with request overrides into ``SamplingParams``.

        Overrides are taken from two key forms: a ``sampling_`` prefix
        (prefix stripped, value JSON-decoded when possible) or a bare key
        from the recognized set below. Unrecognized bare keys are ignored.
        """
        # Bare parameter names accepted without the "sampling_" prefix.
        recognized = {
            "temperature",
            "top_p",
            "top_k",
            "min_p",
            "max_tokens",
            "stop",
            "seed",
            "best_of",
            "presence_penalty",
            "frequency_penalty",
            "repetition_penalty",
            "length_penalty",
        }

        def _coerce(value: Any) -> Any:
            # Strings may carry JSON-encoded values (numbers, lists, bools);
            # fall back to the raw text when decoding fails.
            if isinstance(value, str):
                text = value.strip()
                if not text:
                    return text
                try:
                    return json.loads(text)
                except json.JSONDecodeError:
                    return text
            return value

        overrides: dict[str, Any] = {}
        for key, value in params.items():
            if key.startswith("sampling_"):
                # Strip the prefix: "sampling_top_p" -> "top_p".
                overrides[key.split("_", 1)[1]] = _coerce(value)
            elif key in recognized:
                overrides[key] = _coerce(value)

        sampling = {**self._sampling_defaults, **overrides}
        # Normalize "stop" to a list; a lone string means one stop sequence.
        if "stop" in sampling:
            stop_value = sampling["stop"]
            if isinstance(stop_value, str):
                sampling["stop"] = [stop_value]
            elif isinstance(stop_value, (list, tuple)):
                sampling["stop"] = list(stop_value)
        # DELTA: each streamed chunk carries only newly generated output.
        sampling["output_kind"] = RequestOutputKind.DELTA
        return SamplingParams(**sampling)

    async def _stream_response(
        self, *, prompt: str, request_id: str, sampling_params: Any
    ) -> Response:
        """Consume the engine's stream and aggregate it into a Response.

        Accumulates delta text, counts completion tokens, and records
        time-to-first-token relative to entering this coroutine.

        Raises:
            RuntimeError: If the engine is missing or generation fails.
        """
        if self._engine is None:
            raise RuntimeError("vLLM engine is not initialized")

        start_time = time.perf_counter()
        prompt_tokens: int | None = None  # From the first chunk's prompt ids.
        completion_tokens = 0
        ttft_ms: float | None = None  # Set once, on first text/token delta.
        content_parts: list[str] = []

        try:
            async for chunk in self._engine.generate(
                request_id=request_id,
                prompt=prompt,
                sampling_params=sampling_params,
            ):
                outputs = getattr(chunk, "outputs", []) or []
                # Prompt token count is read once, from the first chunk that
                # exposes prompt_token_ids.
                if prompt_tokens is None:
                    prompt_ids = getattr(chunk, "prompt_token_ids", None) or []
                    prompt_tokens = len(prompt_ids)

                stop_requested = False

                for completion in outputs:
                    delta_text = getattr(completion, "text", "") or ""
                    if delta_text:
                        content_parts.append(delta_text)
                        if ttft_ms is None:
                            ttft_ms = (time.perf_counter() - start_time) * 1000.0

                    # Tolerate differing vLLM output attribute names: try
                    # delta_token_ids, then token_ids_delta, then fall back
                    # to token_ids.
                    delta_token_ids = getattr(completion, "delta_token_ids", None)
                    if delta_token_ids is None:
                        delta_token_ids = getattr(completion, "token_ids_delta", None)
                    if delta_token_ids is not None:
                        completion_tokens += len(delta_token_ids)
                    else:
                        token_ids = getattr(completion, "token_ids", None)
                        if token_ids:
                            completion_tokens += len(token_ids)
                            if ttft_ms is None:
                                ttft_ms = (time.perf_counter() - start_time) * 1000.0

                    finished_reason = getattr(completion, "finished_reason", None)
                    if finished_reason is not None:
                        # Treat any stop/EOS-style reason as end of stream.
                        if str(finished_reason).lower() in {
                            "stop",
                            "stopped",
                            "eos",
                            "eos_token",
                        }:
                            stop_requested = True

                # NOTE(review): breaking out of the async generator here does
                # not explicitly abort the engine request — confirm AsyncLLM
                # cleans up the request when the generator is abandoned.
                if stop_requested:
                    break

                if getattr(chunk, "finished", False):
                    break
        except (
            Exception
        ) as exc:  # pragma: no cover - actual streaming exercised in integration
            raise RuntimeError(f"vLLM offline generation failed: {exc}") from exc

        usage = ChatUsage(
            prompt_tokens=prompt_tokens or 0,
            completion_tokens=completion_tokens,
            total_tokens=(prompt_tokens or 0) + completion_tokens,
        )
        content = "".join(content_parts)
        # ttft_ms stays None if nothing was generated; reported as 0.0.
        return Response(
            content=content, usage=usage, time_to_first_token_ms=ttft_ms or 0.0
        )