"""
Refined OCR -> MCQ extraction and validation pipeline
Preserves the original functionality, reorganized into modular functions with a clearer flow.

Behavior change requested by user:
- Use Mistral (mistralai SDK) exclusively for file upload and OCR.
- Use OpenAI (openai.OpenAI client) for both question extraction and validation LLM calls.

Notes:
- Expects MISTRAL_API_KEY and OPENAI_API_KEY to be set as environment variables or filled in constants below.
- Writes intermediate files into BASE_DIR.
- Uses Mistral SDK for file upload/ocr and OpenAI client for LLM calls.

"""

from __future__ import annotations

import base64
import json
import logging
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import demjson3
from jsonschema import validate as jsonschema_validate
from mistralai import Mistral
from openai import OpenAI
from celery_app import celery

# Configuration ---------------------------------------------------------------
LOG_LEVEL = logging.INFO

# Credentials: read from the environment, defaulting to "" when unset.
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY", "")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")

# Input document and working directory; both can be overridden via env vars.
PDF_PATH = os.environ.get("PDF_PATH", "QuantitativeAptitudeVOL1.pdf")
BASE_DIR = Path(os.environ.get("BASE_DIR", "./public/tmp/ocr/"))
# NOTE: created at import time, so merely importing this module touches the filesystem.
BASE_DIR.mkdir(parents=True, exist_ok=True)

# Pipeline artifacts, all written under BASE_DIR.
OCR_OUTPUT = BASE_DIR.joinpath("ocr_output.md")  # raw OCR markdown, pages joined by "---"
CLEANED_OCR_OUTPUT = BASE_DIR.joinpath("cleaned_ocr_output.md")  # re-serialised extraction JSON
VALIDATED_OCR_OUTPUT = BASE_DIR.joinpath("validated_ocr_output.md")  # extraction JSON after validation
OUTPUT_JSON_PATH = BASE_DIR.joinpath("structured_mcq_output.json")  # per-page extraction results

# SDK clients, created on first use by the get_*_client accessors.
_mistral_client: Optional[Mistral] = None
_openai_client: Optional[OpenAI] = None

# Logging: configure the root logger at import time and take a module-level logger.
logging.basicConfig(format="%(asctime)s — %(levelname)s — %(message)s", level=LOG_LEVEL)
logger = logging.getLogger(__name__)

# Prompts ---------------------------------------------------------------------
# Both prompts are sent with response_format={"type": "json_object"}, and in that
# mode the model can only emit a JSON *object*. A bare array or a bare `null` is
# not allowed. The wording below therefore asks for an object in every case:
# extraction returns {"Questions": [...]}, which is the key main() reads, and
# validation returns {} when nothing needs correcting.
SYSTEM_PROMPT_EXTRACT = (
    "You are a precise data extraction system designed to parse OCR-derived markdown text "
    "containing multiple-choice questions (MCQs) and image references. "
    "Your objective is to produce a single JSON object with exactly one key, \"Questions\", "
    "whose value is an array where each element represents one question. "
    "You must not include any text, markdown, or commentary outside the JSON.\n\n"
    "Each element of the \"Questions\" array must follow this exact schema:\n"
    "{\n"
    '  "Question": string or null,             # the question text\n'
    '  "Options": {                            # the four options if available\n'
    '    "A": string,\n'
    '    "B": string,\n'
    '    "C": string,\n'
    '    "D": string\n'
    '  } or null,                              # if options are missing, set to null\n'
    '  "FigureRefs": [                         # filenames of any referenced images\n'
    '    "img-5.jpeg", "img-6.jpeg"            # or [] if none\n'
    '  ]\n'
    "}\n\n"
    "Parsing rules:\n"
    "- Identify each question block accurately, even if spacing or numbering is inconsistent.\n"
    "- Extract the full question text excluding option lines and image markdown.\n"
    "- Detect image references from markdown patterns like ![...](filename.jpeg) "
    "and list only the filename(s) in 'FigureRefs'.\n"
    "- If a question contains an image but no options, set 'Question' and 'Options' to null, "
    "and include the image filename in 'FigureRefs'.\n"
    "- Never fabricate or guess content not present in the text.\n"
    "- Preserve all punctuation and numeric tokens exactly as seen.\n"
    "- If the page contains no questions, return {\"Questions\": []}.\n"
    "- The final output must be a valid JSON object of the form {\"Questions\": [...]} "
    "with no comments, markdown, or explanatory text.\n\n"
    "Output only that JSON object — nothing else."
)

