@VisualizationRegistry.register("output_kde")
class OutputTokenKDE(VisualizationProvider):
    """Generate KDE plot of completion token distribution."""

    visualization_id = "output_kde"

    def _no_artifacts(self, warning: str) -> VisualizationResult:
        """Build a result that carries no artifacts and a single warning."""
        return VisualizationResult(
            visualization=self.visualization_id,
            artifacts={},
            warnings=(warning,),
        )

    def render(self, context: VisualizationContext) -> VisualizationResult:
        """Render completion token KDE plot.

        The dataset is loaded from ``context.results_dir``. The model is taken
        from the ``"model"`` entry of ``context.options`` when that is truthy,
        otherwise it is inferred from the dataset.

        Returns:
            On success, a result whose ``artifacts`` maps ``"kde_plot"`` to
            ``context.output_dir / "completion_tokens_kde.png"`` and whose
            ``metadata`` holds the model, hardware label, sample count, and
            the mean and median of the token values.

            Otherwise, a result with empty ``artifacts`` and one warning when:
            no model name could be determined; fewer than two completion
            token samples were found; or ``_create_kde_plot`` reported
            failure.
        """
        dataset = _load_dataset(context.results_dir)

        # An explicit option wins; inference is only the fallback.
        model_name = context.options.get("model") or _infer_model_name(dataset)
        if not model_name:
            return self._no_artifacts(
                "No model found in dataset. Specify --model in options."
            )

        tokens = _extract_completion_tokens(dataset, model_name)
        # The guard accepts a falsy `tokens` (e.g. None), so the count must be
        # computed without calling len() on it; otherwise building the warning
        # message itself would raise TypeError.
        sample_count = len(tokens) if tokens else 0
        if sample_count < 2:
            return self._no_artifacts(
                f"Insufficient completion token data for model '{model_name}' "
                f"(found {sample_count} samples)."
            )

        hardware_label = _infer_hardware_label(dataset, model_name)

        output_path = context.output_dir / "completion_tokens_kde.png"
        plotted = _create_kde_plot(tokens, output_path, model_name, hardware_label)
        if not plotted:
            return self._no_artifacts(
                "Failed to generate KDE plot (insufficient variation)."
            )

        return VisualizationResult(
            visualization=self.visualization_id,
            artifacts={"kde_plot": output_path},
            metadata={
                "model": model_name,
                "hardware": hardware_label,
                "sample_count": sample_count,
                "mean_tokens": float(np.mean(tokens)),
                "median_tokens": float(np.median(tokens)),
            },
        )