import os
import subprocess
from pathlib import Path

import cv2
import librosa
import numpy as np
import torch
from tqdm import tqdm

# --- Core imports
from wav2lip_core.models import Wav2Lip
from wav2lip_core.audio import melspectrogram
from wav2lip_core.face_detection.detection.sfd.sfd_detector import SFDDetector



# -------------------------------------------------------------
#  DEVICE SETUP
# -------------------------------------------------------------
# Prefer the GPU whenever PyTorch can see one; everything below moves tensors here.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"[INIT] Using device: {device.upper()}")


# -------------------------------------------------------------
#  LOAD WAV2LIP MODEL
# -------------------------------------------------------------
def load_model(path):
    """Build a ``Wav2Lip`` network, load weights from ``path``, and return it in eval mode.

    Accepts either a training checkpoint dict with a ``"state_dict"`` entry or a
    bare state dict. A leading ``"module."`` prefix on parameter names (as written
    when the model was wrapped in ``torch.nn.DataParallel``) is stripped; the same
    text appearing elsewhere in a key is left intact.

    Args:
        path: Filesystem path of the checkpoint file.

    Returns:
        The loaded model, moved to the module-level ``device`` and set to eval mode.
    """
    print(f"[INFO] Loading Wav2Lip model from {path} ...")
    model = Wav2Lip()
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point this at checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    # Some checkpoints store the raw state dict rather than a {"state_dict": ...} wrapper.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    prefix = "module."
    # Strip only the leading DataParallel prefix; str.replace would also mangle
    # any parameter name that merely contains "module." further in.
    new_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    model.load_state_dict(new_state_dict)
    model = model.to(device).eval()
    print("[INFO] Model loaded successfully!")
    return model


# -------------------------------------------------------------
#  AUDIO → MEL CHUNKS
# -------------------------------------------------------------
def prepare_mel_chunks(audio_path, fps):
    """Cut the audio at ``audio_path`` into one mel-spectrogram window per video frame.

    The audio is loaded at 16 kHz and converted with ``melspectrogram``. The
    result is sliced along axis 1 into windows 16 columns wide, whose start
    advances by ``80 / fps`` columns per frame. The last window is right-aligned
    to the end of the spectrogram, so the audio tail is not silently dropped.

    Args:
        audio_path: Path to an audio file readable by ``librosa.load``.
        fps: Video frame rate; must be positive.

    Returns:
        List of arrays, each ``mel[:, start:start + 16]``.

    Raises:
        ValueError: if ``fps`` is not positive, the spectrogram contains NaN, or
            the audio yields fewer than 16 mel columns (too short for one window).
    """
    # A non-positive fps would divide by zero, or make the window start move
    # backwards forever (infinite loop).
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    wav, _ = librosa.load(audio_path, sr=16000)
    mel = melspectrogram(wav)
    if np.isnan(mel).any():
        raise ValueError("Mel spectrogram contains NaN values; check the input audio.")

    mel_step_size = 16
    n_cols = mel.shape[1]
    if n_cols < mel_step_size:
        # Previously this returned [], which crashed much later in the pipeline.
        raise ValueError(
            f"Audio is too short: {n_cols} mel frames, need at least {mel_step_size}."
        )

    # 80.0 is presumably the number of mel columns per second produced by
    # wav2lip_core.audio.melspectrogram (TODO confirm); divided by fps it gives
    # columns per video frame.
    mel_idx_multiplier = 80.0 / fps
    mel_chunks = []

    i = 0
    while True:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > n_cols:
            # Right-align one final window so the trailing audio still maps to a frame.
            mel_chunks.append(mel[:, n_cols - mel_step_size:])
            break
        mel_chunks.append(mel[:, start_idx:start_idx + mel_step_size])
        i += 1

    return mel_chunks


# -------------------------------------------------------------
#  LIP-SYNC GENERATION FUNCTION
# -------------------------------------------------------------
def _derived_path(path, suffix):
    """Return ``path`` with ``suffix`` inserted before its extension (``a.mp4`` -> ``a<suffix>.mp4``)."""
    # splitext keeps the result distinct from ``path`` for any extension; the old
    # str.replace(".mp4", ...) left e.g. ".MP4"/".mov" paths unchanged, so ffmpeg
    # overwrote its own input.
    root, ext = os.path.splitext(path)
    return f"{root}{suffix}{ext or '.mp4'}"


def _run_ffmpeg(args):
    """Run ``ffmpeg -y <args>`` without a shell; raise CalledProcessError on failure."""
    # An argv list avoids quoting bugs and shell injection through file names.
    subprocess.run(["ffmpeg", "-y", *args], check=True)


def _read_frames(video_path):
    """Read all frames of ``video_path`` with OpenCV; return ``(frames, fps)``, fps defaulting to 25."""
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) or 25
        frames = []
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(frame)
    finally:
        cap.release()
    return frames, fps


def _clip_box(box, frame_shape):
    """Clamp the first four values of ``box`` (x1, y1, x2, y2) to the frame; None if empty."""
    height, width = frame_shape[:2]
    x1, y1, x2, y2 = (int(v) for v in box[:4])
    # Negative coordinates would otherwise slice from the end of the array.
    x1, y1 = max(0, x1), max(0, y1)
    x2, y2 = min(width, x2), min(height, y2)
    if x2 <= x1 or y2 <= y1:
        return None
    return x1, y1, x2, y2


def _detect_faces(frames, face_detector):
    """Return, per frame, the clipped box of the first detection, or None when there is none."""
    boxes = []
    # Looping the video repeats the same frame objects; detect each object only once.
    cache = {}
    for frame in tqdm(frames, desc="Detecting"):
        key = id(frame)
        if key not in cache:
            bboxes = face_detector.detect_from_image(frame)
            cache[key] = _clip_box(bboxes[0], frame.shape) if len(bboxes) else None
        boxes.append(cache[key])
    return boxes


