# utils/mistral_pipeline.py
from pathlib import Path
import os
import json
import logging
import base64
from typing import Any, List, Dict, Optional, Tuple
import demjson3

from mistralai import Mistral
from openai import OpenAI

logger = logging.getLogger(__name__)

# default base dir used by tasks — will be overridden by caller if necessary


# API keys are read from the environment once at import time; an empty string
# is passed through to the client constructors and will only fail on first use.
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY", "")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")

# lazy clients
# Module-level singletons, constructed on demand by the get_*_client() helpers.
_mistral_client: Optional[Mistral] = None
_openai_client: Optional[OpenAI] = None

# System prompt for the extraction LLM: demands a strict JSON array of MCQ
# objects (QuestionNumber / Question / Options / FigureRefs) with no prose.
# NOTE: this is a runtime string sent to the model — edit with care.
SYSTEM_PROMPT_EXTRACT = (
    "You are a precise data extraction system designed to parse OCR-derived markdown text "
    "containing multiple-choice questions (MCQs) and image references. "
    "Your objective is to produce a strictly structured JSON array where each element represents one question. "
    "You must not include any text, markdown, or commentary outside the JSON.\n\n"
    "Each extracted object must follow this exact schema:\n"
    "{\n"
    '  "QuestionNumber": string,                # e.g. "1", "Q1", "12.", etc., or "" if not found\n'
    '  "Question": string or null,              # the question text\n'
    '  "Options": {                             # the four options if available\n'
    '    "A": string,\n'
    '    "B": string,\n'
    '    "C": string,\n'
    '    "D": string\n'
    '  } or null,                               # if options are missing, set to null\n'
    '  "FigureRefs": [                          # filenames of any referenced images\n'
    '    "img-5.jpeg", "img-6.jpeg"             # or [] if none\n'
    '  ]\n'
    "}\n\n"
    "Parsing rules:\n"
    "- Identify each question block accurately, even if spacing or numbering is inconsistent.\n"
    "- If a question number is visible (e.g., '1.', 'Q5', '23)'), extract it into 'QuestionNumber'.\n"
    "- If no clear number is visible, set 'QuestionNumber' to an empty string \"\".\n"
    "- Extract the full question text excluding option lines and image markdown.\n"
    "- Detect image references from markdown patterns like ![...](filename.jpeg) "
    "and list only the filename(s) in 'FigureRefs'.\n"
    "- If a question contains an image but no options, set 'Question' and 'Options' to null, "
    "and include the image filename in 'FigureRefs'.\n"
    "- Never fabricate or guess content not present in the text.\n"
    "- Preserve all punctuation, numbering, and numeric tokens exactly as seen.\n"
    "- The final output must be a valid JSON array with no comments, markdown, or explanatory text.\n\n"
    "Output only the JSON array — nothing else."
)

# System prompt for the validation LLM: normalizes one MCQ and must return
# either the literal token "null" (no changes) or a corrected JSON object.
# NOTE: this is a runtime string sent to the model — edit with care.
VALIDATION_PROMPT = (
    "You are an expert linguistic normalizer and validator for OCR-derived multiple-choice questions (MCQs).\n"
    "Your task is to analyze the given MCQ — consisting of a \"Question\" and its \"Options\" — and determine whether any textual correction is required.\n\n"
    "Evaluation and Correction Rules:\n\n"
    "1. If both the question and all options are already clear, readable, and semantically correct, return exactly null (without quotes, spaces, or JSON formatting).\n\n"
    "2. Otherwise, apply precise corrections only where necessary, following these principles:\n"
    "   - Correct OCR artefacts, spacing errors, or punctuation issues.\n"
    "   - Fix clear grammatical mistakes without altering meaning.\n"
    "   - Preserve all numbers, symbols, and domain-specific notation.\n"
    "   - Keep the same option labels (\"A\", \"B\", \"C\", \"D\") and their logical order.\n"
    "   - Do not invent, rephrase, or guess missing content.\n"
    "   - If some text is incomplete or ambiguous, leave it unchanged.\n\n"
    "Output Requirements:\n\n"
    "- If no correction is needed → return: null\n"
    "- If corrections are made → return a valid JSON object strictly in this form:\n"
    "{\n"
    '  "Question": "corrected question text",\n'
    '  "Options": {\n'
    '    "A": "corrected text",\n'
    '    "B": "corrected text",\n'
    '    "C": "corrected text",\n'
    '    "D": "corrected text"\n'
    '  }\n'
    "}\n\n"
    "Important:\n"
    "- Do not wrap the JSON in code fences or explanations.\n"
    "- Output only the JSON object or null.\n"
    "- The output must be valid JSON when applicable.\n"
)


