from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from langchain.memory import ConversationSummaryBufferMemory
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_openai import OpenAIEmbeddings

from app.database.fetch_data import get_character_by_id
from Ai_Agents.services.llm_functions.count_tokens import count_tokens
from app.core import metadata  # dict keyed by session id
import os

def create_retriever(collection, k=10, search_type="similarity"):
    """Build a LangChain retriever over a MongoDB Atlas vector index.

    Args:
        collection: Mongo collection holding the embedded documents.
        k: Number of documents to return per query (default 10).
        search_type: Retrieval strategy forwarded to the vector store
            (default "similarity").

    Returns:
        A retriever backed by ``MongoDBAtlasVectorSearch`` using the
        ``vector_search_index`` index and OpenAI embeddings.
    """
    store = MongoDBAtlasVectorSearch(
        collection=collection,
        embedding=OpenAIEmbeddings(),
        index_name="vector_search_index",
    )
    retriever = store.as_retriever(
        search_type=search_type,
        search_kwargs={"k": k},
    )
    return retriever

def create_rag_chain(sid, retriever):
    """Assemble a history-aware RAG chain for one session and register
    its conversation memory in the shared ``metadata`` dict.

    Args:
        sid: Session id; ``metadata[sid]`` must already exist (the
            caller creates the entry via ``setdefault``).
        retriever: Base retriever over the session's knowledge bank.

    Returns:
        A retrieval chain expecting ``{"input", "chat_history"}`` and
        producing a dict that includes an ``"answer"`` key.
    """
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.2)

    # Summarizing buffer keeps the transcript under ~2048 tokens by
    # folding older turns into a running summary.
    # FIX: return_messages=True so load_memory_variables() yields message
    # objects (required by MessagesPlaceholder) instead of a flat string.
    memory = ConversationSummaryBufferMemory(
        llm=llm,
        memory_key="chat_history",
        max_token_limit=2048,
        return_messages=True,
    )
    metadata[sid]["memory"] = memory

    # FIX: the contextualization prompt previously contained no system
    # instruction, so the LLM was never told to rewrite the latest turn
    # into a standalone retrieval query — the documented contract of
    # create_history_aware_retriever.
    contextualize = ChatPromptTemplate.from_messages([
        ("system",
         "Given the chat history and the latest user message, rewrite the "
         "message as a standalone question that can be understood without "
         "the history. Do NOT answer the question; reformulate it if "
         "needed, otherwise return it as is."),
        MessagesPlaceholder("chat_history"),
        ("user", "{input}"),
    ])

    history_aware_retriever = create_history_aware_retriever(
        llm, retriever, contextualize
    )

    # Prompt that produces the final answer from the retrieved context.
    answer_prompt = ChatPromptTemplate.from_messages([
        ("system", "Use the context below:\n{context}"),
        MessagesPlaceholder("chat_history"),
        ("user", "{input}"),
    ])

    combine_chain = create_stuff_documents_chain(llm=llm, prompt=answer_prompt)
    return create_retrieval_chain(history_aware_retriever, combine_chain)

from langchain_core.messages import HumanMessage, AIMessage

def process_request(client, data, sid):
    """Answer one chat message using the session's cached RAG chain.

    Args:
        client: Mongo client; ``client["Knowledge_Bank"][character_id]``
            holds the character's embedded documents.
        data: Request payload with ``character_id``, ``message`` and an
            optional ``section_id``.
        sid: Session id keying the per-session ``metadata`` cache.

    Returns:
        Dict with the HTML-ready ``response`` (newlines as ``<br>``),
        the echoed ``section_id`` and the ``token_usage`` for this turn.
    """
    character_id = data["character_id"]
    question = data["message"]

    character = get_character_by_id(character_id)
    name = character.get("character_name", "Unknown")
    gender = character.get("gender", "Unknown")
    backstory = character.get("backstory", "")

    # Lazily build (and cache) the chain + memory for this session.
    metadata.setdefault(sid, {"memory": None, "rag_chain": None})
    if metadata[sid]["rag_chain"] is None:
        collection = client["Knowledge_Bank"][character_id]
        retriever = create_retriever(collection)
        metadata[sid]["rag_chain"] = create_rag_chain(sid, retriever)

    rag_chain = metadata[sid]["rag_chain"]
    memory = metadata[sid]["memory"]

    # Messages stored in chat_memory already ARE HumanMessage/AIMessage
    # instances, so the old re-wrapping loop was redundant; keep only
    # the human/ai filter it implied.
    chat_history = []
    if memory and memory.chat_memory:
        chat_history = [
            m for m in memory.chat_memory.messages if m.type in ("human", "ai")
        ]

    llm_input = {
        "input": f"You are {name}, a {gender}. Your backstory- {backstory}.\nUser: {question}",
        "chat_history": chat_history,
    }

    response = rag_chain.invoke(llm_input)
    answer = response["answer"]

    # FIX: persist this exchange. Previously nothing ever wrote to the
    # memory, so chat_history was empty on every request and the
    # conversation-memory machinery was dead code.
    if memory is not None:
        memory.save_context({"input": question}, {"output": answer})

    token_usage = count_tokens(question + answer)
    return {
        "response": answer.replace("\n", "<br>"),
        "section_id": data.get("section_id"),
        "token_usage": token_usage,
    }