"""
Modified pipeline implementing the user's plan:
1) First pass: extract MCQs and *try to reconstruct mathematical expressions / figure refs* (single prompt function)
2) Second pass: validate/repair each MCQ for logical validity and optionally rephrase while keeping key points
3) Skip any questions that appear to continue on another page (heuristics applied)

Outputs structured, Pydantic-validated JSON using an extended MCQ model.

Notes:
- Keep your OPENAI_API_KEY in .env or env vars as before.
- Tweak MODEL_NAME, timeouts and retry values to taste.

"""

import os
import re
import json
import time
import base64
import logging
import concurrent.futures
from pathlib import Path
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, ValidationError
from PIL import Image
from utils.extractor_functions import *
from dotenv import load_dotenv

# Install: pip install openai python-dotenv pillow pydantic
import openai

# ---------- Basic config & logging ----------
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
PER_PAGE_TIMEOUT = 120   # seconds per page (adjust)
MAX_RETRIES = 2
RETRY_DELAY = 1.5

load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") or "YOUR_KEY_HERE"
openai.api_key = OPENAI_API_KEY

# ---------- Pipeline config ----------
DEBUG_DIR = Path("public/debug_outputs")
DEBUG_DIR.mkdir(exist_ok=True)
MODEL_NAME = os.getenv("OPENAI_VIS_MODEL", "gpt-4o-mini")  # adjust if needed
MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", "2000"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.0"))

# ---------- Pydantic Models (Add these if missing) ----------

class MCQ(BaseModel):
    """Extended MCQ schema emitted by the two-pass extraction pipeline.

    Question/Options are required; the remaining fields are optional
    metadata produced by the reconstruction and validation passes.
    """

    Question: str
    Options: Dict[str, str]  # option label -> option text (dict keyed by label, not a bare list)
    Answer: Optional[str] = None  # correct option, when the model could determine it
    Explanation: Optional[str] = None  # rationale for the answer, if provided
    FigureRefs: Optional[List[str]] = None  # figure paths/identifiers referenced by the question
    QuestionLatex: Optional[str] = None  # LaTeX reconstruction of the question's math, if any
    OptionsLatex: Optional[Dict[str, str]] = None  # LaTeX reconstruction per option label
    AmbiguousMath: Optional[bool] = None  # True when the math reconstruction is uncertain



# ---------- Model call helpers (two-stage) ----------

