import cv2
import numpy as np
from pathlib import Path
from PIL import Image
from typing import List, Optional, Tuple, Union
from ultralytics import YOLO
from PIL import Image
import uuid

class Prediction:
    """Detect regions on a page image with a YOLO model and save them as crops.

    The detector is loaded once in the constructor. ``generate`` then takes a
    PIL image or an image path, runs the model, and writes one PNG per accepted
    detection into ``output_dir``.
    """

    # Errors the defensive per-box accessors treat as "field unavailable";
    # the affected field then falls back to a neutral default.
    _FIELD_ERRORS = (
        AttributeError,
        IndexError,
        KeyError,
        RuntimeError,
        TypeError,
        ValueError,
    )

    def __init__(
        self,
        model_path: str,
        output_dir: str = "detected_crops",
        conf_thresh: float = 0.9,
        allowed_labels: Optional[List[str]] = None,
        crop_padding: int = 10,
    ):
        """
        model_path: path to your YOLO .pt
        output_dir: directory where crops are saved (created if missing)
        conf_thresh: minimum confidence threshold for saving crops
        allowed_labels: list of class names to accept (e.g. ["Figure","Table"]); None => accept all
        crop_padding: pixels trimmed from the left and right of every accepted
            box; twice this amount is trimmed from the top and bottom
        """
        self.model = YOLO(model_path)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.conf_thresh = float(conf_thresh)
        self.allowed_labels = allowed_labels
        self.crop_padding = int(crop_padding)

    def _pil_to_bgr(self, pil_img: Image.Image) -> np.ndarray:
        """Convert PIL image to OpenCV BGR numpy array."""
        rgb = np.array(pil_img.convert("RGB"))
        return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    def _read_image_bgr(self, img_input: Union[str, Path, Image.Image]) -> Tuple[np.ndarray, str]:
        """
        Reads input which can be:
         - a PIL Image object
         - a path (str or Path) to an image file
        Returns: (img_bgr, base_name) where base_name is used for saved filenames.
        Raises: FileNotFoundError if a path is given and does not exist.
        """
        if isinstance(img_input, Image.Image):
            # No file name to derive a stem from, so use a random one.
            return self._pil_to_bgr(img_input), f"page_{uuid.uuid4().hex[:8]}"

        path = Path(img_input)
        if not path.exists():
            raise FileNotFoundError(f"Image path not found: {path}")
        # Read with PIL to support many formats (and avoid OpenCV ppm issues).
        # The context manager closes the underlying file once the pixels have
        # been copied into the numpy array.
        with Image.open(path) as pil:
            img_bgr = self._pil_to_bgr(pil)
        return img_bgr, path.stem

    def _predict(self, image_rgb: np.ndarray):
        """
        Run YOLO on an RGB numpy array and return the first result, or None
        when the model returns no results.

        Ultralytics interprets numpy sources as BGR (OpenCV convention), so
        the RGB input is converted before inference. Image size is unchanged,
        so box coordinates still refer to the caller's array.
        """
        image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
        results = self.model.predict(source=image_bgr, verbose=False)
        return results[0] if len(results) else None

    @staticmethod
    def _scalar(value):
        """Return ``value.item()`` when available (tensor/numpy scalar), else ``value``."""
        item = getattr(value, "item", None)
        return item() if callable(item) else value

    @staticmethod
    def _clamp_box(x1, y1, x2, y2, width, height) -> Optional[Tuple[int, int, int, int]]:
        """
        Clamp a box to a ``width`` x ``height`` image.

        The result is meant for slicing (``img[y1:y2, x1:x2]``), where the end
        index is exclusive, so the upper bound is ``width``/``height`` rather
        than ``width - 1``/``height - 1``.
        Returns None when the clamped box is empty.
        """
        x1 = max(0, min(width, x1))
        x2 = max(0, min(width, x2))
        y1 = max(0, min(height, y1))
        y2 = max(0, min(height, y2))
        if x2 <= x1 or y2 <= y1:
            return None
        return x1, y1, x2, y2

    def _box_coords(self, box) -> List[int]:
        """Return the box corners [x1, y1, x2, y2] rounded to integer pixels."""
        try:
            coords = box.xyxy[0].tolist()
        except self._FIELD_ERRORS:
            # fallback if different API
            coords = [float(v) for v in box.xyxy[0]]
        return [int(round(c)) for c in coords]

    def _box_conf_and_label(self, box, names) -> Tuple[float, str]:
        """
        Return (confidence, class name) for a detection.

        Confidence falls back to 0.0 and the label to "obj" when the
        corresponding field cannot be read.
        """
        try:
            conf = float(self._scalar(box.conf[0]))
        except self._FIELD_ERRORS:
            conf = 0.0
        try:
            label = names[int(self._scalar(box.cls[0]))]
        except self._FIELD_ERRORS:
            label = "obj"
        return conf, label

    def generate(self, img_input: Union[str, Path, Image.Image], img_name: Optional[str] = None) -> Tuple[Image.Image, List[str]]:
        """
        img_input: PIL.Image or path to image
        img_name: optional base name for saved files (overrides path stem / random)
        Returns: (page_image, [crop_path1, crop_path2, ...]) where page_image
            is the input page as an RGB PIL image (no regions are blacked out)
        Raises: FileNotFoundError for a missing input path; OSError when a
            crop cannot be written to disk.
        """
        # 1. read image into BGR (OpenCV) and decide base file stem
        img_bgr, auto_stem = self._read_image_bgr(img_input)
        base_stem = img_name if img_name else auto_stem
        height, width = img_bgr.shape[:2]

        # 2. RGB copy: handed to _predict and returned to the caller
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        page_image = Image.fromarray(img_rgb)

        # 3. model inference
        result = self._predict(img_rgb)
        boxes = getattr(result, "boxes", None)
        if boxes is None or len(boxes) == 0:
            return page_image, []

        names = getattr(result, "names", {})
        pad_x = self.crop_padding
        pad_y = 2 * self.crop_padding
        crop_paths: List[str] = []

        # 4. filter detections and save a crop for each accepted one
        for box in boxes:
            region = self._clamp_box(*self._box_coords(box), width, height)
            if region is None:
                continue

            conf, label = self._box_conf_and_label(box, names)
            if conf < self.conf_thresh:
                continue
            if self.allowed_labels is not None and label not in self.allowed_labels:
                continue

            # Shrink the (already clamped) box inwards, then re-validate.
            x1, y1, x2, y2 = region
            region = self._clamp_box(
                x1 + pad_x, y1 + pad_y, x2 - pad_x, y2 - pad_y, width, height
            )
            if region is None:
                continue
            x1, y1, x2, y2 = region

            # crop and save (from original BGR, which is what cv2.imwrite expects)
            crop = img_bgr[y1:y2, x1:x2].copy()
            crop_idx = len(crop_paths) + 1
            # A path separator in a class name would point outside output_dir.
            safe_label = str(label).replace("/", "_").replace("\\", "_")
            crop_name = f"{base_stem}_crop_{crop_idx}_{safe_label}_{int(conf*100)}.png"
            crop_path = self.output_dir / crop_name
            # cv2.imwrite reports failure through its return value.
            if not cv2.imwrite(str(crop_path), crop):
                raise OSError(f"Failed to write crop: {crop_path}")
            crop_paths.append(str(crop_path))

        return page_image, crop_paths


