"""
test_security_enhancements.py

Purpose:
  Unit tests for the security enhancement functions that implement protection against
  sophisticated attack vectors. These tests validate that security functions work
  correctly under normal conditions and properly block malicious inputs while
  maintaining usability for legitimate users.

Test Coverage:
  - Timing attack protection validation
  - Geographic coordinate validation with bounds checking
  - Error message sanitization to prevent information disclosure
  - Filename security validation including path traversal prevention
  - Cryptographic token generation and verification
  - API input validation with security pattern detection
  - Exponential backoff calculation for rate limiting
  - Security headers middleware functionality

Security Testing Philosophy:
  Each test validates both positive cases (legitimate inputs work correctly) and
  negative cases (malicious inputs are properly blocked). Tests ensure security
  measures are transparent to legitimate users while effectively blocking attacks.
"""

import pytest
import time
import hmac
import hashlib
from app_modules.security_enhancements import (
    secure_password_check,
    validate_geographic_coordinates,
    sanitize_error_message,
    validate_filename_security,
    generate_secure_token,
    verify_hmac_token,
    rate_limit_with_exponential_backoff,
    validate_api_input,
    security_headers_middleware
)


@pytest.mark.security
def test_secure_password_check_timing_consistency():
    """
    Check authentication outcomes and compare timing of failed logins.

    The test first confirms the three basic outcomes (correct password,
    wrong password, missing user row). It then times ten failed checks for
    an existing user and ten for a missing user (``user_row=None``) and
    requires the two averages to differ by less than 50 ms, since a larger
    gap could let an attacker tell valid usernames from invalid ones.
    """
    import bcrypt

    attempts_per_case = 10
    plaintext = "testpassword123"
    stored_hash = bcrypt.hashpw(plaintext.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
    known_user = {'username': 'testuser', 'password_hash': stored_hash}

    def mean_failed_check_time(username, user_row):
        """Return the average duration of failed password checks for one user row."""
        durations = []
        for _ in range(attempts_per_case):
            began = time.perf_counter()
            secure_password_check(username, 'wrongpassword', user_row)
            finished = time.perf_counter()
            durations.append(finished - began)
        return sum(durations) / len(durations)

    # Outcome 1: existing user, matching password
    assert secure_password_check('testuser', plaintext, known_user) is True, \
        "Valid password should authenticate successfully"

    # Outcome 2: existing user, non-matching password
    assert secure_password_check('testuser', 'wrongpassword', known_user) is False, \
        "Invalid password should fail authentication"

    # Outcome 3: no user row at all
    assert secure_password_check('nonexistent', plaintext, None) is False, \
        "Non-existent user should fail authentication"

    # Outcome 4: timing comparison, existing user measured before the missing user
    valid_avg = mean_failed_check_time('testuser', known_user)
    invalid_avg = mean_failed_check_time('nonexistent', None)
    timing_diff = abs(valid_avg - invalid_avg)

    # 50 ms tolerance leaves room for scheduling noise in test environments
    assert timing_diff < 0.05, f"Timing difference too large: {timing_diff:.4f}s - potential username enumeration vulnerability"


@pytest.mark.security
def test_validate_geographic_coordinates():
    """
    Exercise coordinate validation with valid, out-of-range and malformed input.

    Each call is unpacked as a ``(success, error, coords)`` triple. Accepted
    input must yield ``success is True`` and the coordinate pair; rejected
    input must yield ``success is False``, an error message and ``None``
    in place of the coordinates.
    """
    # Case 1: a typical in-range pair is accepted and echoed back
    ok, message, point = validate_geographic_coordinates(45.815399, 15.981919)
    assert ok is True, "Valid coordinates should pass validation"
    assert message is None, "Valid coordinates should not have error message"
    assert point == (45.815399, 15.981919), "Valid coordinates should be returned unchanged"

    # Case 2: extremes of the valid range, the origin, and an ordinary pair
    accepted_pairs = (
        (90.0, 180.0),
        (-90.0, -180.0),
        (0.0, 0.0),
        (45.0, 15.5),
    )
    for lat, lng in accepted_pairs:
        ok, message, point = validate_geographic_coordinates(lat, lng)
        assert ok is True, f"Boundary coordinates ({lat}, {lng}) should be valid"
        assert point == (lat, lng), "Boundary coordinates should be returned unchanged"

    # Case 3: latitude outside [-90, 90] paired with a valid longitude
    for lat in (91.0, -91.0, 180.0, -180.0):
        lng = 15.0
        ok, message, point = validate_geographic_coordinates(lat, lng)
        assert ok is False, f"Invalid latitude {lat} should be rejected"
        assert "geografska širina" in message.lower(), f"Should have latitude error message for {lat}"
        assert point is None, "Invalid coordinates should return None"

    # Case 4: longitude outside [-180, 180] paired with a valid latitude
    for lng in (181.0, -181.0, 360.0, -360.0):
        lat = 45.0
        ok, message, point = validate_geographic_coordinates(lat, lng)
        assert ok is False, f"Invalid longitude {lng} should be rejected"
        assert "geografska dužina" in message.lower(), f"Should have longitude error message for {lng}"
        assert point is None, "Invalid coordinates should return None"

    # Case 5: numeric strings are expected to be converted to floats
    numeric_strings = (
        ("45.815399", "15.981919"),
        ("45", "15"),
        ("45.0", "15.0"),
    )
    for lat, lng in numeric_strings:
        ok, message, point = validate_geographic_coordinates(lat, lng)
        assert ok is True, f"String coordinates ({lat}, {lng}) should convert and validate"
        assert isinstance(point[0], float), "Converted latitude should be float"
        assert isinstance(point[1], float), "Converted longitude should be float"

    # Case 6: non-numeric text, None, injection payloads, infinity and NaN
    rejected_pairs = (
        ("invalid", "15.0"),
        ("45.0", "invalid"),
        (None, "15.0"),
        ("45.0", None),
        ("'; DROP TABLE coords; --", "15.0"),
        ("<script>alert('xss')</script>", "15.0"),
        (float('inf'), 15.0),
        (45.0, float('inf')),
        (float('nan'), 15.0),
        (45.0, float('nan')),
    )
    for lat, lng in rejected_pairs:
        ok, message, point = validate_geographic_coordinates(lat, lng)
        assert ok is False, f"Invalid input ({lat}, {lng}) should be rejected"
        assert message is not None, "Invalid input should have error message"
        assert point is None, "Invalid input should return None coordinates"


@pytest.mark.security
def test_sanitize_error_message():
    """
    Confirm sanitized error messages hide internal details.

    For each error category the test passes exceptions whose text contains
    sensitive fragments (schema names, file paths, hosts and ports) and
    asserts that the fragments are absent from the result and that the
    expected generic user-facing message is present.
    """
    # Category 1: database errors must not expose schema or engine details
    database_messages = (
        "UNIQUE constraint failed: users.username",
        "sqlite3.OperationalError: table users has no column named password",
        "Database disk image is malformed",
        "FOREIGN KEY constraint failed",
    )
    for raw in database_messages:
        sanitized = sanitize_error_message(Exception(raw), "database")
        lowered = sanitized.lower()
        for fragment in ("sqlite", "table", "column", "constraint"):
            assert fragment not in lowered, f"Database error leaked info: {sanitized}"
        assert "greška pri obradi podataka" in lowered, "Should have generic database error message"

    # Category 2: filesystem errors must not expose paths
    filesystem_messages = (
        "Permission denied: /etc/passwd",
        "No such file or directory: /var/www/secret.txt",
        "Access is denied to C:\\Windows\\System32\\config",
        "File not found: /home/user/.ssh/id_rsa",
    )
    for raw in filesystem_messages:
        sanitized = sanitize_error_message(Exception(raw), "filesystem")
        lowered = sanitized.lower()
        assert "/etc/" not in sanitized, f"File path leaked: {sanitized}"
        assert "c:\\" not in lowered, f"File path leaked: {sanitized}"
        assert ".ssh" not in sanitized, f"File path leaked: {sanitized}"
        assert "greška pri pristupu datoteci" in lowered, "Should have generic file error message"

    # Category 3: network errors must not expose addresses or ports
    network_messages = (
        "Connection timeout to internal-server:3306",
        "Network unreachable: 192.168.1.100",
        "Socket connection failed",
    )
    for raw in network_messages:
        sanitized = sanitize_error_message(Exception(raw), "network")
        assert "192.168" not in sanitized, f"IP address leaked: {sanitized}"
        assert ":3306" not in sanitized, f"Port number leaked: {sanitized}"
        assert "greška mreže" in sanitized.lower(), "Should have generic network error message"

    # Category 4: anything else falls back to the generic message
    sanitized = sanitize_error_message(Exception("Some random error message"), "generic")
    assert "došlo je do greške" in sanitized.lower(), "Should have generic error message"


@pytest.mark.security
def test_validate_filename_security():
    """
    Check that filename validation accepts safe names and rejects unsafe ones.

    Each call is unpacked as an ``(is_valid, error)`` pair. The rejected
    groups cover path traversal, dangerous extensions, special characters,
    excessive length, embedded null bytes and empty or ``None`` names.
    """
    # Group 1: ordinary filenames are accepted without an error
    for name in (
        "image.jpg",
        "camera_photo_123.png",
        "PICT_20231201_120000_123456789012.jpg",
        "test_file.jpeg",
        "document.pdf",
    ):
        accepted, problem = validate_filename_security(name)
        assert accepted is True, f"Valid filename should pass: {name}"
        assert problem is None, f"Valid filename should not have error: {name}"

    # Group 2: path traversal in plain, Windows, doubled and URL-encoded form
    for name in (
        "../../../etc/passwd",
        "..\\..\\..\\windows\\system32\\config\\sam",
        "....//....//....//etc/passwd",
        "file/../../../sensitive.txt",
        "..%2f..%2f..%2fetc%2fpasswd",
    ):
        accepted, problem = validate_filename_security(name)
        assert accepted is False, f"Path traversal should be blocked: {name}"
        assert problem is not None, f"Path traversal should have error message: {name}"
        assert "neispravno ime datoteke" in problem.lower(), f"Should have generic error message for: {name}"

    # Group 3: executable, script and configuration extensions
    for name in (
        "malicious.php",
        "script.php3",
        "backdoor.asp",
        "shell.jsp",
        "trojan.exe",
        "virus.bat",
        "config.htaccess",
        "secrets.ini",
        "exploit.py",
    ):
        accepted, problem = validate_filename_security(name)
        assert accepted is False, f"Dangerous extension should be blocked: {name}"
        assert problem is not None, f"Dangerous extension should have error: {name}"
        assert "nedozvoljena vrsta datoteke" in problem.lower(), f"Should have file type error for: {name}"

    # Group 4: shell and filesystem metacharacters inside the name
    for name in (
        "file<script>.jpg",
        "file>redirect.jpg",
        "file|pipe.jpg",
        "file:stream.jpg",
        "file*wildcard.jpg",
        "file?query.jpg",
        'file"quote.jpg',
    ):
        accepted, problem = validate_filename_security(name)
        assert accepted is False, f"Special characters should be blocked: {name}"
        assert problem is not None, f"Special characters should have error: {name}"

    # Group 5: a 304-character name exceeds the allowed length
    accepted, problem = validate_filename_security("a" * 300 + ".jpg")
    assert accepted is False, "Overly long filename should be rejected"
    assert "predugačko" in problem.lower(), "Should have length error message"

    # Group 6: embedded null byte
    accepted, problem = validate_filename_security("file\x00.jpg")
    assert accepted is False, "Null byte injection should be blocked"
    assert problem is not None, "Null byte injection should have error message"

    # Group 7: missing filename, first as an empty string and then as None
    for name, label in (("", "Empty"), (None, "None")):
        accepted, problem = validate_filename_security(name)
        assert accepted is False, f"{label} filename should be rejected"


@pytest.mark.security
def test_generate_secure_token():
    """
    Check generated tokens for uniqueness, length and character set.

    One hundred default tokens must all differ. The first ten must be 40 to
    50 characters long and, ignoring trailing ``=`` padding, contain only
    letters, digits, ``-`` and ``_``. Tokens requested with sizes 16 and 64
    must be at least 20 and 80 characters long respectively.
    """
    import string

    batch = [generate_secure_token() for _ in range(100)]
    sample = batch[:10]

    # Uniqueness across the whole batch
    assert len(set(batch)) == 100, "All generated tokens should be unique"

    # Default-size tokens fall in the expected encoded length window
    for token in sample:
        size = len(token)
        assert size >= 40, f"Token too short: {size} chars"
        assert size <= 50, f"Token too long: {size} chars"

    # Explicit sizes scale the encoded length accordingly
    short_token = generate_secure_token(16)
    long_token = generate_secure_token(64)

    assert len(short_token) >= 20, "Short token should be encoded properly"
    assert len(long_token) >= 80, "Long token should be encoded properly"

    # Only the URL-safe base64 alphabet is allowed once padding is stripped
    allowed = set(string.ascii_letters + string.digits + '-_')
    for token in sample:
        invalid_chars = set(token.rstrip('=')) - allowed
        assert not invalid_chars, f"Token contains invalid characters: {invalid_chars}"


@pytest.mark.security
def test_verify_hmac_token():
    """
    Test HMAC token verification for integrity and authenticity.

    A reference token is computed locally as the hex HMAC-SHA256 of the data
    under the secret key. Verification must succeed only for the matching
    (data, token, key) triple and fail when any one of them is changed.

    Malformed input (empty strings or ``None``) may be rejected either by
    returning ``False`` or by raising; it must never verify as ``True``.
    """

    # Test 1: Valid token verification
    secret_key = "test_secret_key_12345"
    data = "user_data_to_sign"

    # Generate the reference token independently of the function under test
    valid_token = hmac.new(
        secret_key.encode('utf-8'),
        data.encode('utf-8'),
        hashlib.sha256
    ).hexdigest()

    # Should verify successfully
    is_valid = verify_hmac_token(data, valid_token, secret_key)
    assert is_valid is True, "Valid HMAC token should verify successfully"

    # Test 2: Invalid token should fail
    invalid_token = "invalid_token_12345"
    is_valid = verify_hmac_token(data, invalid_token, secret_key)
    assert is_valid is False, "Invalid HMAC token should fail verification"

    # Test 3: Modified data should fail
    modified_data = "modified_user_data"
    is_valid = verify_hmac_token(modified_data, valid_token, secret_key)
    assert is_valid is False, "Modified data should fail HMAC verification"

    # Test 4: Wrong secret key should fail
    wrong_secret = "wrong_secret_key"
    is_valid = verify_hmac_token(data, valid_token, wrong_secret)
    assert is_valid is False, "Wrong secret key should fail verification"

    # Test 5: Malformed inputs should fail gracefully
    malformed_tests = [
        ("", valid_token, secret_key),
        (data, "", secret_key),
        (data, valid_token, ""),
        (None, valid_token, secret_key),
        (data, None, secret_key),
    ]

    for test_data, test_token, test_secret in malformed_tests:
        try:
            is_valid = verify_hmac_token(test_data, test_token, test_secret)
        except Exception:
            # Raising is an acceptable way to reject malformed input.
            continue
        # The assertion lives outside the try block on purpose: AssertionError
        # is a subclass of Exception, so asserting inside the try would let the
        # handler above swallow a genuine failure.
        assert is_valid is False, f"Malformed input should fail: {test_data}, {test_token}, {test_secret}"


@pytest.mark.security
def test_rate_limit_with_exponential_backoff():
    """
    Check the delay schedule returned for successive failed attempts.

    Delays for attempts 1 to 10 must never decrease and must start at 5.
    Attempt counts of zero or below give no delay, very high counts stay at
    or under 3600, and attempts 2 and 3 give two and four times the delay
    of attempt 1.
    """
    # Attempts 1..10: monotonic schedule starting at the base delay
    delays = [rate_limit_with_exponential_backoff(attempt) for attempt in range(1, 11)]

    for previous, current in zip(delays, delays[1:]):
        assert current >= previous, f"Delay should increase: {previous} -> {current}"

    first_delay = delays[0]
    assert first_delay == 5, f"First attempt should have 5 second delay, got {first_delay}"

    # Non-positive attempt counts mean no delay at all
    assert rate_limit_with_exponential_backoff(0) == 0, "Zero attempts should have zero delay"
    assert rate_limit_with_exponential_backoff(-1) == 0, "Negative attempts should have zero delay"

    # The delay is bounded for very large attempt counts
    high_delay = rate_limit_with_exponential_backoff(50)
    assert high_delay <= 3600, f"Delay should be capped at 1 hour, got {high_delay}"

    # Doubling per attempt over the first three attempts
    delay1, delay2, delay3 = (rate_limit_with_exponential_backoff(attempt) for attempt in (1, 2, 3))

    assert delay2 == delay1 * 2, f"Delay should double: {delay1} -> {delay2}"
    assert delay3 == delay1 * 4, f"Delay should quadruple: {delay1} -> {delay3}"


@pytest.mark.security
def test_validate_api_input():
    """
    Validate API payloads against a schema of three required string fields.

    Each call is unpacked as an ``(is_valid, error, sanitized)`` triple. The
    scenarios cover a well-formed payload, missing fields, length limits,
    injection-style payloads, a convertible integer value and a payload
    that is not a dictionary.
    """

    def required_text(min_length, max_length):
        """Build the schema entry for a required string with length bounds."""
        return {
            'type': str,
            'required': True,
            'min_length': min_length,
            'max_length': max_length
        }

    schema = {
        'username': required_text(3, 32),
        'camera_id': required_text(12, 12),
        'camera_name': required_text(1, 100),
    }

    # Scenario 1: a well-formed payload passes and comes back unchanged
    good_payload = {
        'username': 'testuser123',
        'camera_id': '123456789012',
        'camera_name': 'Test Camera'
    }
    passed, problem, cleaned = validate_api_input(good_payload, schema)
    assert passed is True, "Valid input should pass validation"
    assert problem is None, "Valid input should not have error"
    assert cleaned == good_payload, "Valid input should be returned unchanged"

    # Scenario 2: camera_id and camera_name are absent
    passed, problem, cleaned = validate_api_input({'username': 'testuser123'}, schema)
    assert passed is False, "Missing required fields should fail"
    assert problem is not None, "Missing fields should have error message"
    assert "obavezno" in problem.lower(), "Should have required field error message"

    # Scenario 3: values shorter or longer than the schema allows
    baseline = {'username': 'testuser', 'camera_id': '123456789012', 'camera_name': 'Test'}
    length_cases = (
        ({**baseline, 'username': 'xy'}, 'prekratak'),
        ({**baseline, 'username': 'x' * 50}, 'predugačak'),
        ({**baseline, 'camera_id': '12345'}, 'prekratak'),
        ({**baseline, 'camera_id': '1234567890123'}, 'predugačak'),
    )
    for payload, expected_fragment in length_cases:
        passed, problem, cleaned = validate_api_input(payload, schema)
        assert passed is False, f"Length validation should fail for: {payload}"
        assert expected_fragment in problem.lower(), f"Should have length error for: {payload}"

    # Scenario 4: script and SQL injection patterns in camera_name
    hostile_names = (
        '<script>alert("xss")</script>',
        "'; DROP TABLE cameras; --",
        'UNION SELECT password FROM users',
        'javascript:alert("xss")',
        'onload=alert("xss")',
    )
    for hostile_name in hostile_names:
        payload = {'username': 'testuser', 'camera_id': '123456789012', 'camera_name': hostile_name}
        passed, problem, cleaned = validate_api_input(payload, schema)
        assert passed is False, f"Malicious input should be blocked: {payload}"
        assert problem is not None, f"Malicious input should have error: {payload}"
        assert "neispravna vrijednost" in problem.lower(), f"Should have invalid value error for: {payload}"

    # Scenario 5: an integer camera_id is expected to be converted to str
    convertible_payload = {
        'username': 'testuser123',
        'camera_id': 123456789012,
        'camera_name': 'Test Camera'
    }
    passed, problem, cleaned = validate_api_input(convertible_payload, schema)
    assert passed is True, "Type conversion should work for compatible types"
    assert isinstance(cleaned['camera_id'], str), "Camera ID should be converted to string"

    # Scenario 6: the payload itself is not a dictionary
    passed, problem, cleaned = validate_api_input("not a dictionary", schema)
    assert passed is False, "Non-dict input should be rejected"
    assert "neispravni podaci" in problem.lower(), "Should have invalid data error"


@pytest.mark.security
def test_security_headers_middleware():
    """
    Check the headers that the security middleware sets on a response.

    A minimal stand-in response exposing only a ``headers`` dict is passed
    through the middleware. The test then checks that the required headers
    exist, that the Content-Security-Policy contains each expected directive,
    that selected headers carry strict values, and that an existing
    ``Server`` header is removed.
    """

    class MockResponse:
        """Minimal response stand-in carrying only a headers mapping."""

        def __init__(self):
            self.headers = {}

    # Run a bare response through the middleware
    result_headers = security_headers_middleware(MockResponse()).headers

    # Check 1: every required header name is present
    for header in (
        'Content-Security-Policy',
        'X-Frame-Options',
        'X-Content-Type-Options',
        'X-XSS-Protection',
        'Referrer-Policy',
        'Permissions-Policy',
    ):
        assert header in result_headers, f"Required security header missing: {header}"

    # Check 2: the CSP value contains each expected directive
    policy = result_headers['Content-Security-Policy']
    for directive in (
        "default-src 'self'",
        "script-src 'self'",
        "style-src 'self'",
        "img-src 'self'",
        "font-src 'self'",
        "connect-src 'self'",
        "frame-ancestors 'none'",
        "base-uri 'self'",
        "form-action 'self'",
    ):
        assert directive in policy, f"CSP directive missing: {directive}"

    # Check 3: individual headers carry their strict values
    assert result_headers['X-Frame-Options'] == 'DENY', "X-Frame-Options should deny framing"
    assert result_headers['X-Content-Type-Options'] == 'nosniff', "Should prevent MIME sniffing"
    assert 'mode=block' in result_headers['X-XSS-Protection'], "XSS protection should block attacks"

    # Check 4: a pre-existing Server banner must not survive the middleware
    banner_response = MockResponse()
    banner_response.headers['Server'] = 'Flask/2.0.1 Werkzeug/2.0.1'

    processed = security_headers_middleware(banner_response)
    assert 'Server' not in processed.headers, "Server header should be removed"
