Files
aitbc/tests/test_api_utils.py
aitbc f4688aefbd
Some checks failed
Cross-Node Transaction Testing / transaction-test (push) Has been cancelled
Deploy to Testnet / deploy-testnet (push) Has been cancelled
Documentation Validation / validate-docs (push) Has been cancelled
Documentation Validation / validate-policies-strict (push) Has been cancelled
Integration Tests / test-service-integration (push) Has been cancelled
Multi-Node Stress Testing / stress-test (push) Has been cancelled
Python Tests / test-python (push) Has been cancelled
Security Scanning / security-scan (push) Has been cancelled
refactor: improve imports, fix datetime usage, and reorganize cross-chain services
- Added logger initialization to EventRouter in events.py
- Fixed datetime.timedelta references to use timedelta directly in security_hardening.py
- Fixed StateTransition timestamp default_factory to use lambda in state.py
- Fixed StateValidator.validate_transitions to only check source states exist
- Moved cross_chain_bridge_enhanced.py to cross_chain/bridge_enhanced.py
- Updated import paths in global_marketplace
2026-05-12 20:49:01 +02:00

528 lines
18 KiB
Python

"""
Tests for API utilities
"""
import pytest
from datetime import datetime, timezone
from unittest.mock import Mock
from aitbc.api_utils import (
APIResponse,
PaginatedResponse,
success_response,
error_response,
not_found_response,
unauthorized_response,
forbidden_response,
validation_error_response,
conflict_response,
internal_error_response,
PaginationParams,
paginate_items,
build_paginated_response,
RateLimitHeaders,
build_cors_headers,
build_standard_headers,
validate_sort_field,
validate_sort_order,
build_sort_params,
filter_fields,
exclude_fields,
sanitize_response,
merge_responses,
get_client_ip,
get_user_agent,
build_request_metadata,
)
class TestAPIResponse:
    """Checks that APIResponse stores its fields and supplies a timestamp."""

    def test_api_response_creation(self):
        """Explicitly supplied fields are readable back and a timestamp exists."""
        fields = {
            "success": True,
            "message": "Test message",
            "data": {"key": "value"},
        }
        resp = APIResponse(**fields)

        assert resp.success is True
        assert resp.message == "Test message"
        assert resp.data == {"key": "value"}
        assert resp.timestamp is not None

    def test_api_response_default_timestamp(self):
        """A response built without a timestamp still carries an ISO one."""
        resp = APIResponse(success=True, message="Test")
        stamp = resp.timestamp

        assert stamp is not None
        # Parsing is the check: fromisoformat raises on a malformed value,
        # which fails the test.
        datetime.fromisoformat(stamp)
class TestPaginatedResponse:
    """Checks that PaginatedResponse keeps data and pagination metadata."""

    def test_paginated_response_creation(self):
        """Constructor arguments are exposed unchanged as attributes."""
        page_items = [1, 2, 3]
        page_meta = {"page": 1, "total": 10}

        resp = PaginatedResponse(
            success=True,
            message="Success",
            data=page_items,
            pagination=page_meta,
        )

        # Compare against fresh literals rather than the objects passed in.
        assert resp.success is True
        assert resp.data == [1, 2, 3]
        assert resp.pagination == {"page": 1, "total": 10}
        assert resp.timestamp is not None
