from flask import request, Blueprint, jsonify
from flask_socketio import emit
import base64
import os
import io
import time
import threading
from pydub import AudioSegment  # type: ignore
from app.core import (
    socketio, client,
    audio_buffers, sectionIds, characterIds, last_ping_time, metadata,
    buttonsStates, character_actions, distance_objects, action_modes,
    get_available_stt, release_stt, STT_INSTANCES,
    get_available_tts, release_tts, TTS_INSTANCES, CHUNK_SIZE
)
from Ai_Agents.services.process_requests.process_request_narriative_design import process_request_1
from Ai_Agents.services.process_requests.process_request_knowledge_bank     import process_request
from Ai_Agents.services.llm_functions.get_action_response                  import get_action_response
# NOTE: supporting modules (io, os, base64, time, AudioSegment, etc.) are imported above.
from app.database.fetch_data import get_character_by_id, update_token_usage, get_mongo_client
from TTS.api import TTS  # type: ignore
thread_lock = threading.Lock()
background_thread = None

# Periodically check for inactive clients
def check_inactive_users():
    """Background task: disconnect clients whose last ping is older than 30s.

    Runs forever on the socketio background-task loop, scanning every 10
    seconds. Inactive clients are notified via 'force_disconnect' and then
    have all their per-session state purged.
    """
    TIMEOUT_SECONDS = 30
    CHECK_INTERVAL_SECONDS = 10
    while True:
        current_time = time.time()
        for sid in list(last_ping_time.keys()):  # copy: dict may shrink mid-loop
            # .get() instead of [] — a disconnect handler may pop the key
            # between the keys() snapshot and this access (KeyError race).
            last_seen = last_ping_time.get(sid)
            if last_seen is None:
                continue  # already cleaned up elsewhere
            if current_time - last_seen > TIMEOUT_SECONDS:
                print(f"⏳ Disconnecting inactive client: {sid}")
                socketio.emit('force_disconnect', {}, room=sid)  # Notify client
                onUserDisconnect(sid)  # Purge all server-side state for sid
        socketio.sleep(CHECK_INTERVAL_SECONDS)


@socketio.on("connect")
def handle_connect():
    """Acknowledge a new client and lazily start the inactivity watcher."""
    global background_thread
    with thread_lock:
        watcher_not_started = background_thread is None
        if watcher_not_started:
            # First connection: socketio.server is guaranteed to be
            # initialized by the time a connect event fires.
            background_thread = socketio.start_background_task(check_inactive_users)
    emit("message", {"data": "Connected to server!"})

@socketio.on('disconnect')
def handle_disconnect():
    """Tear down all per-session state when a client drops its socket."""
    session_id = request.sid
    print(f"❌ Client disconnected: {session_id}")
    # Purge every per-sid dictionary for this session
    onUserDisconnect(session_id)
    print(f"🗑 Cleaned up data for {session_id}")

@socketio.on("toggle_action_mode")
def handle_toggle_action_mode(data):
    """Record whether this session wants action mode on, and echo the state."""
    is_enabled = data.get("enabled", False)
    action_modes[request.sid] = is_enabled
    emit("action_mode_toggled", {"enabled": is_enabled}, room=request.sid)

@socketio.on('ping_from_client')
def handle_ping():
    """Refresh the session's liveness timestamp and reply with a pong."""
    session_id = request.sid
    print(f"🔄 Received ping from {session_id}")
    # Timestamp consumed by check_inactive_users for the 30s timeout
    last_ping_time[session_id] = time.time()
    socketio.emit('pong_from_server', {}, room=session_id)

def onUserDisconnect(sid):
    """Drop every piece of per-session state held for *sid*.

    Safe to call multiple times — pop(..., None) makes each removal a no-op
    when the key is already gone.
    """
    per_session_stores = (
        audio_buffers,
        characterIds,
        sectionIds,
        last_ping_time,
        buttonsStates,
        character_actions,
        distance_objects,
        action_modes,
        metadata,
    )
    for store in per_session_stores:
        store.pop(sid, None)


