Skip to content

regression

ipw.visualization.regression

Regression visualization provider - generates scatter plots from regression analysis.

RegressionVisualization

Bases: VisualizationProvider

Generate scatter plots with regression lines from regression analysis results.

Source code in intelligence-per-watt/src/ipw/visualization/regression.py
@VisualizationRegistry.register("regression")
class RegressionVisualization(VisualizationProvider):
    """Generate scatter plots with regression lines from regression analysis results.

    For each entry in ``_PLOT_SPECS`` the provider looks up the matching fit
    in the regression analysis output, extracts the per-query (x, y) samples
    for the selected model from the results dataset, and renders a scatter
    plot to ``context.output_dir / filename``. Plots that have no fit data or
    no samples are skipped and reported as warnings rather than aborting the
    whole render.
    """

    visualization_id = "regression"

    # One entry per plot. ``key`` names the fit inside the analysis output's
    # ``data["regressions"]`` mapping; ``x_path``/``y_path`` are handed to
    # ``_extract_regression_samples`` (presumably nested key paths into each
    # dataset record -- confirm against that helper); ``log_key`` optionally
    # names a secondary log fit to overlay on the same plot.
    # Built once at class level because the specs never change between renders.
    _PLOT_SPECS = (
        {
            "key": "input_tokens_vs_ttft",
            "x_path": ["token_metrics", "input"],
            "y_path": ["latency_metrics", "time_to_first_token_seconds"],
            "title": "Prompt Tokens vs Time to First Token",
            "x_label": "Prompt tokens",
            "y_label": "TTFT (seconds)",
            "filename": "ttft.png",
        },
        {
            "key": "total_tokens_vs_latency",
            "x_path": ["token_metrics", "total"],
            "y_path": ["latency_metrics", "total_query_seconds"],
            "title": "Total Tokens vs Latency",
            "x_label": "Total tokens",
            "y_label": "Total latency (seconds)",
            "filename": "latency.png",
        },
        {
            "key": "total_tokens_vs_energy",
            "x_path": ["token_metrics", "total"],
            "y_path": ["energy_metrics", "per_query_joules"],
            "title": "Total Tokens vs Energy",
            "x_label": "Total tokens",
            "y_label": "Per-query energy (joules)",
            "filename": "energy.png",
        },
        {
            "key": "total_tokens_vs_power",
            "x_path": ["token_metrics", "total"],
            "y_path": ["power_metrics", "gpu", "per_query_watts", "avg"],
            "title": "Total Tokens vs Power",
            "x_label": "Total tokens",
            "y_label": "Per-query power (watts)",
            "filename": "power.png",
            "log_key": "total_tokens_vs_power_log",
        },
    )

    @staticmethod
    def _has_fit(stats) -> bool:
        """Return True if *stats* is a non-empty fit with a non-zero ``count``."""
        return bool(stats) and stats.get("count", 0) != 0

    def render(self, context: VisualizationContext) -> VisualizationResult:
        """Render regression scatter plots.

        Args:
            context: Supplies ``results_dir`` (where the regression analysis
                output and the dataset are loaded from), ``output_dir`` (where
                plots are written) and ``options`` (an optional ``"model"``
                override).

        Returns:
            A ``VisualizationResult`` whose ``artifacts`` map each rendered
            plot's regression key to its output path, with the resolved model
            and hardware labels as metadata and one warning per skipped plot.
        """
        regression_data = _load_regression_data(context.results_dir)
        # ``or {}`` instead of a ``.get`` default so that keys present but set
        # to null in the analysis output are tolerated as well as missing ones.
        data = regression_data.get("data") or {}
        regressions = data.get("regressions") or {}

        dataset = _load_dataset(context.results_dir)

        # An explicit model option wins; otherwise infer it from the dataset.
        # The ``or`` chain only calls ``_infer_model_name`` when needed.
        model_name = (
            context.options.get("model")
            or _infer_model_name(dataset)
            or "unknown"
        )

        hardware_label = _infer_hardware_label(dataset, model_name)

        artifacts = {}
        warnings = []

        for spec in self._PLOT_SPECS:
            key = spec["key"]
            filename = spec["filename"]

            stats = regressions.get(key)
            if not self._has_fit(stats):
                warnings.append(f"Skipping {filename}: no regression data")
                continue

            xs, ys = _extract_regression_samples(
                dataset, model_name, spec["x_path"], spec["y_path"]
            )
            if not xs or not ys:
                warnings.append(f"Skipping {filename}: no valid samples")
                continue

            # Optional secondary log fit. Hold it to the same standard as the
            # primary fit: an empty or zero-count fit is treated as absent.
            log_key = spec.get("log_key")
            log_stats = regressions.get(log_key) if log_key else None
            if not self._has_fit(log_stats):
                log_stats = None

            output_path = context.output_dir / filename
            _create_scatter_plot(
                xs=xs,
                ys=ys,
                stats=stats,
                title=str(spec["title"]),
                x_label=str(spec["x_label"]),
                y_label=str(spec["y_label"]),
                output_path=output_path,
                model=model_name,
                hardware=hardware_label,
                log_fit_stats=log_stats,
            )

            artifacts[key] = output_path

        return VisualizationResult(
            visualization="regression",
            artifacts=artifacts,
            metadata={
                "model": model_name,
                "hardware": hardware_label,
            },
            warnings=tuple(warnings),
        )