def _sync_frame(model, frame, box, mel_chunk):
    """Run the model on one 96x96 face crop and paste its output into a copy of ``frame``."""
    x1, y1, x2, y2 = box
    face = cv2.resize(frame[y1:y2, x1:x2], (96, 96)) / 255.0
    masked_face = face.copy()
    masked_face[48:, :] = 0  # zero the bottom half of the crop; the unmasked copy is stacked next to it
    face_input = np.concatenate([masked_face, face], axis=2)
    face_tensor = torch.FloatTensor(face_input).permute(2, 0, 1).unsqueeze(0).to(device)
    mel_tensor = torch.FloatTensor(mel_chunk).unsqueeze(0).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model(mel_tensor, face_tensor)

    pred = pred.squeeze().permute(1, 2, 0).cpu().numpy() * 255
    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
    pred = np.clip(pred, 0, 255).astype(np.uint8)
    out = frame.copy()
    out[y1:y2, x1:x2] = cv2.resize(pred, (x2 - x1, y2 - y1))
    return out


def _write_video(frames, path, fps):
    """Write ``frames`` to ``path`` with the mp4v codec; raise RuntimeError if the writer cannot open."""
    height, width = frames[0].shape[:2]
    writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    if not writer.isOpened():
        raise RuntimeError(f"Could not open video writer for {path}")
    try:
        for f in frames:
            writer.write(f)
    finally:
        writer.release()


def generate_lip_sync(face_path, audio_path, output_path):
    """Generate lip-synced video using a face video and a separate input audio.

    Pipeline:
    1. Strip the source audio with ffmpeg.
    2. Read the frames with OpenCV.
    3. Build one mel window per frame and loop the video if the audio needs more frames.
    4. Detect a face box per frame and run Wav2Lip on each crop.
    5. Write a silent video, then mux in ``audio_path`` with ffmpeg.

    Args:
        face_path: Input video containing the face.
        audio_path: Audio whose speech the lips should follow.
        output_path: Destination of the final video; parent directories are
            created if missing.

    Raises:
        ValueError: if the video has no readable frames or the audio produces no
            mel chunks.
        subprocess.CalledProcessError: if an ffmpeg call exits non-zero.
        Any other error (model loading, detection, inference, writing) is printed
        and re-raised unchanged.

    Intermediate files (``<face>_muted<ext>`` and ``<output>_no_audio<ext>``) are
    deleted on both success and failure.
    """
    muted_video = None
    temp_video_path = None
    try:
        model_path = "wav2lip_core/checkpoints/wav2lip.pth"
        model = load_model(model_path)

        print("[INFO] Loading SFD face detector...")
        face_detector = SFDDetector(device=device)

        # --- Step 1: Remove original audio from video (mute)
        muted_video = _derived_path(face_path, "_muted")
        _run_ffmpeg(["-i", face_path, "-an", "-vcodec", "copy", muted_video])

        # --- Step 2: Load frames from muted video
        frames, fps = _read_frames(muted_video)
        print(f"[INFO] Video loaded: {len(frames)} frames @ {fps} fps")
        if not frames:
            raise ValueError(f"No frames could be read from {face_path}")

        # --- Step 3: Prepare mel chunks (one per output frame)
        mel_chunks = prepare_mel_chunks(audio_path, fps)
        print(f"[INFO] Mel chunks: {len(mel_chunks)}")
        if not mel_chunks:
            raise ValueError(f"Audio {audio_path} produced no mel chunks")

        # --- Step 4: Loop video so there is a frame for every mel chunk
        if len(mel_chunks) > len(frames):
            loop_count = int(np.ceil(len(mel_chunks) / len(frames)))
            frames = (frames * loop_count)[: len(mel_chunks)]
            print(f"[INFO] Looping video {loop_count}x to cover {len(mel_chunks)} mel chunks.")

        min_len = min(len(frames), len(mel_chunks))
        frames = frames[:min_len]
        mel_chunks = mel_chunks[:min_len]

        # --- Step 5: Detect faces frame-by-frame
        print("[INFO] Detecting faces...")
        boxes = _detect_faces(frames, face_detector)

        # --- Step 6: Lip-sync inference (frames without a face pass through unchanged)
        print("[INFO] Performing lip-sync inference...")
        out_frames = []
        for i, box in enumerate(tqdm(boxes, desc="Lip-sync")):
            if box is None:
                out_frames.append(frames[i])
            else:
                out_frames.append(_sync_frame(model, frames[i], box, mel_chunks[i]))

        # --- Step 7: Save temporary silent video
        print("[INFO] Saving generated video (without audio)...")
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        temp_video_path = _derived_path(output_path, "_no_audio")
        _write_video(out_frames, temp_video_path, fps)

        # --- Step 8: Add input audio and finalize
        print("[INFO] Adding uploaded audio to the video...")
        _run_ffmpeg([
            "-i", temp_video_path, "-i", audio_path,
            "-c:v", "copy", "-c:a", "aac", "-shortest", output_path,
        ])

        print(f"[DONE] Final lip-sync video generated: {output_path}")

    except Exception as e:
        print(f"[ERROR] Lip-sync generation failed: {e}")
        raise
    finally:
        # Best-effort cleanup of intermediates; never let it mask the real error
        # or delete the requested output.
        for tmp in (muted_video, temp_video_path):
            if tmp and tmp != output_path and os.path.exists(tmp):
                try:
                    os.remove(tmp)
                except OSError:
                    pass