import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class S3FD(nn.Module):
    """Simplified S3FD model (used for face detection in Wav2Lip).

    The network is a small three-layer convolutional stack. Note that
    ``detect_faces`` does not run the network: it returns a fixed placeholder
    box derived only from the image dimensions.
    """

    def __init__(self):
        super().__init__()
        # 3x3 convolutions with padding=1 keep the spatial size unchanged;
        # only the pooling layer downsamples.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Run the convolutional stack on ``x``.

        Args:
            x: Input tensor with 3 channels (as required by ``conv1``).

        Returns:
            A 64-channel feature map. The two 2x2 max-pool steps reduce the
            spatial height and width by a factor of 4 (floored).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        return F.relu(self.conv3(x))

    def detect_faces(self, image, conf_th=0.9, scales=(1.0,)):
        """Return a fake bounding box covering the center of the frame.

        Args:
            image: Any object exposing ``.shape`` whose first two entries are
                height and width, either ``(h, w)`` or ``(h, w, channels)``.
            conf_th: Accepted for interface compatibility; currently unused.
            scales: Accepted for interface compatibility; currently unused.
                The default is an immutable tuple so it cannot be mutated and
                shared between calls.

        Returns:
            A list holding one box ``[w * 0.25, h * 0.25, w * 0.75, h * 0.75]``.

        Raises:
            ValueError: If ``image`` is neither 2- nor 3-dimensional.
        """
        shape = tuple(image.shape)
        if len(shape) not in (2, 3):
            raise ValueError(
                "expected image of shape (h, w) or (h, w, c), got %r" % (shape,)
            )
        h, w = shape[:2]
        return [[w * 0.25, h * 0.25, w * 0.75, h * 0.75]]

