@ClientRegistry.register("vllm")
class VLLMClient(InferenceClient):
"""Offline AsyncLLM client."""
client_id = "vllm"
client_name = "vLLM Offline"
DEFAULT_BASE_URL = "offline://vllm"

    def __init__(self, base_url: str | None = None, **config: Any) -> None:
        super().__init__(base_url or self.DEFAULT_BASE_URL, **config)
        self._engine_kwargs: dict[str, Any] = {}
        self._sampling_defaults: dict[str, Any] = {
            "max_tokens": 4096,
            "temperature": 0.6,
            "top_p": 0.95,
            "top_k": 20,
            "min_p": 0.0,
        }
        self._engine = None
        self._engine_args = None
        self._model_name = None
        # _AsyncLoopRunner drives the AsyncLLM coroutines on behalf of this
        # client's synchronous API.
        self._loop_runner: _AsyncLoopRunner | None = _AsyncLoopRunner()
        self._closed = False
        # Best-effort cleanup in case the caller never calls close().
        atexit.register(self.close)

    def prepare(self, model: str) -> None:
        """Eagerly initialize the engine for ``model`` ahead of the first request."""
        if self._closed:
            raise RuntimeError("vLLM client has been closed")
        self._ensure_engine(model)

    def stream_chat_completion(
        self, model: str, prompt: str, **params: Any
    ) -> Response:
        """Run a streaming generation and return the fully aggregated response."""
        if self._closed:
            raise RuntimeError("vLLM client has been closed")
        self._ensure_engine(model)
        sampling_params = self._build_sampling_params(params)
        request_id = str(params.get("request_id", uuid.uuid4()))
        runner = self._loop_runner
        if runner is None:
            raise RuntimeError("vLLM client is shut down")
        # Block until the coroutine has consumed the whole stream.
        return runner.run(
            self._stream_response(
                prompt=prompt, request_id=request_id, sampling_params=sampling_params
            )
        )

    def list_models(self) -> Sequence[str]:
        """Return the single loaded model, or an empty list before first use."""
        return [self._model_name] if self._model_name else []

    def health(self) -> bool:
        """The offline client is healthy as long as it has not been closed."""
        return not self._closed

    def close(self) -> None:
        """Shut down the engine and loop runner. Safe to call more than once."""
        if self._closed:
            return
        self._closed = True
        try:
            if self._engine is not None:
                self._engine.shutdown()
        except Exception:  # pragma: no cover - shutdown is best-effort
            pass
        finally:
            self._engine = None
            if self._loop_runner is not None:
                self._loop_runner.shutdown()
                self._loop_runner = None

    def _ensure_engine(self, model: str) -> None:
        """Lazily create the AsyncLLM engine, pinning this client to one model."""
        if not model:
            raise ValueError("model name is required")
        if self._engine is not None:
            if model != self._model_name:
                raise RuntimeError(
                    f"vLLM client already loaded model '{self._model_name}', "
                    f"cannot switch to '{model}'"
                )
            return
        kwargs = dict(self._engine_kwargs)
        kwargs["model"] = model
        try:
            self._engine_args = AsyncEngineArgs(**kwargs)
            self._engine = AsyncLLM.from_engine_args(self._engine_args)
        except Exception as exc:  # pragma: no cover - forwarded to caller
            raise RuntimeError(f"Failed to initialize vLLM engine: {exc}") from exc
        self._model_name = model

    def _build_sampling_params(self, params: Mapping[str, Any]) -> SamplingParams:
        """Merge caller overrides onto the defaults and build SamplingParams."""
        recognized = {
            "temperature",
            "top_p",
            "top_k",
            "min_p",
            "max_tokens",
            "stop",
            "seed",
            "best_of",
            "presence_penalty",
            "frequency_penalty",
            "repetition_penalty",
            "length_penalty",
        }

        def _coerce(value: Any) -> Any:
            # String values may arrive JSON-encoded (e.g. from CLI flags or
            # config files); decode them when possible, else keep the raw text.
            if isinstance(value, str):
                text = value.strip()
                if not text:
                    return text
                try:
                    return json.loads(text)
                except json.JSONDecodeError:
                    return text
            return value

        overrides: dict[str, Any] = {}
        for key, value in params.items():
            # "sampling_"-prefixed keys override the matching bare parameter.
            if key.startswith("sampling_"):
                overrides[key.split("_", 1)[1]] = _coerce(value)
            elif key in recognized:
                overrides[key] = _coerce(value)
        sampling = {**self._sampling_defaults, **overrides}
        # SamplingParams expects ``stop`` as a list of strings.
        if "stop" in sampling:
            stop_value = sampling["stop"]
            if isinstance(stop_value, str):
                sampling["stop"] = [stop_value]
            elif isinstance(stop_value, (list, tuple)):
                sampling["stop"] = list(stop_value)
        # Request incremental deltas rather than cumulative outputs per chunk.
        sampling["output_kind"] = RequestOutputKind.DELTA
        return SamplingParams(**sampling)
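
    # Illustrative only: with the defaults above, a call like
    #   self._build_sampling_params({"sampling_temperature": "0.2", "stop": '["</s>"]'})
    # JSON-decodes both string values and yields SamplingParams(temperature=0.2,
    # stop=["</s>"], ...) with the remaining defaults (max_tokens=4096,
    # top_p=0.95, top_k=20, min_p=0.0) unchanged.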

    async def _stream_response(
        self, *, prompt: str, request_id: str, sampling_params: Any
    ) -> Response:
        """Consume the engine's delta stream and aggregate it into a Response."""
        if self._engine is None:
            raise RuntimeError("vLLM engine is not initialized")
        start_time = time.perf_counter()
        prompt_tokens: int | None = None
        completion_tokens = 0
        ttft_ms: float | None = None
        content_parts: list[str] = []
        try:
            async for chunk in self._engine.generate(
                request_id=request_id,
                prompt=prompt,
                sampling_params=sampling_params,
            ):
                outputs = getattr(chunk, "outputs", []) or []
                if prompt_tokens is None:
                    prompt_ids = getattr(chunk, "prompt_token_ids", None) or []
                    prompt_tokens = len(prompt_ids)
                stop_requested = False
                for completion in outputs:
                    # In DELTA mode each output carries only the new text.
                    delta_text = getattr(completion, "text", "") or ""
                    if delta_text:
                        content_parts.append(delta_text)
                        if ttft_ms is None:
                            ttft_ms = (time.perf_counter() - start_time) * 1000.0
                    delta_token_ids = getattr(completion, "delta_token_ids", None)
                    if delta_token_ids is None:
                        delta_token_ids = getattr(completion, "token_ids_delta", None)
                    if delta_token_ids is not None:
                        completion_tokens += len(delta_token_ids)
                    else:
                        # In DELTA mode ``token_ids`` holds only this chunk's tokens.
                        token_ids = getattr(completion, "token_ids", None)
                        if token_ids:
                            completion_tokens += len(token_ids)
                            if ttft_ms is None:
                                ttft_ms = (time.perf_counter() - start_time) * 1000.0
                    # vLLM spells this attribute ``finish_reason``; keep the old
                    # ``finished_reason`` spelling as a defensive fallback.
                    finished_reason = getattr(
                        completion, "finish_reason", None
                    ) or getattr(completion, "finished_reason", None)
                    if finished_reason is not None:
                        if str(finished_reason).lower() in {
                            "stop",
                            "stopped",
                            "eos",
                            "eos_token",
                        }:
                            stop_requested = True
                if stop_requested:
                    break
                if getattr(chunk, "finished", False):
                    break
        except Exception as exc:  # pragma: no cover - exercised in integration tests
            raise RuntimeError(f"vLLM offline generation failed: {exc}") from exc
usage = ChatUsage(
prompt_tokens=prompt_tokens or 0,
completion_tokens=completion_tokens,
total_tokens=(prompt_tokens or 0) + completion_tokens,
)
content = "".join(content_parts)
return Response(
content=content, usage=usage, time_to_first_token_ms=ttft_ms or 0.0
)
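

if __name__ == "__main__":  # pragma: no cover - illustrative smoke test
    # Minimal usage sketch, not part of the client itself. Assumptions: a
    # vLLM-compatible model is available locally, and "Qwen/Qwen3-0.6B" is a
    # stand-in name rather than anything this module prescribes. It shows the
    # intended call sequence: optional prepare(), one blocking streaming call,
    # then an explicit close().
    client = VLLMClient()
    try:
        client.prepare("Qwen/Qwen3-0.6B")  # optional eager engine load
        response = client.stream_chat_completion(
            "Qwen/Qwen3-0.6B",
            "Explain KV caching in one paragraph.",
            sampling_temperature=0.2,  # overrides the 0.6 default
            max_tokens=256,
        )
        print(response.content)
        print("tokens used:", response.usage.total_tokens)
    finally:
        client.close()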