def get_mistral_client() -> Mistral:
    """Return the shared Mistral client, constructing it on first use."""
    global _mistral_client
    if _mistral_client is not None:
        return _mistral_client
    _mistral_client = Mistral(api_key=MISTRAL_API_KEY)
    return _mistral_client

def get_openai_client() -> OpenAI:
    """Return the shared OpenAI client, constructing it on first use."""
    global _openai_client
    if _openai_client is not None:
        return _openai_client
    _openai_client = OpenAI(api_key=OPENAI_API_KEY)
    return _openai_client

def write_image_from_base64(target_dir: Path, image_id: str, base64_string: str) -> Optional[Path]:
    """Decode a base64 image payload and write it into *target_dir*.

    Accepts either a full data URI ("data:image/jpeg;base64,<payload>") or a
    bare base64 payload — previously a bare payload was rejected outright.

    Parameters:
        target_dir: directory to write into (created if missing).
        image_id: filename used for the written file.
        base64_string: the (optionally data-URI prefixed) base64 content.

    Returns the written Path, or None on decode/write failure.
    """
    # Strip an optional "data:...;base64," header; when there is no comma,
    # rpartition leaves the whole string as the payload.
    _, _, encoded = base64_string.rpartition(",")
    try:
        data = base64.b64decode(encoded)
    except Exception:
        logger.error("Unexpected base64 image format for %s", image_id)
        return None
    try:
        target_dir.mkdir(parents=True, exist_ok=True)
        out_path = target_dir / image_id
        out_path.write_bytes(data)
        return out_path
    except Exception as e:
        logger.exception("Failed to write image %s: %s", image_id, e)
        return None

def upload_pdf_for_ocr(pdf_path: Path, mistral: Mistral) -> Any:
    """Upload the PDF at *pdf_path* to Mistral with purpose="ocr"; return the upload handle."""
    logger.info("Uploading pdf to Mistral: %s", pdf_path)
    with pdf_path.open("rb") as stream:
        return mistral.files.upload(
            file={"file_name": pdf_path.name, "content": stream},
            purpose="ocr",
        )

def run_ocr_on_uploaded_file(uploaded_file: Any, client: Mistral) -> Any:
    """Request OCR processing and return the OCR response object."""
    logger.info("Retrieving signed URL for file id=%s", uploaded_file.id)
    signed = client.files.get_signed_url(file_id=uploaded_file.id)

    logger.info("Calling OCR model")
    result = client.ocr.process(
        model="mistral-ocr-latest",
        document={"type": "document_url", "document_url": signed.url},
        include_image_base64=True,
    )
    logger.info("OCR completed — pages: %s", len(getattr(result, "pages", [])))
    return result

def save_ocr_markdown(ocr_response: Any, out_path: Path) -> None:
    """Join per-page markdown with a '---' separator and write it to *out_path*."""
    separator = "\n\n---\n\n"
    combined = separator.join(getattr(page, "markdown", "") for page in ocr_response.pages)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(combined, encoding="utf-8")

def extract_images_from_ocr(ocr_response: Any, target_dir: Path) -> List[Path]:
    """Write every base64-embedded image in the OCR response into *target_dir*.

    Returns the list of paths actually written; images that fail to decode or
    write are skipped (write_image_from_base64 returns None for those).
    """
    saved: List[Path] = []
    target_dir.mkdir(parents=True, exist_ok=True)
    for page in ocr_response.pages:
        for image in getattr(page, "images", []) or []:
            encoded = getattr(image, "image_base64", None)
            if not encoded:
                continue
            # Fall back to a positional name when the OCR result has no id.
            name = getattr(image, "id", None) or f"img-{len(saved)+1}.bin"
            written_path = write_image_from_base64(target_dir, name, encoded)
            if written_path is not None:
                saved.append(written_path)
    return saved

def split_pages_from_markdown(md_text: str) -> List[str]:
    """Split concatenated OCR markdown back into per-page chunks, dropping blanks."""
    pages: List[str] = []
    for chunk in md_text.split("\n\n---\n\n"):
        stripped = chunk.strip()
        if stripped:
            pages.append(stripped)
    return pages

