# docai_math_pipeline.py
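#
# Splits a local PDF into chunks, runs each chunk through a Document AI batch job
# (processor with the Math OCR add-on), downloads the per-chunk JSON output and merges
# it into one per-page JSON file. Requires valid Google credentials (e.g. via
# GOOGLE_APPLICATION_CREDENTIALS) and an existing GCS bucket.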
import os
import time
import json
import uuid
from pathlib import Path
from typing import List, Dict, Any
from pypdf import PdfReader, PdfWriter
from google.cloud import documentai_v1 as documentai
from google.cloud import storage
from dotenv import load_dotenv

# load .env (e.g. GOOGLE_APPLICATION_CREDENTIALS) so the Google clients below can authenticate
load_dotenv()

# -------------------------
# CONFIG - fill these
# -------------------------
PROJECT_ID = "isentropic-card-468114-f6"
LOCATION = "us"         # Document AI location: "us" or "eu"
PROCESSOR_ID = "7e408abd889e028f" # Document AI processor with Math OCR add-on
GCS_BUCKET = "docai004"
INPUT_PDF = Path("QuantitativeAptitudeVOL1.pdf")     # path to local PDF
CHUNK_PAGES = 20                 # split PDF into 20-page chunks
OUTPUT_JSON = Path("docai_merged_output.json")
TMP_DIR = Path("tmp_docai_chunks")
TMP_DIR.mkdir(exist_ok=True)
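
# Optional: read the config from the environment instead of hardcoding it here (load_dotenv()
# above already loads a local .env). The variable names below are only a suggestion and are
# not referenced anywhere else in this script:
#   PROJECT_ID = os.getenv("DOCAI_PROJECT_ID", PROJECT_ID)
#   LOCATION = os.getenv("DOCAI_LOCATION", LOCATION)
#   PROCESSOR_ID = os.getenv("DOCAI_PROCESSOR_ID", PROCESSOR_ID)
#   GCS_BUCKET = os.getenv("DOCAI_GCS_BUCKET", GCS_BUCKET)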

# -------------------------
# Clients
# -------------------------
docai_client = documentai.DocumentProcessorServiceClient()
storage_client = storage.Client()

# Helper to build processor name
def processor_name(project: str, location: str, processor: str) -> str:
    return docai_client.processor_path(project, location, processor)
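
# The returned resource name has the form
# "projects/<PROJECT_ID>/locations/<LOCATION>/processors/<PROCESSOR_ID>".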

# -------------------------
# 1) Split PDF into N-page chunks
# -------------------------
def split_pdf_to_chunks(input_pdf: Path, chunk_pages: int, out_dir: Path) -> List[Path]:
    reader = PdfReader(str(input_pdf))
    total = len(reader.pages)
    chunk_files = []
    start = 0
    idx = 0
    while start < total:
        end = min(start + chunk_pages, total)
        writer = PdfWriter()
        for p in range(start, end):
            writer.add_page(reader.pages[p])
        out_path = out_dir / f"{input_pdf.stem}_chunk_{idx+1:03d}.pdf"
        with open(out_path, "wb") as f:
            writer.write(f)
        chunk_files.append(out_path)
        idx += 1
        start = end
    return chunk_files
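
# Example (illustrative; assumes a 45-page "book.pdf" and CHUNK_PAGES = 20):
#   split_pdf_to_chunks(Path("book.pdf"), 20, TMP_DIR)
#   -> [tmp_docai_chunks/book_chunk_001.pdf,   # pages 1-20
#       tmp_docai_chunks/book_chunk_002.pdf,   # pages 21-40
#       tmp_docai_chunks/book_chunk_003.pdf]   # pages 41-45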

# -------------------------
# 2) Upload local file to GCS
# -------------------------
def upload_to_gcs(bucket_name: str, local_path: Path, dest_blob_name: str) -> str:
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(dest_blob_name)
    blob.upload_from_filename(str(local_path))
    gcs_uri = f"gs://{bucket_name}/{dest_blob_name}"
    return gcs_uri
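
# Example (illustrative bucket/object names):
#   upload_to_gcs("my-bucket", Path("tmp_docai_chunks/book_chunk_001.pdf"),
#                 "Input/book/book_chunk_001.pdf")
#   -> "gs://my-bucket/Input/book/book_chunk_001.pdf"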

# -------------------------
# 3) Call Document AI batch_process_documents for a single chunk (GCS input -> GCS output prefix)
# -------------------------
def batch_process_gcs_pdf(
    gcs_input_uri: str,
    gcs_output_prefix: str,
    project: str,
    location: str,
    processor_id: str,
    mime_type: str = "application/pdf",
) -> str:
    """
    Submits a batch_process_documents job and blocks until it finishes.
    Document AI writes its output JSON files under gcs_output_prefix; that prefix
    is returned so the caller can list and download the results.
    """
    name = processor_name(project, location, processor_id)

    # Build request objects
    gcs_document = documentai.GcsDocument(gcs_uri=gcs_input_uri, mime_type=mime_type)
    gcs_docs = documentai.GcsDocuments(documents=[gcs_document])
    input_config = documentai.BatchDocumentsInputConfig(gcs_documents=gcs_docs)

    gcs_output_config = documentai.DocumentOutputConfig.GcsOutputConfig(gcs_uri=gcs_output_prefix)
    document_output_config = documentai.DocumentOutputConfig(gcs_output_config=gcs_output_config)

    request = {
        "name": name,
        "input_documents": input_config,
        "document_output_config": document_output_config,
    }

    operation = docai_client.batch_process_documents(request=request)
    print(f"Submitted Document AI batch job for {gcs_input_uri}; waiting for completion...")
    operation.result(timeout=1800)  # waits up to 30 minutes; adjust as needed
    print("Batch process completed.")
    return gcs_output_prefix
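
# Note on the output layout: Document AI typically writes the results under a sub-folder named
# after the long-running operation, with one numbered folder per input document and sharded
# JSON files inside, e.g.
#   gs://<bucket>/<output-prefix>/<operation-id>/0/<chunk-name>-0.json
# The exact layout is not something this script relies on: list_output_blobs_and_download()
# below simply lists every JSON object under the prefix.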

# -------------------------
# 4) List and download output JSONs from the GCS prefix
# -------------------------
def list_output_blobs_and_download(bucket_name: str, prefix: str, local_out_dir: Path) -> List[Path]:
    bucket = storage_client.bucket(bucket_name)
    blobs = list(storage_client.list_blobs(bucket, prefix=prefix))
    local_out_dir.mkdir(parents=True, exist_ok=True)
    downloaded = []
    for blob in blobs:
        # keep only the JSON outputs (skips "directory" placeholder objects and any non-JSON files)
        if not blob.name.endswith(".json") and not blob.name.endswith(".json.gz"):
            continue
        local_path = local_out_dir / Path(blob.name).name
        blob.download_to_filename(str(local_path))
        downloaded.append(local_path)
    return downloaded

# -------------------------
# 5) Parse Document AI output JSON -> per-page text & math (heuristic)
# -------------------------
def extract_pages_from_docai_json(doc_json: dict) -> List[Dict[str, Any]]:
    """
    doc_json is the parsed JSON file created by Document AI batch output.
    We extract:
      - document['text'] entire text
      - pages: for each page, assemble page text using the page.layout.textAnchor.textSegments ranges
    Additionally, we attempt to collect any 'latex' fields (math OCR add-on may store LaTeX in custom fields).
    """
    results = []
    document = doc_json.get("document") or doc_json  # sometimes doc_json is already the document
    full_text = document.get("text", "")

    pages = document.get("pages", [])
    for page in pages:
        page_number = page.get("pageNumber", None)
        # Retrieve text segments from the page layout
        page_text = ""
        layout = page.get("layout", {})
        text_anchor = layout.get("textAnchor", {})
        segments = text_anchor.get("textSegments", []) if text_anchor else []
        if segments:
            # Concatenate all text segments for the page
            parts = []
            for seg in segments:
                # indices may be serialized as strings; startIndex is omitted for the first segment
                start = int(seg.get("startIndex", 0))
                end = int(seg.get("endIndex", 0))
                # protect against invalid indices
                if 0 <= start < end <= len(full_text):
                    parts.append(full_text[start:end])
            page_text = "".join(parts).strip()
        else:
            # no textAnchor segments on this page; Document AI may also expose text via
            # blocks/paragraphs/tokens, but that fallback is skipped here for brevity
            page_text = ""

        # collect any math-latex occurrences (heuristic: search for keys named 'latex' in the page dict)
        math_entries = []

        def find_latex(obj):
            if isinstance(obj, dict):
                for k, v in obj.items():
                    if k.lower() == "latex" and isinstance(v, str):
                        math_entries.append({"latex": v})
                    else:
                        find_latex(v)
            elif isinstance(obj, list):
                for el in obj:
                    find_latex(el)

        find_latex(page)

        results.append({
            "pageNumber": page_number,
            "page_text": page_text,
            "math": math_entries
        })
    return results
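
# Illustrative input/output for extract_pages_from_docai_json (field names follow the camelCase
# JSON that batch output produces; the values below are made up):
#   doc = {"text": "2+2=4\n",
#          "pages": [{"pageNumber": 1,
#                     "layout": {"textAnchor": {"textSegments": [{"startIndex": "0",
#                                                                 "endIndex": "6"}]}}}]}
#   extract_pages_from_docai_json(doc)
#   -> [{"pageNumber": 1, "page_text": "2+2=4", "math": []}]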

# -------------------------
# 6) Full orchestrator
# -------------------------
def run_documentai_pipeline():
    # split PDF
    chunk_files = split_pdf_to_chunks(INPUT_PDF, CHUNK_PAGES, TMP_DIR)
    print(f"Created {len(chunk_files)} chunk files in {TMP_DIR}")

    merged_results = []   # will contain dicts per page across all chunks
    bucket = GCS_BUCKET

    for i, chunk in enumerate(chunk_files, start=1):
        print(f"\n--- Processing chunk {i}/{len(chunk_files)}: {chunk.name} ---")
        blob_name = f"Input/{INPUT_PDF.stem}/{chunk.name}"
        gcs_input = upload_to_gcs(bucket, chunk, blob_name)
        print(f"Uploaded to {gcs_input}")

        # set output prefix for this chunk (Document AI will create JSONs under this prefix)
        unique_prefix = f"Output/{INPUT_PDF.stem}/{chunk.stem}-{uuid.uuid4().hex}/"
        gcs_output_prefix = f"gs://{bucket}/{unique_prefix}"

        # submit batch job (blocks until finished)
        batch_process_gcs_pdf(gcs_input, gcs_output_prefix, PROJECT_ID, LOCATION, PROCESSOR_ID)

        # list and download output JSONs (they are under the prefix in the bucket)
        local_out_dir = TMP_DIR / f"outputs_{chunk.stem}"
        downloaded = list_output_blobs_and_download(bucket, prefix=unique_prefix, local_out_dir=local_out_dir)
        print(f"Downloaded {len(downloaded)} output JSON(s) for chunk {chunk.name}")

        # parse each downloaded json and append per-page results
        for jl in downloaded:
            with open(jl, "r", encoding="utf-8") as f:
                doc_json = json.load(f)
            pages_data = extract_pages_from_docai_json(doc_json)
            for p in pages_data:
                # convert the chunk-local page number to an absolute page number:
                # chunk i covers pages (i-1)*CHUNK_PAGES + 1 .. i*CHUNK_PAGES of the original PDF
                base_page = (i - 1) * CHUNK_PAGES
                page_num_in_chunk = p.get("pageNumber")
                if page_num_in_chunk is not None:
                    absolute_page_num = base_page + int(page_num_in_chunk)
                else:
                    absolute_page_num = None
                merged_results.append({
                    "page": absolute_page_num,
                    "page_text": p.get("page_text", ""),
                    "math": p.get("math", [])
                })

        # optionally delete input chunk from GCS or local chunk if desired
        # storage_client.bucket(bucket).blob(blob_name).delete()
        # chunk.unlink()  # if you want to remove local chunk

    # sort results by absolute page number (pages with unknown numbers sort first)
    merged_sorted = sorted(merged_results, key=lambda x: x["page"] if x["page"] is not None else 0)
    # save final json
    with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
        json.dump(merged_sorted, f, ensure_ascii=False, indent=2)

    print(f"\nDone. Merged output written to {OUTPUT_JSON} (pages: {len(merged_sorted)})")


if __name__ == "__main__":
    run_documentai_pipeline()
