import os
import cv2
import torch
import numpy as np
from .s3fd import S3FD

# Build the S3FD face detector exactly once, at import time, so every user of
# this module shares a single instance (and a single copy of the weights).
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
detector_model = S3FD()

# The pretrained weights are expected to live next to this source file;
# raise a FileNotFoundError naming the expected path if they are absent.
weights_path = os.path.join(os.path.dirname(__file__), 's3fd.pth')
if not os.path.exists(weights_path):
    raise FileNotFoundError(f"Missing face detector weights: {weights_path}")

# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code -- only ship/load a trusted s3fd.pth. If the file is a plain
# state dict, consider passing weights_only=True (available since torch 1.13).
state_dict = torch.load(weights_path, map_location=device)
detector_model.load_state_dict(state_dict)
# Move the model to the selected device and put it in eval mode.
detector_model = detector_model.to(device).eval()


def detect_faces(image, conf_th=0.9, scales=None):
    """Detect faces in an image and return bounding boxes.

    Runs the shared module-level ``detector_model`` with gradient tracking
    disabled.

    Args:
        image: Passed unchanged to ``detector_model.detect_faces``; the
            expected format (channel order, dtype, layout) is defined by the
            S3FD implementation -- TODO confirm against ``.s3fd``.
        conf_th: Confidence threshold forwarded to the detector. Defaults to
            0.9, the value this function previously hard-coded.
        scales: Scales forwarded to the detector. ``None`` (default) means
            ``[1.0]``, the value this function previously hard-coded.

    Returns:
        Whatever ``detector_model.detect_faces`` returns (the bounding boxes).
    """
    # Build a fresh list per call rather than sharing a mutable default.
    if scales is None:
        scales = [1.0]
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        bboxes = detector_model.detect_faces(image, conf_th=conf_th, scales=scales)
    return bboxes


class FaceDetector:
    """Object wrapper around an S3FD-style face detector.

    By default it reuses the module-level ``detector_model`` built at import
    time, so constructing several ``FaceDetector`` objects does not load the
    weights again.
    """

    def __init__(self, detector=None):
        """Wrap ``detector``, or the shared module-level model if omitted.

        Args:
            detector: Any object exposing
                ``detect_faces(image, conf_th=..., scales=...)``. ``None``
                (default) uses the shared module-level ``detector_model``.
        """
        self.detector = detector_model if detector is None else detector

    def detect(self, image, conf_th=0.9, scales=None):
        """Detect faces in ``image`` and return the detector's bounding boxes.

        Args:
            image: Passed unchanged to ``self.detector.detect_faces``; its
                expected format is defined by that implementation.
            conf_th: Confidence threshold forwarded to the detector.
            scales: Scales forwarded to the detector. ``None`` (default)
                means ``[1.0]``, the value previously hard-coded here.

        Returns:
            Whatever ``self.detector.detect_faces`` returns.
        """
        # Build a fresh list per call rather than sharing a mutable default.
        if scales is None:
            scales = [1.0]
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            return self.detector.detect_faces(image, conf_th=conf_th, scales=scales)
