"""
security_enhancements.py

Purpose:
  Advanced security enhancement functions that implement protection against sophisticated
  attack vectors identified during red-team security auditing. This module provides
  timing attack protection, input validation hardening, and security-focused error
  handling to maintain the highest security standards across the application.

Security Enhancements:
  - Timing attack resistant authentication checks (username enumeration prevention)
  - Geographic coordinate validation with proper bounds checking
  - Security-focused error message sanitization to prevent information disclosure
  - Enhanced input validation for all user-controllable parameters
  - Cryptographic helper functions for secure token generation and validation

Implementation Philosophy:
  All functions follow defense-in-depth principles with multiple validation layers.
  Error handling is designed to prevent information leakage while maintaining
  system functionality. Security measures are transparent to legitimate users
  but effectively block malicious attempts.
"""

import time
import bcrypt
import secrets
import hashlib
import hmac
from typing import Tuple, Optional, Dict, Any
from functools import wraps
from flask import current_app


def secure_password_check(username: str, password: str, user_row: Optional[Dict]) -> bool:
    """
    Verify a password with bcrypt, running a bcrypt check on every call.

    A bcrypt comparison is performed whether or not a usable stored hash is
    available, so that a missing user does not return noticeably faster than
    an existing one (username enumeration via timing).

    Args:
        username: Username being authenticated. Not used by the check itself;
            kept for interface compatibility.
        password: Plain-text password to verify (encoded as UTF-8).
        user_row: Mapping with a 'password_hash' entry, or None if the user
            doesn't exist. The hash may be stored as str or bytes.

    Returns:
        bool: True only if a stored hash exists and matches the password.
        False for unknown users, empty hashes, mismatches and stored hashes
        bcrypt cannot process.

    Security Features:
        - A bcrypt operation is attempted on every call
        - Fails closed (returns False) instead of raising on a corrupt hash
    """
    # Syntactically valid cost-12 bcrypt hash, used only to spend comparable
    # CPU time when there is no real hash to check against.
    # NOTE(review): timing only matches if real hashes also use cost 12 - confirm.
    dummy_hash = b"$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/LewiyTuTpbsEf.ug."
    candidate = password.encode('utf-8')

    stored_hash = user_row['password_hash'] if user_row else None
    if stored_hash:
        if isinstance(stored_hash, str):
            stored_hash = stored_hash.encode('utf-8')
        try:
            return bool(bcrypt.checkpw(candidate, stored_hash))
        except (ValueError, TypeError):
            # bcrypt rejects malformed hashes (e.g. "Invalid salt"). Treat the
            # row as unusable and fall through to the dummy check so this path
            # is neither a crash nor a fast exit.
            pass

    try:
        bcrypt.checkpw(candidate, dummy_hash)
    except ValueError:
        # The password itself was unacceptable to bcrypt; still a failed login.
        pass
    return False


def validate_geographic_coordinates(lat: Any, lng: Any) -> Tuple[bool, Optional[str], Optional[Tuple[float, float]]]:
    """
    Validate geographic coordinates with proper bounds checking.

    Both values are converted with float() and then checked against the
    valid latitude/longitude ranges. NaN is rejected explicitly before the
    range checks so it receives the generic "invalid coordinates" message
    rather than an axis-specific one.

    Args:
        lat: Latitude value (any type accepted by float())
        lng: Longitude value (any type accepted by float())

    Returns:
        Tuple[bool, Optional[str], Optional[Tuple[float, float]]]:
            - success: Whether validation passed
            - error_message: Error description if validation failed, else None
            - coordinates: Validated (lat, lng) tuple of floats if successful,
              else None

    Security Features:
        - Values that cannot be converted to float are rejected
        - NaN is rejected; infinities fail the bounds checks
        - Latitude limited to [-90, 90], longitude to [-180, 180]
        - Error messages are fixed strings that never echo the input
    """
    import math  # local import keeps this block self-contained

    try:
        lat_float = float(lat)
        lng_float = float(lng)
    except (ValueError, TypeError, OverflowError):
        return False, "Neispravne koordinate.", None

    # NaN compares False against every bound, so it must be tested before the
    # range checks or it would be reported as an out-of-range latitude/longitude.
    if math.isnan(lat_float) or math.isnan(lng_float):
        return False, "Neispravne koordinate.", None

    if not (-90.0 <= lat_float <= 90.0):
        return False, "Neispravna geografska širina.", None

    if not (-180.0 <= lng_float <= 180.0):
        return False, "Neispravna geografska dužina.", None

    return True, None, (lat_float, lng_float)