render(context)

Render regression scatter plots.

Source code in intelligence-per-watt/src/ipw/visualization/regression.py
def render(self, context: VisualizationContext) -> VisualizationResult:
    """Render regression scatter plots.

    Args:
        context: Supplies ``results_dir`` (where the regression analysis
            output and the dataset are loaded from), ``output_dir`` (where
            plots are written) and ``options`` (an optional ``"model"``
            override).

    Returns:
        A ``VisualizationResult`` whose ``artifacts`` map each rendered plot's
        regression key to its output path, with the resolved model and
        hardware labels as metadata and one warning per skipped plot.
    """

    def has_fit(stats) -> bool:
        # A fit counts only if it is non-empty and has a non-zero sample count.
        return bool(stats) and stats.get("count", 0) != 0

    regression_data = _load_regression_data(context.results_dir)
    # ``or {}`` instead of a ``.get`` default so that keys present but set to
    # null in the analysis output are tolerated as well as missing ones.
    data = regression_data.get("data") or {}
    regressions = data.get("regressions") or {}

    dataset = _load_dataset(context.results_dir)

    # An explicit model option wins; otherwise infer it from the dataset.
    # The ``or`` chain only calls ``_infer_model_name`` when needed.
    model_name = (
        context.options.get("model")
        or _infer_model_name(dataset)
        or "unknown"
    )

    hardware_label = _infer_hardware_label(dataset, model_name)

    artifacts = {}
    warnings = []

    # One entry per plot: ``key`` names the fit in ``data["regressions"]``;
    # ``x_path``/``y_path`` are handed to ``_extract_regression_samples``;
    # ``log_key`` optionally names a secondary log fit to overlay.
    plots = (
        {
            "key": "input_tokens_vs_ttft",
            "x_path": ["token_metrics", "input"],
            "y_path": ["latency_metrics", "time_to_first_token_seconds"],
            "title": "Prompt Tokens vs Time to First Token",
            "x_label": "Prompt tokens",
            "y_label": "TTFT (seconds)",
            "filename": "ttft.png",
        },
        {
            "key": "total_tokens_vs_latency",
            "x_path": ["token_metrics", "total"],
            "y_path": ["latency_metrics", "total_query_seconds"],
            "title": "Total Tokens vs Latency",
            "x_label": "Total tokens",
            "y_label": "Total latency (seconds)",
            "filename": "latency.png",
        },
        {
            "key": "total_tokens_vs_energy",
            "x_path": ["token_metrics", "total"],
            "y_path": ["energy_metrics", "per_query_joules"],
            "title": "Total Tokens vs Energy",
            "x_label": "Total tokens",
            "y_label": "Per-query energy (joules)",
            "filename": "energy.png",
        },
        {
            "key": "total_tokens_vs_power",
            "x_path": ["token_metrics", "total"],
            "y_path": ["power_metrics", "gpu", "per_query_watts", "avg"],
            "title": "Total Tokens vs Power",
            "x_label": "Total tokens",
            "y_label": "Per-query power (watts)",
            "filename": "power.png",
            "log_key": "total_tokens_vs_power_log",
        },
    )

    for spec in plots:
        key = spec["key"]
        filename = spec["filename"]

        stats = regressions.get(key)
        if not has_fit(stats):
            warnings.append(f"Skipping {filename}: no regression data")
            continue

        xs, ys = _extract_regression_samples(
            dataset, model_name, spec["x_path"], spec["y_path"]
        )
        if not xs or not ys:
            warnings.append(f"Skipping {filename}: no valid samples")
            continue

        # Optional secondary log fit. Hold it to the same standard as the
        # primary fit: an empty or zero-count fit is treated as absent.
        log_key = spec.get("log_key")
        log_stats = regressions.get(log_key) if log_key else None
        if not has_fit(log_stats):
            log_stats = None

        output_path = context.output_dir / filename
        _create_scatter_plot(
            xs=xs,
            ys=ys,
            stats=stats,
            title=str(spec["title"]),
            x_label=str(spec["x_label"]),
            y_label=str(spec["y_label"]),
            output_path=output_path,
            model=model_name,
            hardware=hardware_label,
            log_fit_stats=log_stats,
        )

        artifacts[key] = output_path

    return VisualizationResult(
        visualization="regression",
        artifacts=artifacts,
        metadata={
            "model": model_name,
            "hardware": hardware_label,
        },
        warnings=tuple(warnings),
    )