from typing import Any

from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.retrievers import BaseRetriever, RetrieverLike
from pydantic import ConfigDict


class ContextualCompressionRetriever(BaseRetriever):
    """Retriever that post-processes another retriever's output with a compressor.

    The query is first sent to ``base_retriever``. Any documents it returns are
    then passed, together with the same query, to ``base_compressor``. When the
    base retriever yields nothing, the compressor is never called and an empty
    list is returned.
    """

    base_compressor: BaseDocumentCompressor
    """Compressor applied to the documents produced by ``base_retriever``."""

    base_retriever: RetrieverLike
    """Underlying retriever that supplies the candidate documents."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> list[Document]:
        """Retrieve documents for ``query`` and compress them.

        Args:
            query: Text to look up.
            run_manager: Callback manager for this retriever run. A separate
                child manager is created for the base retriever and, if it
                returns anything, for the compressor.
            **kwargs: Extra keyword arguments forwarded unchanged to the base
                retriever's ``invoke``.

        Returns:
            A list of compressed documents, or an empty list when the base
            retriever found nothing.
        """
        retriever_config = {"callbacks": run_manager.get_child()}
        retrieved = self.base_retriever.invoke(query, config=retriever_config, **kwargs)
        if not retrieved:
            # Nothing to compress; skip the compressor (and its child callbacks).
            return []
        compressed = self.base_compressor.compress_documents(
            retrieved,
            query,
            callbacks=run_manager.get_child(),
        )
        return list(compressed)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> list[Document]:
        """Asynchronously retrieve documents for ``query`` and compress them.

        Args:
            query: Text to look up.
            run_manager: Async callback manager for this retriever run. A
                separate child manager is created for the base retriever and,
                if it returns anything, for the compressor.
            **kwargs: Extra keyword arguments forwarded unchanged to the base
                retriever's ``ainvoke``.

        Returns:
            A list of compressed documents, or an empty list when the base
            retriever found nothing.
        """
        retriever_config = {"callbacks": run_manager.get_child()}
        retrieved = await self.base_retriever.ainvoke(
            query, config=retriever_config, **kwargs
        )
        if not retrieved:
            # Nothing to compress; skip the compressor (and its child callbacks).
            return []
        compressed = await self.base_compressor.acompress_documents(
            retrieved,
            query,
            callbacks=run_manager.get_child(),
        )
        return list(compressed)
