# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Sequence, TypedDict, Union

from typing_extensions import TypeAlias, overload

from ..image_utils import is_pil_image
from ..utils import is_vision_available, requires_backends
from .base import Pipeline


if is_vision_available():
    from PIL import Image

    from ..image_utils import load_image


ImagePair: TypeAlias = Sequence[Union["Image.Image", str]]

Keypoint = TypedDict("Keypoint", {"x": float, "y": float})
Match = TypedDict("Match", {"keypoint_image_0": Keypoint, "keypoint_image_1": Keypoint, "score": float})


def validate_image_pairs(images: Any) -> Sequence[ImagePair]:
    """
    Validate the pipeline inputs and normalize them to a sequence of image pairs.

    Args:
        images (`Any`):
            Either a single pair of images or a sequence of pairs of images. Each image must be a PIL image or a
            string.

    Returns:
        `Sequence[ImagePair]`: `[images]` when `images` is a single pair, or `images` itself (unchanged) when it is
        already a sequence of pairs.

    Raises:
        ValueError: If `images` is neither a pair of images nor a sequence of pairs of images.
    """
    error_message = (
        "Input images must be one of the following:\n"
        " - A pair of images.\n"
        " - A list of pairs of images."
    )

    def _is_valid_image(image) -> bool:
        """An image is a PIL Image or a string."""
        # Cheap `str` check first; `is_pil_image` is only consulted for non-string objects.
        return isinstance(image, str) or is_pil_image(image)

    def _is_container(candidate) -> bool:
        """A sequence that is not itself text: `str` is a `Sequence` of characters, but never a pair or a batch."""
        return isinstance(candidate, Sequence) and not isinstance(candidate, (str, bytes))

    def _is_image_pair(candidate) -> bool:
        """Exactly two valid images held in a non-string sequence."""
        return _is_container(candidate) and len(candidate) == 2 and all(_is_valid_image(image) for image in candidate)

    if _is_image_pair(images):
        return [images]
    if _is_container(images) and all(_is_image_pair(image_pair) for image_pair in images):
        return images
    raise ValueError(error_message)