class TestResponseBuilders:
    """Checks for the helpers that build success and error responses."""

    def test_success_response(self):
        """A message plus payload yields a successful response with both."""
        built = success_response("Operation successful", {"id": 1})

        assert built.success is True
        assert built.message == "Operation successful"
        assert built.data == {"id": 1}

    def test_success_response_no_data(self):
        """Omitting the payload leaves ``data`` as None."""
        built = success_response("Success")

        assert built.success is True
        assert built.message == "Success"
        assert built.data is None

    def test_error_response(self):
        """The given status code, message and error code are all reported."""
        err = error_response("Error occurred", "ERROR_CODE", 400)

        assert err.status_code == 400
        detail = err.detail
        assert detail["success"] is False
        assert detail["message"] == "Error occurred"
        assert detail["error"] == "ERROR_CODE"

    def test_not_found_response(self):
        """A 404 is produced whose message names the missing resource."""
        err = not_found_response("User")

        assert err.status_code == 404
        detail = err.detail
        assert "User not found" in detail["message"]
        assert detail["error"] == "NOT_FOUND"

    def test_unauthorized_response(self):
        """A 401 is produced with the supplied message."""
        err = unauthorized_response("Access denied")

        assert err.status_code == 401
        detail = err.detail
        assert detail["message"] == "Access denied"
        assert detail["error"] == "UNAUTHORIZED"

    def test_forbidden_response(self):
        """A 403 is produced with the supplied message."""
        err = forbidden_response("Forbidden")

        assert err.status_code == 403
        detail = err.detail
        assert detail["message"] == "Forbidden"
        assert detail["error"] == "FORBIDDEN"

    def test_validation_error_response(self):
        """A list of validation messages produces a 422 VALIDATION_ERROR."""
        problems = ["Field required", "Invalid format"]
        err = validation_error_response(problems)

        assert err.status_code == 422
        assert err.detail["error"] == "VALIDATION_ERROR"

    def test_conflict_response(self):
        """A 409 is produced with the supplied message."""
        err = conflict_response("Resource already exists")

        assert err.status_code == 409
        detail = err.detail
        assert detail["message"] == "Resource already exists"
        assert detail["error"] == "CONFLICT"

    def test_internal_error_response(self):
        """A 500 INTERNAL_ERROR is produced."""
        err = internal_error_response("Server error")

        assert err.status_code == 500
        assert err.detail["error"] == "INTERNAL_ERROR"
class TestPaginationParams:
    """Checks PaginationParams defaults, clamping and derived values."""

    def test_pagination_params_defaults(self):
        """With no arguments: page 1, page size 10, offset 0."""
        pagination = PaginationParams()

        assert pagination.page == 1
        assert pagination.page_size == 10
        assert pagination.offset == 0

    def test_pagination_params_custom(self):
        """Explicit page and page size are kept; offset is 20 for page 2 of 20."""
        pagination = PaginationParams(page=2, page_size=20)

        assert pagination.page == 2
        assert pagination.page_size == 20
        assert pagination.offset == 20

    def test_pagination_params_page_minimum(self):
        """A page of 0 is raised to 1."""
        pagination = PaginationParams(page=0)

        assert pagination.page == 1

    def test_pagination_params_page_size_minimum(self):
        """A page size of 0 is raised to 1."""
        pagination = PaginationParams(page_size=0)

        assert pagination.page_size == 1

    def test_pagination_params_page_size_maximum(self):
        """A page size above ``max_page_size`` is capped to it."""
        pagination = PaginationParams(page_size=200, max_page_size=100)

        assert pagination.page_size == 100

    def test_get_limit(self):
        """``get_limit`` returns the page size."""
        pagination = PaginationParams(page_size=25)

        assert pagination.get_limit() == 25

    def test_get_offset(self):
        """``get_offset`` returns 20 for page 3 with 10 items per page."""
        pagination = PaginationParams(page=3, page_size=10)

        assert pagination.get_offset() == 20