@socketio.on("update_Character_id")
def updateCharacterid(data):
    """Bind this session to a character and persist the choice in MongoDB.

    Expects ``data`` with a ``character_id`` and, optionally, a ``client_id``
    identifying the client record whose selection should be updated.
    """
    characterid = data.get("character_id", "")
    characterIds[request.sid] = characterid
    sectionIds[request.sid] = '$'
    metadata[request.sid] = {"rag_chain": None, "memory": {}}

    # BUG FIX: keep the incoming payload and the fetched character document
    # in separate variables. The original rebound `data` here, so the
    # `client_id` lookup below read the character document instead of the
    # event payload and the MongoDB update silently never ran.
    character_doc = get_character_by_id(characterid)
    buttonsStates[request.sid] = character_doc.get("toggle_button", False)

    # Update the client's selected_character_id in MongoDB
    db = client['character_test']
    clients_collection = db['clients']
    # client_id is assumed to be sent by the client in the event payload
    client_id = data.get('client_id')
    if client_id:
        clients_collection.update_one(
            {"client_id": client_id},
            {"$set": {"selected_character_id": characterid}}
        )
    # Emit success message back to the client
    socketio.emit("update_success", {"message": "Character ID updated successfully"}, room=request.sid)


# 🎤 **Speech-to-Text (STT) Streaming with Byte Arrays**
#region streaming audio input
@socketio.on("start_audio_stream")
def start_audio_stream():
    """Allocate a fresh in-memory buffer for this client's audio stream."""
    fresh_buffer = io.BytesIO()
    audio_buffers[request.sid] = fresh_buffer

@socketio.on("audio_chunk")
def handle_audio_chunk(chunk):
    """Append one streamed audio chunk (raw bytes or base64 text) to the buffer."""
    sid = request.sid
    # Lazily create the buffer in case start_audio_stream was never received
    stream_buffer = audio_buffers.setdefault(sid, io.BytesIO())
    # Clients may send base64-encoded text; normalize everything to raw bytes
    payload = base64.b64decode(chunk) if isinstance(chunk, str) else chunk
    stream_buffer.write(payload)

@socketio.on("end_audio_stream")
def end_audio_stream():
    """Finish the audio stream: transcribe the buffered audio and route it to chat.

    Emits 'user_audio_text' with the transcription and feeds it through
    handleChat(). The pooled STT slot is always released and the temp file
    always removed, even on failure.
    """
    import tempfile  # function-scope: only this handler needs it

    if request.sid not in audio_buffers:
        emit("transcription", {"error": "No audio received."})
        return

    audio_data = audio_buffers.pop(request.sid)

    # Reserve one of the pooled STT model instances
    stt_index = get_available_stt()
    stt_model = STT_INSTANCES[stt_index]

    temp_wav_path = None
    try:
        # Unique temp file instead of the old predictable temp_audio_{sid}.wav
        # in CWD — avoids collisions and leftover files with guessable names.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            f.write(audio_data.getvalue())  # getvalue() is position-independent
            temp_wav_path = f.name

        result = stt_model.transcribe(temp_wav_path)
        transcription_text = result["text"]
        emit("user_audio_text", {"text": transcription_text})
        handleChat(transcription_text)
    except Exception as e:
        emit("error", {"message": str(e)})
    finally:
        if temp_wav_path and os.path.exists(temp_wav_path):
            os.remove(temp_wav_path)
        release_stt(stt_index)

#endregion

