Skip to content

gaia

ipw.evaluation.gaia

GAIAHandler

Bases: EvaluationHandler

GAIA evaluation: exact match with normalization + LLM fallback.

Source code in intelligence-per-watt/src/ipw/evaluation/gaia.py
@EvaluationRegistry.register("gaia")
class GAIAHandler(EvaluationHandler):
    """GAIA evaluation: exact match with normalization + LLM fallback.

    Answers are first compared with ``_exact_match``. Only when that fails is
    the handler's LLM client (``self._client.chat``) asked to judge the
    response against the ground truth.
    """

    evaluation_method = "gaia"

    # Structured judge verdict line, e.g. "correct: yes". Leading whitespace and
    # common Markdown decoration ("**correct:** no", "- correct: yes") are
    # tolerated so a formatted "no" verdict is not handed to the keyword
    # heuristic (which would see "CORRECT" and grade it as correct).
    _VERDICT_RE = re.compile(
        r"^[ \t>*_#-]*correct[ \t*_]*:[ \t*_]*(yes|no)",
        re.MULTILINE | re.IGNORECASE,
    )
    # Judge's restatement of the final answer. "[ \t]*" (not "\s*") keeps the
    # match on one line, so an empty value cannot capture the following line.
    _EXTRACTED_RE = re.compile(
        r"^[ \t]*extracted_final_answer:[ \t]*(.+)", re.MULTILINE
    )

    def evaluate(
        self,
        *,
        problem: str,
        reference: str,
        model_answer: str,
        metadata: Dict[str, object],
    ) -> Tuple[Optional[bool], Dict[str, object]]:
        """Grade ``model_answer`` against ``reference``.

        Args:
            problem: Question text; interpolated into the judge prompt
                (a placeholder is used when it is empty).
            reference: Ground-truth answer.
            model_answer: The model's response to grade.
            metadata: Accepted for interface compatibility; not read here.

        Returns:
            ``(verdict, details)``. ``verdict`` is ``None`` when there is no
            ground truth to grade against, otherwise a bool. ``details``
            records how the verdict was reached (``match_type`` and/or
            ``reason``) and, on the LLM path, the raw judge output or the
            error message.
        """
        if not model_answer or not model_answer.strip():
            return False, {"reason": "empty_response"}

        if not reference or not reference.strip():
            return None, {"reason": "no_ground_truth"}

        # Exact match first: cheap and needs no API call.
        if _exact_match(model_answer, reference):
            return True, {"match_type": "exact"}

        # getattr so a handler without a ``_client`` attribute takes the
        # "no client" path instead of raising AttributeError.
        client = getattr(self, "_client", None)
        if not hasattr(client, "chat"):
            return False, {
                "match_type": "exact_failed",
                "reason": "no_llm_client_for_fallback",
            }

        try:
            prompt = _LLM_FALLBACK_PROMPT.format(
                question=problem or "(No question provided)",
                response=model_answer,
                ground_truth=reference,
            )
            raw = client.chat(
                system_prompt="",
                user_prompt=prompt,
                temperature=0.0,
                max_output_tokens=1024,
            )
            return self._parse_judge_output(raw)
        except Exception as exc:
            # Best effort: a failed judge call (or an unparseable reply) is
            # graded as incorrect and the error is recorded, not propagated.
            LOGGER.error("GAIA LLM fallback failed: %s", exc)
            return False, {
                "match_type": "llm_fallback_error",
                "error": str(exc),
            }

    @classmethod
    def _parse_judge_output(cls, raw: str) -> Tuple[bool, Dict[str, object]]:
        """Turn the judge's raw reply into ``(is_correct, details)``.

        A structured ``correct: yes|no`` line takes precedence. Without one,
        a keyword heuristic is used: correct only if "CORRECT" appears and
        neither "INCORRECT" nor "NOT CORRECT" does (case-insensitive).

        Raises:
            TypeError: if ``raw`` is not a string (propagates to the caller's
                error handling).
        """
        verdict = cls._VERDICT_RE.search(raw)
        if verdict:
            is_correct = verdict.group(1).lower() == "yes"
        else:
            upper = raw.upper()
            is_correct = (
                "CORRECT" in upper
                and "INCORRECT" not in upper
                and "NOT CORRECT" not in upper
            )

        meta: Dict[str, object] = {
            "match_type": "llm_fallback",
            "raw_judge_output": raw,
        }
        extracted = cls._EXTRACTED_RE.search(raw)
        if extracted:
            answer = extracted.group(1).strip()
            # Skip whitespace-only values rather than recording an empty answer.
            if answer:
                meta["extracted_answer"] = answer
        return is_correct, meta