VALIDATION_PROMPT = (
    "You are an expert linguistic normalizer and validator for OCR-derived multiple-choice questions (MCQs).\n"
    "Your task is to analyze the given MCQ — consisting of a \"Question\" and its \"Options\" — and determine whether any textual correction is required.\n\n"
    "Evaluation and Correction Rules:\n\n"
    "1. If both the question and all options are already clear, readable, and semantically correct, return exactly the empty JSON object {} and nothing else.\n\n"
    "2. Otherwise, apply precise corrections only where necessary, following these principles:\n"
    "   - Correct OCR artefacts, spacing errors, or punctuation issues.\n"
    "   - Fix clear grammatical mistakes without altering meaning.\n"
    "   - Preserve all numbers, symbols, and domain-specific notation.\n"
    "   - Keep the same option labels (\"A\", \"B\", \"C\", \"D\") and their logical order.\n"
    "   - Do not invent, rephrase, or guess missing content.\n"
    "   - If some text is incomplete or ambiguous, leave it unchanged.\n\n"
    "Output Requirements:\n\n"
    "- If no correction is needed → return: {}\n"
    "- If corrections are made → return a valid JSON object strictly in this form:\n"
    "{\n"
    '  "Question": "corrected question text",\n'
    '  "Options": {\n'
    '    "A": "corrected text",\n'
    '    "B": "corrected text",\n'
    '    "C": "corrected text",\n'
    '    "D": "corrected text"\n'
    '  }\n'
    "}\n\n"
    "Important:\n"
    "- Do not wrap the JSON in code fences or explanations.\n"
    "- Output only the JSON object ({} when no correction is needed).\n"
    "- The output must always be valid JSON.\n"
)

# Utility functions -----------------------------------------------------------

def get_mistral_client() -> Mistral:
    """Return the module-wide Mistral client, constructing it on the first call.

    The client is built with ``MISTRAL_API_KEY`` and cached in ``_mistral_client``,
    so later calls return the same instance.
    """
    global _mistral_client
    if _mistral_client is not None:
        return _mistral_client
    _mistral_client = Mistral(api_key=MISTRAL_API_KEY)
    return _mistral_client


def get_openai_client() -> OpenAI:
    """Return the module-wide OpenAI client, constructing it on the first call.

    The client is built with ``OPENAI_API_KEY`` and cached in ``_openai_client``,
    so later calls return the same instance.
    """
    global _openai_client
    if _openai_client is not None:
        return _openai_client
    _openai_client = OpenAI(api_key=OPENAI_API_KEY)
    return _openai_client


def save_text(path: Path, content: str) -> None:
    """Write ``content`` to ``path`` as UTF-8, creating parent directories if needed.

    An existing file at ``path`` is overwritten.
    """
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    path.write_text(content, encoding="utf-8")
    logger.debug("Saved text to %s", path)


def load_text(path: Path) -> str:
    """Read ``path`` as UTF-8 text, retrying as latin-1 if UTF-8 decoding fails.

    latin-1 maps every byte to a character, so the fallback read does not raise
    a decode error. A warning is logged when the fallback is used.
    """
    try:
        text = path.read_text(encoding="utf-8")
    except UnicodeDecodeError:
        logger.warning("UTF-8 decode error when reading %s, falling back to latin-1", path)
        text = path.read_text(encoding="latin-1")
    return text


def detect_image_extension_from_header(base64_string: str) -> str:
    """Map a data-URI prefix to a file extension.

    Returns ``".jpeg"`` for strings starting with ``data:image/jpeg``, ``".png"``
    for ``data:image/png``, and ``".bin"`` for anything else.
    """
    prefix_to_extension = (
        ("data:image/jpeg", ".jpeg"),
        ("data:image/png", ".png"),
    )
    return next(
        (extension for prefix, extension in prefix_to_extension if base64_string.startswith(prefix)),
        ".bin",
    )


def write_image_from_base64(target_dir: Path, image_id: str, base64_string: str) -> Optional[Path]:
    """Decode a ``<header>,<base64 payload>`` string and write the bytes into ``target_dir``.

    Args:
        target_dir: Directory to write into. It is created if it does not exist,
            matching how ``save_text`` treats its parent directory.
        image_id: Identifier used as the output file name. Only its final path
            component is used, so an id such as ``"../x.png"`` cannot write
            outside ``target_dir``.
        base64_string: Data-URI style string. Everything after the first comma
            is decoded as base64, and the header before it is ignored.

    Returns:
        The written path, or ``None`` if the input is empty or malformed, or if
        decoding or writing fails. Failures are logged, not raised.
    """
    if not base64_string:
        return None

    _header, separator, encoded = base64_string.partition(",")
    if not separator:
        logger.error("Unexpected base64 format for image %s", image_id)
        return None

    # The id comes from the OCR API response; keep only the basename so it
    # cannot escape target_dir.
    file_name = Path(image_id).name if isinstance(image_id, str) else ""
    if not file_name:
        logger.error("Invalid image id %r; cannot derive a file name", image_id)
        return None
    target_path = target_dir / file_name

    try:
        image_bytes = base64.b64decode(encoded)
        target_dir.mkdir(parents=True, exist_ok=True)
        target_path.write_bytes(image_bytes)
    except (ValueError, OSError) as exc:  # binascii.Error subclasses ValueError
        logger.exception("Failed to write image %s: %s", image_id, exc)
        return None

    logger.debug("Wrote image %s", target_path)
    return target_path