def handleChat(transcription_text):
    """Route one user utterance through the appropriate chat pipeline and reply.

    Depending on per-session flags the text goes through the action responder,
    the narrative-design pipeline, or the knowledge bank. The response text is
    emitted to the client, token usage recorded, and the reply synthesized to
    streamed audio.

    Must be called from within a SocketIO event context (uses request.sid).
    """
    sid = request.sid
    # .get() with defaults: a session that never ran update_Character_id
    # previously raised KeyError here.
    chat_data = {
        "message": transcription_text,
        "section_id": sectionIds.get(sid, '$'),
        "character_id": characterIds.get(sid, ""),
    }

    if action_modes.get(sid):
        # Action mode: respond with an in-world action chosen from the
        # character's available actions and nearby objects.
        print("action mode is enables and actions and objects are added to the chat data")
        actions = character_actions.get(sid, [])
        distances = distance_objects.get(sid, [])

        chat_response = get_action_response(transcription_text, actions, distances)
        print("action response : ", chat_response)
        chat_response["section_id"] = sectionIds.get(sid, '$')
    elif buttonsStates.get(sid):
        print("Narriative design is enabled")
        chat_response = process_request_1(client, chat_data, sid)
    else:
        print("Knowledge bank is enabled")
        chat_response = process_request(client, chat_data, sid)

    print(f"💬 Chat Response: {chat_response}")
    print(f"🔵 New section ID: {chat_response['section_id']}")
    sectionIds[sid] = chat_response["section_id"]

    # Emit the text reply, record token usage, then stream synthesized audio
    emit("chat_response", {"text": chat_response["response"]})
    update_token_usage(client, character_id=characterIds.get(sid, ""),
                       token_usage=chat_response["token_usage"])
    synthesize_audio({"text": chat_response["response"], "gender": "female"})

@socketio.on("synthesize_audio")
def synthesize_audio(data):
    """Convert text to speech and stream the audio back in base64 chunks."""
    text = data.get("text", "")
    if not text:
        emit("error", {"message": "Missing text"})
        return

    # Reserve a TTS model instance from the shared pool
    tts_index = get_available_tts()
    tts_model = TTS_INSTANCES[tts_index]

    try:
        # Synthesize into memory, then normalize to 16 kHz mono 16-bit PCM
        raw_wav = io.BytesIO()
        tts_model.tts_to_file(text=text, file_path=raw_wav)
        raw_wav.seek(0)

        segment = AudioSegment.from_file(raw_wav, format="wav")
        segment = segment.set_frame_rate(16000).set_channels(1).set_sample_width(2)

        exported = io.BytesIO()
        segment.export(exported, format="wav", codec="pcm_s16le")
        payload = exported.getvalue()

        # Stream fixed-size base64 chunks so the client can buffer playback
        total_chunks = (len(payload) + CHUNK_SIZE - 1) // CHUNK_SIZE
        for index in range(total_chunks):
            start = index * CHUNK_SIZE
            piece = payload[start:start + CHUNK_SIZE]
            emit("audio_chunk", {
                "chunk": base64.b64encode(piece).decode("utf-8"),
                "index": index,
                "total_chunks": total_chunks,
            })

        emit("audio_complete", {"message": "Audio synthesis complete"})
    except Exception as e:
        emit("error", {"message": str(e)})
    finally:
        release_tts(tts_index)


@socketio.on("text_chat")
def text_chat(data):
    """Handle a typed (non-audio) chat message from the client."""
    user_message = data.get("userInput")
    handleChat(user_message)

# REST API Blueprint for chat
chat_rest_bp = Blueprint("chat_rest", __name__, url_prefix="/api")

@chat_rest_bp.route("/chat", methods=["POST"])
def chat_with_character():
    """REST endpoint: run one prompt through the knowledge-bank pipeline.

    Expects JSON: {"prompt": str, "character_id": str}.
    Returns {"response": str} on success; 400 on a missing/invalid JSON body
    or missing fields.
    """
    # silent=True: a non-JSON body yields None instead of raising; the
    # original crashed with AttributeError on data.get() for such requests.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({"error": "Missing prompt or character_id"}), 400

    prompt = data.get("prompt")
    character_id = data.get("character_id")
    if not prompt or not character_id:
        return jsonify({"error": "Missing prompt or character_id"}), 400

    # Prepare chat_data for processing
    chat_data = {
        "message": prompt,
        "character_id": character_id,
        "section_id": "$"  # Default section_id; adjust as needed
    }
    # Use the knowledge bank process_request by default
    # NOTE(review): passing client=None and sid=None relies on
    # process_request tolerating them — confirm against its implementation.
    response = process_request(None, chat_data, None)

    return jsonify({"response": response.get("response", "No response")})