import os
import re
import json
import time
import base64
import logging
import concurrent.futures
from pathlib import Path
from typing import List, Optional, Dict, Any, Set
from pydantic import BaseModel, ValidationError
from PIL import Image, ImageChops, ImageOps
import io
import shutil
import glob
from dotenv import load_dotenv

# Install: pip install openai python-dotenv pillow pydantic
import openai

# ---------- Basic config & logging ----------
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
PER_PAGE_TIMEOUT = 120   # seconds per page (adjust)
MAX_RETRIES = 2
RETRY_DELAY = 1.5

load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") or "YOUR_KEY_HERE"
openai.api_key = OPENAI_API_KEY

# ---------- Pipeline config ----------
INPUT_PAGES_JSON = Path("ocr_output.json")   # your OCR+YOLO per-page JSON input
OUTPUT_MCQ_JSON = Path("all_mcqs.json")
DEBUG_DIR = Path("debug_outputs")
DEBUG_DIR.mkdir(exist_ok=True)
MODEL_NAME = os.getenv("OPENAI_VIS_MODEL", "gpt-4o")  # adjust if needed
MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", "2000"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.0"))
# Project-local temp dir for intermediate pipeline files
TMP_DIR = Path("public/tmp")

# Directories for cropped figure images and per-page debug dumps
CROPS_OUT_DIR = "public/images/crops_out"
DEBUG_OUTPUTS_DIR = "public/debug_outputs"


# ---------- Extended Pydantic model ----------
class MCQ(BaseModel):
    Question: str
    Options: Dict[str, str]               # {"A":"...", "B":"..."}
    Answer: Optional[str] = None
    Explanation: Optional[str] = None
    FigureRefs: Optional[List[str]] = None
    # Optional math-aware fields produced by first pass
    QuestionLatex: Optional[str] = None
    OptionsLatex: Optional[Dict[str, str]] = None
    AmbiguousMath: Optional[bool] = False
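
# Construction sketch with illustrative data (not produced by the pipeline):
# >>> mcq = MCQ(Question="What is 2 + 2?", Options={"A": "3", "B": "4"}, Answer="B")
# >>> mcq.Answer
# 'B'
# Payloads missing required fields (e.g. Question) raise pydantic.ValidationError,
# which callers can catch and log per item.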

# ---------- Helpers ----------

def extract_figure_references_from_results(results: List[dict]) -> Set[str]:
    """
    Extract all unique figure references from the final results.
    Handles different formats of FigureRefs in the results.
    """
    referenced_images = set()
    
    if not results:
        return referenced_images
    
    for result in results:
        if not isinstance(result, dict):
            continue
            
        # Handle different formats of FigureRefs
        figure_refs = result.get("FigureRefs") or result.get("figureRefs") or result.get("figure_refs")
        
        if figure_refs:
            if isinstance(figure_refs, list):
                for ref in figure_refs:
                    if isinstance(ref, str):
                        # Extract filename from path if it's a full path
                        filename = os.path.basename(ref)
                        # Remove any query parameters or fragments
                        filename = filename.split('?')[0].split('#')[0]
                        referenced_images.add(filename)
            elif isinstance(figure_refs, str):
                # Handle comma-separated string
                for ref in figure_refs.split(','):
                    ref = ref.strip()
                    if ref:
                        filename = os.path.basename(ref)
                        filename = filename.split('?')[0].split('#')[0]
                        referenced_images.add(filename)
    
    return referenced_images
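
# Usage sketch (file names are hypothetical; sorted only for stable display):
# >>> results = [
# ...     {"FigureRefs": ["public/images/crops_out/fig_p1_0.png?v=2"]},
# ...     {"figure_refs": "fig_p2_0.png, fig_p2_1.png"},
# ... ]
# >>> sorted(extract_figure_references_from_results(results))
# ['fig_p1_0.png', 'fig_p2_0.png', 'fig_p2_1.png']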

def extract_json_from_text(text: str) -> Optional[Any]:
    """Extract JSON from text, handling various formats including code blocks."""
    # First, try to parse the entire text as JSON
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    
    # Look for JSON in code blocks
    json_patterns = [
        r"```(?:json)?\s*([\s\S]*?)\s*```",  # ```json ... ```
        r"```\s*([\s\S]*?)\s*```",           # ``` ... ```
        r"`([^`]*)`",                         # `...`
    ]
    
    for pattern in json_patterns:
        matches = re.findall(pattern, text, re.IGNORECASE)
        for match in matches:
            try:
                return json.loads(match)
            except json.JSONDecodeError:
                continue
    
    # Try to find any JSON-like structure in the text
    try:
        # Look for the first { and last } and try to parse
        start = text.find('{')
        end = text.rfind('}') + 1
        if start >= 0 and end > start:
            return json.loads(text[start:end])
    except json.JSONDecodeError:
        pass
    
    return None
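
# Expected behavior on typical model output (illustrative strings):
# >>> extract_json_from_text('Sure:\n```json\n[{"Question": "Q1"}]\n```')
# [{'Question': 'Q1'}]
# >>> extract_json_from_text("no json here") is None
# True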


def normalize_options_field(mcq_data: dict) -> dict:
    """Normalize the Options field to be a dictionary with letter keys."""
    options = mcq_data.get("Options", {})
    
    # If options is a list, convert to dict with A, B, C, D keys
    if isinstance(options, list):
        normalized = {}
        letters = ["A", "B", "C", "D", "E", "F", "G", "H"]
        for i, option in enumerate(options):
            if i < len(letters):
                normalized[letters[i]] = str(option)
        mcq_data["Options"] = normalized
    
    # If options is a dict (numeric or mismatched keys), re-key sequentially to letters
    elif isinstance(options, dict):
        normalized = {}
        letters = ["A", "B", "C", "D", "E", "F", "G", "H"]
        for i, (key, value) in enumerate(options.items()):
            if i < len(letters):
                normalized[letters[i]] = str(value)
        mcq_data["Options"] = normalized
    
    return mcq_data
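
# Example: list-style options are re-keyed to letters (illustrative input):
# >>> normalize_options_field({"Question": "Pick one", "Options": ["3", "4"]})
# {'Question': 'Pick one', 'Options': {'A': '3', 'B': '4'}}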

def cleanup_directories(final_results: Optional[List[dict]] = None):
    """
    Clean up directories while preserving images referenced in final results.
    - Always clean debug_outputs completely
    - Only clean crops_out images that are not referenced in final results
    """
    # Always clean debug_outputs completely
    try:
        if os.path.exists(DEBUG_OUTPUTS_DIR):
            shutil.rmtree(DEBUG_OUTPUTS_DIR)
            os.makedirs(DEBUG_OUTPUTS_DIR, exist_ok=True)
            logging.info(f"Cleaned up debug directory: {DEBUG_OUTPUTS_DIR}")
    except Exception as e:
        logging.warning(f"Failed to clean up {DEBUG_OUTPUTS_DIR}: {e}")
    
    # Clean crops_out directory selectively
    try:
        if os.path.exists(CROPS_OUT_DIR):
            # Get all referenced images from final results using enhanced extraction
            referenced_images = extract_figure_references_from_results(final_results)
            
            logging.info(f"Found {len(referenced_images)} referenced images: {list(referenced_images)}")
            
            # Get all images in crops_out directory
            all_images = set()
            image_patterns = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff', '*.webp']
            
            for pattern in image_patterns:
                for img_path in glob.glob(os.path.join(CROPS_OUT_DIR, pattern)):
                    if os.path.isfile(img_path):
                        filename = os.path.basename(img_path)
                        all_images.add(filename)
            
            # Find images to delete (those not referenced)
            images_to_delete = all_images - referenced_images
            
            # Delete unreferenced images
            deleted_count = 0
            for img_name in images_to_delete:
                img_path = os.path.join(CROPS_OUT_DIR, img_name)
                try:
                    if os.path.exists(img_path):
                        os.remove(img_path)
                        logging.info(f"Removed unreferenced image: {img_name}")
                        deleted_count += 1
                except Exception as e:
                    logging.warning(f"Failed to remove image {img_name}: {e}")
            
            # Clean empty subdirectories but preserve the main crops_out directory
            for item in os.listdir(CROPS_OUT_DIR):
                item_path = os.path.join(CROPS_OUT_DIR, item)
                if os.path.isdir(item_path):
                    try:
                        # Only remove if empty
                        if not os.listdir(item_path):
                            os.rmdir(item_path)
                            logging.info(f"Removed empty subdirectory: {item}")
                        else:
                            logging.info(f"Keeping non-empty subdirectory: {item}")
                    except Exception as e:
                        logging.warning(f"Failed to remove subdirectory {item}: {e}")
            
            logging.info(f"Cleaned crops_out directory. Kept {len(referenced_images)} referenced images, removed {deleted_count} unreferenced images.")
            
    except Exception as e:
        logging.warning(f"Failed to clean up {CROPS_OUT_DIR}: {e}")

def cleanup_temp_files(task_id: Optional[str] = None):
    """
    Clean up temporary files created during processing
    """
    try:
        # Clean up temporary JSON files
        temp_files = glob.glob(os.path.join(TMP_DIR, f"*{task_id}*")) if task_id else []
        temp_files.extend(glob.glob(os.path.join(TMP_DIR, "*_results.json")))
        temp_files.extend(glob.glob(os.path.join(TMP_DIR, "*_mcqs.json")))
        
        for temp_file in temp_files:
            try:
                if os.path.exists(temp_file):
                    os.remove(temp_file)
                    logging.info(f"Removed temp file: {temp_file}")
            except Exception as e:
                logging.warning(f"Failed to remove temp file {temp_file}: {e}")
                
    except Exception as e:
        logging.warning(f"Failed to clean temp files: {e}")
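
# Typical call order at the end of a run (sketch; "task_123" is a placeholder id):
#   cleanup_directories(final_results)   # keep only crop images still referenced
#   cleanup_temp_files("task_123")       # drop per-task scratch JSON in TMP_DIR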

def merge_pages(json_data):
    """Flatten the page-wise results into a single list of MCQs."""
    merged = []
    for page in json_data:
        if isinstance(page, dict) and 'mcqs' in page:
            # extend with page's MCQs (may be empty list)
            merged.extend(page['mcqs'])
    return merged
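
# Example with a minimal two-page payload (illustrative):
# >>> merge_pages([{"page": 1, "mcqs": [{"Question": "Q1"}]}, {"page": 2, "mcqs": []}])
# [{'Question': 'Q1'}]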


# Maximum encoded size for a single image payload sent to the vision model
MAX_IMAGE_BYTES = int(os.getenv("MAX_IMAGE_BYTES", "30000"))  # 30 KB

def compress_image_to_data_uri(path: Path, max_bytes: int = MAX_IMAGE_BYTES) -> Optional[str]:
    """
    Compress image to a JPEG data URI under max_bytes.
    - Auto-crops whitespace
    - Converts to grayscale if it saves significant space
    - Aggressive resizing & quality tuning
    Returns the data URI string, or None if compression cannot produce a small enough image.
    """
    try:
        img = Image.open(path).convert("RGB")
    except Exception as e:
        logging.warning(f"Failed to open image {path}: {e}")
        return None

    # 1. Auto-crop white/near-white borders (helps for scanned diagrams/tables)
    bg = Image.new(img.mode, img.size, (255, 255, 255))
    diff = ImageChops.difference(img, bg)
    bbox = diff.getbbox()
    if bbox:
        img = img.crop(bbox)

    w, h = img.size

    # 2. Start with aggressive scaling for large images
    max_dim_target = 512  # target max dimension (tune for readability)
    if max(w, h) > max_dim_target:
        scale = max_dim_target / max(w, h)
        img = img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)

    # 3. Try grayscale for diagrams/tables
    img_gray = img.convert("L")
    buf_gray = io.BytesIO()
    img_gray.save(buf_gray, format="JPEG", quality=40, subsampling=2)
    gray_size = len(buf_gray.getvalue())
    if gray_size < max_bytes * 0.9:  # switch to grayscale if it already fits well under budget
        img = img_gray.convert("RGB")

    # 4. Iteratively adjust quality & size
    quality = 50  # start lower than 75
    scale_factor = 0.85  # scale down aggressively if too big

    for _ in range(8):
        buf = io.BytesIO()
        img.save(buf, format="JPEG", quality=quality, subsampling=2, optimize=True)
        raw = buf.getvalue()

        if len(raw) <= max_bytes:
            b64 = base64.b64encode(raw).decode("ascii")
            return f"data:image/jpeg;base64,{b64}"

        # Adjust aggressively: reduce quality if still high, else resize
        if quality > 25:
            quality = max(20, int(quality * 0.8))
        else:
            w, h = img.size
            img = img.resize((int(w * scale_factor), int(h * scale_factor)), Image.LANCZOS)

    return None
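
# Usage sketch (the crop path is hypothetical; real crops live under CROPS_OUT_DIR):
# >>> uri = compress_image_to_data_uri(Path("public/images/crops_out/fig_p1_0.png"))
# >>> uri[:22] if uri else uri
# 'data:image/jpeg;base64'
# A None return means the crop could not be squeezed under MAX_IMAGE_BYTES and
# should be skipped rather than attached to the vision request.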



# ---------- Prompt builders ----------

def build_system_prompt_reconstruct() -> str:
    """Build the system prompt for the first pass (extraction and reconstruction)."""
    return """You are an expert at extracting multiple choice questions (MCQs) from text. 
Your task is to:
1. Extract all complete MCQs from the provided text
2. Reconstruct mathematical expressions using LaTeX when possible
3. Associate figures with questions when appropriate
4. Skip questions that are clearly incomplete or continue on another page

Return a JSON array of objects with the following structure for each MCQ:
{
  "Question": "The question text",
  "Options": {"A": "Option A", "B": "Option B", ...},
  "Answer": "Correct answer letter (A, B, C, etc.) if available",
  "Explanation": "Explanation if available",
  "FigureRefs": ["List of figure names if applicable"],
  "QuestionLatex": "LaTeX version of the question if it contains math",
  "OptionsLatex": {"A": "LaTeX for option A", ...} if options contain math,
  "AmbiguousMath": true/false if mathematical notation is ambiguous
}

If no MCQs are found, return an empty array [].
"""

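# A minimal first-pass wiring sketch: it assumes the openai>=1.0 SDK's
# module-level client (configured via openai.api_key above) and the Chat
# Completions vision message format. The function name and its arguments are
# illustrative only; the validate pass (prompt defined below) would be wired
# the same way with build_system_prompt_validate().
def _example_reconstruct_call(page_text: str, figure_uris: List[str]) -> Optional[Any]:
    """Send one page of OCR text plus compressed figure crops to the model."""
    user_content: List[Dict[str, Any]] = [{"type": "text", "text": page_text}]
    for uri in figure_uris:
        user_content.append({"type": "image_url", "image_url": {"url": uri}})

    resp = openai.chat.completions.create(
        model=MODEL_NAME,
        temperature=TEMPERATURE,
        max_tokens=MAX_OUTPUT_TOKENS,
        messages=[
            {"role": "system", "content": build_system_prompt_reconstruct()},
            {"role": "user", "content": user_content},
        ],
    )
    # The model may wrap its JSON in a code fence; reuse the tolerant parser above.
    return extract_json_from_text(resp.choices[0].message.content or "")
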
def build_system_prompt_validate() -> str:
    """Build the system prompt for the second pass (validation and repair)."""
    return """You are an expert at validating and repairing multiple choice questions (MCQs). 
Your task is to:
1. Validate the logical consistency of each MCQ
2. Repair any issues with formatting, clarity, or completeness
3. Return a valid JSON object for each MCQ (even if you make changes)
4. Only skip questions that are fundamentally flawed or incomplete

Return a JSON object with the same structure as the input, or if the question should be skipped:
{"skip": true, "reason": "Explanation for skipping"}

If the question is valid, return it in the proper format:
{
  "Question": "The question text",
  "Options": {"A": "Option A", "B": "Option B", ...},
  "Answer": "Correct answer letter (A, B, C, etc.) if available",
  "Explanation": "Explanation if available",
  "FigureRefs": ["List of figure names if applicable"],
  "QuestionLatex": "LaTeX version of the question if it contains math",
  "OptionsLatex": {"A": "LaTeX for option A", ...} if options contain math,
  "AmbiguousMath": true/false if mathematical notation is ambiguous
}
"""