class TestPaginateItems:
    """Checks paginate_items slicing/metadata and build_paginated_response."""

    def test_paginate_items_basic(self):
        """First page of 25 items at 10 per page, with full metadata."""
        source = list(range(25))
        outcome = paginate_items(source, page=1, page_size=10)

        rows = outcome["items"]
        assert len(rows) == 10
        assert rows == list(range(10))

        meta = outcome["pagination"]
        assert meta["page"] == 1
        assert meta["total"] == 25
        assert meta["total_pages"] == 3
        assert meta["has_next"] is True
        assert meta["has_prev"] is False

    def test_paginate_items_second_page(self):
        """A middle page has both a next and a previous page."""
        source = list(range(25))
        outcome = paginate_items(source, page=2, page_size=10)

        assert outcome["items"] == list(range(10, 20))

        meta = outcome["pagination"]
        assert meta["has_next"] is True
        assert meta["has_prev"] is True

    def test_paginate_items_last_page(self):
        """The last page holds the 5 remaining items and has no next page."""
        source = list(range(25))
        outcome = paginate_items(source, page=3, page_size=10)

        assert outcome["items"] == list(range(20, 25))

        meta = outcome["pagination"]
        assert meta["has_next"] is False
        assert meta["has_prev"] is True

    def test_paginate_items_empty_list(self):
        """An empty input gives no items, zero total and zero pages."""
        outcome = paginate_items([], page=1, page_size=10)

        assert outcome["items"] == []

        meta = outcome["pagination"]
        assert meta["total"] == 0
        assert meta["total_pages"] == 0

    def test_build_paginated_response(self):
        """The builder wraps the first page of 15 items in a PaginatedResponse."""
        source = list(range(15))
        resp = build_paginated_response(source, page=1, page_size=10)

        assert isinstance(resp, PaginatedResponse)
        assert resp.success is True
        assert len(resp.data) == 10
        assert resp.pagination["total"] == 15
class TestRateLimitHeaders:
    """Checks the header dictionaries produced by RateLimitHeaders."""

    def test_get_headers(self):
        """Each rate-limit value is returned as a string under its header."""
        produced = RateLimitHeaders.get_headers(
            limit=100,
            remaining=50,
            reset=3600,
            window=60,
        )

        expected = {
            "X-RateLimit-Limit": "100",
            "X-RateLimit-Remaining": "50",
            "X-RateLimit-Reset": "3600",
            "X-RateLimit-Window": "60",
        }
        for header_name, header_value in expected.items():
            assert produced[header_name] == header_value

    def test_get_retry_after(self):
        """The retry delay is returned as a string under ``Retry-After``."""
        produced = RateLimitHeaders.get_retry_after(30)

        assert produced["Retry-After"] == "30"
class TestHeaderBuilders:
    """Checks build_cors_headers and build_standard_headers."""

    def test_build_cors_headers_defaults(self):
        """With no arguments, all four CORS headers are present."""
        produced = build_cors_headers()

        required = (
            "Access-Control-Allow-Origin",
            "Access-Control-Allow-Methods",
            "Access-Control-Allow-Headers",
            "Access-Control-Max-Age",
        )
        for header_name in required:
            assert header_name in produced

    def test_build_cors_headers_custom(self):
        """Custom origins, methods and max age appear in the header values."""
        produced = build_cors_headers(
            allowed_origins=["http://localhost:3000"],
            allowed_methods=["GET", "POST"],
            max_age=7200,
        )

        assert "http://localhost:3000" in produced["Access-Control-Allow-Origin"]
        assert "GET, POST" in produced["Access-Control-Allow-Methods"]
        assert produced["Access-Control-Max-Age"] == "7200"

    def test_build_standard_headers_defaults(self):
        """Defaults give a JSON content type and no optional headers."""
        produced = build_standard_headers()

        assert produced["Content-Type"] == "application/json"
        assert "Cache-Control" not in produced
        assert "X-Request-ID" not in produced

    def test_build_standard_headers_with_options(self):
        """Supplied options are returned under their respective headers."""
        options = {
            "content_type": "application/xml",
            "cache_control": "no-cache",
            "x_request_id": "req-123",
        }
        produced = build_standard_headers(**options)

        assert produced["Content-Type"] == "application/xml"
        assert produced["Cache-Control"] == "no-cache"
        assert produced["X-Request-ID"] == "req-123"
