# mcq_creator.py — Unified MCQ Content Generator (Final Production Version)
# (modified: single-slide-per-question + JSON-serializable slides output)

import os
import re
import json
import random
import logging
import base64
from pathlib import Path
from typing import List, Optional, Dict, Any
from pydantic import BaseModel
from dotenv import load_dotenv
from mistralai import Mistral
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from celery_app import celery

# --------------------------------
# Setup
# --------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Everything is resolved relative to this module's directory; configuration
# is read from a sibling .env file.
ROOT_DIR = Path(__file__).parent
load_dotenv(ROOT_DIR / ".env")

# API credentials (None when unset) and model names with built-in defaults.
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY")
MISTRAL_OCR_MODEL = os.environ.get("MISTRAL_OCR_MODEL", "mistral-ocr-latest")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")

# Output locations (images live in a subfolder); both are created eagerly at
# import time.
OUTPUT_DIR = ROOT_DIR / "outputs"
IMAGES_DIR = OUTPUT_DIR / "images"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
IMAGES_DIR.mkdir(parents=True, exist_ok=True)

# --------------------------------
# Models
# --------------------------------
class MCQModel(BaseModel):
    """One multiple-choice question, as extracted from a PDF and later enriched."""

    # Question number; None when unknown. Defaults to None like the other
    # Optional fields (pydantic v2 would otherwise treat it as required).
    q_no: Optional[int] = None
    question: str
    options: List[str]
    # Correct option, as text or letter (whatever the extractor returned).
    correct_answer: str
    explanation: Optional[str] = None
    # Path of an image file attached to this question, if any.
    image_reference: Optional[str] = None
    # Filled in by the OpenAI refinement step.
    refined_explanation: Optional[str] = None
    # Teacher-style narration used as slide speaker notes.
    narration: Optional[str] = None


class MCQSetModel(BaseModel):
    """A titled collection of MCQs."""

    # Optional set title; defaults to None so the model can be built without it
    # (pydantic v2 treats an un-defaulted Optional field as required).
    title: Optional[str] = None
    questions: List[MCQModel]


# --------------------------------
# Helpers: TTS text normalisation and LaTeX tagging
# --------------------------------
# "<MATH:...>" placeholders (the format sanitize_math emits for LaTeX regions).
_MATH_PLACEHOLDER = re.compile(r"<MATH:(.*?)>", re.DOTALL)
# Common LaTeX constructs, spelled out for speech.
_LATEX_SQRT = re.compile(r"\\sqrt\s*\{([^{}]*)\}")
_LATEX_FRAC = re.compile(r"\\d?frac\s*\{([^{}]*)\}\s*\{([^{}]*)\}")
# Superscript: braced "^{...}" (group 1) or a single unbraced token "^2" (group 2).
_LATEX_POWER = re.compile(r"\^(?:\{([^{}]*)\}|([A-Za-z0-9]))")
_LATEX_COMMAND = re.compile(r"\\([A-Za-z]+)")
# Named LaTeX commands we can read aloud; anything else is dropped.
_LATEX_WORDS = {
    "times": " times ",
    "cdot": " times ",
    "div": " divided by ",
    "pm": " plus or minus ",
    "sqrt": " square root of ",
    "pi": " pi ",
    "theta": " theta ",
    "alpha": " alpha ",
    "beta": " beta ",
    "circ": " degrees ",
    "degree": " degrees ",
    "le": " less than or equal to ",
    "leq": " less than or equal to ",
    "ge": " greater than or equal to ",
    "geq": " greater than or equal to ",
    "ne": " not equal to ",
    "neq": " not equal to ",
    "infty": " infinity ",
}
# Single-character math symbols -> words (applied in one pass).
_TTS_SYMBOLS = str.maketrans({
    "√": "square root of ",
    "²": " squared ",
    "³": " cubed ",
    "+": " plus ",
    "−": " minus ",  # U+2212 MINUS SIGN
    "×": " times ",
    "÷": " divided by ",
    "=": " equals ",
    "°": " degrees ",
    "%": " percent ",
})
# A "-" that is NOT between two letters, i.e. a minus sign rather than a
# word hyphen ("well-known" keeps its hyphen).
_MINUS_SIGN = re.compile(r"(?<![A-Za-z])-|-(?![A-Za-z])")
# Unit abbreviations directly after a digit; order matters ("m" before "cm").
_UNIT_PATTERNS = (
    (r"(?<=\d)\s*m\b", " meters"),
    (r"(?<=\d)\s*cm\b", " centimeters"),
    (r"(?<=\d)\s*mm\b", " millimeters"),
    (r"(?<=\d)\s*km\b", " kilometers"),
    (r"(?<=\d)\s*g\b", " grams"),
    (r"(?<=\d)\s*kg\b", " kilograms"),
    (r"(?<=\d)\s*l\b", " liters"),
    (r"(?<=\d)\s*s\b", " seconds"),
    (r"(?<=\d)\s*hr\b", " hours"),
    (r"(?<=\d)\s*min\b", " minutes"),
)


