"""
rate_limit.py

Purpose:
  Implements robust, database-backed login rate limiting and lockout to
  mitigate credential stuffing and brute-force attacks. Tracks failed login
  attempts per (username, ip) pair with a sliding window and enforces a
  temporary lock after too many failures.

How it works:
  - On each login POST, `is_login_allowed(username, ip)` is called.
    - If an active lock exists (locked_until > now), it returns (False, retry_s).
  - On password failure, `record_login_failure(username, ip)`:
    - Resets the window if last_failed_at is outside the window.
    - Increments fail_count and, when threshold is reached, sets locked_until.
  - On success, `record_login_success(username, ip)` resets counters.

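Defense-in-depth:
  - Also provides an IP-level limiter (`is_ip_allowed(ip)`) to prevent trivial
    bypass by switching usernames. Callers must require both the username+ip
    and the ip-only checks to pass.
  - Loopback clients (127.x.x.x, ::1) as well as '0.0.0.0' and 'localhost' are
    exempt: they are always allowed and their failures are never recorded.

Configuration (env vars):
  - LOGIN_MAX_FAILS: max failures in the window before lock (default: 10)
  - LOGIN_WINDOW_SECONDS: window duration in seconds, measured from the most
    recent failure (default: 300)
  - LOGIN_LOCK_SECONDS: lock duration once the threshold is reached (default: 300)
  - LOGIN_LOCK_IP_ON_ACCOUNT_LOCK: '1' (default) to also lock the IP when
    `is_login_allowed` applies a lock to an account; any other value disables it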
"""

from __future__ import annotations

import ipaddress
import os
import time
from typing import Tuple

from .db import get_db


def _now_ts() -> int:
    return int(time.time())


def _is_localhost_ip(ip: str) -> bool:
    """Check if the IP address is localhost and should be exempt from rate limiting."""
    if not ip:
        return False
    
    # Handle common localhost addresses
    localhost_ips = {
        '127.0.0.1',    # IPv4 localhost
        '::1',          # IPv6 localhost
        '0.0.0.0',      # Sometimes used in tests
        'localhost'     # In case hostname is passed
    }
    
    if ip in localhost_ips:
        return True
    
    # Check for any IP in 127.x.x.x range
    if ip.startswith('127.'):
        return True
    
    return False


def _cfg_int(name: str, default: int) -> int:
    try:
        return int(os.getenv(name, str(default)))
    except Exception:
        return default


# Rate-limit tunables read once at import time. Malformed integer values
# fall back to the defaults given here.
LOGIN_MAX_FAILS = _cfg_int(name='LOGIN_MAX_FAILS', default=10)
LOGIN_WINDOW_SECONDS = _cfg_int(name='LOGIN_WINDOW_SECONDS', default=300)
LOGIN_LOCK_SECONDS = _cfg_int(name='LOGIN_LOCK_SECONDS', default=300)
# Env var LOGIN_LOCK_IP_ON_ACCOUNT_LOCK: only the exact string '1' enables it;
# when the variable is unset it is enabled.
LOCK_IP_ON_ACCOUNT_LOCK = os.environ.get('LOGIN_LOCK_IP_ON_ACCOUNT_LOCK', '1') == '1'


def _ensure_schema():
    """Create the login rate-limit tables if they are missing.

    - ``login_attempts``: one row per (username, ip) pair.
    - ``login_ip_attempts``: one row per ip.

    Both statements use ``CREATE TABLE IF NOT EXISTS``, so repeated calls are
    harmless; each statement is committed right after it runs.
    """
    conn = get_db()
    statements = (
        # Per-(username, ip) counters
        """
        CREATE TABLE IF NOT EXISTS login_attempts (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            username TEXT NOT NULL,
            ip TEXT NOT NULL,
            fail_count INTEGER NOT NULL DEFAULT 0,
            last_failed_at INTEGER NOT NULL DEFAULT 0,
            locked_until INTEGER NOT NULL DEFAULT 0,
            UNIQUE(username, ip)
        )
        """,
        # IP-only attempts table
        """
        CREATE TABLE IF NOT EXISTS login_ip_attempts (
            ip TEXT PRIMARY KEY,
            fail_count INTEGER NOT NULL DEFAULT 0,
            last_failed_at INTEGER NOT NULL DEFAULT 0,
            locked_until INTEGER NOT NULL DEFAULT 0
        )
        """,
    )
    for ddl in statements:
        conn.execute(ddl)
        conn.commit()


def _get_row(username: str, ip: str):
    """Fetch the counters stored for ``(username, ip)``, or None if absent.

    Selected columns: ``fail_count``, ``last_failed_at``, ``locked_until``.
    """
    query = 'SELECT fail_count, last_failed_at, locked_until FROM login_attempts WHERE username = ? AND ip = ?'
    return get_db().execute(query, (username, ip)).fetchone()


def _get_ip_row(ip: str):
    """Fetch the IP-only counters for *ip*, or None if absent.

    Selected columns: ``fail_count``, ``last_failed_at``, ``locked_until``.
    """
    query = 'SELECT fail_count, last_failed_at, locked_until FROM login_ip_attempts WHERE ip = ?'
    return get_db().execute(query, (ip,)).fetchone()


def _upsert_row(username: str, ip: str, fail_count: int, last_failed_at: int, locked_until: int):
    """Insert or overwrite the counters for ``(username, ip)`` and commit.

    Relies on the ``UNIQUE(username, ip)`` constraint for the upsert target.
    """
    conn = get_db()
    params = (username, ip, fail_count, last_failed_at, locked_until)
    conn.execute(
        """
        INSERT INTO login_attempts (username, ip, fail_count, last_failed_at, locked_until)
        VALUES (?, ?, ?, ?, ?)
        ON CONFLICT(username, ip) DO UPDATE SET
            fail_count=excluded.fail_count,
            last_failed_at=excluded.last_failed_at,
            locked_until=excluded.locked_until
        """,
        params,
    )
    conn.commit()


def _upsert_ip_row(ip: str, fail_count: int, last_failed_at: int, locked_until: int):
    """Insert or overwrite the IP-only counters for *ip* and commit.

    ``ip`` is the table's primary key, which serves as the upsert target.
    """
    conn = get_db()
    params = (ip, fail_count, last_failed_at, locked_until)
    conn.execute(
        """
        INSERT INTO login_ip_attempts (ip, fail_count, last_failed_at, locked_until)
        VALUES (?, ?, ?, ?)
        ON CONFLICT(ip) DO UPDATE SET
            fail_count=excluded.fail_count,
            last_failed_at=excluded.last_failed_at,
            locked_until=excluded.locked_until
        """,
        params,
    )
    conn.commit()


def is_login_allowed(username: str, ip: str) -> Tuple[bool, int]:
    """Check whether a login for ``(username, ip)`` may proceed right now.

    Returns:
        ``(allowed, retry_after_seconds)``; the second item is 0 when allowed.

    Blocks while the pair's ``locked_until`` lies in the future. If the pair
    has reached ``LOGIN_MAX_FAILS`` failures and its last failure is still
    inside ``LOGIN_WINDOW_SECONDS`` but no lock is active, a fresh lock of
    ``LOGIN_LOCK_SECONDS`` is written. When ``LOCK_IP_ON_ACCOUNT_LOCK`` is set,
    that lock is also applied to the IP row (with its ``last_failed_at`` set
    to now) unless the IP already has a lock lasting at least as long.
    Loopback addresses are always allowed.
    """
    # Loopback clients bypass limiting entirely (tests / local development).
    if _is_localhost_ip(ip):
        return True, 0

    _ensure_schema()
    record = _get_row(username, ip)
    current = _now_ts()
    if not record:
        return True, 0

    lock_expiry = int(record['locked_until'] or 0)
    if lock_expiry > current:
        return False, lock_expiry - current

    # Extra guard: threshold reached inside the window but no active lock
    # (e.g. the previous lock expired before the window did).
    failures = int(record['fail_count'] or 0)
    if failures < LOGIN_MAX_FAILS:
        return True, 0
    last_failure = int(record['last_failed_at'] or 0)
    if last_failure + LOGIN_WINDOW_SECONDS <= current:
        return True, 0

    new_lock = current + LOGIN_LOCK_SECONDS
    _upsert_row(username, ip, failures, last_failure, new_lock)

    # Optionally escalate to an IP-level lock.
    if LOCK_IP_ON_ACCOUNT_LOCK:
        ip_record = _get_ip_row(ip)
        ip_lock = int(ip_record['locked_until'] or 0) if ip_record else 0
        if ip_lock < new_lock:
            ip_failures = int(ip_record['fail_count'] or 0) if ip_record else 0
            _upsert_ip_row(ip, ip_failures, current, new_lock)

    return False, LOGIN_LOCK_SECONDS


def is_ip_allowed(ip: str) -> Tuple[bool, int]:
    """IP-only counterpart of the per-account check.

    Returns:
        ``(allowed, retry_after_seconds)``; the second item is 0 when allowed.

    Blocks while the IP's ``locked_until`` lies in the future. If the IP has
    reached ``LOGIN_MAX_FAILS`` failures with its last failure still inside
    ``LOGIN_WINDOW_SECONDS`` but no active lock, a lock of
    ``LOGIN_LOCK_SECONDS`` is written and the request is blocked. Loopback
    addresses are always allowed.
    """
    # Loopback clients bypass limiting entirely (tests / local development).
    if _is_localhost_ip(ip):
        return True, 0

    _ensure_schema()
    record = _get_ip_row(ip)
    current = _now_ts()
    if not record:
        return True, 0

    lock_expiry = int(record['locked_until'] or 0)
    if lock_expiry > current:
        return False, lock_expiry - current

    failures = int(record['fail_count'] or 0)
    if failures < LOGIN_MAX_FAILS:
        return True, 0
    last_failure = int(record['last_failed_at'] or 0)
    if last_failure + LOGIN_WINDOW_SECONDS <= current:
        return True, 0

    _upsert_ip_row(ip, failures, last_failure, current + LOGIN_LOCK_SECONDS)
    return False, LOGIN_LOCK_SECONDS


def record_ip_failure(ip: str) -> None:
    """Count one failed attempt against *ip* alone (no username involved).

    Intended for attempts rejected before credentials are verified, so that
    cycling through usernames from one address still escalates to an IP lock.
    The count restarts when the previous failure is older than
    ``LOGIN_WINDOW_SECONDS``; reaching ``LOGIN_MAX_FAILS`` sets a lock of
    ``LOGIN_LOCK_SECONDS`` from now. Loopback addresses are ignored.
    """
    if _is_localhost_ip(ip):
        return

    _ensure_schema()
    current = _now_ts()
    existing = _get_ip_row(ip)
    if not existing:
        _upsert_ip_row(ip, 1, current, 0)
        return

    count = int(existing['fail_count'] or 0)
    lock_expiry = int(existing['locked_until'] or 0)
    window_expired = int(existing['last_failed_at'] or 0) + LOGIN_WINDOW_SECONDS < current
    count = 1 if window_expired else count + 1
    if count >= LOGIN_MAX_FAILS:
        lock_expiry = current + LOGIN_LOCK_SECONDS
    _upsert_ip_row(ip, count, current, lock_expiry)


def record_login_failure(username: str, ip: str) -> None:
    """Record a failed password check for ``(username, ip)`` and for ``ip``.

    Per-(username, ip) counters restart when the previous failure is older
    than ``LOGIN_WINDOW_SECONDS``. Once the count reaches ``LOGIN_MAX_FAILS``,
    ``locked_until`` is set to now + ``LOGIN_LOCK_SECONDS``.

    The IP-level counter is then updated through ``record_ip_failure`` so the
    IP windowing/locking rules live in exactly one place. Loopback addresses
    are ignored.
    """
    # Don't record failures for localhost IPs
    if _is_localhost_ip(ip):
        return

    _ensure_schema()
    now = _now_ts()
    row = _get_row(username, ip)
    if not row:
        _upsert_row(username, ip, 1, now, 0)
    else:
        fail_count = int(row['fail_count'] or 0)
        last_failed_at = int(row['last_failed_at'] or 0)
        locked_until = int(row['locked_until'] or 0)

        # Reset window if outside
        if last_failed_at + LOGIN_WINDOW_SECONDS < now:
            fail_count = 0

        fail_count += 1
        last_failed_at = now

        if fail_count >= LOGIN_MAX_FAILS:
            locked_until = now + LOGIN_LOCK_SECONDS

        _upsert_row(username, ip, fail_count, last_failed_at, locked_until)

    # IP-only failure tracking: delegate to the shared implementation instead
    # of duplicating its rules here. Its localhost check and schema creation
    # (CREATE TABLE IF NOT EXISTS) repeat work already done above but are
    # idempotent.
    record_ip_failure(ip)


def record_login_success(username: str, ip: str) -> None:
    """Clear the ``(username, ip)`` failure counters after a successful login.

    IP-level state is deliberately kept, to avoid trivial bypass by switching
    usernames after abuse. Unlike the failure paths, this runs for loopback
    addresses too (nothing was recorded for them, so the delete is a no-op).
    """
    _ensure_schema()
    conn = get_db()
    params = (username, ip)
    conn.execute('DELETE FROM login_attempts WHERE username = ? AND ip = ?', params)
    conn.commit()