def sanitize_error_message(error: Exception, context: str = "operation") -> str:
    """
    Map an exception to a generic, user-safe message.

    The lower-cased text of the exception is scanned for keywords and the
    first matching category decides which fixed message is returned. The
    original error text is never included in the result.

    Category precedence (first match wins): database, authorization,
    file system, network, then a generic fallback. Authorization is checked
    before file system because the file-system keyword 'access' would
    otherwise swallow every "access denied" error.

    Args:
        error: The original exception
        context: Context of the operation. Currently unused by this function;
            kept for interface compatibility.

    Returns:
        str: Sanitized error message safe for user display

    Security Features:
        - Only fixed, predefined messages are ever returned
        - Database, path and network details in the exception text are dropped
    """
    error_str = str(error).lower()

    # Ordered (keywords, message) rules; order defines precedence.
    rules = (
        (('sqlite', 'database', 'table', 'column', 'constraint'),
         "Greška pri obradi podataka."),
        (('unauthorized', 'forbidden', 'access denied'),
         "Nemate dozvolu za ovu operaciju."),
        (('permission', 'access', 'file', 'directory', 'path'),
         "Greška pri pristupu datoteci."),
        (('connection', 'timeout', 'network', 'socket'),
         "Greška mreže."),
    )

    for keywords, message in rules:
        if any(keyword in error_str for keyword in keywords):
            return message

    # Generic server errors
    return "Došlo je do greške. Molimo pokušajte ponovno."


def validate_filename_security(filename: str) -> Tuple[bool, Optional[str]]:
    """
    Validate a bare filename for path traversal and dangerous extensions.

    The name must be a non-empty string of at most 255 characters, contain no
    path separators, '..' or shell/Windows-reserved characters, contain no
    ASCII control characters (including NUL), and must not end in a
    blocklisted extension. The extension check ignores trailing dots and
    spaces so that names such as "shell.php." cannot slip past it.

    Args:
        filename: Filename to validate

    Returns:
        Tuple[bool, Optional[str]]: (is_valid, error_message); error_message
        is None when the name is valid.

    Security Features:
        - Path traversal prevention (../, ..\\, etc.)
        - Malicious extension detection (case-insensitive)
        - Special and control character filtering
        - Length validation
    """
    if not filename or not isinstance(filename, str):
        return False, "Nedostaje ime datoteke."

    # Length validation
    if len(filename) > 255:
        return False, "Ime datoteke je predugačko."

    # Path traversal and reserved-character detection
    forbidden_fragments = ('..', '/', '\\', ':', '|', '<', '>', '*', '?', '"')
    if any(fragment in filename for fragment in forbidden_fragments):
        return False, "Neispravno ime datoteke."

    # Control characters (NUL, newlines, DEL, ...) have no place in a filename
    # and enable truncation or header-injection tricks downstream.
    if any(ord(char) < 32 or ord(char) == 127 for char in filename):
        return False, "Neispravno ime datoteke."

    # Some filesystems (notably Windows) ignore trailing dots and spaces, so
    # "shell.php." would be stored as "shell.php". Judge the name without them.
    effective_name = filename.lower().rstrip('. ')
    if not effective_name:
        # Name consisted solely of dots/spaces (e.g. ".").
        return False, "Neispravno ime datoteke."

    # Malicious extension detection
    dangerous_extensions = (
        '.php', '.php3', '.php4', '.php5', '.phtml',
        '.asp', '.aspx', '.jsp', '.jspx',
        '.py', '.pl', '.rb', '.sh', '.bat', '.cmd',
        '.exe', '.scr', '.com', '.pif',
        '.htaccess', '.htpasswd',
        '.config', '.ini', '.cfg',
    )
    if effective_name.endswith(dangerous_extensions):
        return False, "Nedozvoljena vrsta datoteke."

    return True, None