def extract_and_reconstruct(page_num: int, page_text: str, figure_paths: List[str]) -> Optional[List[dict]]:
    """
    First pass: ask the model to extract MCQs and attempt to reconstruct math/figure refs.

    Skips obviously incomplete/continued questions (prompt tells model to skip).
    Properly sends images in multimodal format, falling back to listing paths if unavailable.

    Args:
        page_num: page number, used for logging and debug-file names.
        page_text: raw extracted text of the page.
        figure_paths: paths of figure images cropped from the page.

    Returns:
        A list of raw MCQ dicts (possibly empty); [] when all retries fail.
    """
    system_prompt = build_system_prompt_reconstruct()

    # The text part always lists the image paths so the model can emit
    # FigureRefs even when the images themselves cannot be attached below.
    if figure_paths:
        img_list_str = ", ".join(figure_paths)
    else:
        img_list_str = "None"

    user_parts = [
        {"type": "text", "text": f"PAGE {page_num} TEXT:\n\n{page_text}\n\nIMAGE PATHS: {img_list_str}"}
    ]

    # Attach each figure as a data URI when it exists and compresses under
    # the size limit; otherwise warn and rely on the path listing above.
    for fig_path_str in figure_paths:
        fig_path = Path(fig_path_str)
        if fig_path.exists():
            try:
                data_uri = compress_image_to_data_uri(fig_path)
                if data_uri:
                    user_parts.append({"type": "image_url", "image_url": {"url": data_uri}})
                else:
                    logging.warning(f"[Page {page_num}] Could not compress image under limit: {fig_path}")
            except Exception as e:
                logging.warning(f"[Page {page_num}] Image compression failed for {fig_path}: {e}")
        else:
            logging.warning(f"[Page {page_num}] Missing image file: {fig_path}")

    attempt = 0
    while attempt < MAX_RETRIES:
        try:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_parts},
            ]
            resp = openai.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                max_tokens=MAX_OUTPUT_TOKENS,
                temperature=TEMPERATURE,
            )
        except Exception as e:
            logging.warning(f"[Page {page_num}] OpenAI API error on attempt {attempt}: {e}")
            attempt += 1
            time.sleep(RETRY_DELAY * attempt)  # linear backoff
            continue

        raw_text = getattr(resp.choices[0].message, "content", None)
        if raw_text:
            # Fixed: this was a stray debugging print(); route it through logging.
            logging.info(f"Raw response for page {page_num} on attempt {attempt}: {raw_text[:200]}...")
            try:
                (DEBUG_DIR / f"page{page_num}_extract_attempt{attempt}_raw.txt").write_text(raw_text, encoding="utf-8")
            except Exception:
                pass  # debug dump is best-effort; never fail the pipeline for it
        else:
            logging.info(f"[Page {page_num}] Empty model output on attempt {attempt}. Retrying...")
            attempt += 1
            time.sleep(RETRY_DELAY * attempt)
            continue

        parsed_json = extract_json_from_text(raw_text)
        if parsed_json is not None:
            # Normalize to a list: a single object becomes a one-item list.
            if isinstance(parsed_json, dict):
                return [parsed_json]
            elif isinstance(parsed_json, list):
                return parsed_json
            else:
                logging.warning(f"[Page {page_num}] Unexpected JSON type: {type(parsed_json)}")
                return []

        # Repair pass: ask the model to re-emit its previous reply as valid JSON.
        logging.info(f"[Page {page_num}] Failed to parse JSON; running repair pass.")
        try:
            repair_prompt = (
                "The previous reply was not valid JSON. Re-format it into a valid JSON array of objects "
                "with keys: Question, Options, optional Answer, optional Explanation, optional FigureRefs, "
                "optional QuestionLatex, optional OptionsLatex, optional AmbiguousMath. "
                "Return ONLY the JSON array.\n\n"
                f"PREVIOUS_OUTPUT:\n\n{raw_text}\n\n"
            )
            repair_messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": repair_prompt},
            ]
            repair_resp = openai.chat.completions.create(
                model=MODEL_NAME,
                messages=repair_messages,
                max_tokens=MAX_OUTPUT_TOKENS,
                temperature=0.0,  # deterministic reformat
            )
            repair_text = getattr(repair_resp.choices[0].message, "content", None)
        except Exception as e:
            logging.warning(f"[Page {page_num}] Repair API error: {e}")
            attempt += 1
            time.sleep(RETRY_DELAY * attempt)
            continue

        if repair_text:
            repaired = extract_json_from_text(repair_text)
            if repaired is not None:
                if isinstance(repaired, dict):
                    return [repaired]
                elif isinstance(repaired, list):
                    return repaired
                else:
                    # Fixed: previously fell through silently when the repaired
                    # JSON was a scalar; now log before retrying.
                    logging.warning(f"[Page {page_num}] Repair produced unexpected JSON type: {type(repaired)}")
            else:
                logging.info(f"[Page {page_num}] Repair attempt failed to produce parseable JSON.")
        else:
            logging.info(f"[Page {page_num}] Repair attempt returned empty.")

        attempt += 1
        time.sleep(RETRY_DELAY * attempt)

    logging.error(f"[Page {page_num}] All attempts exhausted; skipping page.")
    return []


def _get_message_content(resp):
    """Robustly extract assistant text content across SDK variants."""
    try:
        return resp.choices[0].message.content  # modern openai>=1.x
    except Exception:
        try:
            return resp["choices"][0]["message"]["content"]  # legacy style
        except Exception:
            return None

def validate_single(mcq_item: dict) -> Optional[dict]:
    """
    Second pass: check logical validity and (if valid) return a cleaned object.

    Sends the raw MCQ dict through the validator prompt; the model may repair
    or rephrase it, or flag it with {"skip": true, "reason": ...}.

    Args:
        mcq_item: raw MCQ dict from the first extraction pass.

    Returns:
        The cleaned MCQ dict, or None when the item is skipped by the
        validator, unparseable, or all retries are exhausted.
    """
    system_prompt = build_system_prompt_validate()
    user_content = json.dumps(mcq_item, ensure_ascii=False)

    attempt = 0
    while attempt < MAX_RETRIES:
        try:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
            ]
            resp = openai.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                max_tokens=1200,
                temperature=0.0,
            )
        except Exception as e:
            logging.warning(f"Validation API error on attempt {attempt}: {e}")
            attempt += 1
            time.sleep(RETRY_DELAY * attempt)  # linear backoff
            continue

        raw_text = _get_message_content(resp)
        if not raw_text:
            logging.info("Validation returned empty. Retrying...")
            attempt += 1
            time.sleep(RETRY_DELAY * attempt)
            continue

        # Extract JSON from response
        parsed = extract_json_from_text(raw_text)

        if parsed is None:
            logging.info("Validation JSON parse failed. Retrying...")
            attempt += 1
            time.sleep(RETRY_DELAY * attempt)
            continue

        # If the model (incorrectly) returned a single-item array, unwrap it.
        if isinstance(parsed, list) and len(parsed) == 1 and isinstance(parsed[0], dict):
            parsed = parsed[0]

        # Fixed: anything still not a dict (multi-item array, scalar) cannot be
        # consumed downstream (normalize_options_field / MCQ(**...)); treat it
        # as a failed validation instead of passing it through.
        if not isinstance(parsed, dict):
            logging.warning(f"Validator returned non-object JSON ({type(parsed).__name__}); skipping item.")
            return None

        # Skip signal
        if parsed.get("skip"):
            logging.info(f"Item skipped by validator: {parsed.get('reason', 'no reason')}")
            return None

        # Return the cleaned object
        return parsed

    logging.info("Validation attempts exhausted; skipping item.")
    return None

