import os
from datetime import datetime, timedelta, timezone
from functools import wraps

import jwt
from flask import request

# Key used to sign and verify the HS256 tokens below, taken from the
# SECRET_KEY environment variable when present.
# WARNING: when the variable is unset, the publicly visible fallback
# "secret123" is used, and anyone who knows it can mint valid tokens.
# The fallback is only acceptable for local development; set SECRET_KEY
# in every deployed environment.
SECRET_KEY = os.environ.get("SECRET_KEY", "secret123")

def generate_token(payload, expires_in_minutes=60*24*30):
    """
    Generate an HS256-signed JWT for a given payload.

    A shallow copy of ``payload`` is signed, so the caller's dict is left
    untouched. An ``exp`` claim is set to ``expires_in_minutes`` from now
    (default = 30 days); any ``exp`` already present in ``payload`` is
    overwritten.

    Returns whatever ``jwt.encode`` returns for the signed token.
    """
    payload_copy = payload.copy()
    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
    # PyJWT converts both forms to the same UTC epoch seconds for "exp".
    payload_copy["exp"] = datetime.now(timezone.utc) + timedelta(
        minutes=expires_in_minutes
    )

    token = jwt.encode(payload_copy, SECRET_KEY, algorithm="HS256")
    return token

def decode_token(token):
    """
    Verify ``token`` against SECRET_KEY (HS256 only) and return its claims.

    On success the decoded payload from ``jwt.decode`` is returned. Every
    verification failure (bad signature, malformed token, or an expired
    ``exp`` claim) results in ``None`` rather than an exception.
    """
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
    except jwt.InvalidTokenError:
        # ExpiredSignatureError derives from InvalidTokenError, so expired
        # tokens are caught here as well.
        return None


def jwt_required(f):
    """
    Flask view decorator that requires a valid HS256 JWT.

    Reads the ``Authorization`` header, expected as ``"<scheme> <token>"``
    (normally ``"Bearer <jwt>"``; the scheme word is not checked). On success,
    the token's ``"id"`` claim is stored on ``request.user_id`` and the wrapped
    view runs. Otherwise a ``({"status": False, "message": ...}, 401)``
    response is returned:

    - "Token missing": no or empty ``Authorization`` header.
    - "Token expired": the token's ``exp`` is in the past.
    - "Invalid token": malformed header, bad or undecodable token, or a token
      without an ``"id"`` claim.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        auth = request.headers.get("Authorization", None)
        if not auth:
            return {"status": False, "message": "Token missing"}, 401

        # A JWT never contains whitespace, so a well-formed header splits into
        # exactly two parts. Previously a header such as "Bearer" raised an
        # uncaught IndexError and surfaced as a 500.
        parts = auth.split()
        if len(parts) != 2:
            return {"status": False, "message": "Invalid token"}, 401
        token = parts[1]

        try:
            payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
        except jwt.ExpiredSignatureError:
            return {"status": False, "message": "Token expired"}, 401
        except jwt.InvalidTokenError:
            return {"status": False, "message": "Invalid token"}, 401

        # A correctly signed token without an "id" claim cannot identify the
        # caller. Reject it instead of letting KeyError become a 500.
        if "id" not in payload:
            return {"status": False, "message": "Invalid token"}, 401
        request.user_id = payload["id"]

        return f(*args, **kwargs)
    return decorated