class KeypointMatchingPipeline(Pipeline):
    """
    Keypoint matching pipeline using any `AutoModelForKeypointMatching`. This pipeline matches keypoints between two
    images.

    Only an image processor is loaded alongside the model (no processor, feature extractor or tokenizer), and only
    the PyTorch framework is supported.
    """

    _load_processor = False
    _load_image_processor = True
    _load_feature_extractor = False
    _load_tokenizer = False

    def __init__(self, *args, **kwargs):
        """
        Build the pipeline, forwarding all arguments to the base `Pipeline`.

        Raises:
            ValueError: If the resolved framework is not PyTorch (`"pt"`).
        """
        super().__init__(*args, **kwargs)
        requires_backends(self, "vision")
        if self.framework != "pt":
            raise ValueError("Keypoint matching pipeline only supports PyTorch (framework='pt').")

    def _sanitize_parameters(self, threshold=None, timeout=None):
        """
        Split call-time parameters into preprocess, forward and postprocess parameters.

        `timeout` is routed to `preprocess` and `threshold` to `postprocess`; parameters left to `None` are omitted so
        that the defaults of those methods apply. No parameter is forwarded to `_forward`.
        """
        preprocess_params = {}
        if timeout is not None:
            preprocess_params["timeout"] = timeout
        postprocess_params = {}
        if threshold is not None:
            postprocess_params["threshold"] = threshold
        return preprocess_params, {}, postprocess_params

    @overload
    def __call__(self, inputs: ImagePair, threshold: float = 0.0, **kwargs: Any) -> list[Match]: ...

    @overload
    def __call__(self, inputs: list[ImagePair], threshold: float = 0.0, **kwargs: Any) -> list[list[Match]]: ...

    def __call__(
        self,
        inputs: Union[list[ImagePair], ImagePair],
        threshold: float = 0.0,
        **kwargs: Any,
    ) -> Union[list[Match], list[list[Match]]]:
        """
        Find matches between keypoints in two images.

        Args:
            inputs (a pair of `str` or `PIL.Image`, or a list of such pairs):
                The pipeline handles three types of images:

                - A string containing a http link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single pair of images or a batch of image pairs, which must then be
                passed as a list of pairs. Images in a batch must all be in the same format: all as http links, all
                as local paths, or all as PIL images.

            threshold (`float`, *optional*, defaults to 0.0):
                The threshold to use for keypoint matching. Keypoints matched with a lower matching score will be
                filtered out. A value of 0 means that all matched keypoints will be returned.

            kwargs:
                `timeout (`float`, *optional*, defaults to None)`
                    The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set
                    and the call may block forever.

        Return:
            Union[list[Match], list[list[Match]]]:
                A list of matches if a single image pair is provided, or a list of lists of matches if a batch of
                image pairs is provided. A batch that contains exactly one pair is unwrapped as well and yields a
                flat list of matches. Matches are sorted by decreasing score. Each match is a dictionary containing
                the following keys:

                - **keypoint_image_0** (`Keypoint`): The keypoint in the first image (x, y coordinates).
                - **keypoint_image_1** (`Keypoint`): The keypoint in the second image (x, y coordinates).
                - **score** (`float`): The matching score between the two keypoints.

        Raises:
            ValueError: If `inputs` is `None`, or is neither a pair of images nor a list of pairs of images.
        """
        if inputs is None:
            raise ValueError("Cannot call the keypoint-matching pipeline without an inputs argument!")
        formatted_inputs = validate_image_pairs(inputs)
        outputs = super().__call__(formatted_inputs, threshold=threshold, **kwargs)
        # A single pair was wrapped into a one-element batch by the validation step: unwrap its result.
        # NOTE: a user-provided batch holding exactly one pair takes this branch too.
        if len(formatted_inputs) == 1:
            return outputs[0]
        return outputs

    def preprocess(self, images, timeout=None):
        """
        Load one pair of images and run the image processor on it.

        Args:
            images: The two images of the pair (anything accepted by `load_image`).
            timeout: Forwarded to `load_image`.

        Returns:
            A dict with the processed `model_inputs` and the original `target_sizes` of both images.
        """
        images = [load_image(image, timeout=timeout) for image in images]
        model_inputs = self.image_processor(images=images, return_tensors=self.framework)
        model_inputs = model_inputs.to(self.torch_dtype)
        # `post_process_keypoint_matching` documents target sizes as (height, width), whereas PIL's `Image.size` is
        # (width, height): passing `image.size` would swap the axes when keypoints are rescaled to pixel coordinates.
        target_sizes = [(image.height, image.width) for image in images]
        preprocess_outputs = {"model_inputs": model_inputs, "target_sizes": target_sizes}
        return preprocess_outputs

    def _forward(self, preprocess_outputs):
        """
        Run the model on the preprocessed pair.

        Returns:
            A dict with the raw `model_outputs` and the pair's `target_sizes` wrapped in an outer list, i.e. a
            one-element list holding the sizes of both images of the pair.
        """
        model_inputs = preprocess_outputs["model_inputs"]
        model_outputs = self.model(**model_inputs)
        forward_outputs = {"model_outputs": model_outputs, "target_sizes": [preprocess_outputs["target_sizes"]]}
        return forward_outputs

    def postprocess(self, forward_outputs, threshold=0.0) -> list[Match]:
        """
        Convert the model outputs of one image pair into a list of matches.

        Args:
            forward_outputs: The dict returned by `_forward`.
            threshold: Forwarded to the image processor's `post_process_keypoint_matching`.

        Returns:
            The matches of the pair, sorted by decreasing score.
        """
        model_outputs = forward_outputs["model_outputs"]
        target_sizes = forward_outputs["target_sizes"]
        postprocess_outputs = self.image_processor.post_process_keypoint_matching(
            model_outputs, target_sizes=target_sizes, threshold=threshold
        )
        # `_forward` handles a single pair, so only the first entry of the post-processed results is used.
        pair_outputs = postprocess_outputs[0]
        pair_result = [
            Match(
                keypoint_image_0=Keypoint(x=kp_0[0].item(), y=kp_0[1].item()),
                keypoint_image_1=Keypoint(x=kp_1[0].item(), y=kp_1[1].item()),
                score=score.item(),
            )
            for kp_0, kp_1, score in zip(
                pair_outputs["keypoints0"],
                pair_outputs["keypoints1"],
                pair_outputs["matching_scores"],
            )
        ]
        return sorted(pair_result, key=lambda match: match["score"], reverse=True)