def _latex_power_to_words(match) -> str:
    """Spell out a superscript match from _LATEX_POWER ("^2" -> " squared ")."""
    exponent = match.group(1) if match.group(1) is not None else match.group(2)
    exponent = exponent.strip()
    if exponent == "2":
        return " squared "
    if exponent == "3":
        return " cubed "
    return f" to the power of {exponent} " if exponent else " "


def make_math_tts_friendly(text: str) -> str:
    """Rewrite math notation in *text* as words so a TTS engine reads it naturally.

    Steps, in order:
      1. Unwrap ``<MATH:...>`` placeholders (as produced by ``sanitize_math``).
      2. Spell out common LaTeX: ``\\sqrt{..}``, ``\\frac{..}{..}``,
         superscripts (``^2`` -> "squared") and named commands (``\\times``,
         ``\\pi``, ...). Unknown commands and stray braces/backslashes/carets
         are dropped.
      3. Replace math symbols (√ ² ³ + − × ÷ = ° %) with words.
      4. Read "-" as "minus" unless it sits between two letters (a word
         hyphen such as "well-known"; letter-letter forms like "x-y" are
         therefore also left as-is).
      5. Expand unit abbreviations that directly follow a digit ("5 cm").
      6. Collapse whitespace and strip the ends.
    """
    # Structural LaTeX first (it relies on braces), then named commands,
    # then remove any leftover LaTeX syntax characters.
    text = _MATH_PLACEHOLDER.sub(r" \1 ", text)
    text = _LATEX_SQRT.sub(r" square root of \1 ", text)
    text = _LATEX_FRAC.sub(r" \1 over \2 ", text)
    text = _LATEX_POWER.sub(_latex_power_to_words, text)
    text = _LATEX_COMMAND.sub(lambda m: _LATEX_WORDS.get(m.group(1), " "), text)
    text = re.sub(r"[{}\\^]", " ", text)

    text = text.translate(_TTS_SYMBOLS)
    text = _MINUS_SIGN.sub(" minus ", text)

    for pattern, repl in _UNIT_PATTERNS:
        text = re.sub(pattern, repl, text)

    return re.sub(r"\s+", " ", text).strip()


# LaTeX math delimiters: inline \( ... \) and display \[ ... \]. DOTALL lets a
# formula span several lines (common for display math).
MATH_INLINE = re.compile(r"\\\((.*?)\\\)", re.DOTALL)
MATH_DISPLAY = re.compile(r"\\\[(.*?)\\\]", re.DOTALL)

def sanitize_math(text: Any) -> str:
    """Replace LaTeX math regions in *text* with ``<MATH:...>`` placeholders.

    Both ``\\( ... \\)`` and ``\\[ ... \\]`` regions become ``<MATH:content>``
    (content stripped), so raw backslash delimiters never reach the slides.
    ``None`` becomes ``""``; other non-string values (e.g. numeric options in
    LLM-produced JSON) are converted with ``str`` first.
    """
    if text is None:
        return ""
    text = str(text)
    text = MATH_INLINE.sub(lambda m: f"<MATH:{m.group(1).strip()}>", text)
    return MATH_DISPLAY.sub(lambda m: f"<MATH:{m.group(1).strip()}>", text)



