# Content_creator.py
"""
Presentation generator with image_suggestion, slide_no, and expanded bullets.

Features:
- Each slide has: heading, bullets (1-2 sentence items), speaker_notes, image_suggestion, slide_no
- Bullets are expanded to 1-2 sentences if they appear too short
- Ensures the first slide is "Introduction" and the last slide is "Conclusion"
- Preserves newlines correctly when extracting PDF text and writing logs
- Retries with deterministic temperature=0.0 if the initial parse fails
Dependencies: langchain, langchain_openai (ChatOpenAI), pydantic, python-dotenv, PyPDF2, celery
"""

from typing import List, Tuple, Dict, Any, Optional
from pydantic import BaseModel, Field, ValidationError
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.output_parsers import PydanticOutputParser
import json
import math
import logging
from pathlib import Path
import os
from dotenv import load_dotenv
from PyPDF2 import PdfReader, errors
from celery_app import celery

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
root_dir = Path(__file__).parent.parent
dotenv_path = root_dir / ".env"
load_dotenv(dotenv_path=dotenv_path)

# ---------------------------
# Pydantic schema (with image_suggestion + slide_no)
# ---------------------------
class SlideModel(BaseModel):
    heading: str = Field(..., description="Slide heading/title")
    bullets: List[str] = Field(..., description="List of short bullet points (1-2 sentence each)")
    speaker_notes: str = Field(..., description="Elaborative speaker notes (~6-12 sentences)")
    image_suggestion: Optional[str] = Field(None, description="Short prompt suggestion for an image for this slide")
    slide_no: Optional[int] = Field(None, description="Slide index (1-based). Will be normalized in post-processing")


class PresentationModel(BaseModel):
    title: str = Field(..., description="Presentation title")
    slides: List[SlideModel] = Field(..., description="List of slides")
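
# Illustrative example of JSON that satisfies PresentationModel (placeholder
# values only, not output captured from this module):
#   {
#     "title": "Photosynthesis Basics",
#     "slides": [
#       {"heading": "Introduction",
#        "bullets": ["Photosynthesis converts light energy into chemical energy."],
#        "speaker_notes": "Welcome. In this presentation we cover ...",
#        "image_suggestion": "A green leaf in sunlight",
#        "slide_no": 1}
#     ]
#   }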

# ---------- small helper (fallback for bullet expansion) ----------
def _expand_short_bullets(bullets: List[str], heading: str, context_excerpt: str, model_name: str = "gpt-4o-mini") -> List[str]:
    """
    If you already have a stronger _expand_short_bullets, keep that.
    This fallback simply returns bullets unchanged but ensures they are non-empty strings.
    You can replace this to call a lightweight deterministic LLM to expand tiny bullets.
    """
    out = []
    for b in bullets:
        b = (b or "").strip()
        if not b:
            continue
        # simple rule: if bullet short (<20 chars) append brief phrase to help memory
        if len(b) < 20:
            out.append(b + " — key point")
        else:
            out.append(b)
    return out
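

# Illustrative sketch only (not wired into the pipeline): how very short bullets
# could instead be expanded with a deterministic LLM call, as the fallback above
# suggests. Assumes the same ChatOpenAI client and OPENAI_API_KEY used elsewhere
# in this module; adapt before relying on it.
def _expand_short_bullets_llm_sketch(bullets: List[str], heading: str, model_name: str = "gpt-4o-mini") -> List[str]:
    llm = ChatOpenAI(model_name=model_name, temperature=0.0, openai_api_key=os.getenv("OPENAI_API_KEY"))
    expanded: List[str] = []
    for b in bullets:
        b = (b or "").strip()
        if not b:
            continue
        if len(b) < 20:
            # Ask the model to rewrite the fragment as one or two full sentences.
            msg = llm.invoke(
                f"Rewrite this bullet from a slide titled '{heading}' as 1-2 concise sentences: {b}"
            )
            expanded.append((getattr(msg, "content", "") or b).strip())
        else:
            expanded.append(b)
    return expanded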


# ---------- synthesize speaker notes fallback ----------
def _synthesize_speaker_notes_from_bullets(heading: str, bullets: List[str], context_excerpt: str) -> str:
    """
    Quick deterministic fallback that builds 3-7 sentence speaker notes
    (an intro, one sentence per bullet for up to five bullets, and a closing).
    Not as rich as an LLM but safe and guaranteed to produce text.
    """
    if not bullets:
        # generic fallback using heading and context excerpt
        snippet = (context_excerpt or "").strip()
        return f"This slide introduces {heading}. {snippet[:240]}".strip()

    # Compose sentences from bullets
    sentences = []
    for b in bullets[:5]:
        # make it into a sentence (if already short phrase)
        s = b.rstrip(".")
        # simple connector heuristics
        sentences.append(f"{s}.")
    # Add a short intro and closing
    intro = f"On the slide titled '{heading}', we will cover the following points."
    closing = "These points summarize the core ideas and give you takeaways to remember."
    notes = " ".join([intro] + sentences + [closing])
    return notes
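
# Illustrative call:
#   _synthesize_speaker_notes_from_bullets(
#       "Photosynthesis", ["Plants convert light energy into chemical energy"], "")
# returns roughly: "On the slide titled 'Photosynthesis', we will cover the
# following points. Plants convert light energy into chemical energy. These
# points summarize the core ideas and give you takeaways to remember."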

# ---------------------------
# LLM + parser setup
# ---------------------------
def _build_chain(
    model_name: str = "gpt-4o-mini", temperature: float = 0.2, max_tokens: int = 1400
) -> Tuple[LLMChain, PydanticOutputParser]:
    parser = PydanticOutputParser(pydantic_object=PresentationModel)

    prompt = """You are an assistant that generates a presentation JSON object precisely matching the schema.
Return JSON ONLY (no explanations).

IMPORTANT: Use the CONTEXT below as the authoritative source/topic for the entire presentation.
Do NOT change the topic or invent a different subject. Base headings, bullets, speaker_notes and image_suggestion on the CONTEXT.

Follow these strict rules:
- Output MUST be a single JSON object with keys exactly: "title" and "slides".
- "title" is a short presentation title (string).
- "slides" is an array of slide objects; each slide object must include these keys:
    - "heading" (string)
    - "bullets" (array of short strings) — each bullet should be a concise 1–2 sentence explanation (useful as a memory cue for students)
    - "speaker_notes" (string, elaborative about 6-12 sentences, natural narration style, include one brief example/analogy if appropriate)
    - "image_suggestion" (string or null) — a one-line prompt describing an image for the slide
    - "slide_no" (int or null) — optional; if omitted we'll set it sequentially during post-processing
- The FIRST slide must be an "Introduction" slide with heading exactly "Introduction".
- The LAST slide must be a "Conclusion" slide with heading exactly "Conclusion".
- Bullets should be 1–2 short sentences. Provide 5–8 bullets per slide.
- speaker_notes should be comprehensive and suitable for narration; avoid lists inside speaker_notes.
- If the CONTEXT includes a title or heading line, prefer that for the "title" field; otherwise synthesize a concise, accurate title.
- Do NOT return any extra keys, comments, or markup — only the JSON object matching the schema.

Here are the format instructions you must follow:
{format_instructions}

CONTEXT (use this to build the slides - base the presentation on this):
{context}

Audience: {audience}
Tone: {tone}

Produce JSON only.
"""

    prompt_template = PromptTemplate(
        template=prompt,
        input_variables=["context", "audience", "tone"],
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )

    llm = ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        openai_api_key=os.getenv("OPENAI_API_KEY"),
    )
    chain = LLMChain(llm=llm, prompt=prompt_template)
    return chain, parser
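
# Illustrative usage of the chain/parser pair (this mirrors what
# generate_presentation_json does below):
#   chain, parser = _build_chain()
#   raw = chain.run({"context": "...", "audience": "students", "tone": "informative"})
#   presentation_model = parser.parse(raw)  # -> PresentationModel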

def _normalize_raw_obj(obj: Dict[str, Any], context_excerpt: str) -> Dict[str, Any]:
    """
    Ensure 'title' exists and every slide has: heading, bullets(list), speaker_notes(str), image_suggestion (opt), slide_no (int).
    Returns normalized dict suitable for Pydantic model construction or direct use.
    """
    title = obj.get("title") or ""
    if not title and "slides" in obj and isinstance(obj["slides"], list) and len(obj["slides"]) > 0:
        # try the first slide heading, or the first 80 chars of its speaker_notes
        first = obj["slides"][0]
        title = first.get("heading") or (first.get("speaker_notes") or "")[:80]
    title = (title or "").strip() or "Presentation"

    slides_raw = obj.get("slides") or []
    normalized_slides = []
    for idx, s in enumerate(slides_raw, start=1):
        if not isinstance(s, dict):
            # if it's a Pydantic model, convert
            try:
                s = dict(s)
            except Exception:
                s = {}

        heading = (s.get("heading") or "").strip() or f"Slide {idx}"
        bullets = s.get("bullets") or []
        # Ensure bullet list of strings
        if isinstance(bullets, str):
            # sometimes bullets arrive as a single newline-separated string
            bullets = [ln.strip() for ln in bullets.splitlines() if ln.strip()]
        bullets = [str(b).strip() for b in bullets if str(b).strip()]
        if not bullets:
            bullets = ["Key point"]

        speaker_notes = (s.get("speaker_notes") or "").strip()
        image_suggestion = s.get("image_suggestion") or None
        slide_no = s.get("slide_no") or idx

        normalized_slides.append({
            "heading": heading,
            "bullets": bullets,
            "speaker_notes": speaker_notes,
            "image_suggestion": image_suggestion,
            "slide_no": slide_no
        })

    # Ensure intro/conclusion present
    if not normalized_slides or normalized_slides[0]["heading"].strip().lower() != "introduction":
        intro = {
            "heading": "Introduction",
            "bullets": ["Overview of the topic", "Why it matters", "What you'll learn"],
            "speaker_notes": f"Welcome. This introduction summarizes the topic briefly: {context_excerpt[:300]}",
            "image_suggestion": None,
            "slide_no": 1
        }
        normalized_slides.insert(0, intro)
        # re-index slide_no below

    if normalized_slides[-1]["heading"].strip().lower() != "conclusion":
        conclusion = {
            "heading": "Conclusion",
            "bullets": ["Key takeaways", "Next steps", "Further resources"],
            "speaker_notes": "In conclusion, we've summarized the main points and recommended next steps. Review the takeaways and explore resources to deepen your understanding.",
            "image_suggestion": None,
            "slide_no": len(normalized_slides) + 1
        }
        normalized_slides.append(conclusion)

    # Reassign sequential slide_no and ensure required fields
    out_slides = []
    for i, s in enumerate(normalized_slides, start=1):
        heading = s["heading"]
        bullets = _expand_short_bullets(s["bullets"], heading=heading, context_excerpt=context_excerpt, model_name="gpt-4o-mini")
        speaker_notes = s["speaker_notes"] or _synthesize_speaker_notes_from_bullets(heading, bullets, context_excerpt)
        img = s.get("image_suggestion")
        out_slides.append({
            "heading": heading,
            "bullets": bullets,
            "speaker_notes": speaker_notes,
            "image_suggestion": img,
            "slide_no": i
        })

    return {"title": title, "slides": out_slides}


# ---------------------------
# Public API
# ---------------------------
@celery.task(name="Content_creator.generate_presentation_json")
def generate_presentation_json(
    data: dict,
    audience: str = "general",
    tone: str = "informative",
    model_name: str = "gpt-4o-mini",
    temperature: float = 0.18,
    max_tokens: int = 1400,
) -> Dict[str, Any]:
    """
    Robust wrapper that calls LLM, parses output, and ensures the final JSON is valid and complete.
    Returns {"json": presentation_out, "id": session_id}
    """

    session_id = data.get("id")
    context = data.get("text", "")
    if not context or not context.strip():
        raise ValueError("Empty context: provide text (e.g., from a PDF) as the 'text' parameter.")

    context_for_prompt = context if len(context) <= 32000 else context[:32000]
    logger.info("Context preview (first 600 chars):\n%s", context_for_prompt[:600].replace("\n", " "))

    chain, parser = _build_chain(model_name=model_name, temperature=temperature, max_tokens=max_tokens)

    raw = chain.run(
        {
            "context": context_for_prompt,
            "audience": audience,
            "tone": tone
        }
    )
    logger.debug("LLM raw output (truncated): %s", (raw[:2000] if raw else "<empty>"))

    # Try parse with pydantic parser; on failure, fallback to JSON load + normalization
    presentation_obj = None
    try:
        presentation_model = parser.parse(raw)
        # convert into Python dict structure for downstream pipeline
        presentation_obj = {"title": presentation_model.title, "slides": [s.dict() if hasattr(s, "dict") else dict(s) for s in presentation_model.slides]}
        logger.info("LLM parsed into PresentationModel successfully.")
    except Exception as e:
        logger.warning("Pydantic parser failed: %s", e)
        # attempt deterministic retry with temperature=0.0
        try:
            chain_det, parser_det = _build_chain(model_name=model_name, temperature=0.0, max_tokens=max_tokens)
            raw2 = chain_det.run({"context": context_for_prompt, "audience": audience, "tone": tone})
            logger.debug("LLM retry raw (truncated): %s", (raw2[:2000] if raw2 else "<empty>"))
            # try parse again
            try:
                presentation_model = parser_det.parse(raw2)
                presentation_obj = {"title": presentation_model.title, "slides": [s.dict() if hasattr(s, "dict") else dict(s) for s in presentation_model.slides]}
                raw = raw2
                logger.info("Deterministic parse succeeded.")
            except Exception:
                # no luck — attempt to json.loads and normalize
                logger.warning("Deterministic parse failed; attempting json.loads normalization.")
                candidate = None
                try:
                    candidate = json.loads(raw2)
                except Exception:
                    try:
                        candidate = json.loads(raw)
                    except Exception:
                        candidate = None

                if candidate and isinstance(candidate, dict):
                    presentation_obj = _normalize_raw_obj(candidate, context_for_prompt)
                else:
                    # final fallback: synthesize a very minimal presentation from context
                    logger.error("LLM outputs could not be parsed as JSON; synthesizing minimal slides from context.")
                    presentation_obj = {
                        "title": (context_for_prompt.splitlines()[0][:80]) or "Presentation",
                        "slides": [
                            {
                                "heading": "Introduction",
                                "bullets": ["Overview"],
                                "speaker_notes": context_for_prompt[:300],
                                "image_suggestion": None,
                                "slide_no": 1
                            },
                            {
                                "heading": "Conclusion",
                                "bullets": ["Key takeaways"],
                                "speaker_notes": "Summary notes.",
                                "image_suggestion": None,
                                "slide_no": 2
                            }
                        ]
                    }
        except Exception as e2:
            logger.exception("Retry chain failed: %s", e2)
            raise RuntimeError("Failed to parse LLM output into PresentationModel schema.") from e2

    # If we have a presentation_obj, ensure normalization and fill missing fields
    if presentation_obj is None:
        raise RuntimeError("Unable to produce a presentation object from LLM output.")

    # Normalize (ensures speaker_notes/bullets/image_suggestion/slide_no present)
    normalized = _normalize_raw_obj(presentation_obj, context_for_prompt)

    presentation_out = {
        "title": normalized["title"],
        "slides": normalized["slides"]
    }

    logger.info("Generated presentation with %d slides (title: %s).", len(presentation_out["slides"]), presentation_out["title"])
    return {"json": presentation_out, "id": session_id}


def save_presentation_json(presentation: Dict[str, Any], path: str = "presentation.json") -> str:
    with open(path, "w", encoding="utf-8") as f:
        json.dump(presentation, f, ensure_ascii=False, indent=2)
    return path


def build_narration_script_from_presentation(presentation: Dict[str, Any]) -> str:
    lines: List[str] = []
    title = presentation.get("title", "")
    lines.append(f"TITLE: {title}\n")
    for s in presentation.get("slides", []):
        lines.append(f"SLIDE {s.get('slide_no', '?')}: {s.get('heading','')}")
        bullets = s.get("bullets", [])
        if bullets:
            lines.append("Bullets:")
            for b in bullets:
                lines.append(f"- {b}")
        lines.append("")  # blank line
        lines.append(s.get("speaker_notes", ""))
        lines.append("\n")
    return "\n".join(lines)


# ---------------------------
# PDF loader helper (fixed newlines)
# ---------------------------
def load_text_from_pdf(pdf_path: str) -> str:
    """
    Extract text from a PDF using PyPDF2. Returns a single string with proper newlines.
    NOTE: scanned PDFs require OCR (tesseract) — PyPDF2 extracts embedded text only.
    """
    reader=""
    if not os.path.exists(pdf_path):
        raise FileNotFoundError(f"PDF not found: {pdf_path}")

    file_size = os.path.getsize(pdf_path)
    print("file size (bytes):", file_size)
    if file_size == 0:
        raise ValueError("Uploaded file is empty")

    with open(pdf_path, "rb") as f:
        header = f.read(5)
        print("header bytes:", header)
        if header != b"%PDF-":
            raise ValueError("File is not a valid PDF (missing %PDF- header)")
        f.seek(0)
        try:
            reader = PdfReader(f, strict=False)  # <-- important
            print(type(reader),":reader")
            print("pages:", len(reader.pages))
            parts: List[str] = []
            for page in reader.pages:
                try:
                    t = page.extract_text() or ""
                    if t.strip():
                        parts.append(t.strip())
                except Exception as e:
                    logger.warning("Failed to extract text from a page: %s", e)
            return "\n\n".join(parts)
        except errors.PdfReadError as e:
            print("PdfReadError while reading PDF:", e)

    

# ---------------------------
# Example usage (debugging)
# ---------------------------
if __name__ == "__main__":
    pdf_path = "D:/Internship2/StudyBuddy/EduRuby_flask/AI.pdf"
    if os.path.exists(pdf_path):
        ctx = load_text_from_pdf(pdf_path)
    else:
        ctx = (
            "Photosynthesis Basics:\n"
            "Cover: what is photosynthesis, the chemical equation, light and dark reactions, significance to life, "
            "basic experiments and examples."
        )

    result = generate_presentation_json(
        data={"text": ctx},
        audience="students",
        tone="conversational",
        model_name="gpt-4o-mini",
        temperature=0.15,
        max_tokens=1200,
    )
    presentation = result["json"]

    script_dir = Path(__file__).parent
    save_presentation_json(presentation, str(script_dir / "presentation.json"))

    # Save narration text
    narration = build_narration_script_from_presentation(presentation)
    with open(script_dir / "narration.txt", "w", encoding="utf-8") as f:
        f.write(narration)

    print(f"Saved presentation.json and narration.txt in {script_dir}")