# Core pipeline functions -----------------------------------------------------

def upload_pdf_for_ocr(pdf_path: Path, client: Mistral) -> Any:
    """Upload ``pdf_path`` to Mistral with ``purpose="ocr"``.

    The file handle stays open only for the duration of the upload call.

    Returns:
        The upload object returned by ``client.files.upload``; its ``id`` is used later
        to request a signed URL.
    """
    logger.info("Uploading PDF: %s", pdf_path)
    with pdf_path.open("rb") as pdf_stream:
        payload = {"file_name": pdf_path.name, "content": pdf_stream}
        upload = client.files.upload(file=payload, purpose="ocr")
    logger.info("Uploaded file id=%s", getattr(upload, "id", None))
    return upload


def run_ocr_on_uploaded_file(uploaded_file: Any, client: Mistral) -> Any:
    """Run Mistral OCR over a previously uploaded file.

    Fetches a signed URL for ``uploaded_file.id`` and passes it to
    ``client.ocr.process`` using model ``mistral-ocr-latest``. Embedded images
    are requested as base64 via ``include_image_base64=True``.

    Returns:
        The OCR response object. Callers read its ``pages`` attribute.
    """
    file_id = uploaded_file.id
    logger.info("Retrieving signed URL for file id=%s", file_id)
    signed = client.files.get_signed_url(file_id=file_id)

    logger.info("Calling OCR model")
    document = {"type": "document_url", "document_url": signed.url}
    result = client.ocr.process(
        model="mistral-ocr-latest",
        document=document,
        include_image_base64=True,
    )
    page_count = len(getattr(result, "pages", []))
    logger.info("OCR completed — pages: %s", page_count)
    return result


def save_ocr_markdown(ocr_response: Any, out_path: Path) -> None:
    """Write every OCR page's markdown to ``out_path``, separated by a ``---`` rule.

    Pages are joined with a blank line, ``---``, and another blank line. That is
    the same delimiter ``split_pages_from_markdown`` splits on.
    """
    separator = "\n\n---\n\n"
    combined = separator.join(page.markdown for page in ocr_response.pages)
    save_text(out_path, combined)
    logger.info("Saved OCR markdown to %s", out_path)


def extract_images_from_ocr(ocr_response: Any, target_dir: Path) -> List[Path]:
    """Save every base64-encoded image found on the OCR pages into ``target_dir``.

    Images without an ``image_base64`` payload are skipped. Images whose write
    fails are logged as warnings and left out of the result.

    Returns:
        Paths of the images that were written successfully.
    """
    saved_paths: List[Path] = []
    for page_no, page in enumerate(ocr_response.pages):
        page_images = getattr(page, "images", []) or []
        for image in page_images:
            image_id = image.id
            payload = getattr(image, "image_base64", None)
            if not payload:
                continue
            saved = write_image_from_base64(target_dir, image_id, payload)
            if not saved:
                logger.warning("Image %s on page %s failed to save.", image_id, page_no)
                continue
            saved_paths.append(saved)
    logger.info("Extracted %d images", len(saved_paths))
    return saved_paths


def split_pages_from_markdown(md_text: str) -> List[str]:
    """Split combined OCR markdown back into per-page chunks.

    The text is split on a ``---`` rule surrounded by blank lines. Each chunk is
    stripped of surrounding whitespace, and empty chunks are dropped.
    """
    separator = "\n\n---\n\n"
    pages: List[str] = []
    for chunk in md_text.split(separator):
        stripped = chunk.strip()
        if stripped:
            pages.append(stripped)
    logger.info("Detected %d pages in OCR markdown", len(pages))
    return pages