class TestSortValidation:
    """Checks the sort-field, sort-order and sort-parameter helpers."""

    def test_validate_sort_field_valid(self):
        """A field in the allowed list is returned unchanged."""
        allowed = ["name", "email", "age"]

        chosen = validate_sort_field("name", allowed)

        assert chosen == "name"

    def test_validate_sort_field_invalid(self):
        """A field outside the allowed list raises a descriptive ValueError."""
        with pytest.raises(ValueError) as exc_info:
            validate_sort_field("invalid", ["name", "email"])

        # The message is inspected after the context manager has exited.
        message = str(exc_info.value)
        assert "Invalid sort field" in message

    def test_validate_sort_order_asc(self):
        """Lower-case ``asc`` is returned as ``ASC``."""
        normalized = validate_sort_order("asc")

        assert normalized == "ASC"

    def test_validate_sort_order_desc(self):
        """Lower-case ``desc`` is returned as ``DESC``."""
        normalized = validate_sort_order("desc")

        assert normalized == "DESC"

    def test_validate_sort_order_invalid(self):
        """An unrecognized order raises a descriptive ValueError."""
        with pytest.raises(ValueError) as exc_info:
            validate_sort_order("invalid")

        # The message is inspected after the context manager has exited.
        message = str(exc_info.value)
        assert "Invalid sort order" in message

    def test_build_sort_params_valid(self):
        """Valid inputs produce a dict with ``sort_by`` and ``sort_order``."""
        built = build_sort_params(
            sort_by="name",
            sort_order="ASC",
            allowed_fields=["name", "email"],
        )

        assert built == {"sort_by": "name", "sort_order": "ASC"}

    def test_build_sort_params_no_sort(self):
        """With ``sort_by`` set to None the result is an empty dict."""
        built = build_sort_params(sort_by=None, allowed_fields=["name"])

        assert built == {}

    def test_build_sort_params_no_allowed_fields(self):
        """With ``allowed_fields`` set to None the result is an empty dict."""
        built = build_sort_params(sort_by="name", allowed_fields=None)

        assert built == {}
class TestFieldFiltering:
    """Checks the helpers that keep or drop keys from a dictionary."""

    def test_filter_fields(self):
        """Only the listed keys survive filtering."""
        record = {"name": "John", "email": "john@example.com", "age": 30}

        kept = filter_fields(record, ["name", "email"])

        assert kept == {"name": "John", "email": "john@example.com"}

    def test_exclude_fields(self):
        """The listed keys are removed and the rest are kept."""
        record = {"name": "John", "email": "john@example.com", "age": 30}

        remaining = exclude_fields(record, ["age"])

        assert remaining == {"name": "John", "email": "john@example.com"}
class TestSanitizeResponse:
    """Checks that sanitize_response masks sensitive values with ``***``."""

    def test_sanitize_response_dict(self):
        """In a flat dict the password is masked and other keys are kept."""
        payload = {
            "username": "john",
            "password": "secret123",
            "email": "john@example.com",
        }

        cleaned = sanitize_response(payload)

        assert cleaned["username"] == "john"
        assert cleaned["password"] == "***"
        assert cleaned["email"] == "john@example.com"

    def test_sanitize_response_list(self):
        """Every dict in a list has its token masked."""
        payload = [
            {"username": "john", "token": "abc123"},
            {"username": "jane", "token": "xyz789"},
        ]

        cleaned = sanitize_response(payload)

        first = cleaned[0]
        assert first["username"] == "john"
        assert first["token"] == "***"

        second = cleaned[1]
        assert second["username"] == "jane"
        assert second["token"] == "***"

    def test_sanitize_response_custom_fields(self):
        """A key named in ``sensitive_fields`` is masked."""
        payload = {
            "username": "john",
            "api_key": "secret",
            "email": "john@example.com",
        }

        cleaned = sanitize_response(payload, sensitive_fields=["api_key"])

        assert cleaned["username"] == "john"
        assert cleaned["api_key"] == "***"
        assert cleaned["email"] == "john@example.com"

    def test_sanitize_response_nested(self):
        """A password inside a nested dict is masked too."""
        payload = {"user": {"username": "john", "password": "secret"}}

        cleaned = sanitize_response(payload)

        inner = cleaned["user"]
        assert inner["username"] == "john"
        assert inner["password"] == "***"
