Skip to content

Index

ipw.compute

Compute estimation utilities for LLM inference.

estimate_flops(model, input_tokens, output_tokens, use_calflops=False)

Estimate FLOPs for a model inference.

Strategy: (1) if use_calflops=True, try the calflops library first; (2) fall back to the 2·P·T formula using known parameter counts; (3) return (0, 0) if the model is unknown.

Parameters:

Name Type Description Default
model str

Model name or path

required
input_tokens int

Number of input tokens

required
output_tokens int

Number of output tokens

required
use_calflops bool

Whether to try calflops library

False

Returns:

Type Description
tuple[float, float]

Tuple of (total_flops, flops_per_token)

Source code in intelligence-per-watt/src/ipw/compute/flops.py
def estimate_flops(
    model: str,
    input_tokens: int,
    output_tokens: int,
    use_calflops: bool = False,
) -> tuple[float, float]:
    """Estimate total and per-token FLOPs for one model inference.

    Resolution order:
    1. When ``use_calflops`` is set, attempt the calflops-based estimate.
    2. Otherwise (or on calflops failure) apply the 2*P*T approximation
       with a parameter count from the known-model table.
    3. When the model cannot be resolved at all, report (0, 0).

    Args:
        model: Model name or path.
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.
        use_calflops: Whether to try the calflops library first.

    Returns:
        Tuple of (total_flops, flops_per_token).
    """
    if use_calflops:
        calflops_estimate = estimate_flops_calflops(model, input_tokens, output_tokens)
        if calflops_estimate is not None:
            return calflops_estimate

    # calflops unavailable or disabled: use the parameter-count heuristic.
    params_billions = lookup_params(model)
    if params_billions is None:
        logger.debug(f"Unknown model '{model}', cannot estimate FLOPs")
        return 0.0, 0.0
    return estimate_flops_fallback(params_billions, input_tokens, output_tokens)

estimate_flops_calflops(model_name_or_path, input_tokens, output_tokens)

Estimate FLOPs using the calflops library (optional dependency).

Returns None if calflops is not installed or estimation fails.

Source code in intelligence-per-watt/src/ipw/compute/flops.py
def estimate_flops_calflops(
    model_name_or_path: str,
    input_tokens: int,
    output_tokens: int,
) -> tuple[float, float] | None:
    """Estimate FLOPs via the optional calflops library.

    Returns None when calflops is not installed or when the estimation
    itself raises, so callers can fall back to a cheaper heuristic.
    """
    try:
        from calflops import calculate_flops  # type: ignore[import-untyped]
    except ImportError:
        logger.debug("calflops not installed, skipping detailed FLOPs estimation")
        return None

    try:
        # calflops can estimate FLOPs for HuggingFace models; batch size 1.
        flops, _macs, _params = calculate_flops(
            model_name=model_name_or_path,
            input_shape=(1, input_tokens),
            output_as_string=False,
        )
        total = float(flops) if flops else 0.0
        token_count = input_tokens + output_tokens
        # Guard against division by zero for empty requests.
        per_token = total / token_count if token_count > 0 else 0.0
        return total, per_token
    except Exception as e:
        logger.warning(f"calflops estimation failed: {e}")
        return None

estimate_flops_fallback(params_billions, input_tokens, output_tokens)

Estimate FLOPs using the 2·P·T approximation.

For transformer inference: prefill costs ~2 · P · T_input (matrix multiplications), decode costs ~2 · P · T_output (autoregressive generation), so the total is ~2 · P · (T_input + T_output).

Parameters:

Name Type Description Default
params_billions float

Model parameter count in billions

required
input_tokens int

Number of input tokens

required
output_tokens int

Number of output tokens

required

Returns:

Type Description
tuple[float, float]

Tuple of (total_flops, flops_per_token)

Source code in intelligence-per-watt/src/ipw/compute/flops.py
def estimate_flops_fallback(
    params_billions: float,
    input_tokens: int,
    output_tokens: int,
) -> tuple[float, float]:
    """Approximate inference FLOPs with the standard 2*P*T rule.

    Each token (prefill or decode) costs roughly two FLOPs per model
    parameter, so total work is ~2 * P * (T_input + T_output).

    Args:
        params_billions: Model parameter count in billions.
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.

    Returns:
        Tuple of (total_flops, flops_per_token).
    """
    param_count = params_billions * 1e9
    token_total = input_tokens + output_tokens
    # Per-token cost is constant (2P); zero it out when there are no tokens.
    per_token = 2.0 * param_count if token_total > 0 else 0.0
    return 2.0 * param_count * token_total, per_token

lookup_params(model)

Look up parameter count (in billions) for a model.

Returns None if the model is not in the known list.

Source code in intelligence-per-watt/src/ipw/compute/flops.py
def lookup_params(model: str) -> float | None:
    """Return the parameter count (in billions) for a known model.

    Normalizes the name, prefers an exact table hit, then falls back to
    the first substring match in either direction. Returns None when the
    model is not in the known list.
    """
    canonical = normalize_model_name(model)
    if canonical in MODEL_PARAMS:
        return MODEL_PARAMS[canonical]
    # Partial match: accept the first entry that contains, or is
    # contained in, the normalized name (table iteration order wins).
    partial_hits = (
        size
        for known, size in MODEL_PARAMS.items()
        if known in canonical or canonical in known
    )
    return next(partial_hits, None)

normalize_model_name(model)

Normalize model name for parameter lookup.

Handles common naming patterns like: - 'meta-llama/Llama-3.1-8B-Instruct' -> 'llama-3.1-8b' - 'llama3.2:1b' (ollama format) -> 'llama-3.2-1b' - 'qwen2.5-7b-instruct' -> 'qwen-2.5-7b'

Source code in intelligence-per-watt/src/ipw/compute/flops.py
def normalize_model_name(model: str) -> str:
    """Normalize model name for parameter lookup.

    Handles common naming patterns like:
    - 'meta-llama/Llama-3.1-8B-Instruct' -> 'llama-3.1-8b'
    - 'llama3.2:1b' (ollama format) -> 'llama-3.2-1b'
    - 'qwen2.5-7b-instruct' -> 'qwen-2.5-7b'
    """
    name = model.lower()
    # Remove common suffixes
    for suffix in ["-instruct", "-chat", "-it", "-base", ":latest"]:
        name = name.replace(suffix, "")
    # Remove org prefix
    if "/" in name:
        name = name.split("/")[-1]
    # Normalize separators
    name = name.replace(":", "-").replace("_", "-")
    # Insert hyphen between letters and digits (e.g., llama3 -> llama-3, qwen2 -> qwen-2)
    # but not after 'x' to preserve patterns like '8x7b'
    name = re.sub(r"([a-wyz])(\d)", r"\1-\2", name)
    # Remove extra qualifiers
    for q in ["-fp16", "-fp32", "-bf16", "-awq", "-gptq", "-gguf", "-q4", "-q8"]:
        name = name.replace(q, "")
    return name.strip("-")