def call_extraction_llm(openai_client: OpenAI, page_text: str, system_prompt: str) -> Any:
    """Use OpenAI client to extract structured questions from a page's text."""
    prompt = f"Extract the MCQs from the following page text:\n\n{page_text}"
    try:
        response = openai_client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
            response_format={"type": "json_object"},
            temperature=0,
        )
        result_text = response.choices[0].message.content.strip()
        try:
            return json.loads(result_text)
        except json.JSONDecodeError:
            logger.warning("Extraction LLM returned non-JSON for page; saving raw text")
            return {"raw_text": result_text}
    except Exception as exc:
        logger.exception("Error during OpenAI extraction: %s", exc)
        return {"error": str(exc)}


def extract_questions_from_pages(pages: List[str], openai_client: OpenAI, system_prompt: str) -> List[Dict[str, Any]]:
    """Run LLM extraction over each page in order and persist the results.

    Each result is ``{"page": <1-based index into pages>, "content": <extraction result>}``.
    The full list is written to ``OUTPUT_JSON_PATH`` as indented JSON before returning.
    """
    page_count = len(pages)
    results: List[Dict[str, Any]] = []
    for page_number, text in enumerate(pages, start=1):
        logger.info("Processing page %s/%s with OpenAI extraction", page_number, page_count)
        results.append(
            {"page": page_number, "content": call_extraction_llm(openai_client, text, system_prompt)}
        )
    serialized = json.dumps(results, ensure_ascii=False, indent=2)
    save_text(OUTPUT_JSON_PATH, serialized)
    logger.info("Saved structured extraction to %s", OUTPUT_JSON_PATH)
    return results


def is_valid_json_file(path: Path, schema: Optional[dict] = None) -> bool:
    """Return True if ``path`` holds parseable JSON that satisfies ``schema``.

    Schema validation is skipped when ``schema`` is falsy. Any failure (read
    error, parse error, or schema violation) is logged as a warning and yields
    False instead of raising.
    """
    try:
        parsed = json.loads(load_text(path))
        if schema:
            jsonschema_validate(instance=parsed, schema=schema)
    except Exception as exc:
        logger.warning("JSON validation failed for %s: %s", path, exc)
        return False
    return True


def ensure_clean_json(source_path: Path, dest_path: Path) -> None:
    """Re-serialise ``source_path`` as indented JSON at ``dest_path``.

    Strict ``json`` parsing is tried first. If that fails, the lenient
    ``demjson3`` decoder is used to repair the text. The source file is read only
    once, and the same text feeds both parsers.

    Raises:
        Whatever ``load_text`` raises when reading the source, and any exception
        from the demjson3 repair path (logged before re-raising).
    """
    raw = load_text(source_path)
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        try:
            decoded = demjson3.decode(raw)
            save_text(dest_path, json.dumps(decoded, ensure_ascii=False, indent=4))
        except Exception as exc:
            logger.exception("Failed to clean JSON: %s", exc)
            raise
        logger.info("Successfully converted corrupted JSON using demjson3 to %s", dest_path)
        return

    save_text(dest_path, json.dumps(data, ensure_ascii=False, indent=4))
    logger.info("Source was valid JSON — cleaned output saved to %s", dest_path)


# Validation of individual questions via OpenAI LLM ---------------------------

def validate_question_via_openai(question_entry: Dict[str, Any], openai_client: OpenAI, validation_prompt: str) -> Tuple[Dict[str, Any], bool]:
    """Return a possibly-corrected entry and a boolean indicating whether it was changed."""
    # If missing question or options, do not attempt to validate.
    if not question_entry.get("Question") or not question_entry.get("Options"):
        return question_entry, False

    prompt_payload = {"Question": question_entry["Question"], "Options": question_entry["Options"]}
    try:
        response = openai_client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "system", "content": validation_prompt}, {"role": "user", "content": json.dumps(prompt_payload, ensure_ascii=False)}],
            response_format={"type": "json_object"},
            temperature=0,
        )
    except Exception as exc:
        logger.exception("OpenAI API call failed during validation: %s", exc)
        return question_entry, False

    json_string = response.choices[0].message.content
    result: Optional[Dict[str, Any]] = None
    try:
        if json_string is not None and json_string.strip().lower() != "null":
            result = json.loads(json_string)
    except json.JSONDecodeError as exc:
        logger.warning("OpenAI returned invalid JSON during validation: %s", exc)
        logger.debug("Returned string start: %s", (json_string or "")[:200])

    if not result:
        return question_entry, False

    corrected = question_entry.copy()
    corrected["Question"] = result.get("Question", question_entry["Question"])
    if isinstance(result.get("Options"), dict):
        corrected["Options"] = result["Options"]
    return corrected, True