# ---------- Heuristics to decide if MCQ is incomplete / continuation ----------

def is_incomplete_mcq_candidate(item: dict, page_text: str) -> bool:
    # missing options or too few
    opts = item.get("Options")
    if not opts or not isinstance(opts, (dict, list)):
        return True
    if isinstance(opts, dict) and len(opts) < 2:
        return True
    # placeholder checks
    combined = (item.get("Question", "") + " " + " ".join((opts.values() if isinstance(opts, dict) else opts))).lower()
    if "continued on next page" in combined or "continued" in combined or "contd" in combined:
        return True
    if "..." in combined or "—" in combined and combined.strip().endswith("—"):
        return True
    # simple checks: options that are only single letters or empty
    if isinstance(opts, dict):
        for k, v in opts.items():
            if not v or len(v.strip()) < 2:
                return True
    return False


# ---------- Page processing / validation ----------

def process_page_obj(page_obj: dict) -> List[dict]:
    """Run the full two-pass pipeline for a single page object.

    Args:
        page_obj: dict with keys "page" (number), "text" (raw page text)
            and "figures" (list of figure image paths).

    Returns:
        A list of schema-validated MCQ dicts. Items that look incomplete,
        are rejected by the validator, or fail Pydantic validation are
        dropped with a log entry.
    """
    page_num = page_obj.get("page")
    page_text = page_obj.get("text", "") or ""
    figure_paths = page_obj.get("figures", []) or []

    raw_items = extract_and_reconstruct(page_num, page_text, figure_paths)
    if not raw_items:
        return []

    validated = []
    for i, item in enumerate(raw_items, start=1):
        if not isinstance(item, dict):
            logging.warning(f"[Page {page_num}] skipping non-object item #{i}")
            continue

        # If heuristics detect continuation/incomplete, skip immediately
        if is_incomplete_mcq_candidate(item, page_text):
            logging.info(f"[Page {page_num}] skipping item #{i} — detected as incomplete/continued on another page.")
            continue

        # Second pass: validate and possibly rephrase
        repaired = validate_single(item)
        if not repaired:
            logging.info(f"[Page {page_num}] item #{i} was skipped by validator or failed to repair.")
            continue

        # normalize Options and coerce FigureRefs entries to strings
        repaired = normalize_options_field(repaired)
        if "FigureRefs" in repaired and isinstance(repaired["FigureRefs"], list):
            repaired["FigureRefs"] = [str(x) for x in repaired["FigureRefs"]]

        try:
            # Fixed: removed a redundant inner try/except whose two branches
            # made byte-identical model_dump() calls — one call is equivalent.
            validated.append(MCQ(**repaired).model_dump())
        except ValidationError as ve:
            logging.warning(f"[Page {page_num}] Pydantic validation failed for item #{i}: {ve}")
            continue

    return validated


# # ---------- Main runner with per-page timeout & partial saves ----------

# def run_all_pages(input_json_path: Path, out_path: Path):
#     with input_json_path.open("r", encoding="utf-8") as f:
#         pages = json.load(f)

#     all_results = []
#     out_path.parent.mkdir(parents=True, exist_ok=True)
#     temp_save_every = 10

#     # single worker to keep ordering simple (can be changed later)
#     with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
#         for idx, page in enumerate(pages, start=1):
#             pnum = page.get("page", "?")
#             logging.info(f"Starting page {pnum} (index {idx}/{len(pages)})")
#             start_ts = time.time()

#             future = executor.submit(process_page_obj, page)
#             try:
#                 validated_mcqs = future.result(timeout=PER_PAGE_TIMEOUT)
#             except concurrent.futures.TimeoutError:
#                 logging.error(f"Page {pnum} timed out after {PER_PAGE_TIMEOUT}s — skipping.")
#                 future.cancel()
#                 validated_mcqs = []
#             except Exception as e:
#                 logging.exception(f"Page {pnum} raised exception: {e}")
#                 validated_mcqs = []

#             elapsed = time.time() - start_ts
#             logging.info(f"Finished page {pnum} in {elapsed:.1f}s — extracted {len(validated_mcqs)} MCQs")

#             all_results.append({"page": pnum, "mcqs": validated_mcqs})

#             if idx % temp_save_every == 0 or idx == len(pages):
#                 with out_path.open("w", encoding="utf-8") as f:
#                     json.dump(all_results, f, ensure_ascii=False, indent=2)
#                 logging.info(f"Saved partial results after {idx} pages to {out_path}")

#     logging.info(f"Done. Final saved to {out_path}")


# if __name__ == "__main__":
#     run_all_pages(INPUT_PAGES_JSON, OUTPUT_MCQ_JSON)