from flask import Blueprint, request, jsonify, stream_with_context, Response, send_file, render_template, session, redirect, url_for
from flask_cors import CORS, cross_origin
from utils.question_extractor import *
from utils.openai_pipeline import *
from utils.extractor_functions import cleanup_directories, merge_pages
from models.mcq import MCQ as final_mcq
import logging
import concurrent.futures
from pathlib import Path
import json
import os
import time
import uuid
import threading
import shutil
import gc
from queue import Queue
from datetime import datetime
from dotenv import load_dotenv
from werkzeug.utils import secure_filename
import redis
from celery_app import celery

load_dotenv()

extractor_bp = Blueprint('extractor', __name__)
CORS(extractor_bp)

# Redis client used for task progress tracking
redis_client = redis.Redis(
    host=os.getenv('REDIS_HOST', 'redis'),
    port=os.getenv('REDIS_PORT', 6379),
    db=os.getenv('REDIS_DB', 0)
)
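# Progress for each task is stored under the Redis key "task:<task_id>" as a JSON document,
# e.g. {"page": 3, "total": 40, "percent": 7.5, "status": "processing", "subject": "...",
# "filename": "..."} with a TTL; see set_progress() / get_progress() below.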

# Working directory for temporary task files (relative to the app root)
TMP_DIR = Path("public/tmp")

# Output directories and Poppler binary location
CROPS_OUT_DIR = "public/images/crops_out"
DEBUG_OUTPUTS_DIR = "public/debug_outputs"
POPPLER_PATH = os.getenv("POPPLER_PATH")

# Celery task for PDF processing

@celery.task(bind=True, name='extractor.extract_text_from_pdf')
def extract_text_from_pdf_task(self, pdf_path, task_id, batch_size: int = 20, dpi: int = 200,
                               worker_count: int | None = None, queue_maxsize: int = 60):
    """
    Producer-consumer pipeline:
      - producer: convert PDF pages in batches, put (page_number, PIL.Image) into queue
      - consumers: worker_count threads that pop from queue and run OCR+detector
    """
    try:
        # Setup
        TMP_DIR_PATH = Path(TMP_DIR)
        TMP_DIR_PATH.mkdir(parents=True, exist_ok=True)

        pdf_path = str(pdf_path)
        info = pdfinfo_from_path(pdf_path, userpw=None, poppler_path=POPPLER_PATH)
        total_pages = int(info.get("Pages", 0) or 0)
        logging.info(f"[{task_id}] PDF has {total_pages} pages")

        if total_pages == 0:
            raise ValueError("PDF page count returned 0")

        task_tmp = TMP_DIR_PATH / task_id
        task_tmp.mkdir(parents=True, exist_ok=True)
        Path(CROPS_OUT_DIR).mkdir(parents=True, exist_ok=True)

        # Detector (single instance; caution: must be thread-safe)
        detector = Prediction(
            model_path="public/models/layout_detector_model.pt",
            output_dir=CROPS_OUT_DIR,
            conf_thresh=0.9,
            allowed_labels=["Figure", "Table"]
        )

        # concurrency parameters
        cpu_count = os.cpu_count() or 2
        if worker_count is None:
            # conservative default: min(4, cpu_count)
            worker_count = max(1, min(4, cpu_count))
        queue_maxsize = int(queue_maxsize) if queue_maxsize else 60

        q: Queue = Queue(maxsize=queue_maxsize)
        results: list = []
        results_lock = threading.Lock()
        producer_error: list = []  # captures an exception raised inside the producer thread
        prev_state = get_progress(task_id) or {}
        logging.debug(f"[{task_id}] Previous progress state: {prev_state}")


        # Helper: update progress to redis
        def set_progress(page, total, status="processing", extra=None):
            percent = round((page / total) * 100, 2) if total else 0
            progress_data = {
                "page": page,
                "total": total,
                "percent": percent,
                "status": status,
                "subject": prev_state.get("subject", ""),
                "filename": prev_state.get("filename", ""),
            }
            if extra:
                progress_data.update(extra)
            try:
                logging.debug(f"[{task_id}] Progress: {progress_data}")
                redis_client.setex(f"task:{task_id}", 24 * 3600, json.dumps(progress_data, ensure_ascii=False))
            except Exception as e:
                logging.debug(f"[{task_id}] Failed to write progress to redis: {e}")

        # Consumer worker: runs OCR + detector on a page image
        def consumer_worker(worker_idx: int):
            logging.info(f"[{task_id}] Consumer #{worker_idx} started")
            while True:
                try:
                    item = q.get()
                except Exception:
                    # queue get interrupted
                    break
                if item is None:
                    # sentinel received -> exit
                    q.task_done()
                    break

                page_number, page_img = item
                try:
                    text = ocr_page_image(page_img)

                    # detector.generate may return (processed_image, crops_metadata)
                    page_img_out, crops = detector.generate(page_img, img_name=f"{task_id}_page_{page_number:03d}")

                    entry = {"page": page_number, "text": text, "figures": crops}
                    with results_lock:
                        results.append(entry)

                    set_progress(page_number, total_pages)
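                    # NOTE: with multiple consumer threads pages can finish out of order, so the
                    # page/percent reported above may occasionally move backwards.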

                except Exception as page_err:
                    logging.exception(f"[{task_id}] Consumer #{worker_idx} error processing page {page_number}: {page_err}")
                    entry = {"page": page_number, "text": "", "figures": [], "error": str(page_err)}
                    with results_lock:
                        results.append(entry)
                    set_progress(page_number, total_pages, extra={"last_error": str(page_err)})

                finally:
                    # Release the PIL images promptly to keep memory bounded
                    try:
                        if 'page_img_out' in locals() and hasattr(page_img_out, "close"):
                            page_img_out.close()
                    except Exception:
                        pass
                    try:
                        if hasattr(page_img, "close"):
                            page_img.close()
                        del page_img
                    except Exception:
                        pass
                    gc.collect()
                    q.task_done()

            logging.info(f"[{task_id}] Consumer #{worker_idx} exiting")

        # Producer: convert in batches and enqueue images
        def producer():
            try:
                for start in range(1, total_pages + 1, batch_size):
                    end = min(start + batch_size - 1, total_pages)
                    logging.info(f"[{task_id}] Converting pages {start}..{end} (dpi={dpi})")

                    images = convert_from_path(
                        pdf_path,
                        dpi=dpi,
                        first_page=start,
                        last_page=end,
                        output_folder=str(task_tmp),
                        fmt="png",
                        thread_count=1,
                        poppler_path=POPPLER_PATH
                    )

                    # enqueue each image (blocks if queue is full)
                    for offset, page_img in enumerate(images):
                        page_number = start + offset
                        # block until there is room (backpressure)
                        q.put((page_number, page_img))

                    # remove any files pdf2image wrote to disk for this batch
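                    # NOTE: with output_folder set, pdf2image backs the returned PIL images with these
                    # PNG files; deleting them here assumes the already-open image handles stay readable
                    # (POSIX semantics). On Windows the unlink may fail and is ignored below.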
                    try:
                        for fpath in task_tmp.glob("*.png"):
                            try:
                                fpath.unlink()
                            except Exception:
                                pass
                    except Exception:
                        logging.debug(f"[{task_id}] No temp images to delete in {task_tmp}")

                # all pages produced -> send sentinel to consumers
                for _ in range(worker_count):
                    q.put(None)

            except Exception as e:
                logging.exception(f"[{task_id}] Producer error: {e}")
                producer_error.append(e)
                # make sure consumers still receive their sentinels and exit
                for _ in range(worker_count):
                    try:
                        q.put(None)
                    except Exception:
                        pass

        # Start consumer threads
        consumer_threads = []
        for i in range(worker_count):
            t = threading.Thread(target=consumer_worker, args=(i+1,), daemon=True)
            t.start()
            consumer_threads.append(t)

        # Start producer thread (blocks only inside itself)
        prod_thread = threading.Thread(target=producer, daemon=True)
        prod_thread.start()

        # Wait until the producer finishes and every queued page has been processed
        prod_thread.join()
        q.join()  # returns once task_done() has been called for every queued item
        # Consumers exit after receiving their sentinel; join them
        for t in consumer_threads:
            t.join(timeout=30)

        # Surface any error raised inside the producer thread (otherwise the task would be
        # marked "done" even though pages are missing)
        if producer_error:
            raise producer_error[0]

        # Save final results as a single JSON file
        output_path = TMP_DIR_PATH / f"{task_id}_results.json"
        with output_path.open("w", encoding="utf-8") as outf:
            json.dump(sorted(results, key=lambda r: r.get("page", 0)), outf, ensure_ascii=False, indent=2)

        logging.info(f"[{task_id}] Saved OCR results to {output_path}")

        # clean up PDF input
        try:
            pdf_p = Path(pdf_path)
            if pdf_p.exists():
                pdf_p.unlink()
                logging.info(f"[{task_id}] Removed PDF file: {pdf_path}")
        except Exception as e:
            logging.warning(f"[{task_id}] Failed to remove temp pdf {pdf_path}: {e}")

        # Final Redis update
        final_progress = {
            "page": total_pages,
            "total": total_pages,
            "percent": 100,
            "status": "done",
            "subject": prev_state.get("subject", ""),
            "filename": prev_state.get("filename", ""),
            "output": str(output_path)
        }
        redis_client.setex(f"task:{task_id}", 24 * 3600, json.dumps(final_progress, ensure_ascii=False))

        # cleanup per-task tmp folder
        try:
            if task_tmp.exists():
                shutil.rmtree(task_tmp)
        except Exception as e:
            logging.debug(f"[{task_id}] Failed to remove task tmp folder: {e}")

        return str(output_path)

    except Exception as e:
        logging.exception(f"[{task_id}] Error in PDF extraction task: {e}")
        error_data = {"status": "error", "error": str(e)}
        try:
            redis_client.setex(f"task:{task_id}", 24 * 3600, json.dumps(error_data, ensure_ascii=False))
        except Exception as redis_err:
            logging.error(f"[{task_id}] Failed to write error state to redis: {redis_err}")
        raise



# Celery task for MCQ extraction
@celery.task(bind=True, name='extractor.run_all_pages')
def run_all_pages_task(self, input_json_path, out_path, task_id):
    """Celery task for processing all pages and extracting MCQs"""
    try:
        with open(input_json_path, "r", encoding="utf-8") as f:
            pages = json.load(f)

        all_results = []
        out_path = Path(out_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        total_pages = len(pages)
        prev_state = get_progress(task_id) or {}

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
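            # The single-worker executor exists only to enforce PER_PAGE_TIMEOUT per page.
            # Note that future.cancel() cannot stop a future that is already running, so a
            # page that times out keeps occupying the worker until process_page_obj returns.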
            for idx, page in enumerate(pages, start=1):
                pnum = page.get("page", "?")
                logging.info(f"Starting page {pnum} (index {idx}/{len(pages)})")
                start_ts = time.time()

                future = executor.submit(process_page_obj, page)
                try:
                    validated_mcqs = future.result(timeout=PER_PAGE_TIMEOUT)
                except concurrent.futures.TimeoutError:
                    logging.error(f"Page {pnum} timed out after {PER_PAGE_TIMEOUT}s — skipping.")
                    future.cancel()
                    validated_mcqs = []
                except Exception as e:
                    logging.exception(f"Page {pnum} raised exception: {e}")
                    validated_mcqs = []

                elapsed = time.time() - start_ts
                logging.info(f"Finished page {pnum} in {elapsed:.1f}s — extracted {len(validated_mcqs)} MCQs")

                all_results.append({"page": pnum, "mcqs": validated_mcqs})

                # Update progress in Redis
                progress_data = {
                    "page": idx,
                    "total": total_pages,
                    "percent": round((idx / total_pages) * 100, 2),
                    "subject": prev_state.get("subject", ""),
                    "filename": prev_state.get("filename", ""),
                    "status": "processing"
                }
                redis_client.setex(f"task:{task_id}", 3600, json.dumps(progress_data))

        # Flatten pages -> list of MCQs
        merged = merge_pages(all_results)

        # Final save of merged MCQs list
        with out_path.open("w", encoding="utf-8") as f:
            json.dump(merged, f, ensure_ascii=False, indent=2)

        # Clean up input JSON file
        try:
            if os.path.exists(input_json_path):
                os.remove(input_json_path)
                logging.info(f"Removed input JSON: {input_json_path}")
        except Exception as e:
            logging.warning(f"Failed to remove temp input json {input_json_path}: {e}")

        # Clean up directories with reference to final results
        cleanup_directories(merged)

        # Final progress update
        final_progress = {
            "page": total_pages,
            "total": total_pages,
            "percent": 100,
            "status": "done",
            "subject": prev_state.get("subject", ""),
            "filename": prev_state.get("filename", ""),
            "output": str(out_path),
            "results": merged
        }
        redis_client.setex(f"task:{task_id}", 3600, json.dumps(final_progress))

        logging.info(f"Done. Final saved to {out_path}")
        return str(out_path)
        
    except Exception as e:
        logging.exception(f"Error in MCQ extraction task: {e}")
        error_data = {
            "status": "error",
            "error": str(e)
        }
        redis_client.setex(f"task:{task_id}", 3600, json.dumps(error_data))
        raise

# Helper function to get progress from Redis
def get_progress(task_id):
    """Get progress data from Redis"""
    data = redis_client.get(f"task:{task_id}")
    if data:
        return json.loads(data)
    return None

# --- API route: upload file --- 
@extractor_bp.route('/upload', methods=['POST'])
def upload():
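    """
    Accept a multipart/form-data POST with a 'file' field (.pdf) and a 'subject' field.
    Saves the PDF, starts the Celery OCR task and returns {"message", "task_id"}.
    """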
    if 'file' not in request.files:
        return jsonify({"error": "No file part"}), 400
    
    if 'subject' not in request.form or not request.form['subject'].strip():
        return jsonify({"error": "Subject is required"}), 400
    

    file = request.files['file']
    subject = request.form['subject'].strip().lower()
    if file.filename == '' or not file.filename.lower().endswith('.pdf'):
        return jsonify({"error": "Invalid file"}), 400

    # Save to TMP_DIR with a sanitized filename (guards against path traversal via the upload name)
    os.makedirs(TMP_DIR, exist_ok=True)
    temp_pdf_path = os.path.join(TMP_DIR, secure_filename(file.filename))
    file.save(temp_pdf_path)

    task_id = str(uuid.uuid4())
    
    # Initialize progress in Redis
    initial_progress = {"percent": 0, "subject": subject, "filename": file.filename.lower(), "status": "starting"}
    redis_client.setex(f"task:{task_id}", 3600, json.dumps(initial_progress))

    # Start Celery task for PDF processing
    extract_text_from_pdf_task.delay(temp_pdf_path, task_id)

    return jsonify({"message": "Upload started", "task_id": task_id}), 200

# --- API route: SSE progress stream ---
@extractor_bp.route('/progress/<task_id>')
@cross_origin()
def progress(task_id):
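    """Stream task progress as Server-Sent Events; ends on 'done', 'error' or a missing task."""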
    if not task_id or task_id == 'undefined':
        return jsonify({"error": "Invalid task ID"}), 400

    def event_stream():
        last_percent = -1
        while True:
            state = get_progress(task_id)
            if not state:
                yield f"data: {json.dumps({'status':'not_found'})}\n\n"
                break
            if state.get("status") == "done" or state.get("status") == "error":
                yield f"data: {json.dumps(state)}\n\n"
                break
            if state.get("percent") != last_percent:
                yield f"data: {json.dumps(state)}\n\n"
                last_percent = state.get("percent")
            time.sleep(1)

    return Response(
        stream_with_context(event_stream()),
        mimetype="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"
        }
    )

# --- API route: extract MCQs ---
@extractor_bp.route('/extract', methods=['POST'])
def extract():
    """
    Start MCQ extraction pass.
    Accepts either:
      - form field 'upload_task_id' referencing an earlier /upload task (preferred),
      - OR a direct uploaded .json file in form field 'file' (backwards compatible).
    """
    upload_task_id = request.form.get("upload_task_id") or (request.get_json(silent=True) or {}).get("upload_task_id")
    temp_json_path = None
    upload_state = {}

    # If upload_task_id provided, check progress in Redis
    if upload_task_id:
        upload_state = get_progress(upload_task_id)
        if not upload_state:
            return jsonify({"error": "upload_task_id not found"}), 400
        if upload_state.get("status") != "done" or not upload_state.get("output"):
            return jsonify({"error": "upload task not finished or output missing"}), 400
        temp_json_path = Path(upload_state["output"])
        if not temp_json_path.exists():
            return jsonify({"error": "OCR JSON file not found on disk"}), 404

    # If no upload_task_id, allow direct JSON file upload (backwards compat)
    if temp_json_path is None:
        if 'file' not in request.files:
            return jsonify({"error": "No file part and no upload_task_id provided"}), 400
        file = request.files['file']
        if file.filename == '' or not file.filename.lower().endswith('.json'):
            return jsonify({"error": "Invalid file"}), 400
        # Save in TMP_DIR
        os.makedirs(TMP_DIR, exist_ok=True)
        temp_json_path = Path(os.path.join(TMP_DIR, file.filename))
        file.save(temp_json_path)

    # Prepare output path and start extraction
    extraction_task_id = str(uuid.uuid4())
    out_path = Path(os.path.join(TMP_DIR, f"{temp_json_path.stem}_mcqs.json"))
    
    # Initialize progress in Redis
    initial_progress = {"percent": 0, "status": "starting", "subject": upload_state.get("subject", ""), "filename": upload_state.get("filename", "")}
    redis_client.setex(f"task:{extraction_task_id}", 3600, json.dumps(initial_progress))

    # Start Celery task for MCQ extraction
    run_all_pages_task.delay(str(temp_json_path), str(out_path), extraction_task_id)

    return jsonify({"message": "Extraction started", "task_id": extraction_task_id}), 202

# --- API route: download extracted results ---

@extractor_bp.route('/download/<task_id>', methods=['GET'])
def download(task_id):
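    """Send the finished task's output JSON as a file download."""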
    state = get_progress(task_id)
    if not state or state.get("status") != "done":
        return jsonify({"error": "Task not found or not completed"}), 404

    output_file = state.get("output")
    if not output_file or not os.path.exists(output_file):
        return jsonify({"error": "Output file not found"}), 404

    return send_file(output_file, as_attachment=True, download_name=f"extracted_mcqs_{task_id}.json")

@extractor_bp.route('/results/<task_id>', methods=['GET'])
@cross_origin()
def get_results(task_id):
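    """Return extracted MCQs for a completed task, from Redis when cached, otherwise from the output file."""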
    try:
        state = get_progress(task_id)
        if not state:
            return jsonify({"error": "Task not found"}), 404
        
        if state.get("status") != "done":
            return jsonify({"error": "Task not completed yet", "status": state.get("status")}), 202
        
        # Return the results directly from Redis if available
        if "results" in state:
            return jsonify(state["results"])
        
        # Fallback: read from file if results not in Redis
        output_file = state.get("output")
        if not output_file:
            return jsonify({"error": "Output file path not specified"}), 404
        
        output_path = Path(output_file)
        if not output_path.exists():
            return jsonify({"error": "Output file not found"}), 404
        
        try:
            with open(output_path, 'r', encoding='utf-8') as f:
                results = json.load(f)
            
            # Update Redis with results
            state["results"] = results
            redis_client.setex(f"task:{task_id}", 3600, json.dumps(state))
            
            return jsonify(results)
            
        except json.JSONDecodeError as e:
            logging.error(f"Failed to parse output file {output_file}: {e}")
            return jsonify({"error": "Invalid JSON in output file"}), 500
        except Exception as e:
            logging.error(f"Error reading output file {output_file}: {e}")
            return jsonify({"error": "Failed to read output file"}), 500
            
    except Exception as e:
        logging.error(f"Error in get_results for task {task_id}: {e}")
        return jsonify({"error": "Internal server error"}), 500

# --- API route: save extracted MCQs to the database ---
@extractor_bp.route('/save-to-db/<task_id>', methods=['POST'])
@cross_origin()
def save_to_database(task_id):
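    """
    Validate the extracted MCQs for a completed task and persist them via the MCQ model.
    Returns counts of saved and skipped items plus any per-item errors.
    """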
    try:
        # Task state/results written to Redis by the extraction task
        state = get_progress(task_id)

        if not state:
            return jsonify({"error": "No results found"}), 404
            return jsonify({"error": "No results found"}), 404

        # Some tasks return dict with `results`, some return just the list
        mcq_data = state.get("results") if isinstance(state, dict) else state

        if not isinstance(mcq_data, list):
            return jsonify({"error": "Invalid data format"}), 400

        saved_count = 0
        skipped_count = 0
        errors = []

        for index, mcq_item in enumerate(mcq_data):
            try:
                # Validate required fields
                if not all(key in mcq_item for key in ['Question', 'Options', 'Answer']):
                    errors.append(f"Item {index}: Missing required fields")
                    skipped_count += 1
                    continue

                # Extract options
                options = mcq_item['Options']
                option_a = option_b = option_c = option_d = ""

                if isinstance(options, dict):
                    option_a = options.get('A', '') or options.get('a', '')
                    option_b = options.get('B', '') or options.get('b', '')
                    option_c = options.get('C', '') or options.get('c', '')
                    option_d = options.get('D', '') or options.get('d', '')
                elif isinstance(options, list):
                    if len(options) > 0: option_a = str(options[0])
                    if len(options) > 1: option_b = str(options[1])
                    if len(options) > 2: option_c = str(options[2])
                    if len(options) > 3: option_d = str(options[3])

                if not option_a or not option_b:
                    errors.append(f"Item {index}: Insufficient options")
                    skipped_count += 1
                    continue

                # Handle ambiguous math
                ambiguous_math = mcq_item.get('AmbiguousMath', False)
                if ambiguous_math:
                    question = mcq_item.get('QuestionLatex') or mcq_item['Question']
                    option_a = mcq_item.get('OptionsLatex', {}).get('A') or option_a
                    option_b = mcq_item.get('OptionsLatex', {}).get('B') or option_b
                    option_c = mcq_item.get('OptionsLatex', {}).get('C') or option_c
                    option_d = mcq_item.get('OptionsLatex', {}).get('D') or option_d
                else:
                    question = mcq_item['Question']

                # Handle figure refs
                figure_refs = mcq_item.get('FigureRefs', [])
                figure_ref = json.dumps(figure_refs) if figure_refs else None

                # Normalize answer
                correct_answer = mcq_item['Answer'].upper().strip()
                if correct_answer and correct_answer[0] in ['A', 'B', 'C', 'D']:
                    correct_answer = correct_answer[0]
                else:
                    answer_text = mcq_item['Answer'].lower()
                    if answer_text in option_a.lower(): correct_answer = 'A'
                    elif answer_text in option_b.lower(): correct_answer = 'B'
                    elif answer_text in option_c.lower(): correct_answer = 'C'
                    elif answer_text in option_d.lower(): correct_answer = 'D'
                    else:
                        errors.append(f"Item {index}: Could not determine correct answer")
                        skipped_count += 1
                        continue

                # Create DB doc
                mcq_doc = final_mcq(
                    filename=state.get('filename', '') if isinstance(state, dict) else '',
                    subject=state.get('subject', 'general') if isinstance(state, dict) else 'general',
                    question=question.strip(),
                    option_a=option_a.strip(),
                    option_b=option_b.strip(),
                    option_c=option_c.strip() if option_c else "",
                    option_d=option_d.strip() if option_d else "",
                    figure_ref=figure_ref,
                    correct_answer=correct_answer,
                    explanation=mcq_item.get('Explanation', '').strip(),
                    created_at=datetime.utcnow()
                )
                mcq_doc.save()
                saved_count += 1

            except Exception as e:
                errors.append(f"Item {index}: {str(e)}")
                skipped_count += 1
                continue

        # Clean up temporary files once, after all items have been processed
        try:
            cleanup_temp_files(task_id)
        except Exception as cleanup_err:
            logging.warning(f"Failed to clean up temp files for task {task_id}: {cleanup_err}")

        return jsonify({
            "message": f"Successfully saved {saved_count} MCQs to database",
            "saved_count": saved_count,
            "skipped_count": skipped_count,
            "errors": errors if errors else None
        }), 200

    except Exception as e:
        logging.error(f"Error saving to database for task {task_id}: {e}")
        return jsonify({"error": "Internal server error"}), 500


@extractor_bp.route('/extractor', methods=['GET'])
def extraction():
    if 'user' not in session:
        return redirect(url_for('user.user_login'))
    return render_template('user/mcq-extractor.html')