def process_questions_concurrently(questions: List[Dict[str, Any]], openai_client: OpenAI, validation_prompt: str, max_workers: int = 16) -> List[Dict[str, Any]]:
    """Validate every question in parallel threads, preserving input order.

    Each question is passed to ``validate_question_via_openai``. If a worker
    raises, the failure is logged and the original question is kept. One bad
    entry therefore cannot abort validation of the others.

    Args:
        questions: Question dicts for one page.
        openai_client: Configured OpenAI client shared by all workers.
        validation_prompt: System instructions for the validator.
        max_workers: Thread pool size.

    Returns:
        The (possibly corrected) questions in their original order. ``None``
        entries are omitted, as in the original contract.
    """
    # Start from the originals so any slot whose validation fails keeps its input.
    refined: List[Optional[Dict[str, Any]]] = list(questions)
    changed_count = 0

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_index = {
            executor.submit(validate_question_via_openai, entry, openai_client, validation_prompt): idx
            for idx, entry in enumerate(questions)
        }
        for future in as_completed(future_to_index):
            idx = future_to_index[future]
            try:
                refined_entry, changed = future.result()
            except Exception as exc:
                logger.exception("Validation of question %d failed; keeping original: %s", idx, exc)
                continue
            refined[idx] = refined_entry
            if changed:
                changed_count += 1

    logger.info("Processing complete: %d questions refined, %d unchanged.", changed_count, len(questions) - changed_count)
    return [r for r in refined if r is not None]


# Orchestration / main -------------------------------------------------------


def main() -> None:
    """Run the full pipeline: OCR with Mistral, then MCQ extraction and validation with OpenAI.

    Steps:
    1. Upload ``PDF_PATH`` to Mistral and run OCR on it.
    2. Save the OCR markdown to ``OCR_OUTPUT`` and embedded images to ``BASE_DIR``.
    3. Split the markdown into pages and extract MCQs per page. The extraction
       step itself writes ``OUTPUT_JSON_PATH``.
    4. Re-serialise that JSON to ``CLEANED_OCR_OUTPUT``.
    5. Validate each page's ``"Questions"`` list, then write ``VALIDATED_OCR_OUTPUT``.

    Raises:
        RuntimeError: if ``MISTRAL_API_KEY`` or ``OPENAI_API_KEY`` is empty.
    """
    logger.info("Starting OCR -> MCQ pipeline (Mistral for OCR, OpenAI for extraction/validation)")

    # Fail fast with a clear message instead of an authentication error mid-run.
    missing_keys = [
        name
        for name, value in (("MISTRAL_API_KEY", MISTRAL_API_KEY), ("OPENAI_API_KEY", OPENAI_API_KEY))
        if not value
    ]
    if missing_keys:
        raise RuntimeError(f"Missing required API key(s): {', '.join(missing_keys)}")

    mistral = get_mistral_client()
    openai_client = get_openai_client()

    # Upload PDF and run OCR
    uploaded = upload_pdf_for_ocr(Path(PDF_PATH), mistral)
    ocr_resp = run_ocr_on_uploaded_file(uploaded, mistral)

    # Save markdown text and images
    save_ocr_markdown(ocr_resp, OCR_OUTPUT)
    extract_images_from_ocr(ocr_resp, BASE_DIR)

    # Parse pages locally
    md_text = load_text(OCR_OUTPUT)
    pages = split_pages_from_markdown(md_text)

    # Extract structured questions using OpenAI; this already persists to OUTPUT_JSON_PATH.
    pages_data = extract_questions_from_pages(pages, openai_client, SYSTEM_PROMPT_EXTRACT)
    ensure_clean_json(OUTPUT_JSON_PATH, CLEANED_OCR_OUTPUT)

    # Validate questions using OpenAI
    for page in pages_data:
        page_num = page.get("page")
        content = page.get("content")
        # content may be an {"error": ...}/{"raw_text": ...} dict or, in principle,
        # a non-dict JSON value; only a list under "Questions" is validated.
        questions = content.get("Questions") if isinstance(content, dict) else None
        if not isinstance(questions, list) or not questions:
            logger.debug("Page %s: no 'Questions' key or empty — skipping validation.", page_num)
            continue
        logger.info("Validating %d questions on page %s using OpenAI", len(questions), page_num)
        content["Questions"] = process_questions_concurrently(
            questions, openai_client, VALIDATION_PROMPT
        )

    save_text(VALIDATED_OCR_OUTPUT, json.dumps(pages_data, ensure_ascii=False, indent=2))
    logger.info("Validation complete. Output saved to: %s", VALIDATED_OCR_OUTPUT)


if __name__ == "__main__":
    # Script entry point: log any failure with its traceback, then re-raise so
    # the interpreter still reports the error.
    try:
        main()
    except Exception as pipeline_error:
        logger.exception("Pipeline failed: %s", pipeline_error)
        raise