# -------------------------------------------------
# Main Function: Extract MCQs from PDF via Mistral
# -------------------------------------------------
# Maximum number of OCR characters forwarded to the MCQ-extraction LLM.
_OCR_PROMPT_CHAR_LIMIT = 20000

# Either an already-escaped "\\" pair (group 1, kept verbatim) or a lone
# backslash that does not start a legal JSON escape (e.g. LaTeX "\(" or
# "\alpha"); the latter is doubled so json.loads accepts it. Consuming "\\"
# pairs first keeps correctly escaped replies from being double-escaped.
_JSON_STRAY_BACKSLASH = re.compile(r'(\\\\)|\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})')

# A printed question number such as "12."; (?!\d) skips decimals like "3.14".
_QUESTION_NUMBER = re.compile(r"\b(\d+)\.(?!\d)")


def _image_extension(img_bytes: bytes) -> str:
    """Return a file extension for *img_bytes* from its magic number ("bin" if unknown)."""
    if img_bytes.startswith(b"\x89PNG"):
        return "png"
    if img_bytes.startswith(b"\xff\xd8"):
        return "jpeg"
    if img_bytes.startswith((b"GIF87a", b"GIF89a")):
        return "gif"
    if img_bytes[:4] == b"RIFF" and img_bytes[8:12] == b"WEBP":
        return "webp"
    return "bin"


def _save_page_images(page: Any, page_idx: int, page_md: str) -> List[Any]:
    """Decode and save one OCR page's images into IMAGES_DIR.

    Returns ``(position, path)`` tuples sorted by position. *position* is the
    offset of the image's reference token in *page_md*, or -1 if not found.
    Failures are logged and skipped (best effort: one bad image must not
    abort the whole PDF).
    """
    image_positions = []
    for img_idx, img in enumerate(getattr(page, "images", None) or [], start=1):
        try:
            img_b64 = getattr(img, "image_base64", None)
            if not img_b64:
                continue

            # Strip an optional data-URI prefix and any embedded whitespace.
            img_b64 = re.sub(r"^data:image/[a-zA-Z0-9.+-]+;base64,", "", img_b64.strip())
            img_bytes = base64.b64decode(re.sub(r"\s+", "", img_b64))

            img_name = f"page{page_idx:02d}_img{img_idx:02d}.{_image_extension(img_bytes)}"
            if not img_bytes:
                logger.warning(f"⚠️ Image {img_name} is empty — skipped")
                continue

            img_path = IMAGES_DIR / img_name
            img_path.write_bytes(img_bytes)

            # Locate the image's reference token (its OCR id, else our file
            # stem) in the page markdown so it can be tied to a question.
            ref_token = getattr(img, "id", None) or Path(img_name).stem
            image_positions.append((page_md.find(ref_token), str(img_path)))
            logger.info(f"📸 Saved image {img_name} ({len(img_bytes)} bytes)")
        except Exception as e:
            logger.warning(f"⚠️ Failed to save image from page {page_idx}: {e}")
    return sorted(image_positions)


def _escape_stray_backslashes(text: str) -> str:
    """Double every backslash in *text* that is not part of a valid JSON escape."""
    return _JSON_STRAY_BACKSLASH.sub(lambda m: m.group(1) or "\\\\", text)


def _parse_mcq_json(raw: str) -> Dict[str, Any]:
    """Parse the JSON object embedded in the extraction LLM's reply.

    The text is first parsed as-is. Only if that fails are progressively more
    aggressive repairs applied, so valid content (e.g. curly quotes inside a
    question) is never rewritten unnecessarily.

    Raises:
        json.JSONDecodeError: if no repair yields valid JSON.
        ValueError: if the parsed JSON is not an object.
    """
    match = re.search(r"\{[\s\S]*\}", raw)
    json_text = match.group(0) if match else raw

    try:
        # strict=False tolerates raw control characters (e.g. newlines) in strings.
        data = json.loads(json_text, strict=False)
    except json.JSONDecodeError:
        repaired = re.sub(r"[\x00-\x1F\x7F]", " ", json_text)
        repaired = repaired.replace("“", '"').replace("”", '"')
        repaired = repaired.replace("‘", "'").replace("’", "'")
        repaired = re.sub(r",\s*([}\]])", r"\1", repaired)  # trailing commas
        repaired = _escape_stray_backslashes(repaired)
        try:
            data = json.loads(repaired)
        except json.JSONDecodeError:
            logger.warning("⚠️ JSON malformed — attempting auto-key quoting fix.")
            # Quote bare keys only where a key can appear (after "{" or ",").
            data = json.loads(re.sub(r"([{,]\s*)([A-Za-z_]\w*)\s*:", r'\1"\2":', repaired))

    if not isinstance(data, dict):
        raise ValueError(f"Expected a JSON object from MCQ extraction, got {type(data).__name__}")
    return data