def call_extraction_llm(openai_client: "OpenAI", page_text: str, system_prompt: str) -> Any:
    """Ask the extraction model to pull MCQs out of one page of OCR markdown.

    Returns the parsed JSON payload on success, {"raw_text": ...} when the
    model response is not parseable JSON, or {"error": ...} when the API call
    itself fails.

    NOTE(review): response_format={"type": "json_object"} forces an object,
    while SYSTEM_PROMPT_EXTRACT asks for a bare JSON array — confirm which is
    intended.
    """
    prompt = f"Extract the MCQs from the following page text:\n\n{page_text}"
    # Keep the try narrow: only the network call should map to {"error": ...};
    # previously any local parsing bug was mislabelled as an API error.
    try:
        response = openai_client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            response_format={"type": "json_object"},
            temperature=0,
        )
    except Exception as e:
        logger.exception("Extraction LLM error: %s", e)
        return {"error": str(e)}

    # content may legitimately be None (e.g. refusal); normalize to "".
    result_text = (response.choices[0].message.content or "").strip()
    # Be resilient to extra text: strict JSON first, then the lenient demjson3
    # parser, finally fall back to returning the raw text untouched.
    try:
        return json.loads(result_text)
    except Exception:
        try:
            return demjson3.decode(result_text)
        except Exception:
            return {"raw_text": result_text}

# def extract_questions_from_pages(pages: List[str], openai_client: OpenAI, system_prompt: str) -> List[Dict[str, Any]]:
#     results = []
#     for idx, p in enumerate(pages, start=1):
#         logger.info("LLM extract page %s", idx)
#         page_result = call_extraction_llm(openai_client, p, system_prompt)
#         results.append({"page": idx, "content": page_result})
#     return results


def validate_question_via_openai(question_entry: Dict[str, Any], openai_client: "OpenAI", validation_prompt: str) -> Tuple[Dict[str, Any], bool]:
    """Ask the validator model to normalize one MCQ.

    Parameters:
        question_entry: dict with at least "Question" and "Options" keys.
        openai_client: OpenAI client used for the chat completion.
        validation_prompt: system prompt describing the normalization rules.

    Returns (entry, changed): the possibly-corrected entry and whether it was
    modified. Entries missing either the question or the options are returned
    untouched, as are entries the model answers with "null".
    """
    # If missing question or options, do not attempt to validate.
    if not question_entry.get("Question") or not question_entry.get("Options"):
        return question_entry, False

    prompt_payload = {"Question": question_entry["Question"], "Options": question_entry["Options"]}
    try:
        response = openai_client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "system", "content": validation_prompt}, {"role": "user", "content": json.dumps(prompt_payload, ensure_ascii=False)}],
            response_format={"type": "json_object"},
            temperature=0,
        )
    except Exception as exc:
        logger.exception("OpenAI API call failed during validation: %s", exc)
        return question_entry, False

    json_string = response.choices[0].message.content
    result: Optional[Dict[str, Any]] = None
    try:
        if json_string is not None and json_string.strip().lower() != "null":
            result = json.loads(json_string)
    except json.JSONDecodeError as exc:
        logger.warning("OpenAI returned invalid JSON during validation: %s", exc)
        logger.debug("Returned string start: %s", (json_string or "")[:200])

    # Guard against non-dict JSON (e.g. a bare list) as well as null/empty.
    if not isinstance(result, dict) or not result:
        return question_entry, False

    # copy() already preserves QuestionNumber, FigureRefs and any other keys;
    # only Question/Options are overwritten, falling back to the originals when
    # the model omits (or nulls) a field.
    corrected = question_entry.copy()
    corrected["Question"] = result.get("Question") or question_entry["Question"]
    if isinstance(result.get("Options"), dict):
        corrected["Options"] = result["Options"]
    return corrected, True

def process_questions_concurrently(questions: List[Dict[str, Any]], openai_client: OpenAI, validation_prompt: str, max_workers: int = 8):
    """Validate every question in parallel with a thread pool, preserving order.

    Returns the validated entries in their original order; logs how many were
    corrected by the model.
    """
    # simple thread pool implementation similar to your original
    from concurrent.futures import ThreadPoolExecutor, as_completed

    results: List[Optional[Dict[str, Any]]] = [None] * len(questions)
    corrected_count = 0

    def _task(index, entry):
        return index, validate_question_via_openai(entry, openai_client, validation_prompt)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(_task, i, entry) for i, entry in enumerate(questions)]
        for future in as_completed(pending):
            index, (validated, was_changed) = future.result()
            results[index] = validated
            if was_changed:
                corrected_count += 1

    logger.info("Validation done: %d changed", corrected_count)
    return [entry for entry in results if entry is not None]