def generate_secure_token(length: int = 32) -> str:
    """
    Return a random URL-safe text token from the `secrets` module.

    Args:
        length: Number of random bytes to draw (default 32). The returned
            string is longer than this, because the bytes are Base64-encoded.

    Returns:
        str: URL-safe Base64 text produced by secrets.token_urlsafe

    Security Features:
        - Randomness comes from the standard library `secrets` module
        - Output is safe to embed in URLs
    """
    token = secrets.token_urlsafe(nbytes=length)
    return token


def verify_hmac_token(data: str, token: str, secret_key: str) -> bool:
    """
    Check that `token` is the hex HMAC-SHA256 of `data` under `secret_key`.

    The expected value is recomputed from the UTF-8 encodings of the key and
    the data, then compared against the supplied token.

    Args:
        data: Original data that was signed
        token: Hex-encoded HMAC token to verify
        secret_key: Secret key used for signing

    Returns:
        bool: True if the token equals the recomputed hex digest. False on a
        mismatch or if any step raises (for example a non-string argument).

    Security Features:
        - Comparison done with hmac.compare_digest
        - HMAC-SHA256 over the supplied data
        - Any failure is treated as an invalid token
    """
    try:
        key_bytes = secret_key.encode('utf-8')
        message_bytes = data.encode('utf-8')
        signer = hmac.new(key_bytes, message_bytes, hashlib.sha256)
        # compare_digest avoids short-circuiting on the first differing character.
        is_match = hmac.compare_digest(token, signer.hexdigest())
    except Exception:
        return False
    return is_match


def rate_limit_with_exponential_backoff(attempt_count: int, base_delay: int = 5) -> int:
    """
    Calculate exponential backoff delay for rate limiting.

    The delay is base_delay * 2 ** (attempt_count - 1), capped at one hour.
    The exponent is clamped before the power is computed, so a very large
    attempt counter cannot force the creation of an enormous integer.

    Args:
        attempt_count: Number of failed attempts
        base_delay: Base delay in seconds (default 5)

    Returns:
        int: Delay in seconds before next attempt allowed. 0 when
        attempt_count or base_delay is not positive; never more than 3600.

    Security Features:
        - Exponential increase in delay times
        - Maximum cap to prevent excessive delays
        - Constant work regardless of how large attempt_count is
    """
    max_delay = 3600  # Max 1 hour

    if attempt_count <= 0 or base_delay <= 0:
        # No failures yet, or a nonsensical base: never return a negative delay.
        return 0

    # 2 ** max_delay.bit_length() already exceeds max_delay, so for any
    # base_delay >= 1 a larger exponent cannot change the capped result.
    exponent = min(attempt_count - 1, max_delay.bit_length())
    return min(base_delay * (2 ** exponent), max_delay)