def _guess_page(index: int, total_questions: int, total_pages: int) -> int:
    """Estimate the 1-based page of the *index*-th (1-based) question, assuming an even spread."""
    if total_pages <= 1 or total_questions <= 0:
        return 1
    # Ceiling division: question k of n on p pages lands on page ceil(k * p / n).
    return min(total_pages, max(1, -(-index * total_pages // total_questions)))


def _image_for_question(q_number: int, page_md: str, image_positions: List[Any]) -> Optional[str]:
    """Return the first image whose nearest preceding question number equals *q_number*."""
    numbers = [(m.start(), int(m.group(1))) for m in _QUESTION_NUMBER.finditer(page_md)]
    for pos, img_path in image_positions:
        if pos < 0:
            continue  # reference token not found in the markdown
        preceding = [num for start, num in numbers if start < pos]
        if preceding and preceding[-1] == q_number:
            return img_path
    return None


def _as_text(value: Any) -> str:
    """Coerce an LLM-supplied JSON value to text and tag its LaTeX via sanitize_math."""
    return sanitize_math("" if value is None else str(value))


def _normalize_options(raw_options: Any) -> List[str]:
    """Turn LLM-supplied options (list, newline-separated string, or label->text dict) into strings."""
    if isinstance(raw_options, str):
        items = [o.strip() for o in raw_options.splitlines() if o.strip()]
    elif isinstance(raw_options, dict):
        items = list(raw_options.values())
    elif isinstance(raw_options, (list, tuple)):
        items = list(raw_options)
    else:
        items = []
    return [_as_text(o) for o in items if o is not None]


def mistral_ocr_extract_mcqs(pdf_path: str) -> MCQSetModel:
    """Extract MCQs (with images) from a PDF using Mistral OCR and a Mistral chat model.

    Pipeline:
      1. OCR every page, saving embedded images under IMAGES_DIR.
      2. Ask ``mistral-large-latest`` for the MCQs as JSON. Only the first
         _OCR_PROMPT_CHAR_LIMIT characters are sent; a warning is logged on
         truncation.
      3. Parse the reply tolerantly.
      4. Tag LaTeX math in the text fields.
      5. Attach each image to the question whose printed number most closely
         precedes it on the estimated page.

    Args:
        pdf_path: Path to the input PDF.

    Returns:
        MCQSetModel with one MCQModel per extracted question.

    Raises:
        FileNotFoundError: if *pdf_path* does not exist.
        json.JSONDecodeError / ValueError: if the LLM reply cannot be parsed
            into a JSON object.
    """
    if not os.path.exists(pdf_path):
        raise FileNotFoundError(f"PDF not found: {pdf_path}")

    client = Mistral(api_key=MISTRAL_API_KEY)

    # --- Read PDF ---
    with open(pdf_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")

    logger.info("🧠 Running Mistral OCR on PDF with image extraction...")

    # --- OCR call ---
    ocr_result = client.ocr.process(
        model=MISTRAL_OCR_MODEL,
        document={"type": "document_url", "document_url": f"data:application/pdf;base64,{encoded}"},
        include_image_base64=True,
    )

    # Step 1: per-page markdown plus saved images with their markdown offsets.
    pages = list(getattr(ocr_result, "pages", None) or [])
    page_markdown: Dict[int, str] = {}
    page_positions: Dict[int, List[Any]] = {}
    text_parts: List[str] = []
    for page_idx, page in enumerate(pages, start=1):
        page_md = getattr(page, "markdown", "") or ""
        page_markdown[page_idx] = page_md
        page_positions[page_idx] = _save_page_images(page, page_idx, page_md)
        text_parts.append(f"\n\n--- PAGE {page_idx} ---\n\n" + page_md)
    combined_text = "".join(text_parts)

    if len(combined_text) > _OCR_PROMPT_CHAR_LIMIT:
        logger.warning(
            f"⚠️ OCR text is {len(combined_text)} chars; only the first "
            f"{_OCR_PROMPT_CHAR_LIMIT} are sent for MCQ extraction — later questions may be missed."
        )

    # Step 2: ask Mistral to extract MCQs as structured JSON.
    extraction_prompt = f"""
Extract all MCQs and return valid JSON in this schema:
{{
  "title": "<optional title>",
  "questions": [
    {{
      "q_no": <integer>,
      "question": "<text>",
      "options": ["<opt1>", "<opt2>", "<opt3>", "<opt4>"],
      "correct_answer": "<text or letter>",
      "explanation": "<brief reason>",
      "image_reference": "<filename or caption or null>"
    }}
  ]
}}

TEXT:
{combined_text[:_OCR_PROMPT_CHAR_LIMIT]}
"""

    resp = client.chat.complete(
        model="mistral-large-latest",
        messages=[
            {"role": "system", "content": "Extract MCQs precisely as JSON only."},
            {"role": "user", "content": extraction_prompt},
        ],
    )

    # Step 3: tolerant JSON parsing.
    data = _parse_mcq_json(str(resp.choices[0].message.content or "").strip())

    # Steps 4-5: normalise fields (JSON nulls, numeric/dict options, LaTeX)
    # and map images to questions.
    mcq_list = [q for q in (data.get("questions") or []) if isinstance(q, dict)]
    total_pages = len(pages)
    logger.info(f"🔗 Mapping images to {len(mcq_list)} MCQs using question-number proximity...")

    questions: List[MCQModel] = []
    for i, q in enumerate(mcq_list, start=1):
        # Prefer the printed number reported by the LLM; fall back to position.
        try:
            q_number = int(q.get("q_no"))
        except (TypeError, ValueError):
            q_number = i

        page_guess = _guess_page(i, len(mcq_list), total_pages)
        image_positions = page_positions.get(page_guess, [])
        matched_path = _image_for_question(q_number, page_markdown.get(page_guess, ""), image_positions)
        if matched_path:
            logger.info(f"🖼️ Mapped {Path(matched_path).name} → Q{q_number}")
        elif i == 1 and image_positions:
            # Fallback: an unmatched first question takes the first image on its page.
            matched_path = image_positions[0][1]
            logger.info("🖼️ Attached first image → Q1")

        questions.append(
            MCQModel(
                q_no=q_number,
                question=_as_text(q.get("question")),
                options=_normalize_options(q.get("options")),
                correct_answer=_as_text(q.get("correct_answer")),
                explanation=_as_text(q.get("explanation")),
                image_reference=matched_path,
            )
        )

    logger.info(f"✅ Extracted {len(questions)} MCQs from PDF.")
    title = data.get("title")
    return MCQSetModel(title=str(title) if title else "MCQ Set", questions=questions)
# --------------------------------
# Step 2: Refine explanations using OpenAI
# --------------------------------
def _parse_refinement_reply(raw: str) -> Optional[Dict[str, Any]]:
    """Recover the ``{...}`` object from a refinement reply, or None if impossible.

    Tries strict JSON, then a single-quoted Python-style dict, then JSON with
    bare keys / curly quotes. A clean parse is always attempted before
    anything is rewritten, so apostrophes in content (e.g. "let's") survive.
    """
    import ast

    match = re.search(r"\{[\s\S]*\}", raw)
    if not match:
        return None
    candidate = match.group(0)

    def _with_quoted_keys() -> Any:
        fixed = candidate.replace("“", '"').replace("”", '"')
        fixed = re.sub(r"([{,]\s*)([A-Za-z_]\w*)\s*:", r'\1"\2":', fixed)
        return json.loads(fixed, strict=False)

    parsers = (
        lambda: json.loads(candidate, strict=False),
        # literal_eval only evaluates literals (no code execution), so it is a
        # safe way to read the single-quoted dicts models often return.
        lambda: ast.literal_eval(candidate),
        _with_quoted_keys,
    )
    for parse in parsers:
        try:
            parsed = parse()
        except (ValueError, SyntaxError, TypeError, RecursionError):
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def refine_explanations_with_openai(mcq_set: MCQSetModel) -> MCQSetModel:
    """Ask an OpenAI chat model for a refined explanation and narration per question.

    Each question in *mcq_set* is updated in place with ``refined_explanation``
    and ``narration``, and a new MCQSetModel holding the same question objects
    is returned. Failures are best-effort: a question whose call or parse
    fails falls back to its original explanation and a templated narration
    ("The correct answer is ...").
    """
    llm = ChatOpenAI(model_name=OPENAI_MODEL, temperature=0.25, openai_api_key=OPENAI_API_KEY)
    prompt = PromptTemplate(
        input_variables=["question", "correct_answer", "explanation"],
        template=(
            "You are a helpful teacher preparing video explanations for MCQs.\n"
            "Improve the explanation in 4–6 sentences, keeping a friendly teacher tone.\n"
            "Return JSON ONLY:\n"
            # Double-quoted example: valid JSON, unlike the single-quoted form.
            '{{"refined_explanation": "<concise reasoning>", "narration": "<teacher-style narration>"}}\n\n'
            "Question: {question}\n"
            "Correct Answer: {correct_answer}\n"
            "Explanation: {explanation}"
        ),
    )
    chain = prompt | llm

    refined: List[MCQModel] = []
    for q in mcq_set.questions:
        # Avoid sending or narrating the literal string "None".
        base_explanation = q.explanation or ""
        fallback_narration = f"The correct answer is {q.correct_answer}. {base_explanation}"
        try:
            out = chain.invoke({
                "question": q.question,
                "correct_answer": q.correct_answer,
                "explanation": base_explanation,
            })
            raw = str(getattr(out, "content", out)).strip()
            parsed = _parse_refinement_reply(raw)

            # Some replies are plain text ("Refined explanation: ...") rather than JSON.
            if parsed is None and "refined explanation" in raw.lower():
                exp_match = re.search(r"refined explanation[:\-]\s*(.+?)(?:\n|$)", raw, re.IGNORECASE)
                nar_match = re.search(r"narration[:\-]\s*(.+?)(?:\n|$)", raw, re.IGNORECASE)
                parsed = {
                    "refined_explanation": exp_match.group(1).strip() if exp_match else base_explanation,
                    "narration": nar_match.group(1).strip() if nar_match else fallback_narration,
                }

            parsed = parsed or {}
            refined_text = str(parsed.get("refined_explanation") or base_explanation)
            q.refined_explanation = refined_text
            q.narration = str(
                parsed.get("narration") or f"The correct answer is {q.correct_answer}. {refined_text}"
            )

        except Exception as e:  # best effort: one failure must not abort the batch
            logger.warning(f"⚠️ Refinement failed for Q{q.q_no}: {e}")
            q.refined_explanation = base_explanation
            q.narration = fallback_narration

        refined.append(q)

    return MCQSetModel(title=mcq_set.title, questions=refined)

# --------------------------------
# Step 3: Build Slides (1 per question) — returns JSON-serializable dict
# --------------------------------
def build_slides_from_mcqs(mcq_set: MCQSetModel) -> Dict[str, Any]:
    """
    Build one slide per question as a JSON-serializable dict:
      {
        "title": "...",
        "slides": [
           { slide dict ... },
           ...
        ]
      }
    Each slide holds the question, options, answer, explanations, image info
    and TTS-normalised speaker notes. When a question has no narration, the
    notes are scripted as: opening + explanation + conclusion + transition,
    with the phrases picked at random.
    """

    openings = [
        "Alright students, let's look at this question.",
        "Okay everyone, listen carefully to this one.",
        "This question looks interesting — let's go through it.",
        "Let’s try to understand this together.",
        "Now, this might seem tricky, but we’ll make it simple."
    ]
    conclusions = [
        "So the correct answer is {answer}.",
        "That’s why the right answer is {answer}.",
        "Hence, {answer} is the correct option.",
        "Therefore, the answer is {answer}.",
        "And that gives us {answer} as the final answer."
    ]
    transitions = [
        "Let's move to the next question.",
        "Alright, ready for the next one?",
        "Okay, let’s continue.",
        "Now, let’s try another question.",
        "Great, moving on!"
    ]

    slides: List[Dict[str, Any]] = []

    for slide_no, q in enumerate(mcq_set.questions, start=1):
        # Defensive: accept a newline-separated string as well as a list.
        opts = q.options or []
        if isinstance(opts, str):
            opts = [o.strip() for o in opts.splitlines() if o.strip()]
        else:
            opts = list(opts)

        opening = random.choice(openings)
        conclusion = random.choice(conclusions).format(answer=q.correct_answer)
        transition = random.choice(transitions)

        refined = q.refined_explanation or q.explanation or ""
        # Prefer the refined narration; otherwise script one from the phrase
        # pools (the opening was previously drawn but never used).
        narration = q.narration or f"{opening} {refined} {conclusion} {transition}"
        narration = make_math_tts_friendly(narration)

        q_no = q.q_no if q.q_no is not None else slide_no
        slides.append({
            "slide_no": slide_no,
            "q_no": q_no,
            "title": f"Question {q_no}",
            "question": q.question or "",
            "options": opts,
            "correct_answer": q.correct_answer or "",
            "explanation": q.explanation or "",
            "refined_explanation": refined,
            "speaker_notes": narration,
            "image_path": q.image_reference or None,
            "has_image": bool(q.image_reference),
            # keep room for future optional fields (image_suggestion etc.)
            "image_suggestion": None,
        })

    base_title = mcq_set.title or "MCQ Set"
    # Avoid a nonsensical "Slides 1–0" suffix for an empty deck.
    slideset_title = f"{base_title} — Slides 1–{len(slides)}" if slides else base_title

    return {"title": slideset_title, "slides": slides}


# --------------------------------
# Step 4: Save Outputs
# --------------------------------
def save_json_outputs(mcq_set: MCQSetModel, slides_json: Dict):
    """Write the MCQ set and the slide deck to OUTPUT_DIR as indented UTF-8 JSON.

    Files written (overwriting any existing ones):
      * ``mcq_data.json`` — ``mcq_set.model_dump()``
      * ``slides.json``   — ``slides_json`` as given

    Returns:
        ``(mcq_path, slides_path)`` as strings.
    """
    def _dump(target: Path, payload: Any) -> None:
        with open(target, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, indent=2, ensure_ascii=False)

    mcq_file = OUTPUT_DIR / "mcq_data.json"
    slides_file = OUTPUT_DIR / "slides.json"

    _dump(mcq_file, mcq_set.model_dump())
    _dump(slides_file, slides_json)

    slide_count = len(slides_json["slides"])
    logger.info(f"✅ Saved {slide_count} slides to {slides_file}")
    return str(mcq_file), str(slides_file)


# --------------------------------
# Orchestrator
# --------------------------------
@celery.task(name="video_main.generate_mcq_content")
def generate_mcq_content(pdf_path: str, task_id: Optional[str] = None) -> Dict:
    """Celery task: run the PDF → MCQs → refined explanations → slides pipeline.

    Also persists mcq_data.json and slides.json via save_json_outputs.

    Returns:
        ``{"id": task_id or "mcq_task", "json": <slides dict>}``.
    """
    logger.info(f"📘 Processing: {pdf_path}")
    extracted = mistral_ocr_extract_mcqs(pdf_path)
    enriched = refine_explanations_with_openai(extracted)
    deck = build_slides_from_mcqs(enriched)
    save_json_outputs(enriched, deck)

    result_id = task_id or "mcq_task"
    return {"id": result_id, "json": deck}



# --------------------------------
# CLI
# --------------------------------
if __name__ == "__main__":
    import argparse

    # Run the pipeline synchronously from the command line (no Celery worker).
    cli = argparse.ArgumentParser(
        description="Generate MCQ slides and audio-ready JSON from PDF"
    )
    cli.add_argument("pdf", help="Path to the input PDF")
    cli_args = cli.parse_args()

    generate_mcq_content(cli_args.pdf)
    print("✅ All outputs generated in ./outputs/")