class TestMergeResponses:
    """Checks that merge_responses combines the data of its arguments."""

    def test_merge_responses_api_response(self):
        """Data from two response objects ends up under one ``data`` key."""
        first = success_response("Success1", {"key1": "value1"})
        second = success_response("Success2", {"key2": "value2"})

        merged_data = merge_responses(first, second)["data"]

        assert merged_data["key1"] == "value1"
        assert merged_data["key2"] == "value2"

    def test_merge_responses_dict(self):
        """Data from two plain dicts ends up under one ``data`` key."""
        first = {"data": {"key1": "value1"}}
        second = {"data": {"key2": "value2"}}

        merged_data = merge_responses(first, second)["data"]

        assert merged_data["key1"] == "value1"
        assert merged_data["key2"] == "value2"

    def test_merge_responses_mixed(self):
        """A response object and a plain dict can be merged together."""
        first = success_response("Success1", {"key1": "value1"})
        second = {"data": {"key2": "value2"}}

        merged_data = merge_responses(first, second)["data"]

        assert merged_data["key1"] == "value1"
        assert merged_data["key2"] == "value2"

    def test_merge_responses_empty(self):
        """Calling with no arguments returns an empty ``data`` mapping."""
        assert merge_responses() == {"data": {}}
class TestRequestHelpers:
    """Checks the helpers that read client details from a mocked request."""

    def test_get_client_ip_forwarded(self):
        """The first address listed in X-Forwarded-For is returned."""
        fake_request = Mock(
            headers={"X-Forwarded-For": "192.168.1.1, 10.0.0.1"},
            client=Mock(),
        )

        assert get_client_ip(fake_request) == "192.168.1.1"

    def test_get_client_ip_real_ip(self):
        """The X-Real-IP value is returned when it is the only header set."""
        fake_request = Mock(
            headers={"X-Real-IP": "192.168.1.2"},
            client=Mock(),
        )

        assert get_client_ip(fake_request) == "192.168.1.2"

    def test_get_client_ip_from_client(self):
        """With no headers, ``request.client.host`` is returned."""
        fake_request = Mock(headers={}, client=Mock(host="192.168.1.3"))

        assert get_client_ip(fake_request) == "192.168.1.3"

    def test_get_client_ip_unknown(self):
        """With no headers and no client, the result is ``unknown``."""
        fake_request = Mock(headers={}, client=None)

        assert get_client_ip(fake_request) == "unknown"

    def test_get_user_agent(self):
        """The User-Agent header value is returned."""
        fake_request = Mock(headers={"User-Agent": "Mozilla/5.0"})

        assert get_user_agent(fake_request) == "Mozilla/5.0"

    def test_get_user_agent_unknown(self):
        """With no User-Agent header, the result is ``unknown``."""
        fake_request = Mock(headers={})

        assert get_user_agent(fake_request) == "unknown"

    def test_build_request_metadata(self):
        """Client IP, user agent, request ID and a timestamp are reported."""
        request_headers = {
            "X-Forwarded-For": "192.168.1.1",
            "User-Agent": "Mozilla/5.0",
            "X-Request-ID": "req-123",
        }
        fake_request = Mock(
            headers=request_headers,
            client=Mock(host="192.168.1.1"),
        )

        info = build_request_metadata(fake_request)

        assert info["client_ip"] == "192.168.1.1"
        assert info["user_agent"] == "Mozilla/5.0"
        assert info["request_id"] == "req-123"
        assert info["timestamp"] is not None