def validate_api_input(data: Dict[str, Any], schema: Dict[str, Dict]) -> Tuple[bool, Optional[str], Dict[str, Any]]:
    """
    Validate API input against a schema and return the accepted fields.

    Only fields named in `schema` are examined; anything else in `data` is
    ignored and left out of the result. For every schema field the function
    enforces `required`, coerces the value to `type` (default str) when it is
    not already an instance, and for strings applies `min_length` (default 0),
    `max_length` (default 10000) and a substring blocklist.

    Args:
        data: Input data to validate
        schema: Mapping of field name to a definition dict with optional keys
            'required', 'type', 'min_length' and 'max_length'

    Returns:
        Tuple[bool, Optional[str], Dict[str, Any]]: (is_valid, error_message, sanitized_data).
        On failure sanitized_data is an empty dict; on success error_message
        is None and sanitized_data holds the (possibly coerced) values.

    Security Features:
        - Case-insensitive blocklist of SQL/script-related substrings
        - Type and length validation
        - Required field enforcement
        - Coercion failures, including overflow, are reported as validation
          errors rather than raised
    """
    if not isinstance(data, dict):
        return False, "Neispravni podaci.", {}

    # Substrings that cause a string value to be rejected outright.
    # NOTE(review): matching is by plain substring, so benign text is rejected
    # too (e.g. "description" contains "script") - confirm this is intended.
    dangerous_patterns = (
        'script',
        'javascript',
        'vbscript',
        'onload',
        'onerror',
        'onclick',
        'select',
        'union',
        'drop',
        'insert',
        'update',
        'delete',
        '--',
        ';',
    )

    sanitized_data = {}

    for field_name, field_schema in schema.items():
        value = data.get(field_name)

        # Required field check
        if field_schema.get('required', False) and (value is None or value == ''):
            return False, f"Polje '{field_name}' je obavezno.", {}

        if value is None:
            # Optional field that was not supplied: nothing to validate.
            continue

        # Type validation
        expected_type = field_schema.get('type', str)
        if not isinstance(value, expected_type):
            try:
                value = expected_type(value)
            except (ValueError, TypeError, OverflowError):
                # OverflowError covers e.g. int(float('inf')).
                return False, f"Neispravna vrijednost za '{field_name}'.", {}

        # String-specific validations
        if isinstance(value, str):
            min_length = field_schema.get('min_length', 0)
            max_length = field_schema.get('max_length', 10000)

            if len(value) < min_length:
                return False, f"'{field_name}' prekratak.", {}

            if len(value) > max_length:
                return False, f"'{field_name}' predugačak.", {}

            # Security pattern detection
            value_lower = value.lower()
            if any(pattern in value_lower for pattern in dangerous_patterns):
                return False, f"Neispravna vrijednost za '{field_name}'.", {}

        sanitized_data[field_name] = value

    return True, None, sanitized_data


def security_headers_middleware(response):
    """
    Set a fixed group of security-related HTTP headers on a response.

    The headers are written through `response.headers`, replacing any value
    already present under the same name, and the 'Server' header is removed
    if it exists.

    Args:
        response: Response object exposing a mutable `headers` mapping
            (intended for Flask responses)

    Returns:
        The same response object, after its headers have been modified

    Security Features:
        - Content Security Policy (CSP)
        - X-Frame-Options and CSP frame-ancestors against framing
        - X-Content-Type-Options: nosniff
        - X-XSS-Protection, Referrer-Policy and Permissions-Policy
        - 'Server' header stripped
    """
    # The policy is assembled from individual directives for readability.
    content_security_policy = '; '.join((
        "default-src 'self'",
        "script-src 'self' 'unsafe-inline' https://unpkg.com",
        "style-src 'self' 'unsafe-inline' https://fonts.googleapis.com https://unpkg.com",
        "img-src 'self' data: https://*.tile.openstreetmap.org https://unpkg.com",
        "font-src 'self' https://fonts.gstatic.com data:",
        "connect-src 'self'",
        "frame-ancestors 'none'",
        "base-uri 'self'",
        "form-action 'self'",
    ))

    hardening_headers = {
        'Content-Security-Policy': content_security_policy,
        'X-Frame-Options': 'DENY',
        'X-Content-Type-Options': 'nosniff',
        'X-XSS-Protection': '1; mode=block',
        'Referrer-Policy': 'strict-origin-when-cross-origin',
        'Permissions-Policy': 'geolocation=(self), microphone=(), camera=()',
    }

    headers = response.headers
    for header_name, header_value in hardening_headers.items():
        headers[header_name] = header_value

    # Do not advertise the server software.
    headers.pop('Server', None)

    return response