# ------------------------------
# Example usage:
# ------------------------------
if __name__ == "__main__":
    # Demo: run the detector on the first page of a PDF.
    # pdf2image is only needed for this demo, hence the local import.
    from pdf2image import convert_from_path

    weights = "layout_detector_model.pt"
    saved_page_path = "crops_out/page_001_blacked.png"

    # Rasterise the PDF; every page comes back as a PIL image.
    pdf_pages = convert_from_path(
        "book.pdf",
        dpi=200,
        poppler_path=r"C:/poppler-23.05.0/Library/bin",
    )

    # Build the detector; only "Figure" and "Table" detections are kept.
    detector = Prediction(
        weights,
        output_dir="crops_out",
        conf_thresh=0.9,
        allowed_labels=["Figure", "Table"],
    )

    # Run on the first page only, with an explicit base name for the crops.
    first_page = pdf_pages[0]
    page_image, crop_files = detector.generate(first_page, img_name="page_001")

    # generate() hands back the page as a PIL image; persist it next to the crops.
    page_image.save(saved_page_path)

    print("Blacked image saved as:", saved_page_path)
    print("Crops:", crop_files)

    # Variant: pass an image path instead of a PIL image.
    # page_image2, crop_files2 = detector.generate("page_002.png")
    # page_image2.save("crops_out/page_002_blacked.png")
    # print("Blacked image saved as:", "crops_out/page_002_blacked.png")
    # print("Crops:", crop_files2)
