diff --git a/aitbc/__init__.py b/aitbc/__init__.py index ca7bcb2a..db7c021b 100644 --- a/aitbc/__init__.py +++ b/aitbc/__init__.py @@ -29,6 +29,9 @@ from .exceptions import ( DatabaseError, ValidationError, BridgeError, + RetryError, + CircuitBreakerOpenError, + RateLimitError, ) from .env import ( get_env_var, @@ -60,7 +63,7 @@ from .json_utils import ( set_nested_value, flatten_json, ) -from .http_client import AITBCHTTPClient +from .http_client import AITBCHTTPClient, AsyncAITBCHTTPClient from .config import BaseAITBCConfig, AITBCConfig from .decorators import ( retry, @@ -105,6 +108,142 @@ from .monitoring import ( PerformanceTimer, HealthChecker, ) +from .crypto import ( + derive_ethereum_address, + sign_transaction_hash, + verify_signature, + encrypt_private_key, + decrypt_private_key, + generate_secure_random_bytes, + keccak256_hash, + sha256_hash, + validate_ethereum_address, + generate_ethereum_private_key, +) +from .web3_utils import Web3Client, create_web3_client +from .security import ( + generate_token, + generate_api_key, + validate_token_format, + validate_api_key, + SessionManager, + APIKeyManager, + generate_secure_random_string, + generate_secure_random_int, + SecretManager, + hash_password, + verify_password, + generate_nonce, + generate_hmac, + verify_hmac, +) +from .time_utils import ( + get_utc_now, + get_timestamp_utc, + format_iso8601, + parse_iso8601, + timestamp_to_iso, + iso_to_timestamp, + format_duration, + format_duration_precise, + parse_duration, + add_duration, + subtract_duration, + get_time_until, + get_time_since, + calculate_deadline, + is_deadline_passed, + get_deadline_remaining, + format_time_ago, + format_time_in, + to_timezone, + get_timezone_offset, + is_business_hours, + get_start_of_day, + get_end_of_day, + get_start_of_week, + get_end_of_week, + get_start_of_month, + get_end_of_month, + sleep_until, + retry_until_deadline, + Timer, +) +from .api_utils import ( + APIResponse, + PaginatedResponse, + success_response, + error_response, + not_found_response, + unauthorized_response, + forbidden_response, + validation_error_response, + conflict_response, + internal_error_response, + PaginationParams, + paginate_items, + build_paginated_response, + RateLimitHeaders, + build_cors_headers, + build_standard_headers, + validate_sort_field, + validate_sort_order, + build_sort_params, + filter_fields, + exclude_fields, + sanitize_response, + merge_responses, + get_client_ip, + get_user_agent, + build_request_metadata, +) +from .events import ( + Event, + EventPriority, + EventBus, + AsyncEventBus, + event_handler, + publish_event, + get_global_event_bus, + set_global_event_bus, + EventFilter, + EventAggregator, + EventRouter, +) +from .queue import ( + Job, + JobStatus, + JobPriority, + TaskQueue, + JobScheduler, + BackgroundTaskManager, + WorkerPool, + debounce, + throttle, +) +from .state import ( + StateTransition, + StateTransitionError, + StatePersistenceError, + StateMachine, + ConfigurableStateMachine, + StatePersistence, + AsyncStateMachine, + StateMonitor, + StateValidator, + StateSnapshot, +) +from .testing import ( + MockFactory, + TestDataGenerator, + TestHelpers, + MockResponse, + MockDatabase, + MockCache, + mock_async_call, + create_mock_config, + create_test_scenario, +) __version__ = "0.6.0" __all__ = [ @@ -135,6 +274,9 @@ __all__ = [ "DatabaseError", "ValidationError", "BridgeError", + "RetryError", + "CircuitBreakerOpenError", + "RateLimitError", # Environment helpers "get_env_var", "get_required_env_var", @@ -164,6 +306,7 @@ __all__ = [ "flatten_json", # HTTP client "AITBCHTTPClient", + "AsyncAITBCHTTPClient", # Configuration "BaseAITBCConfig", "AITBCConfig", @@ -205,4 +348,134 @@ __all__ = [ "MetricsCollector", "PerformanceTimer", "HealthChecker", + # Cryptography + "derive_ethereum_address", + "sign_transaction_hash", + "verify_signature", + "encrypt_private_key", + "decrypt_private_key", + "generate_secure_random_bytes", + "keccak256_hash", + "sha256_hash", + "validate_ethereum_address", + "generate_ethereum_private_key", + # Web3 utilities + "Web3Client", + "create_web3_client", + # Security + "generate_token", + "generate_api_key", + "validate_token_format", + "validate_api_key", + "SessionManager", + "APIKeyManager", + "generate_secure_random_string", + "generate_secure_random_int", + "SecretManager", + "hash_password", + "verify_password", + "generate_nonce", + "generate_hmac", + "verify_hmac", + # Time utilities + "get_utc_now", + "get_timestamp_utc", + "format_iso8601", + "parse_iso8601", + "timestamp_to_iso", + "iso_to_timestamp", + "format_duration", + "format_duration_precise", + "parse_duration", + "add_duration", + "subtract_duration", + "get_time_until", + "get_time_since", + "calculate_deadline", + "is_deadline_passed", + "get_deadline_remaining", + "format_time_ago", + "format_time_in", + "to_timezone", + "get_timezone_offset", + "is_business_hours", + "get_start_of_day", + "get_end_of_day", + "get_start_of_week", + "get_end_of_week", + "get_start_of_month", + "get_end_of_month", + "sleep_until", + "retry_until_deadline", + "Timer", + # API utilities + "APIResponse", + "PaginatedResponse", + "success_response", + "error_response", + "not_found_response", + "unauthorized_response", + "forbidden_response", + "validation_error_response", + "conflict_response", + "internal_error_response", + "PaginationParams", + "paginate_items", + "build_paginated_response", + "RateLimitHeaders", + "build_cors_headers", + "build_standard_headers", + "validate_sort_field", + "validate_sort_order", + "build_sort_params", + "filter_fields", + "exclude_fields", + "sanitize_response", + "merge_responses", + "get_client_ip", + "get_user_agent", + "build_request_metadata", + # Events + "Event", + "EventPriority", + "EventBus", + "AsyncEventBus", + "event_handler", + "publish_event", + "get_global_event_bus", + "set_global_event_bus", + "EventFilter", + "EventAggregator", + "EventRouter", + # Queue + "Job", + "JobStatus", + "JobPriority", + "TaskQueue", + "JobScheduler", + "BackgroundTaskManager", + "WorkerPool", + "debounce", + "throttle", + # State + "StateTransition", + "StateTransitionError", + "StatePersistenceError", + "StateMachine", + "ConfigurableStateMachine", + "StatePersistence", + "AsyncStateMachine", + "StateMonitor", + "StateValidator", + "StateSnapshot", + # Testing + "MockFactory", + "TestDataGenerator", + "TestHelpers", + "MockResponse", + "MockDatabase", + "MockCache", + "mock_async_call", + "create_mock_config", + "create_test_scenario", ] diff --git a/aitbc/api_utils.py b/aitbc/api_utils.py new file mode 100644 index 00000000..0427610f --- /dev/null +++ b/aitbc/api_utils.py @@ -0,0 +1,322 @@ +""" +API utilities for AITBC +Provides standard response formatters, pagination helpers, error response builders, and rate limit headers helpers +""" + +from typing import Any, Optional, List, Dict, Union +from datetime import datetime +from fastapi import HTTPException, status +from pydantic import BaseModel + + +class APIResponse(BaseModel): + """Standard API response model""" + success: bool + message: str + data: Optional[Any] = None + error: Optional[str] = None + timestamp: str = None + + def __init__(self, **data): + if 'timestamp' not in data: + data['timestamp'] = datetime.utcnow().isoformat() + super().__init__(**data) + + +class PaginatedResponse(BaseModel): + """Paginated API response model""" + success: bool + message: str + data: List[Any] + pagination: Dict[str, Any] + timestamp: str = None + + def __init__(self, **data): + if 'timestamp' not in data: + data['timestamp'] = datetime.utcnow().isoformat() + super().__init__(**data) + + +def success_response(message: str = "Success", data: Optional[Any] = None) -> APIResponse: + """Create a success response""" + return APIResponse(success=True, message=message, data=data) + + +def error_response(message: str, error: Optional[str] = None, status_code: int = 400) -> HTTPException: + """Create an error response""" + return HTTPException( + status_code=status_code, + detail={"success": False, "message": message, "error": error} + ) + + +def not_found_response(resource: str = "Resource") -> HTTPException: + """Create a not found response""" + return error_response( + message=f"{resource} not found", + error="NOT_FOUND", + status_code=404 + ) + + +def unauthorized_response(message: str = "Unauthorized") -> HTTPException: + """Create an unauthorized response""" + return error_response( + message=message, + error="UNAUTHORIZED", + status_code=401 + ) + + +def forbidden_response(message: str = "Forbidden") -> HTTPException: + """Create a forbidden response""" + return error_response( + message=message, + error="FORBIDDEN", + status_code=403 + ) + + +def validation_error_response(errors: List[str]) -> HTTPException: + """Create a validation error response""" + return error_response( + message="Validation failed", + error="VALIDATION_ERROR", + status_code=422 + ) + + +def conflict_response(message: str = "Resource conflict") -> HTTPException: + """Create a conflict response""" + return error_response( + message=message, + error="CONFLICT", + status_code=409 + ) + + +def internal_error_response(message: str = "Internal server error") -> HTTPException: + """Create an internal server error response""" + return error_response( + message=message, + error="INTERNAL_ERROR", + status_code=500 + ) + + +class PaginationParams: + """Pagination parameters""" + + def __init__(self, page: int = 1, page_size: int = 10, max_page_size: int = 100): + """Initialize pagination parameters""" + self.page = max(1, page) + self.page_size = min(max_page_size, max(1, page_size)) + self.offset = (self.page - 1) * self.page_size + + def get_limit(self) -> int: + """Get SQL limit""" + return self.page_size + + def get_offset(self) -> int: + """Get SQL offset""" + return self.offset + + +def paginate_items(items: List[Any], page: int = 1, page_size: int = 10) -> Dict[str, Any]: + """Paginate a list of items""" + total = len(items) + params = PaginationParams(page, page_size) + + paginated_items = items[params.offset:params.offset + params.page_size] + total_pages = (total + params.page_size - 1) // params.page_size + + return { + "items": paginated_items, + "pagination": { + "page": params.page, + "page_size": params.page_size, + "total": total, + "total_pages": total_pages, + "has_next": params.page < total_pages, + "has_prev": params.page > 1 + } + } + + +def build_paginated_response( + items: List[Any], + page: int = 1, + page_size: int = 10, + message: str = "Success" +) -> PaginatedResponse: + """Build a paginated API response""" + pagination_data = paginate_items(items, page, page_size) + + return PaginatedResponse( + success=True, + message=message, + data=pagination_data["items"], + pagination=pagination_data["pagination"] + ) + + +class RateLimitHeaders: + """Rate limit headers helper""" + + @staticmethod + def get_headers( + limit: int, + remaining: int, + reset: int, + window: int + ) -> Dict[str, str]: + """Get rate limit headers""" + return { + "X-RateLimit-Limit": str(limit), + "X-RateLimit-Remaining": str(remaining), + "X-RateLimit-Reset": str(reset), + "X-RateLimit-Window": str(window) + } + + @staticmethod + def get_retry_after(retry_after: int) -> Dict[str, str]: + """Get retry after header""" + return {"Retry-After": str(retry_after)} + + +def build_cors_headers( + allowed_origins: List[str] = ["*"], + allowed_methods: List[str] = ["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allowed_headers: List[str] = ["*"], + max_age: int = 3600 +) -> Dict[str, str]: + """Build CORS headers""" + return { + "Access-Control-Allow-Origin": ", ".join(allowed_origins), + "Access-Control-Allow-Methods": ", ".join(allowed_methods), + "Access-Control-Allow-Headers": ", ".join(allowed_headers), + "Access-Control-Max-Age": str(max_age) + } + + +def build_standard_headers( + content_type: str = "application/json", + cache_control: Optional[str] = None, + x_request_id: Optional[str] = None +) -> Dict[str, str]: + """Build standard response headers""" + headers = { + "Content-Type": content_type, + } + + if cache_control: + headers["Cache-Control"] = cache_control + + if x_request_id: + headers["X-Request-ID"] = x_request_id + + return headers + + +def validate_sort_field(field: str, allowed_fields: List[str]) -> str: + """Validate and return sort field""" + if field not in allowed_fields: + raise ValueError(f"Invalid sort field: {field}. Allowed fields: {', '.join(allowed_fields)}") + return field + + +def validate_sort_order(order: str) -> str: + """Validate and return sort order""" + order = order.upper() + if order not in ["ASC", "DESC"]: + raise ValueError(f"Invalid sort order: {order}. Must be 'ASC' or 'DESC'") + return order + + +def build_sort_params( + sort_by: Optional[str] = None, + sort_order: str = "ASC", + allowed_fields: Optional[List[str]] = None +) -> Dict[str, Any]: + """Build sort parameters""" + if sort_by and allowed_fields: + sort_by = validate_sort_field(sort_by, allowed_fields) + sort_order = validate_sort_order(sort_order) + return {"sort_by": sort_by, "sort_order": sort_order} + return {} + + +def filter_fields(data: Dict[str, Any], fields: List[str]) -> Dict[str, Any]: + """Filter dictionary to only include specified fields""" + return {k: v for k, v in data.items() if k in fields} + + +def exclude_fields(data: Dict[str, Any], fields: List[str]) -> Dict[str, Any]: + """Exclude specified fields from dictionary""" + return {k: v for k, v in data.items() if k not in fields} + + +def sanitize_response(data: Any, sensitive_fields: List[str] = None) -> Any: + """Sanitize response by masking sensitive fields""" + if sensitive_fields is None: + sensitive_fields = ["password", "token", "api_key", "secret", "private_key"] + + if isinstance(data, dict): + return { + k: "***" if any(sensitive in k.lower() for sensitive in sensitive_fields) else sanitize_response(v, sensitive_fields) + for k, v in data.items() + } + elif isinstance(data, list): + return [sanitize_response(item, sensitive_fields) for item in data] + else: + return data + + +def merge_responses(*responses: Union[APIResponse, Dict[str, Any]]) -> Dict[str, Any]: + """Merge multiple responses into one""" + merged = {"data": {}} + + for response in responses: + if isinstance(response, APIResponse): + if response.data: + if isinstance(response.data, dict): + merged["data"].update(response.data) + else: + merged["data"] = response.data + elif isinstance(response, dict): + if "data" in response: + if isinstance(response["data"], dict): + merged["data"].update(response["data"]) + else: + merged["data"] = response["data"] + + return merged + + +def get_client_ip(request) -> str: + """Get client IP address from request""" + # Check for forwarded headers first + forwarded = request.headers.get("X-Forwarded-For") + if forwarded: + return forwarded.split(",")[0].strip() + + real_ip = request.headers.get("X-Real-IP") + if real_ip: + return real_ip + + return request.client.host if request.client else "unknown" + + +def get_user_agent(request) -> str: + """Get user agent from request""" + return request.headers.get("User-Agent", "unknown") + + +def build_request_metadata(request) -> Dict[str, str]: + """Build request metadata""" + return { + "client_ip": get_client_ip(request), + "user_agent": get_user_agent(request), + "request_id": request.headers.get("X-Request-ID", "unknown"), + "timestamp": datetime.utcnow().isoformat() + } diff --git a/aitbc/crypto.py b/aitbc/crypto.py new file mode 100644 index 00000000..5385b91b --- /dev/null +++ b/aitbc/crypto.py @@ -0,0 +1,174 @@ +""" +Cryptographic utilities for AITBC +Provides Ethereum-specific cryptographic operations and security functions +""" + +from typing import Any, Optional +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +import base64 +import os +import hashlib + + +def derive_ethereum_address(private_key: str) -> str: + """Derive Ethereum address from private key using eth-account""" + try: + from eth_account import Account + # Remove 0x prefix if present + if private_key.startswith("0x"): + private_key = private_key[2:] + + account = Account.from_key(private_key) + return account.address + except ImportError: + raise ImportError("eth-account is required for Ethereum address derivation. Install with: pip install eth-account") + except Exception as e: + raise ValueError(f"Failed to derive address from private key: {e}") + + +def sign_transaction_hash(transaction_hash: str, private_key: str) -> str: + """Sign a transaction hash with private key using eth-account""" + try: + from eth_account import Account + # Remove 0x prefix if present + if private_key.startswith("0x"): + private_key = private_key[2:] + if transaction_hash.startswith("0x"): + transaction_hash = transaction_hash[2:] + + account = Account.from_key(private_key) + signed_message = account.sign_hash(bytes.fromhex(transaction_hash)) + return signed_message.signature.hex() + except ImportError: + raise ImportError("eth-account is required for signing. Install with: pip install eth-account") + except Exception as e: + raise ValueError(f"Failed to sign transaction hash: {e}") + + +def verify_signature(message_hash: str, signature: str, address: str) -> bool: + """Verify a signature using eth-account""" + try: + from eth_account import Account + from eth_utils import to_bytes + + # Remove 0x prefixes if present + if message_hash.startswith("0x"): + message_hash = message_hash[2:] + if signature.startswith("0x"): + signature = signature[2:] + if address.startswith("0x"): + address = address[2:] + + message_bytes = to_bytes(hexstr=message_hash) + signature_bytes = to_bytes(hexstr=signature) + + recovered_address = Account.recover_message(message_bytes, signature_bytes) + return recovered_address.lower() == address.lower() + except ImportError: + raise ImportError("eth-account and eth-utils are required for signature verification. Install with: pip install eth-account eth-utils") + except Exception as e: + raise ValueError(f"Failed to verify signature: {e}") + + +def encrypt_private_key(private_key: str, password: str) -> str: + """Encrypt private key using Fernet symmetric encryption""" + try: + # Derive key from password + password_bytes = password.encode('utf-8') + salt = os.urandom(16) + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=100000, + ) + key = base64.urlsafe_b64encode(kdf.derive(password_bytes)) + + # Encrypt private key + fernet = Fernet(key) + encrypted_key = fernet.encrypt(private_key.encode('utf-8')) + + # Combine salt and encrypted key + combined = salt + encrypted_key + return base64.urlsafe_b64encode(combined).decode('utf-8') + except Exception as e: + raise ValueError(f"Failed to encrypt private key: {e}") + + +def decrypt_private_key(encrypted_key: str, password: str) -> str: + """Decrypt private key using Fernet symmetric encryption""" + try: + # Decode combined data + combined = base64.urlsafe_b64decode(encrypted_key.encode('utf-8')) + salt = combined[:16] + encrypted_data = combined[16:] + + # Derive key from password + password_bytes = password.encode('utf-8') + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=100000, + ) + key = base64.urlsafe_b64encode(kdf.derive(password_bytes)) + + # Decrypt private key + fernet = Fernet(key) + decrypted_key = fernet.decrypt(encrypted_data) + return decrypted_key.decode('utf-8') + except Exception as e: + raise ValueError(f"Failed to decrypt private key: {e}") + + +def generate_secure_random_bytes(length: int = 32) -> str: + """Generate cryptographically secure random bytes as hex string""" + return os.urandom(length).hex() + + +def keccak256_hash(data: str) -> str: + """Compute Keccak-256 hash of data""" + try: + from eth_hash.auto import keccak + if isinstance(data, str): + data = data.encode('utf-8') + return keccak(data).hex() + except ImportError: + raise ImportError("eth-hash is required for Keccak-256 hashing. Install with: pip install eth-hash") + except Exception as e: + raise ValueError(f"Failed to compute Keccak-256 hash: {e}") + + +def sha256_hash(data: str) -> str: + """Compute SHA-256 hash of data""" + try: + if isinstance(data, str): + data = data.encode('utf-8') + return hashlib.sha256(data).hexdigest() + except Exception as e: + raise ValueError(f"Failed to compute SHA-256 hash: {e}") + + +def validate_ethereum_address(address: str) -> bool: + """Validate Ethereum address format and checksum""" + try: + from eth_utils import is_address, is_checksum_address + return is_address(address) and is_checksum_address(address) + except ImportError: + raise ImportError("eth-utils is required for address validation. Install with: pip install eth-utils") + except Exception: + return False + + +def generate_ethereum_private_key() -> str: + """Generate a new Ethereum private key""" + try: + from eth_account import Account + account = Account.create() + return account.key.hex() + except ImportError: + raise ImportError("eth-account is required for private key generation. Install with: pip install eth-account") + except Exception as e: + raise ValueError(f"Failed to generate private key: {e}") diff --git a/aitbc/events.py b/aitbc/events.py new file mode 100644 index 00000000..f11026f5 --- /dev/null +++ b/aitbc/events.py @@ -0,0 +1,267 @@ +""" +Event utilities for AITBC +Provides event bus implementation, pub/sub patterns, and event decorators +""" + +import asyncio +from typing import Any, Callable, Dict, List, Optional, TypeVar, Generic +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +import inspect +import functools + + +T = TypeVar('T') + + +class EventPriority(Enum): + """Event priority levels""" + LOW = 1 + MEDIUM = 2 + HIGH = 3 + CRITICAL = 4 + + +@dataclass +class Event: + """Base event class""" + event_type: str + data: Dict[str, Any] + timestamp: datetime = None + priority: EventPriority = EventPriority.MEDIUM + source: Optional[str] = None + + def __post_init__(self): + if self.timestamp is None: + self.timestamp = datetime.utcnow() + + +class EventBus: + """Simple in-memory event bus for pub/sub patterns""" + + def __init__(self): + """Initialize event bus""" + self.subscribers: Dict[str, List[Callable]] = {} + self.event_history: List[Event] = [] + self.max_history = 1000 + + def subscribe(self, event_type: str, handler: Callable) -> None: + """Subscribe to an event type""" + if event_type not in self.subscribers: + self.subscribers[event_type] = [] + self.subscribers[event_type].append(handler) + + def unsubscribe(self, event_type: str, handler: Callable) -> bool: + """Unsubscribe from an event type""" + if event_type in self.subscribers: + try: + self.subscribers[event_type].remove(handler) + return True + except ValueError: + pass + return False + + async def publish(self, event: Event) -> None: + """Publish an event to all subscribers""" + # Add to history + self.event_history.append(event) + if len(self.event_history) > self.max_history: + self.event_history.pop(0) + + # Notify subscribers + handlers = self.subscribers.get(event.event_type, []) + + for handler in handlers: + try: + if inspect.iscoroutinefunction(handler): + await handler(event) + else: + handler(event) + except Exception as e: + print(f"Error in event handler: {e}") + + def publish_sync(self, event: Event) -> None: + """Publish an event synchronously""" + asyncio.run(self.publish(event)) + + def get_event_history(self, event_type: Optional[str] = None, limit: int = 100) -> List[Event]: + """Get event history""" + events = self.event_history + if event_type: + events = [e for e in events if e.event_type == event_type] + return events[-limit:] + + def clear_history(self) -> None: + """Clear event history""" + self.event_history.clear() + + +class AsyncEventBus(EventBus): + """Async event bus with additional features""" + + def __init__(self, max_concurrent_handlers: int = 10): + """Initialize async event bus""" + super().__init__() + self.semaphore = asyncio.Semaphore(max_concurrent_handlers) + + async def publish(self, event: Event) -> None: + """Publish event with concurrency control""" + self.event_history.append(event) + if len(self.event_history) > self.max_history: + self.event_history.pop(0) + + handlers = self.subscribers.get(event.event_type, []) + + tasks = [] + for handler in handlers: + async def safe_handler(): + async with self.semaphore: + try: + if inspect.iscoroutinefunction(handler): + await handler(event) + else: + handler(event) + except Exception as e: + print(f"Error in event handler: {e}") + + tasks.append(safe_handler()) + + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + +def event_handler(event_type: str, event_bus: Optional[EventBus] = None): + """Decorator to register event handler""" + def decorator(func: Callable) -> Callable: + # Use global event bus if none provided + bus = event_bus or get_global_event_bus() + bus.subscribe(event_type, func) + return func + return decorator + + +def publish_event(event_type: str, data: Dict[str, Any], event_bus: Optional[EventBus] = None) -> None: + """Helper to publish an event""" + bus = event_bus or get_global_event_bus() + event = Event(event_type=event_type, data=data) + bus.publish_sync(event) + + +# Global event bus instance +_global_event_bus: Optional[EventBus] = None + + +def get_global_event_bus() -> EventBus: + """Get or create global event bus""" + global _global_event_bus + if _global_event_bus is None: + _global_event_bus = EventBus() + return _global_event_bus + + +def set_global_event_bus(bus: EventBus) -> None: + """Set global event bus""" + global _global_event_bus + _global_event_bus = bus + + +class EventFilter: + """Filter events based on criteria""" + + def __init__(self, event_bus: Optional[EventBus] = None): + """Initialize event filter""" + self.event_bus = event_bus or get_global_event_bus() + self.filters: List[Callable[[Event], bool]] = [] + + def add_filter(self, filter_func: Callable[[Event], bool]) -> None: + """Add a filter function""" + self.filters.append(filter_func) + + def matches(self, event: Event) -> bool: + """Check if event matches all filters""" + return all(f(event) for f in self.filters) + + def get_filtered_events(self, event_type: Optional[str] = None, limit: int = 100) -> List[Event]: + """Get filtered events""" + events = self.event_bus.get_event_history(event_type, limit) + return [e for e in events if self.matches(e)] + + +class EventAggregator: + """Aggregate events over time windows""" + + def __init__(self, window_seconds: int = 60): + """Initialize event aggregator""" + self.window_seconds = window_seconds + self.aggregated_events: Dict[str, Dict[str, Any]] = {} + + def add_event(self, event: Event) -> None: + """Add event to aggregation""" + key = event.event_type + now = datetime.utcnow() + + if key not in self.aggregated_events: + self.aggregated_events[key] = { + "count": 0, + "first_seen": now, + "last_seen": now, + "data": {} + } + + agg = self.aggregated_events[key] + agg["count"] += 1 + agg["last_seen"] = now + + # Merge data + for k, v in event.data.items(): + if k not in agg["data"]: + agg["data"][k] = v + elif isinstance(v, (int, float)): + agg["data"][k] = agg["data"].get(k, 0) + v + + def get_aggregated_events(self) -> Dict[str, Dict[str, Any]]: + """Get aggregated events""" + # Remove old events + now = datetime.utcnow() + cutoff = now.timestamp() - self.window_seconds + + to_remove = [] + for key, agg in self.aggregated_events.items(): + if agg["last_seen"].timestamp() < cutoff: + to_remove.append(key) + + for key in to_remove: + del self.aggregated_events[key] + + return self.aggregated_events + + def clear(self) -> None: + """Clear all aggregated events""" + self.aggregated_events.clear() + + +class EventRouter: + """Route events to different handlers based on criteria""" + + def __init__(self): + """Initialize event router""" + self.routes: List[Callable[[Event], Optional[Callable]]] = [] + + def add_route(self, condition: Callable[[Event], bool], handler: Callable) -> None: + """Add a route""" + self.routes.append((condition, handler)) + + async def route(self, event: Event) -> bool: + """Route event to matching handler""" + for condition, handler in self.routes: + if condition(event): + try: + if inspect.iscoroutinefunction(handler): + await handler(event) + else: + handler(event) + return True + except Exception as e: + print(f"Error in routed handler: {e}") + return False diff --git a/aitbc/exceptions.py b/aitbc/exceptions.py index 4687b4a7..42e769b2 100644 --- a/aitbc/exceptions.py +++ b/aitbc/exceptions.py @@ -42,3 +42,18 @@ class ValidationError(AITBCError): class BridgeError(AITBCError): """Base exception for bridge errors""" pass + + +class RetryError(AITBCError): + """Raised when retry attempts are exhausted""" + pass + + +class CircuitBreakerOpenError(AITBCError): + """Raised when circuit breaker is open and requests are rejected""" + pass + + +class RateLimitError(AITBCError): + """Raised when rate limit is exceeded""" + pass diff --git a/aitbc/http_client.py b/aitbc/http_client.py index 0e9b521c..32cc6f30 100644 --- a/aitbc/http_client.py +++ b/aitbc/http_client.py @@ -4,8 +4,13 @@ Base HTTP client with common utilities for AITBC applications """ import requests +import time +import asyncio from typing import Dict, Any, Optional, Union -from .exceptions import NetworkError +from datetime import datetime, timedelta +from functools import lru_cache +from .exceptions import NetworkError, RetryError, CircuitBreakerOpenError, RateLimitError +from .aitbc_logging import get_logger class AITBCHTTPClient: @@ -18,7 +23,13 @@ class AITBCHTTPClient: self, base_url: str = "", timeout: int = 30, - headers: Optional[Dict[str, str]] = None + headers: Optional[Dict[str, str]] = None, + max_retries: int = 3, + enable_cache: bool = False, + cache_ttl: int = 300, + enable_logging: bool = False, + circuit_breaker_threshold: int = 5, + rate_limit: Optional[int] = None ): """ Initialize HTTP client. @@ -27,12 +38,37 @@ class AITBCHTTPClient: base_url: Base URL for all requests timeout: Request timeout in seconds headers: Default headers for all requests + max_retries: Maximum retry attempts with exponential backoff + enable_cache: Enable request/response caching for GET requests + cache_ttl: Cache time-to-live in seconds + enable_logging: Enable request/response logging + circuit_breaker_threshold: Failures before opening circuit breaker + rate_limit: Rate limit in requests per minute """ self.base_url = base_url.rstrip("/") self.timeout = timeout self.headers = headers or {} + self.max_retries = max_retries + self.enable_cache = enable_cache + self.cache_ttl = cache_ttl + self.enable_logging = enable_logging + self.circuit_breaker_threshold = circuit_breaker_threshold + self.rate_limit = rate_limit + self.session = requests.Session() self.session.headers.update(self.headers) + self.logger = get_logger(__name__) + + # Cache storage: {url: (data, timestamp)} + self._cache: Dict[str, tuple] = {} + + # Circuit breaker state + self._failure_count = 0 + self._circuit_open = False + self._circuit_open_time = None + + # Rate limiting state + self._request_times: list = [] def _build_url(self, endpoint: str) -> str: """ @@ -48,6 +84,98 @@ class AITBCHTTPClient: return endpoint return f"{self.base_url}/{endpoint.lstrip('/')}" + def _check_circuit_breaker(self) -> None: + """Check if circuit breaker is open and raise exception if so.""" + if self._circuit_open: + # Check if circuit should be reset (after 60 seconds) + if self._circuit_open_time and (datetime.now() - self._circuit_open_time).total_seconds() > 60: + self._circuit_open = False + self._failure_count = 0 + self.logger.info("Circuit breaker reset to half-open state") + else: + raise CircuitBreakerOpenError("Circuit breaker is open, rejecting request") + + def _record_failure(self) -> None: + """Record a failure and potentially open circuit breaker.""" + self._failure_count += 1 + if self._failure_count >= self.circuit_breaker_threshold: + self._circuit_open = True + self._circuit_open_time = datetime.now() + self.logger.warning(f"Circuit breaker opened after {self._failure_count} failures") + + def _check_rate_limit(self) -> None: + """Check if rate limit is exceeded and raise exception if so.""" + if not self.rate_limit: + return + + now = datetime.now() + # Remove requests older than 1 minute + self._request_times = [t for t in self._request_times if (now - t).total_seconds() < 60] + + if len(self._request_times) >= self.rate_limit: + raise RateLimitError(f"Rate limit exceeded: {self.rate_limit} requests per minute") + + def _record_request(self) -> None: + """Record a request timestamp for rate limiting.""" + if self.rate_limit: + self._request_times.append(datetime.now()) + + def _get_cache_key(self, url: str, params: Optional[Dict[str, Any]] = None) -> str: + """Generate cache key from URL and params.""" + if params: + import hashlib + param_str = str(sorted(params.items())) + return f"{url}:{hashlib.md5(param_str.encode()).hexdigest()}" + return url + + def _get_cache(self, cache_key: str) -> Optional[Dict[str, Any]]: + """Get cached response if available and not expired.""" + if not self.enable_cache: + return None + + if cache_key in self._cache: + data, timestamp = self._cache[cache_key] + if (datetime.now() - timestamp).total_seconds() < self.cache_ttl: + if self.enable_logging: + self.logger.info(f"Cache hit for {cache_key}") + return data + else: + # Expired, remove from cache + del self._cache[cache_key] + return None + + def _set_cache(self, cache_key: str, data: Dict[str, Any]) -> None: + """Cache response data.""" + if self.enable_cache: + self._cache[cache_key] = (data, datetime.now()) + if self.enable_logging: + self.logger.info(f"Cached response for {cache_key}") + + def _retry_request(self, request_func, *args, **kwargs) -> Dict[str, Any]: + """Execute request with retry logic and exponential backoff.""" + last_error = None + for attempt in range(self.max_retries + 1): + try: + if attempt > 0: + backoff_time = 2 ** (attempt - 1) + if self.enable_logging: + self.logger.info(f"Retry attempt {attempt}/{self.max_retries} after {backoff_time}s backoff") + time.sleep(backoff_time) + + return request_func(*args, **kwargs) + except requests.RequestException as e: + last_error = e + if attempt < self.max_retries: + if self.enable_logging: + self.logger.warning(f"Request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}") + continue + else: + if self.enable_logging: + self.logger.error(f"All retry attempts exhausted: {e}") + raise RetryError(f"Retry attempts exhausted: {e}") + + raise NetworkError(f"Request failed: {last_error}") + def get( self, endpoint: str, @@ -67,11 +195,29 @@ class AITBCHTTPClient: Raises: NetworkError: If request fails + CircuitBreakerOpenError: If circuit breaker is open + RateLimitError: If rate limit is exceeded """ url = self._build_url(endpoint) + cache_key = self._get_cache_key(url, params) + + # Check cache first + cached_data = self._get_cache(cache_key) + if cached_data is not None: + return cached_data + + # Check circuit breaker and rate limit + self._check_circuit_breaker() + self._check_rate_limit() + req_headers = {**self.headers, **(headers or {})} - try: + if self.enable_logging: + self.logger.info(f"GET {url} with params={params}") + + start_time = datetime.now() + + def _make_request(): response = self.session.get( url, params=params, @@ -80,7 +226,26 @@ class AITBCHTTPClient: ) response.raise_for_status() return response.json() + + try: + result = self._retry_request(_make_request) + + # Cache successful GET requests + self._set_cache(cache_key, result) + + # Record success for circuit breaker + self._failure_count = 0 + self._record_request() + + if self.enable_logging: + elapsed = (datetime.now() - start_time).total_seconds() + self.logger.info(f"GET {url} succeeded in {elapsed:.3f}s") + + return result + except (RetryError, CircuitBreakerOpenError, RateLimitError): + raise except requests.RequestException as e: + self._record_failure() raise NetworkError(f"GET request failed: {e}") def post( @@ -104,11 +269,23 @@ class AITBCHTTPClient: Raises: NetworkError: If request fails + CircuitBreakerOpenError: If circuit breaker is open + RateLimitError: If rate limit is exceeded """ url = self._build_url(endpoint) + + # Check circuit breaker and rate limit + self._check_circuit_breaker() + self._check_rate_limit() + req_headers = {**self.headers, **(headers or {})} - try: + if self.enable_logging: + self.logger.info(f"POST {url} with json={json}") + + start_time = datetime.now() + + def _make_request(): response = self.session.post( url, data=data, @@ -118,7 +295,23 @@ class AITBCHTTPClient: ) response.raise_for_status() return response.json() + + try: + result = self._retry_request(_make_request) + + # Record success for circuit breaker + self._failure_count = 0 + self._record_request() + + if self.enable_logging: + elapsed = (datetime.now() - start_time).total_seconds() + self.logger.info(f"POST {url} succeeded in {elapsed:.3f}s") + + return result + except (RetryError, CircuitBreakerOpenError, RateLimitError): + raise except requests.RequestException as e: + self._record_failure() raise NetworkError(f"POST request failed: {e}") def put( @@ -142,11 +335,23 @@ class AITBCHTTPClient: Raises: NetworkError: If request fails + CircuitBreakerOpenError: If circuit breaker is open + RateLimitError: If rate limit is exceeded """ url = self._build_url(endpoint) + + # Check circuit breaker and rate limit + self._check_circuit_breaker() + self._check_rate_limit() + req_headers = {**self.headers, **(headers or {})} - try: + if self.enable_logging: + self.logger.info(f"PUT {url} with json={json}") + + start_time = datetime.now() + + def _make_request(): response = self.session.put( url, data=data, @@ -156,7 +361,23 @@ class AITBCHTTPClient: ) response.raise_for_status() return response.json() + + try: + result = self._retry_request(_make_request) + + # Record success for circuit breaker + self._failure_count = 0 + self._record_request() + + if self.enable_logging: + elapsed = (datetime.now() - start_time).total_seconds() + self.logger.info(f"PUT {url} succeeded in {elapsed:.3f}s") + + return result + except (RetryError, CircuitBreakerOpenError, RateLimitError): + raise except requests.RequestException as e: + self._record_failure() raise NetworkError(f"PUT request failed: {e}") def delete( @@ -178,11 +399,23 @@ class AITBCHTTPClient: Raises: NetworkError: If request fails + CircuitBreakerOpenError: If circuit breaker is open + RateLimitError: If rate limit is exceeded """ url = self._build_url(endpoint) + + # Check circuit breaker and rate limit + self._check_circuit_breaker() + self._check_rate_limit() + req_headers = {**self.headers, **(headers or {})} - try: + if self.enable_logging: + self.logger.info(f"DELETE {url} with params={params}") + + start_time = datetime.now() + + def _make_request(): response = self.session.delete( url, params=params, @@ -191,7 +424,23 @@ class AITBCHTTPClient: ) response.raise_for_status() return response.json() if response.content else {} + + try: + result = self._retry_request(_make_request) + + # Record success for circuit breaker + self._failure_count = 0 + self._record_request() + + if self.enable_logging: + elapsed = (datetime.now() - start_time).total_seconds() + self.logger.info(f"DELETE {url} succeeded in {elapsed:.3f}s") + + return result + except (RetryError, CircuitBreakerOpenError, RateLimitError): + raise except requests.RequestException as e: + self._record_failure() raise NetworkError(f"DELETE request failed: {e}") def close(self) -> None: @@ -205,3 +454,279 @@ class AITBCHTTPClient: def __exit__(self, exc_type, exc_val, exc_tb): """Context manager exit.""" self.close() + + +class AsyncAITBCHTTPClient: + """ + Async HTTP client for AITBC applications. + Provides async HTTP methods with error handling. + """ + + def __init__( + self, + base_url: str = "", + timeout: int = 30, + headers: Optional[Dict[str, str]] = None, + max_retries: int = 3, + enable_cache: bool = False, + cache_ttl: int = 300, + enable_logging: bool = False, + circuit_breaker_threshold: int = 5, + rate_limit: Optional[int] = None + ): + """ + Initialize async HTTP client. + + Args: + base_url: Base URL for all requests + timeout: Request timeout in seconds + headers: Default headers for all requests + max_retries: Maximum retry attempts with exponential backoff + enable_cache: Enable request/response caching for GET requests + cache_ttl: Cache time-to-live in seconds + enable_logging: Enable request/response logging + circuit_breaker_threshold: Failures before opening circuit breaker + rate_limit: Rate limit in requests per minute + """ + self.base_url = base_url.rstrip("/") + self.timeout = timeout + self.headers = headers or {} + self.max_retries = max_retries + self.enable_cache = enable_cache + self.cache_ttl = cache_ttl + self.enable_logging = enable_logging + self.circuit_breaker_threshold = circuit_breaker_threshold + self.rate_limit = rate_limit + + self.logger = get_logger(__name__) + self._client = None + + # Cache storage: {url: (data, timestamp)} + self._cache: Dict[str, tuple] = {} + + # Circuit breaker state + self._failure_count = 0 + self._circuit_open = False + self._circuit_open_time = None + + # Rate limiting state + self._request_times: list = [] + + async def __aenter__(self): + """Async context manager entry.""" + import httpx + self._client = httpx.AsyncClient(timeout=self.timeout, headers=self.headers) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + if self._client: + await self._client.aclose() + + def _build_url(self, endpoint: str) -> str: + """Build full URL from base URL and endpoint.""" + if endpoint.startswith("http://") or endpoint.startswith("https://"): + return endpoint + return f"{self.base_url}/{endpoint.lstrip('/')}" + + def _check_circuit_breaker(self) -> None: + """Check if circuit breaker is open and raise exception if so.""" + if self._circuit_open: + if self._circuit_open_time and (datetime.now() - self._circuit_open_time).total_seconds() > 60: + self._circuit_open = False + self._failure_count = 0 + self.logger.info("Circuit breaker reset to half-open state") + else: + raise CircuitBreakerOpenError("Circuit breaker is open, rejecting request") + + def _record_failure(self) -> None: + """Record a failure and potentially open circuit breaker.""" + self._failure_count += 1 + if self._failure_count >= self.circuit_breaker_threshold: + self._circuit_open = True + self._circuit_open_time = datetime.now() + self.logger.warning(f"Circuit breaker opened after {self._failure_count} failures") + + def _check_rate_limit(self) -> None: + """Check if rate limit is exceeded and raise exception if so.""" + if not self.rate_limit: + return + + now = datetime.now() + self._request_times = [t for t in self._request_times if (now - t).total_seconds() < 60] + + if len(self._request_times) >= self.rate_limit: + raise RateLimitError(f"Rate limit exceeded: {self.rate_limit} requests per minute") + + def _record_request(self) -> None: + """Record a request timestamp for rate limiting.""" + if self.rate_limit: + self._request_times.append(datetime.now()) + + def _get_cache_key(self, url: str, params: Optional[Dict[str, Any]] = None) -> str: + """Generate cache key from URL and params.""" + if params: + import hashlib + param_str = str(sorted(params.items())) + return f"{url}:{hashlib.md5(param_str.encode()).hexdigest()}" + return url + + def _get_cache(self, cache_key: str) -> Optional[Dict[str, Any]]: + """Get cached response if available and not expired.""" + if not self.enable_cache: + return None + + if cache_key in self._cache: + data, timestamp = self._cache[cache_key] + if (datetime.now() - timestamp).total_seconds() < self.cache_ttl: + if self.enable_logging: + self.logger.info(f"Cache hit for {cache_key}") + return data + else: + del self._cache[cache_key] + return None + + def _set_cache(self, cache_key: str, data: Dict[str, Any]) -> None: + """Cache response data.""" + if self.enable_cache: + self._cache[cache_key] = (data, datetime.now()) + if self.enable_logging: + self.logger.info(f"Cached response for {cache_key}") + + async def _retry_request(self, request_func, *args, **kwargs) -> Dict[str, Any]: + """Execute async request with retry logic and exponential backoff.""" + last_error = None + for attempt in range(self.max_retries + 1): + try: + if attempt > 0: + backoff_time = 2 ** (attempt - 1) + if self.enable_logging: + self.logger.info(f"Retry attempt {attempt}/{self.max_retries} after {backoff_time}s backoff") + await asyncio.sleep(backoff_time) + + return await request_func(*args, **kwargs) + except Exception as e: + last_error = e + if attempt < self.max_retries: + if self.enable_logging: + self.logger.warning(f"Request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}") + continue + else: + if self.enable_logging: + self.logger.error(f"All retry attempts exhausted: {e}") + raise RetryError(f"Retry attempts exhausted: {e}") + + raise NetworkError(f"Request failed: {last_error}") + + async def async_get( + self, + endpoint: str, + params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None + ) -> Dict[str, Any]: + """ + Perform async GET request. + + Args: + endpoint: API endpoint + params: Query parameters + headers: Additional headers + + Returns: + Response data as dictionary + """ + if not self._client: + raise RuntimeError("Async client not initialized. Use async context manager.") + + url = self._build_url(endpoint) + cache_key = self._get_cache_key(url, params) + + cached_data = self._get_cache(cache_key) + if cached_data is not None: + return cached_data + + self._check_circuit_breaker() + self._check_rate_limit() + + req_headers = {**self.headers, **(headers or {})} + + if self.enable_logging: + self.logger.info(f"ASYNC GET {url} with params={params}") + + start_time = datetime.now() + + async def _make_request(): + response = await self._client.get(url, params=params, headers=req_headers) + response.raise_for_status() + return response.json() + + try: + result = await self._retry_request(_make_request) + self._set_cache(cache_key, result) + self._failure_count = 0 + self._record_request() + + if self.enable_logging: + elapsed = (datetime.now() - start_time).total_seconds() + self.logger.info(f"ASYNC GET {url} succeeded in {elapsed:.3f}s") + + return result + except (RetryError, CircuitBreakerOpenError, RateLimitError): + raise + except Exception as e: + self._record_failure() + raise NetworkError(f"ASYNC GET request failed: {e}") + + async def async_post( + self, + endpoint: str, + data: Optional[Dict[str, Any]] = None, + json: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None + ) -> Dict[str, Any]: + """ + Perform async POST request. + + Args: + endpoint: API endpoint + data: Form data + json: JSON data + headers: Additional headers + + Returns: + Response data as dictionary + """ + if not self._client: + raise RuntimeError("Async client not initialized. Use async context manager.") + + url = self._build_url(endpoint) + self._check_circuit_breaker() + self._check_rate_limit() + + req_headers = {**self.headers, **(headers or {})} + + if self.enable_logging: + self.logger.info(f"ASYNC POST {url} with json={json}") + + start_time = datetime.now() + + async def _make_request(): + response = await self._client.post(url, data=data, json=json, headers=req_headers) + response.raise_for_status() + return response.json() + + try: + result = await self._retry_request(_make_request) + self._failure_count = 0 + self._record_request() + + if self.enable_logging: + elapsed = (datetime.now() - start_time).total_seconds() + self.logger.info(f"ASYNC POST {url} succeeded in {elapsed:.3f}s") + + return result + except (RetryError, CircuitBreakerOpenError, RateLimitError): + raise + except Exception as e: + self._record_failure() + raise NetworkError(f"ASYNC POST request failed: {e}") diff --git a/aitbc/queue.py b/aitbc/queue.py new file mode 100644 index 00000000..267e5acc --- /dev/null +++ b/aitbc/queue.py @@ -0,0 +1,431 @@ +""" +Queue utilities for AITBC +Provides task queue helpers, job scheduling, and background task management +""" + +import asyncio +import heapq +import time +from typing import Any, Callable, Dict, List, Optional, TypeVar +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from enum import Enum +import uuid + + +T = TypeVar('T') + + +class JobStatus(Enum): + """Job status enumeration""" + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class JobPriority(Enum): + """Job priority levels""" + LOW = 1 + MEDIUM = 2 + HIGH = 3 + CRITICAL = 4 + + +@dataclass(order=True) +class Job: + """Background job""" + priority: int + job_id: str = field(compare=False) + func: Callable = field(compare=False) + args: tuple = field(default_factory=tuple, compare=False) + kwargs: dict = field(default_factory=dict, compare=False) + status: JobStatus = field(default=JobStatus.PENDING, compare=False) + created_at: datetime = field(default_factory=datetime.utcnow, compare=False) + started_at: Optional[datetime] = field(default=None, compare=False) + completed_at: Optional[datetime] = field(default=None, compare=False) + result: Any = field(default=None, compare=False) + error: Optional[str] = field(default=None, compare=False) + retry_count: int = field(default=0, compare=False) + max_retries: int = field(default=3, compare=False) + + def __post_init__(self): + if self.job_id is None: + self.job_id = str(uuid.uuid4()) + + +class TaskQueue: + """Priority-based task queue""" + + def __init__(self): + """Initialize task queue""" + self.queue: List[Job] = [] + self.jobs: Dict[str, Job] = {} + self.lock = asyncio.Lock() + + async def enqueue( + self, + func: Callable, + args: tuple = (), + kwargs: dict = None, + priority: JobPriority = JobPriority.MEDIUM, + max_retries: int = 3 + ) -> str: + """Enqueue a task""" + if kwargs is None: + kwargs = {} + + job = Job( + priority=priority.value, + func=func, + args=args, + kwargs=kwargs, + max_retries=max_retries + ) + + async with self.lock: + heapq.heappush(self.queue, job) + self.jobs[job.job_id] = job + + return job.job_id + + async def dequeue(self) -> Optional[Job]: + """Dequeue a task""" + async with self.lock: + if not self.queue: + return None + + job = heapq.heappop(self.queue) + return job + + async def get_job(self, job_id: str) -> Optional[Job]: + """Get job by ID""" + return self.jobs.get(job_id) + + async def cancel_job(self, job_id: str) -> bool: + """Cancel a job""" + async with self.lock: + job = self.jobs.get(job_id) + if job and job.status == JobStatus.PENDING: + job.status = JobStatus.CANCELLED + # Remove from queue + self.queue = [j for j in self.queue if j.job_id != job_id] + heapq.heapify(self.queue) + return True + return False + + async def get_queue_size(self) -> int: + """Get queue size""" + return len(self.queue) + + async def get_jobs_by_status(self, status: JobStatus) -> List[Job]: + """Get jobs by status""" + return [job for job in self.jobs.values() if job.status == status] + + +class JobScheduler: + """Job scheduler for delayed and recurring tasks""" + + def __init__(self): + """Initialize job scheduler""" + self.scheduled_jobs: Dict[str, Dict[str, Any]] = {} + self.running = False + self.task: Optional[asyncio.Task] = None + + async def schedule( + self, + func: Callable, + delay: float = 0, + interval: Optional[float] = None, + job_id: Optional[str] = None, + args: tuple = (), + kwargs: dict = None + ) -> str: + """Schedule a job""" + if job_id is None: + job_id = str(uuid.uuid4()) + + if kwargs is None: + kwargs = {} + + run_at = time.time() + delay + + self.scheduled_jobs[job_id] = { + "func": func, + "args": args, + "kwargs": kwargs, + "run_at": run_at, + "interval": interval, + "job_id": job_id + } + + return job_id + + async def cancel_scheduled_job(self, job_id: str) -> bool: + """Cancel a scheduled job""" + if job_id in self.scheduled_jobs: + del self.scheduled_jobs[job_id] + return True + return False + + async def start(self) -> None: + """Start the scheduler""" + if self.running: + return + + self.running = True + self.task = asyncio.create_task(self._run_scheduler()) + + async def stop(self) -> None: + """Stop the scheduler""" + self.running = False + if self.task: + self.task.cancel() + try: + await self.task + except asyncio.CancelledError: + pass + + async def _run_scheduler(self) -> None: + """Run the scheduler loop""" + while self.running: + now = time.time() + to_run = [] + + for job_id, job in list(self.scheduled_jobs.items()): + if job["run_at"] <= now: + to_run.append(job) + + for job in to_run: + try: + if asyncio.iscoroutinefunction(job["func"]): + await job["func"](*job["args"], **job["kwargs"]) + else: + job["func"](*job["args"], **job["kwargs"]) + + if job["interval"]: + job["run_at"] = now + job["interval"] + else: + del self.scheduled_jobs[job["job_id"]] + except Exception as e: + print(f"Error running scheduled job {job['job_id']}: {e}") + if not job["interval"]: + del self.scheduled_jobs[job["job_id"]] + + await asyncio.sleep(0.1) + + +class BackgroundTaskManager: + """Manage background tasks""" + + def __init__(self, max_concurrent_tasks: int = 10): + """Initialize background task manager""" + self.max_concurrent_tasks = max_concurrent_tasks + self.semaphore = asyncio.Semaphore(max_concurrent_tasks) + self.tasks: Dict[str, asyncio.Task] = {} + self.task_info: Dict[str, Dict[str, Any]] = {} + + async def run_task( + self, + func: Callable, + task_id: Optional[str] = None, + args: tuple = (), + kwargs: dict = None + ) -> str: + """Run a background task""" + if task_id is None: + task_id = str(uuid.uuid4()) + + if kwargs is None: + kwargs = {} + + async def wrapped_task(): + async with self.semaphore: + try: + self.task_info[task_id]["status"] = "running" + self.task_info[task_id]["started_at"] = datetime.utcnow() + + if asyncio.iscoroutinefunction(func): + result = await func(*args, **kwargs) + else: + result = func(*args, **kwargs) + + self.task_info[task_id]["status"] = "completed" + self.task_info[task_id]["result"] = result + self.task_info[task_id]["completed_at"] = datetime.utcnow() + except Exception as e: + self.task_info[task_id]["status"] = "failed" + self.task_info[task_id]["error"] = str(e) + self.task_info[task_id]["completed_at"] = datetime.utcnow() + finally: + if task_id in self.tasks: + del self.tasks[task_id] + + self.task_info[task_id] = { + "status": "pending", + "created_at": datetime.utcnow(), + "started_at": None, + "completed_at": None, + "result": None, + "error": None + } + + task = asyncio.create_task(wrapped_task()) + self.tasks[task_id] = task + + return task_id + + async def cancel_task(self, task_id: str) -> bool: + """Cancel a background task""" + if task_id in self.tasks: + self.tasks[task_id].cancel() + try: + await self.tasks[task_id] + except asyncio.CancelledError: + pass + + self.task_info[task_id]["status"] = "cancelled" + self.task_info[task_id]["completed_at"] = datetime.utcnow() + del self.tasks[task_id] + return True + return False + + async def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]: + """Get task status""" + return self.task_info.get(task_id) + + async def get_all_tasks(self) -> Dict[str, Dict[str, Any]]: + """Get all tasks""" + return self.task_info.copy() + + async def wait_for_task(self, task_id: str, timeout: Optional[float] = None) -> Any: + """Wait for task completion""" + if task_id not in self.tasks: + raise ValueError(f"Task {task_id} not found") + + try: + await asyncio.wait_for(self.tasks[task_id], timeout) + except asyncio.TimeoutError: + await self.cancel_task(task_id) + raise TimeoutError(f"Task {task_id} timed out") + + info = self.task_info.get(task_id) + if info["status"] == "failed": + raise Exception(info["error"]) + + return info["result"] + + +class WorkerPool: + """Worker pool for parallel task execution""" + + def __init__(self, num_workers: int = 4): + """Initialize worker pool""" + self.num_workers = num_workers + self.queue: asyncio.Queue = asyncio.Queue() + self.workers: List[asyncio.Task] = [] + self.running = False + + async def start(self) -> None: + """Start worker pool""" + if self.running: + return + + self.running = True + for i in range(self.num_workers): + worker = asyncio.create_task(self._worker(i)) + self.workers.append(worker) + + async def stop(self) -> None: + """Stop worker pool""" + self.running = False + + # Cancel all workers + for worker in self.workers: + worker.cancel() + + # Wait for workers to finish + await asyncio.gather(*self.workers, return_exceptions=True) + self.workers.clear() + + async def submit(self, func: Callable, *args, **kwargs) -> Any: + """Submit task to worker pool""" + future = asyncio.Future() + await self.queue.put((func, args, kwargs, future)) + return await future + + async def _worker(self, worker_id: int) -> None: + """Worker coroutine""" + while self.running: + try: + func, args, kwargs, future = await self.queue.get() + + try: + if asyncio.iscoroutinefunction(func): + result = await func(*args, **kwargs) + else: + result = func(*args, **kwargs) + future.set_result(result) + except Exception as e: + future.set_exception(e) + finally: + self.queue.task_done() + except asyncio.CancelledError: + break + except Exception as e: + print(f"Worker {worker_id} error: {e}") + + async def get_queue_size(self) -> int: + """Get queue size""" + return self.queue.qsize() + + +def debounce(delay: float = 0.5): + """Decorator to debounce function calls""" + def decorator(func: Callable) -> Callable: + last_called = [0] + timer = [None] + + async def wrapped(*args, **kwargs): + async def call(): + await asyncio.sleep(delay) + if asyncio.get_event_loop().time() - last_called[0] >= delay: + if asyncio.iscoroutinefunction(func): + return await func(*args, **kwargs) + else: + return func(*args, **kwargs) + + last_called[0] = asyncio.get_event_loop().time() + if timer[0]: + timer[0].cancel() + + timer[0] = asyncio.create_task(call()) + return await timer[0] + + return wrapped + return decorator + + +def throttle(calls_per_second: float = 1.0): + """Decorator to throttle function calls""" + def decorator(func: Callable) -> Callable: + min_interval = 1.0 / calls_per_second + last_called = [0] + + async def wrapped(*args, **kwargs): + now = asyncio.get_event_loop().time() + elapsed = now - last_called[0] + + if elapsed < min_interval: + await asyncio.sleep(min_interval - elapsed) + + last_called[0] = asyncio.get_event_loop().time() + + if asyncio.iscoroutinefunction(func): + return await func(*args, **kwargs) + else: + return func(*args, **kwargs) + + return wrapped + return decorator diff --git a/aitbc/security.py b/aitbc/security.py new file mode 100644 index 00000000..29050e3c --- /dev/null +++ b/aitbc/security.py @@ -0,0 +1,282 @@ +""" +Security utilities for AITBC +Provides token generation, session management, API key management, and secret management +""" + +import os +import secrets +import hashlib +import time +import json +from typing import Optional, Dict, Any +from datetime import datetime, timedelta +from cryptography.fernet import Fernet + + +def generate_token(length: int = 32, prefix: str = "") -> str: + """Generate a secure random token""" + token = secrets.token_urlsafe(length) + return f"{prefix}{token}" if prefix else token + + +def generate_api_key(prefix: str = "aitbc") -> str: + """Generate a secure API key with prefix""" + random_part = secrets.token_urlsafe(32) + return f"{prefix}_{random_part}" + + +def validate_token_format(token: str, min_length: int = 16) -> bool: + """Validate token format""" + return bool(token) and len(token) >= min_length and all(c.isalnum() or c in '-_' for c in token) + + +def validate_api_key(api_key: str, prefix: str = "aitbc") -> bool: + """Validate API key format""" + if not api_key or not api_key.startswith(f"{prefix}_"): + return False + token_part = api_key[len(prefix)+1:] + return validate_token_format(token_part) + + +class SessionManager: + """Simple in-memory session manager""" + + def __init__(self, session_timeout: int = 3600): + """Initialize session manager with timeout in seconds""" + self.sessions: Dict[str, Dict[str, Any]] = {} + self.session_timeout = session_timeout + + def create_session(self, user_id: str, data: Optional[Dict[str, Any]] = None) -> str: + """Create a new session""" + session_id = generate_token() + self.sessions[session_id] = { + "user_id": user_id, + "data": data or {}, + "created_at": time.time(), + "expires_at": time.time() + self.session_timeout + } + return session_id + + def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: + """Get session data""" + session = self.sessions.get(session_id) + if not session: + return None + + # Check if session expired + if time.time() > session["expires_at"]: + del self.sessions[session_id] + return None + + return session + + def update_session(self, session_id: str, data: Dict[str, Any]) -> bool: + """Update session data""" + session = self.get_session(session_id) + if not session: + return False + + session["data"].update(data) + return True + + def delete_session(self, session_id: str) -> bool: + """Delete a session""" + if session_id in self.sessions: + del self.sessions[session_id] + return True + return False + + def cleanup_expired_sessions(self) -> int: + """Clean up expired sessions""" + current_time = time.time() + expired_keys = [ + key for key, session in self.sessions.items() + if current_time > session["expires_at"] + ] + + for key in expired_keys: + del self.sessions[key] + + return len(expired_keys) + + +class APIKeyManager: + """API key management with storage""" + + def __init__(self, storage_path: Optional[str] = None): + """Initialize API key manager""" + self.storage_path = storage_path + self.keys: Dict[str, Dict[str, Any]] = {} + + if storage_path: + self._load_keys() + + def create_api_key(self, user_id: str, scopes: Optional[list[str]] = None, name: Optional[str] = None) -> str: + """Create a new API key""" + api_key = generate_api_key() + self.keys[api_key] = { + "user_id": user_id, + "scopes": scopes or ["read"], + "name": name, + "created_at": datetime.utcnow().isoformat(), + "last_used": None + } + + if self.storage_path: + self._save_keys() + + return api_key + + def validate_api_key(self, api_key: str) -> Optional[Dict[str, Any]]: + """Validate API key and return key data""" + key_data = self.keys.get(api_key) + if not key_data: + return None + + # Update last used + key_data["last_used"] = datetime.utcnow().isoformat() + if self.storage_path: + self._save_keys() + + return key_data + + def revoke_api_key(self, api_key: str) -> bool: + """Revoke an API key""" + if api_key in self.keys: + del self.keys[api_key] + if self.storage_path: + self._save_keys() + return True + return False + + def list_user_keys(self, user_id: str) -> list[str]: + """List all API keys for a user""" + return [ + key for key, data in self.items() + if data["user_id"] == user_id + ] + + def _load_keys(self): + """Load keys from storage""" + if self.storage_path and os.path.exists(self.storage_path): + try: + with open(self.storage_path, 'r') as f: + self.keys = json.load(f) + except Exception: + self.keys = {} + + def _save_keys(self): + """Save keys to storage""" + if self.storage_path: + try: + with open(self.storage_path, 'w') as f: + json.dump(self.keys, f) + except Exception: + pass + + def items(self): + """Return key items""" + return self.keys.items() + + +def generate_secure_random_string(length: int = 32) -> str: + """Generate a cryptographically secure random string""" + return secrets.token_urlsafe(length) + + +def generate_secure_random_int(min_val: int = 0, max_val: int = 2**32) -> int: + """Generate a cryptographically secure random integer""" + return secrets.randbelow(max_val - min_val) + min_val + + +class SecretManager: + """Simple secret management with encryption""" + + def __init__(self, encryption_key: Optional[str] = None): + """Initialize secret manager""" + if encryption_key: + self.fernet = Fernet(encryption_key) + else: + # Generate a new key if none provided + self.fernet = Fernet(Fernet.generate_key()) + + self.secrets: Dict[str, str] = {} + + def set_secret(self, key: str, value: str) -> None: + """Store an encrypted secret""" + encrypted = self.fernet.encrypt(value.encode('utf-8')) + self.secrets[key] = encrypted.decode('utf-8') + + def get_secret(self, key: str) -> Optional[str]: + """Retrieve and decrypt a secret""" + encrypted = self.secrets.get(key) + if not encrypted: + return None + + try: + decrypted = self.fernet.decrypt(encrypted.encode('utf-8')) + return decrypted.decode('utf-8') + except Exception: + return None + + def delete_secret(self, key: str) -> bool: + """Delete a secret""" + if key in self.secrets: + del self.secrets[key] + return True + return False + + def list_secrets(self) -> list[str]: + """List all secret keys""" + return list(self.secrets.keys()) + + def get_encryption_key(self) -> str: + """Get the encryption key (for backup purposes)""" + return self.fernet._signing_key.decode('utf-8') + + +def hash_password(password: str, salt: Optional[str] = None) -> tuple[str, str]: + """Hash a password with salt""" + if salt is None: + salt = secrets.token_hex(16) + + # Use PBKDF2 for password hashing + from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + from cryptography.hazmat.primitives import hashes + import base64 + + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=salt.encode('utf-8'), + iterations=100000, + ) + hashed = kdf.derive(password.encode('utf-8')) + return base64.b64encode(hashed).decode('utf-8'), salt + + +def verify_password(password: str, hashed_password: str, salt: str) -> bool: + """Verify a password against a hash""" + new_hash, _ = hash_password(password, salt) + return new_hash == hashed_password + + +def generate_nonce(length: int = 16) -> str: + """Generate a nonce for cryptographic operations""" + return secrets.token_hex(length) + + +def generate_hmac(data: str, secret: str) -> str: + """Generate HMAC-SHA256 signature""" + import hmac + return hmac.new( + secret.encode('utf-8'), + data.encode('utf-8'), + hashlib.sha256 + ).hexdigest() + + +def verify_hmac(data: str, signature: str, secret: str) -> bool: + """Verify HMAC-SHA256 signature""" + computed = generate_hmac(data, secret) + return secrets.compare_digest(computed, signature) diff --git a/aitbc/state.py b/aitbc/state.py new file mode 100644 index 00000000..9d4b4c1a --- /dev/null +++ b/aitbc/state.py @@ -0,0 +1,348 @@ +""" +State management utilities for AITBC +Provides state machine base classes, state persistence, and state transition helpers +""" + +import json +import os +from typing import Any, Callable, Dict, Optional, TypeVar, Generic, List +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from abc import ABC, abstractmethod +import asyncio + + +T = TypeVar('T') + + +class StateTransitionError(Exception): + """Raised when invalid state transition is attempted""" + pass + + +class StatePersistenceError(Exception): + """Raised when state persistence fails""" + pass + + +@dataclass +class StateTransition: + """Record of a state transition""" + from_state: str + to_state: str + timestamp: datetime = field(default_factory=datetime.utcnow) + data: Dict[str, Any] = field(default_factory=dict) + + +class StateMachine(ABC): + """Base class for state machines""" + + def __init__(self, initial_state: str): + """Initialize state machine""" + self.current_state = initial_state + self.transitions: List[StateTransition] = [] + self.state_data: Dict[str, Dict[str, Any]] = {initial_state: {}} + + @abstractmethod + def get_valid_transitions(self, state: str) -> List[str]: + """Get valid transitions from a state""" + pass + + def can_transition(self, to_state: str) -> bool: + """Check if transition is valid""" + return to_state in self.get_valid_transitions(self.current_state) + + def transition(self, to_state: str, data: Optional[Dict[str, Any]] = None) -> None: + """Transition to a new state""" + if not self.can_transition(to_state): + raise StateTransitionError( + f"Invalid transition from {self.current_state} to {to_state}" + ) + + from_state = self.current_state + self.current_state = to_state + + # Record transition + transition = StateTransition( + from_state=from_state, + to_state=to_state, + data=data or {} + ) + self.transitions.append(transition) + + # Initialize state data if needed + if to_state not in self.state_data: + self.state_data[to_state] = {} + + def get_state_data(self, state: Optional[str] = None) -> Dict[str, Any]: + """Get data for a state""" + state = state or self.current_state + return self.state_data.get(state, {}).copy() + + def set_state_data(self, data: Dict[str, Any], state: Optional[str] = None) -> None: + """Set data for a state""" + state = state or self.current_state + if state not in self.state_data: + self.state_data[state] = {} + self.state_data[state].update(data) + + def get_transition_history(self, limit: Optional[int] = None) -> List[StateTransition]: + """Get transition history""" + if limit: + return self.transitions[-limit:] + return self.transitions.copy() + + def reset(self, initial_state: str) -> None: + """Reset state machine to initial state""" + self.current_state = initial_state + self.transitions.clear() + self.state_data = {initial_state: {}} + + +class ConfigurableStateMachine(StateMachine): + """State machine with configurable transitions""" + + def __init__(self, initial_state: str, transitions: Dict[str, List[str]]): + """Initialize configurable state machine""" + super().__init__(initial_state) + self.transitions_config = transitions + + def get_valid_transitions(self, state: str) -> List[str]: + """Get valid transitions from configuration""" + return self.transitions_config.get(state, []) + + def add_transition(self, from_state: str, to_state: str) -> None: + """Add a transition to configuration""" + if from_state not in self.transitions_config: + self.transitions_config[from_state] = [] + if to_state not in self.transitions_config[from_state]: + self.transitions_config[from_state].append(to_state) + + +class StatePersistence: + """State persistence to file""" + + def __init__(self, storage_path: str): + """Initialize state persistence""" + self.storage_path = storage_path + self._ensure_storage_dir() + + def _ensure_storage_dir(self) -> None: + """Ensure storage directory exists""" + os.makedirs(os.path.dirname(self.storage_path), exist_ok=True) + + def save_state(self, state_machine: StateMachine) -> None: + """Save state machine to file""" + try: + state_data = { + "current_state": state_machine.current_state, + "state_data": state_machine.state_data, + "transitions": [ + { + "from_state": t.from_state, + "to_state": t.to_state, + "timestamp": t.timestamp.isoformat(), + "data": t.data + } + for t in state_machine.transitions + ] + } + + with open(self.storage_path, 'w') as f: + json.dump(state_data, f, indent=2) + except Exception as e: + raise StatePersistenceError(f"Failed to save state: {e}") + + def load_state(self) -> Optional[Dict[str, Any]]: + """Load state from file""" + try: + if not os.path.exists(self.storage_path): + return None + + with open(self.storage_path, 'r') as f: + return json.load(f) + except Exception as e: + raise StatePersistenceError(f"Failed to load state: {e}") + + def delete_state(self) -> None: + """Delete persisted state""" + try: + if os.path.exists(self.storage_path): + os.remove(self.storage_path) + except Exception as e: + raise StatePersistenceError(f"Failed to delete state: {e}") + + +class AsyncStateMachine(StateMachine): + """Async state machine with async transition handlers""" + + def __init__(self, initial_state: str): + """Initialize async state machine""" + super().__init__(initial_state) + self.transition_handlers: Dict[str, Callable] = {} + + def on_transition(self, to_state: str, handler: Callable) -> None: + """Register a handler for transition to a state""" + self.transition_handlers[to_state] = handler + + async def transition_async(self, to_state: str, data: Optional[Dict[str, Any]] = None) -> None: + """Async transition to a new state""" + if not self.can_transition(to_state): + raise StateTransitionError( + f"Invalid transition from {self.current_state} to {to_state}" + ) + + from_state = self.current_state + self.current_state = to_state + + # Record transition + transition = StateTransition( + from_state=from_state, + to_state=to_state, + data=data or {} + ) + self.transitions.append(transition) + + # Initialize state data if needed + if to_state not in self.state_data: + self.state_data[to_state] = {} + + # Call transition handler if exists + if to_state in self.transition_handlers: + handler = self.transition_handlers[to_state] + if asyncio.iscoroutinefunction(handler): + await handler(transition) + else: + handler(transition) + + +class StateMonitor: + """Monitor state machine state and transitions""" + + def __init__(self, state_machine: StateMachine): + """Initialize state monitor""" + self.state_machine = state_machine + self.observers: List[Callable] = [] + + def add_observer(self, observer: Callable) -> None: + """Add an observer for state changes""" + self.observers.append(observer) + + def remove_observer(self, observer: Callable) -> bool: + """Remove an observer""" + try: + self.observers.remove(observer) + return True + except ValueError: + return False + + def notify_observers(self, transition: StateTransition) -> None: + """Notify all observers of state change""" + for observer in self.observers: + try: + observer(transition) + except Exception as e: + print(f"Error in state observer: {e}") + + def wrap_transition(self, original_transition: Callable) -> Callable: + """Wrap transition method to notify observers""" + def wrapper(*args, **kwargs): + result = original_transition(*args, **kwargs) + # Get last transition + if self.state_machine.transitions: + self.notify_observers(self.state_machine.transitions[-1]) + return result + return wrapper + + +class StateValidator: + """Validate state machine configurations""" + + @staticmethod + def validate_transitions(transitions: Dict[str, List[str]]) -> bool: + """Validate that all target states exist""" + all_states = set(transitions.keys()) + all_states.update(*transitions.values()) + + for from_state, to_states in transitions.items(): + for to_state in to_states: + if to_state not in all_states: + return False + + return True + + @staticmethod + def check_for_deadlocks(transitions: Dict[str, List[str]]) -> List[str]: + """Check for states with no outgoing transitions""" + deadlocks = [] + for state, to_states in transitions.items(): + if not to_states: + deadlocks.append(state) + return deadlocks + + @staticmethod + def check_for_orphans(transitions: Dict[str, List[str]]) -> List[str]: + """Check for states with no incoming transitions""" + incoming = set() + for to_states in transitions.values(): + incoming.update(to_states) + + orphans = [] + for state in transitions.keys(): + if state not in incoming: + orphans.append(state) + + return orphans + + +class StateSnapshot: + """Snapshot of state machine state""" + + def __init__(self, state_machine: StateMachine): + """Create snapshot""" + self.current_state = state_machine.current_state + self.state_data = state_machine.state_data.copy() + self.transitions = state_machine.transitions.copy() + self.timestamp = datetime.utcnow() + + def restore(self, state_machine: StateMachine) -> None: + """Restore state machine from snapshot""" + state_machine.current_state = self.current_state + state_machine.state_data = self.state_data.copy() + state_machine.transitions = self.transitions.copy() + + def to_dict(self) -> Dict[str, Any]: + """Convert snapshot to dict""" + return { + "current_state": self.current_state, + "state_data": self.state_data, + "transitions": [ + { + "from_state": t.from_state, + "to_state": t.to_state, + "timestamp": t.timestamp.isoformat(), + "data": t.data + } + for t in self.transitions + ], + "timestamp": self.timestamp.isoformat() + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'StateSnapshot': + """Create snapshot from dict""" + snapshot = cls.__new__(cls) + snapshot.current_state = data["current_state"] + snapshot.state_data = data["state_data"] + snapshot.transitions = [ + StateTransition( + from_state=t["from_state"], + to_state=t["to_state"], + timestamp=datetime.fromisoformat(t["timestamp"]), + data=t["data"] + ) + for t in data["transitions"] + ] + snapshot.timestamp = datetime.fromisoformat(data["timestamp"]) + return snapshot diff --git a/aitbc/testing.py b/aitbc/testing.py new file mode 100644 index 00000000..910ab6df --- /dev/null +++ b/aitbc/testing.py @@ -0,0 +1,401 @@ +""" +Testing utilities for AITBC +Provides mock factories, test data generators, and test helpers +""" + +import secrets +import json +from typing import Any, Dict, List, Optional, Type, TypeVar, Callable +from datetime import datetime, timedelta +from dataclasses import dataclass, field +from decimal import Decimal +import uuid + + +T = TypeVar('T') + + +class MockFactory: + """Factory for creating mock objects for testing""" + + @staticmethod + def generate_string(length: int = 10, prefix: str = "") -> str: + """Generate a random string""" + random_part = secrets.token_urlsafe(length)[:length] + return f"{prefix}{random_part}" + + @staticmethod + def generate_email() -> str: + """Generate a random email address""" + return f"{MockFactory.generate_string(8)}@example.com" + + @staticmethod + def generate_url() -> str: + """Generate a random URL""" + return f"https://example.com/{MockFactory.generate_string(8)}" + + @staticmethod + def generate_ip_address() -> str: + """Generate a random IP address""" + return f"192.168.{secrets.randbelow(256)}.{secrets.randbelow(256)}" + + @staticmethod + def generate_ethereum_address() -> str: + """Generate a random Ethereum address""" + return f"0x{''.join(secrets.choice('0123456789abcdef') for _ in range(40))}" + + @staticmethod + def generate_bitcoin_address() -> str: + """Generate a random Bitcoin-like address""" + return f"1{''.join(secrets.choice('123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz') for _ in range(33))}" + + @staticmethod + def generate_uuid() -> str: + """Generate a UUID""" + return str(uuid.uuid4()) + + @staticmethod + def generate_hash(length: int = 64) -> str: + """Generate a random hash string""" + return secrets.token_hex(length)[:length] + + +class TestDataGenerator: + """Generate test data for various use cases""" + + @staticmethod + def generate_user_data(**overrides) -> Dict[str, Any]: + """Generate mock user data""" + data = { + "id": MockFactory.generate_uuid(), + "email": MockFactory.generate_email(), + "username": MockFactory.generate_string(8), + "first_name": MockFactory.generate_string(6), + "last_name": MockFactory.generate_string(6), + "created_at": datetime.utcnow().isoformat(), + "updated_at": datetime.utcnow().isoformat(), + "is_active": True, + "role": "user" + } + data.update(overrides) + return data + + @staticmethod + def generate_transaction_data(**overrides) -> Dict[str, Any]: + """Generate mock transaction data""" + data = { + "id": MockFactory.generate_uuid(), + "from_address": MockFactory.generate_ethereum_address(), + "to_address": MockFactory.generate_ethereum_address(), + "amount": str(secrets.randbelow(1000000000000000000)), + "gas_price": str(secrets.randbelow(100000000000)), + "gas_limit": secrets.randbelow(100000), + "nonce": secrets.randbelow(1000), + "timestamp": datetime.utcnow().isoformat(), + "status": "pending" + } + data.update(overrides) + return data + + @staticmethod + def generate_block_data(**overrides) -> Dict[str, Any]: + """Generate mock block data""" + data = { + "number": secrets.randbelow(10000000), + "hash": MockFactory.generate_hash(), + "parent_hash": MockFactory.generate_hash(), + "timestamp": datetime.utcnow().isoformat(), + "transactions": [], + "gas_used": str(secrets.randbelow(10000000)), + "gas_limit": str(15000000), + "miner": MockFactory.generate_ethereum_address() + } + data.update(overrides) + return data + + @staticmethod + def generate_api_key_data(**overrides) -> Dict[str, Any]: + """Generate mock API key data""" + data = { + "id": MockFactory.generate_uuid(), + "api_key": f"aitbc_{secrets.token_urlsafe(32)}", + "user_id": MockFactory.generate_uuid(), + "name": MockFactory.generate_string(10), + "scopes": ["read", "write"], + "created_at": datetime.utcnow().isoformat(), + "last_used": None, + "is_active": True + } + data.update(overrides) + return data + + @staticmethod + def generate_wallet_data(**overrides) -> Dict[str, Any]: + """Generate mock wallet data""" + data = { + "id": MockFactory.generate_uuid(), + "address": MockFactory.generate_ethereum_address(), + "chain_id": 1, + "balance": str(secrets.randbelow(1000000000000000000)), + "created_at": datetime.utcnow().isoformat(), + "is_active": True + } + data.update(overrides) + return data + + +class TestHelpers: + """Helper functions for testing""" + + @staticmethod + def assert_dict_contains(subset: Dict[str, Any], superset: Dict[str, Any]) -> bool: + """Check if superset contains all key-value pairs from subset""" + for key, value in subset.items(): + if key not in superset: + return False + if superset[key] != value: + return False + return True + + @staticmethod + def assert_lists_equal_unordered(list1: List[Any], list2: List[Any]) -> bool: + """Check if two lists contain the same elements regardless of order""" + return sorted(list1) == sorted(list2) + + @staticmethod + def compare_json_objects(obj1: Any, obj2: Any) -> bool: + """Compare two JSON-serializable objects""" + return json.dumps(obj1, sort_keys=True) == json.dumps(obj2, sort_keys=True) + + @staticmethod + def wait_for_condition( + condition: Callable[[], bool], + timeout: float = 10.0, + interval: float = 0.1 + ) -> bool: + """Wait for a condition to become true""" + import time + start = time.time() + while time.time() - start < timeout: + if condition(): + return True + time.sleep(interval) + return False + + @staticmethod + def measure_execution_time(func: Callable, *args, **kwargs) -> tuple[Any, float]: + """Measure execution time of a function""" + import time + start = time.time() + result = func(*args, **kwargs) + elapsed = time.time() - start + return result, elapsed + + @staticmethod + def generate_test_file_path(extension: str = ".tmp") -> str: + """Generate a unique test file path""" + return f"/tmp/test_{secrets.token_hex(8)}{extension}" + + @staticmethod + def cleanup_test_files(prefix: str = "test_") -> int: + """Clean up test files in /tmp""" + import os + import glob + count = 0 + for file_path in glob.glob(f"/tmp/{prefix}*"): + try: + os.remove(file_path) + count += 1 + except: + pass + return count + + +class MockResponse: + """Mock HTTP response for testing""" + + def __init__( + self, + status_code: int = 200, + json_data: Optional[Dict[str, Any]] = None, + text: Optional[str] = None, + headers: Optional[Dict[str, str]] = None + ): + """Initialize mock response""" + self.status_code = status_code + self._json_data = json_data + self._text = text + self.headers = headers or {} + + def json(self) -> Dict[str, Any]: + """Return JSON data""" + if self._json_data is None: + raise ValueError("No JSON data available") + return self._json_data + + def text(self) -> str: + """Return text data""" + if self._text is None: + return "" + return self._text + + def raise_for_status(self) -> None: + """Raise exception if status code indicates error""" + if self.status_code >= 400: + raise Exception(f"HTTP Error: {self.status_code}") + + +class MockDatabase: + """Mock database for testing""" + + def __init__(self): + """Initialize mock database""" + self.data: Dict[str, List[Dict[str, Any]]] = {} + self.tables: List[str] = [] + + def create_table(self, table_name: str) -> None: + """Create a table""" + if table_name not in self.tables: + self.tables.append(table_name) + self.data[table_name] = [] + + def insert(self, table_name: str, record: Dict[str, Any]) -> None: + """Insert a record""" + if table_name not in self.tables: + self.create_table(table_name) + record['id'] = record.get('id', MockFactory.generate_uuid()) + self.data[table_name].append(record) + + def select(self, table_name: str, **filters) -> List[Dict[str, Any]]: + """Select records with optional filters""" + if table_name not in self.tables: + return [] + + records = self.data[table_name] + if not filters: + return records + + filtered = [] + for record in records: + match = True + for key, value in filters.items(): + if record.get(key) != value: + match = False + break + if match: + filtered.append(record) + + return filtered + + def update(self, table_name: str, record_id: str, updates: Dict[str, Any]) -> bool: + """Update a record""" + if table_name not in self.tables: + return False + + for record in self.data[table_name]: + if record.get('id') == record_id: + record.update(updates) + return True + return False + + def delete(self, table_name: str, record_id: str) -> bool: + """Delete a record""" + if table_name not in self.tables: + return False + + for i, record in enumerate(self.data[table_name]): + if record.get('id') == record_id: + del self.data[table_name][i] + return True + return False + + def clear(self) -> None: + """Clear all data""" + self.data.clear() + self.tables.clear() + + +class MockCache: + """Mock cache for testing""" + + def __init__(self, ttl: int = 3600): + """Initialize mock cache""" + self.cache: Dict[str, tuple[Any, float]] = {} + self.ttl = ttl + + def get(self, key: str) -> Optional[Any]: + """Get value from cache""" + if key not in self.cache: + return None + + value, timestamp = self.cache[key] + if time.time() - timestamp > self.ttl: + del self.cache[key] + return None + + return value + + def set(self, key: str, value: Any) -> None: + """Set value in cache""" + self.cache[key] = (value, time.time()) + + def delete(self, key: str) -> bool: + """Delete value from cache""" + if key in self.cache: + del self.cache[key] + return True + return False + + def clear(self) -> None: + """Clear cache""" + self.cache.clear() + + def size(self) -> int: + """Get cache size""" + return len(self.cache) + + +def mock_async_call(return_value: Any = None, delay: float = 0): + """Decorator to mock async calls with optional delay""" + def decorator(func: Callable) -> Callable: + async def wrapper(*args, **kwargs): + if delay > 0: + await asyncio.sleep(delay) + return return_value + return wrapper + return decorator + + +def create_mock_config(**overrides) -> Dict[str, Any]: + """Create mock configuration""" + config = { + "debug": False, + "log_level": "INFO", + "database_url": "sqlite:///test.db", + "redis_url": "redis://localhost:6379", + "api_host": "localhost", + "api_port": 8080, + "secret_key": MockFactory.generate_string(32), + "max_workers": 4, + "timeout": 30 + } + config.update(overrides) + return config + + +import time + + +def create_test_scenario(name: str, steps: List[Callable]) -> Callable: + """Create a test scenario with multiple steps""" + def scenario(): + print(f"Running test scenario: {name}") + results = [] + for i, step in enumerate(steps): + try: + result = step() + results.append({"step": i + 1, "status": "passed", "result": result}) + except Exception as e: + results.append({"step": i + 1, "status": "failed", "error": str(e)}) + return results + return scenario diff --git a/aitbc/time_utils.py b/aitbc/time_utils.py new file mode 100644 index 00000000..10497dcc --- /dev/null +++ b/aitbc/time_utils.py @@ -0,0 +1,321 @@ +""" +Time utilities for AITBC +Provides timestamp helpers, duration helpers, timezone handling, and deadline calculations +""" + +from datetime import datetime, timedelta, timezone +from typing import Optional, Union +import time + + +def get_utc_now() -> datetime: + """Get current UTC datetime""" + return datetime.now(timezone.utc) + + +def get_timestamp_utc() -> float: + """Get current UTC timestamp""" + return time.time() + + +def format_iso8601(dt: Optional[datetime] = None) -> str: + """Format datetime as ISO 8601 string in UTC""" + if dt is None: + dt = get_utc_now() + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt.isoformat() + + +def parse_iso8601(iso_string: str) -> datetime: + """Parse ISO 8601 string to datetime""" + dt = datetime.fromisoformat(iso_string) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt + + +def timestamp_to_iso(timestamp: float) -> str: + """Convert timestamp to ISO 8601 string""" + return datetime.fromtimestamp(timestamp, timezone.utc).isoformat() + + +def iso_to_timestamp(iso_string: str) -> float: + """Convert ISO 8601 string to timestamp""" + dt = parse_iso8601(iso_string) + return dt.timestamp() + + +def format_duration(seconds: Union[int, float]) -> str: + """Format duration in seconds to human-readable string""" + if seconds < 60: + return f"{int(seconds)}s" + elif seconds < 3600: + minutes = int(seconds / 60) + return f"{minutes}m" + elif seconds < 86400: + hours = int(seconds / 3600) + return f"{hours}h" + else: + days = int(seconds / 86400) + return f"{days}d" + + +def format_duration_precise(seconds: Union[int, float]) -> str: + """Format duration with precise breakdown""" + days = int(seconds // 86400) + hours = int((seconds % 86400) // 3600) + minutes = int((seconds % 3600) // 60) + secs = int(seconds % 60) + + parts = [] + if days > 0: + parts.append(f"{days}d") + if hours > 0: + parts.append(f"{hours}h") + if minutes > 0: + parts.append(f"{minutes}m") + if secs > 0 or not parts: + parts.append(f"{secs}s") + + return " ".join(parts) + + +def parse_duration(duration_str: str) -> float: + """Parse duration string to seconds""" + duration_str = duration_str.strip().lower() + + if duration_str.endswith('s'): + return float(duration_str[:-1]) + elif duration_str.endswith('m'): + return float(duration_str[:-1]) * 60 + elif duration_str.endswith('h'): + return float(duration_str[:-1]) * 3600 + elif duration_str.endswith('d'): + return float(duration_str[:-1]) * 86400 + else: + return float(duration_str) + + +def add_duration(dt: datetime, duration: Union[str, timedelta]) -> datetime: + """Add duration to datetime""" + if isinstance(duration, str): + duration = timedelta(seconds=parse_duration(duration)) + return dt + duration + + +def subtract_duration(dt: datetime, duration: Union[str, timedelta]) -> datetime: + """Subtract duration from datetime""" + if isinstance(duration, str): + duration = timedelta(seconds=parse_duration(duration)) + return dt - duration + + +def get_time_until(dt: datetime) -> timedelta: + """Get time until a future datetime""" + now = get_utc_now() + return dt - now + + +def get_time_since(dt: datetime) -> timedelta: + """Get time since a past datetime""" + now = get_utc_now() + return now - dt + + +def calculate_deadline(duration: Union[str, timedelta], from_dt: Optional[datetime] = None) -> datetime: + """Calculate deadline from duration""" + if from_dt is None: + from_dt = get_utc_now() + return add_duration(from_dt, duration) + + +def is_deadline_passed(deadline: datetime) -> bool: + """Check if deadline has passed""" + return get_utc_now() >= deadline + + +def get_deadline_remaining(deadline: datetime) -> float: + """Get remaining seconds until deadline""" + delta = deadline - get_utc_now() + return max(0, delta.total_seconds()) + + +def format_time_ago(dt: datetime) -> str: + """Format datetime as "time ago" string""" + delta = get_time_since(dt) + seconds = delta.total_seconds() + + if seconds < 60: + return "just now" + elif seconds < 3600: + minutes = int(seconds / 60) + return f"{minutes} minute{'s' if minutes > 1 else ''} ago" + elif seconds < 86400: + hours = int(seconds / 3600) + return f"{hours} hour{'s' if hours > 1 else ''} ago" + elif seconds < 604800: + days = int(seconds / 86400) + return f"{days} day{'s' if days > 1 else ''} ago" + else: + weeks = int(seconds / 604800) + return f"{weeks} week{'s' if weeks > 1 else ''} ago" + + +def format_time_in(dt: datetime) -> str: + """Format datetime as "time in" string""" + delta = get_time_until(dt) + seconds = delta.total_seconds() + + if seconds < 0: + return format_time_ago(dt) + + if seconds < 60: + return "in a moment" + elif seconds < 3600: + minutes = int(seconds / 60) + return f"in {minutes} minute{'s' if minutes > 1 else ''}" + elif seconds < 86400: + hours = int(seconds / 3600) + return f"in {hours} hour{'s' if hours > 1 else ''}" + elif seconds < 604800: + days = int(seconds / 86400) + return f"in {days} day{'s' if days > 1 else ''}" + else: + weeks = int(seconds / 604800) + return f"in {weeks} week{'s' if weeks > 1 else ''}" + + +def to_timezone(dt: datetime, tz_name: str) -> datetime: + """Convert datetime to specific timezone""" + try: + import pytz + tz = pytz.timezone(tz_name) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt.astimezone(tz) + except ImportError: + raise ImportError("pytz is required for timezone conversion. Install with: pip install pytz") + except Exception as e: + raise ValueError(f"Failed to convert timezone: {e}") + + +def get_timezone_offset(tz_name: str) -> timedelta: + """Get timezone offset from UTC""" + try: + import pytz + tz = pytz.timezone(tz_name) + now = datetime.now(timezone.utc) + offset = tz.utcoffset(now) + return offset if offset else timedelta(0) + except ImportError: + raise ImportError("pytz is required for timezone operations. Install with: pip install pytz") + + +def is_business_hours(dt: Optional[datetime] = None, start_hour: int = 9, end_hour: int = 17, timezone: str = "UTC") -> bool: + """Check if datetime is within business hours""" + if dt is None: + dt = get_utc_now() + + try: + import pytz + tz = pytz.timezone(timezone) + dt_local = dt.astimezone(tz) + return start_hour <= dt_local.hour < end_hour + except ImportError: + raise ImportError("pytz is required for business hours check. Install with: pip install pytz") + + +def get_start_of_day(dt: Optional[datetime] = None) -> datetime: + """Get start of day (00:00:00) for given datetime""" + if dt is None: + dt = get_utc_now() + return dt.replace(hour=0, minute=0, second=0, microsecond=0) + + +def get_end_of_day(dt: Optional[datetime] = None) -> datetime: + """Get end of day (23:59:59) for given datetime""" + if dt is None: + dt = get_utc_now() + return dt.replace(hour=23, minute=59, second=59, microsecond=999999) + + +def get_start_of_week(dt: Optional[datetime] = None) -> datetime: + """Get start of week (Monday) for given datetime""" + if dt is None: + dt = get_utc_now() + return dt - timedelta(days=dt.weekday()) + + +def get_end_of_week(dt: Optional[datetime] = None) -> datetime: + """Get end of week (Sunday) for given datetime""" + if dt is None: + dt = get_utc_now() + return dt + timedelta(days=(6 - dt.weekday())) + + +def get_start_of_month(dt: Optional[datetime] = None) -> datetime: + """Get start of month for given datetime""" + if dt is None: + dt = get_utc_now() + return dt.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + + +def get_end_of_month(dt: Optional[datetime] = None) -> datetime: + """Get end of month for given datetime""" + if dt is None: + dt = get_utc_now() + if dt.month == 12: + next_month = dt.replace(year=dt.year + 1, month=1, day=1) + else: + next_month = dt.replace(month=dt.month + 1, day=1) + return next_month - timedelta(seconds=1) + + +def sleep_until(dt: datetime) -> None: + """Sleep until a specific datetime""" + now = get_utc_now() + if dt > now: + sleep_seconds = (dt - now).total_seconds() + time.sleep(sleep_seconds) + + +def retry_until_deadline(func, deadline: datetime, interval: float = 1.0) -> bool: + """Retry a function until deadline is reached""" + while not is_deadline_passed(deadline): + try: + result = func() + if result: + return True + except Exception: + pass + time.sleep(interval) + return False + + +class Timer: + """Simple timer context manager for measuring execution time""" + + def __init__(self): + """Initialize timer""" + self.start_time = None + self.end_time = None + self.elapsed = None + + def __enter__(self): + """Start timer""" + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Stop timer""" + self.end_time = time.time() + self.elapsed = self.end_time - self.start_time + + def get_elapsed(self) -> Optional[float]: + """Get elapsed time in seconds""" + if self.elapsed is not None: + return self.elapsed + elif self.start_time is not None: + return time.time() - self.start_time + return None diff --git a/aitbc/web3_utils.py b/aitbc/web3_utils.py new file mode 100644 index 00000000..32a23506 --- /dev/null +++ b/aitbc/web3_utils.py @@ -0,0 +1,206 @@ +""" +Web3 utilities for AITBC +Provides Ethereum blockchain interaction utilities using web3.py +""" + +from typing import Any, Optional +from decimal import Decimal + + +class Web3Client: + """Web3 client wrapper for blockchain operations""" + + def __init__(self, rpc_url: str, timeout: int = 30): + """Initialize Web3 client with RPC URL""" + try: + from web3 import Web3 + from web3.middleware import geth_poa_middleware + + self.w3 = Web3(Web3.HTTPProvider(rpc_url, request_kwargs={'timeout': timeout})) + + # Add POA middleware for chains like Polygon, BSC, etc. + self.w3.middleware_onion.inject(geth_poa_middleware, layer=0) + + if not self.w3.is_connected(): + raise ConnectionError(f"Failed to connect to RPC URL: {rpc_url}") + except ImportError: + raise ImportError("web3 is required for blockchain operations. Install with: pip install web3") + except Exception as e: + raise ConnectionError(f"Failed to initialize Web3 client: {e}") + + def get_eth_balance(self, address: str) -> str: + """Get ETH balance in wei""" + try: + balance_wei = self.w3.eth.get_balance(address) + return str(balance_wei) + except Exception as e: + raise ValueError(f"Failed to get ETH balance: {e}") + + def get_token_balance(self, address: str, token_address: str) -> dict[str, Any]: + """Get ERC-20 token balance""" + try: + # ERC-20 balanceOf function signature: 0x70a08231 + balance_of_signature = '0x70a08231' + # Pad address to 32 bytes + padded_address = address[2:].lower().zfill(64) + call_data = balance_of_signature + padded_address + + result = self.w3.eth.call({ + 'to': token_address, + 'data': f'0x{call_data}' + }) + + balance = int(result.hex(), 16) + + # Get token decimals + decimals_signature = '0x313ce567' + decimals_result = self.w3.eth.call({ + 'to': token_address, + 'data': decimals_signature + }) + decimals = int(decimals_result.hex(), 16) + + # Get token symbol (optional, may fail for some tokens) + try: + symbol_signature = '0x95d89b41' + symbol_result = self.w3.eth.call({ + 'to': token_address, + 'data': symbol_signature + }) + symbol_bytes = bytes.fromhex(symbol_result.hex()[2:]) + symbol = symbol_bytes.rstrip(b'\x00').decode('utf-8') + except: + symbol = "TOKEN" + + return { + "balance": str(balance), + "decimals": decimals, + "symbol": symbol + } + except Exception as e: + raise ValueError(f"Failed to get token balance: {e}") + + def get_gas_price(self) -> int: + """Get current gas price in wei""" + try: + gas_price = self.w3.eth.gas_price + return gas_price + except Exception as e: + raise ValueError(f"Failed to get gas price: {e}") + + def get_gas_price_gwei(self) -> float: + """Get current gas price in Gwei""" + try: + gas_price_wei = self.get_gas_price() + return float(gas_price_wei) / 10**9 + except Exception as e: + raise ValueError(f"Failed to get gas price in Gwei: {e}") + + def get_nonce(self, address: str) -> int: + """Get transaction nonce for address""" + try: + nonce = self.w3.eth.get_transaction_count(address) + return nonce + except Exception as e: + raise ValueError(f"Failed to get nonce: {e}") + + def send_raw_transaction(self, signed_transaction: str) -> str: + """Send raw transaction to blockchain""" + try: + tx_hash = self.w3.eth.send_raw_transaction(signed_transaction) + return tx_hash.hex() + except Exception as e: + raise ValueError(f"Failed to send raw transaction: {e}") + + def get_transaction_receipt(self, tx_hash: str) -> Optional[dict[str, Any]]: + """Get transaction receipt""" + try: + receipt = self.w3.eth.get_transaction_receipt(tx_hash) + if receipt is None: + return None + + return { + "status": receipt['status'], + "blockNumber": hex(receipt['blockNumber']), + "blockHash": receipt['blockHash'].hex(), + "gasUsed": hex(receipt['gasUsed']), + "effectiveGasPrice": hex(receipt['effectiveGasPrice']), + "logs": receipt['logs'], + } + except Exception as e: + raise ValueError(f"Failed to get transaction receipt: {e}") + + def get_transaction_by_hash(self, tx_hash: str) -> dict[str, Any]: + """Get transaction by hash""" + try: + tx = self.w3.eth.get_transaction(tx_hash) + return { + "from": tx['from'], + "to": tx['to'], + "value": hex(tx['value']), + "data": tx['input'].hex() if hasattr(tx['input'], 'hex') else tx['input'], + "nonce": tx['nonce'], + "gas": tx['gas'], + "gasPrice": hex(tx['gasPrice']), + "blockNumber": hex(tx['blockNumber']) if tx['blockNumber'] else None, + } + except Exception as e: + raise ValueError(f"Failed to get transaction by hash: {e}") + + def estimate_gas(self, transaction: dict[str, Any]) -> int: + """Estimate gas for transaction""" + try: + gas_estimate = self.w3.eth.estimate_gas(transaction) + return gas_estimate + except Exception as e: + raise ValueError(f"Failed to estimate gas: {e}") + + def get_block_number(self) -> int: + """Get current block number""" + try: + return self.w3.eth.block_number + except Exception as e: + raise ValueError(f"Failed to get block number: {e}") + + def get_wallet_transactions(self, address: str, limit: int = 100) -> list[dict[str, Any]]: + """Get wallet transactions (simplified implementation)""" + try: + # This is a simplified version - in production you'd want to use + # event logs or a blockchain explorer API for this + transactions = [] + current_block = self.get_block_number() + + # Look back at recent blocks for transactions from/to this address + start_block = max(0, current_block - 1000) + + for block_num in range(current_block, start_block, -1): + if len(transactions) >= limit: + break + + try: + block = self.w3.eth.get_block(block_num, full_transactions=True) + for tx in block['transactions']: + if tx['from'].lower() == address.lower() or \ + (tx['to'] and tx['to'].lower() == address.lower()): + transactions.append({ + "hash": tx['hash'].hex(), + "from": tx['from'], + "to": tx['to'].hex() if tx['to'] else None, + "value": hex(tx['value']), + "blockNumber": hex(tx['blockNumber']), + "timestamp": block['timestamp'], + "gasUsed": hex(tx['gas']), + }) + if len(transactions) >= limit: + break + except: + continue + + return transactions + except Exception as e: + raise ValueError(f"Failed to get wallet transactions: {e}") + + +def create_web3_client(rpc_url: str, timeout: int = 30) -> Web3Client: + """Factory function to create Web3 client""" + return Web3Client(rpc_url, timeout) diff --git a/apps/blockchain-explorer/main.py b/apps/blockchain-explorer/main.py index ddd9a133..4ab89ab0 100755 --- a/apps/blockchain-explorer/main.py +++ b/apps/blockchain-explorer/main.py @@ -1048,19 +1048,15 @@ async def search_transactions( response = await client.get(f"{rpc_url}/rpc/search/transactions", params=params) if response.status_code == 200: return response.json() + elif response.status_code == 404: + return [] else: - # Return mock data for demonstration - return [ - { - "hash": "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", - "type": tx_type or "transfer", - "from": "0xabcdef1234567890abcdef1234567890abcdef1234", - "to": "0x1234567890abcdef1234567890abcdef12345678", - "amount": "1.5", - "fee": "0.001", - "timestamp": datetime.now().isoformat() - } - ] + raise HTTPException( + status_code=response.status_code, + detail=f"Failed to fetch transactions from blockchain RPC: {response.text}" + ) + except httpx.RequestError as e: + raise HTTPException(status_code=503, detail=f"Blockchain RPC unavailable: {str(e)}") except Exception as e: raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") @@ -1095,17 +1091,15 @@ async def search_blocks( response = await client.get(f"{rpc_url}/rpc/search/blocks", params=params) if response.status_code == 200: return response.json() + elif response.status_code == 404: + return [] else: - # Return mock data for demonstration - return [ - { - "height": 12345, - "hash": "0xabcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890", - "validator": validator or "0x1234567890abcdef1234567890abcdef12345678", - "tx_count": min_tx or 5, - "timestamp": datetime.now().isoformat() - } - ] + raise HTTPException( + status_code=response.status_code, + detail=f"Failed to fetch blocks from blockchain RPC: {response.text}" + ) + except httpx.RequestError as e: + raise HTTPException(status_code=503, detail=f"Blockchain RPC unavailable: {str(e)}") except Exception as e: raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") diff --git a/apps/blockchain-node/src/aitbc_chain/network/hub_manager.py b/apps/blockchain-node/src/aitbc_chain/network/hub_manager.py index e69de29b..a68e657e 100644 --- a/apps/blockchain-node/src/aitbc_chain/network/hub_manager.py +++ b/apps/blockchain-node/src/aitbc_chain/network/hub_manager.py @@ -0,0 +1,564 @@ +""" +Hub Manager +Manages hub operations, peer list sharing, and hub registration for federated mesh +""" + +import asyncio +import time +import json +import os +import socket +from typing import Dict, List, Optional, Set +from dataclasses import dataclass, field, asdict +from enum import Enum +from ..config import settings + +from aitbc import get_logger, DATA_DIR, KEYSTORE_DIR + +logger = get_logger(__name__) + + +class HubStatus(Enum): + """Hub registration status""" + REGISTERED = "registered" + UNREGISTERED = "unregistered" + PENDING = "pending" + + +@dataclass +class HubInfo: + """Information about a hub node""" + node_id: str + address: str + port: int + island_id: str + island_name: str + public_address: Optional[str] = None + public_port: Optional[int] = None + registered_at: float = 0 + last_seen: float = 0 + peer_count: int = 0 + + +@dataclass +class PeerInfo: + """Information about a peer""" + node_id: str + address: str + port: int + island_id: str + is_hub: bool + public_address: Optional[str] = None + public_port: Optional[int] = None + last_seen: float = 0 + + +class HubManager: + """Manages hub operations for federated mesh""" + + def __init__(self, local_node_id: str, local_address: str, local_port: int, island_id: str, island_name: str, redis_url: Optional[str] = None): + self.local_node_id = local_node_id + self.local_address = local_address + self.local_port = local_port + self.island_id = island_id + self.island_name = island_name + self.island_chain_id = settings.island_chain_id or settings.chain_id or f"ait-{island_id[:8]}" + self.redis_url = redis_url or "redis://localhost:6379" + + # Hub registration status + self.is_hub = False + self.hub_status = HubStatus.UNREGISTERED + self.registered_at: Optional[float] = None + + # Known hubs + self.known_hubs: Dict[str, HubInfo] = {} # node_id -> HubInfo + + # Peer registry (for providing peer lists) + self.peer_registry: Dict[str, PeerInfo] = {} # node_id -> PeerInfo + + # Island peers (island_id -> set of node_ids) + self.island_peers: Dict[str, Set[str]] = {} + + self.running = False + self._redis = None + + # Initialize island peers for our island + self.island_peers[self.island_id] = set() + + async def _connect_redis(self): + """Connect to Redis""" + try: + import redis.asyncio as redis + self._redis = redis.from_url(self.redis_url) + await self._redis.ping() + logger.info(f"Connected to Redis for hub persistence: {self.redis_url}") + return True + except Exception as e: + logger.error(f"Failed to connect to Redis: {e}") + return False + + async def _persist_hub_registration(self, hub_info: HubInfo) -> bool: + """Persist hub registration to Redis""" + try: + if not self._redis: + await self._connect_redis() + + if not self._redis: + logger.warning("Redis not available, skipping persistence") + return False + + key = f"hub:{hub_info.node_id}" + value = json.dumps(asdict(hub_info), default=str) + await self._redis.setex(key, 3600, value) # TTL: 1 hour + logger.info(f"Persisted hub registration to Redis: {key}") + return True + except Exception as e: + logger.error(f"Failed to persist hub registration: {e}") + return False + + async def _remove_hub_registration(self, node_id: str) -> bool: + """Remove hub registration from Redis""" + try: + if not self._redis: + await self._connect_redis() + + if not self._redis: + logger.warning("Redis not available, skipping removal") + return False + + key = f"hub:{node_id}" + await self._redis.delete(key) + logger.info(f"Removed hub registration from Redis: {key}") + return True + except Exception as e: + logger.error(f"Failed to remove hub registration: {e}") + return False + + async def _load_hub_registration(self) -> Optional[HubInfo]: + """Load hub registration from Redis""" + try: + if not self._redis: + await self._connect_redis() + + if not self._redis: + return None + + key = f"hub:{self.local_node_id}" + value = await self._redis.get(key) + if value: + data = json.loads(value) + return HubInfo(**data) + return None + except Exception as e: + logger.error(f"Failed to load hub registration: {e}") + return None + + def _get_blockchain_credentials(self) -> dict: + """Get blockchain credentials from keystore""" + try: + credentials = {} + + # Get genesis block hash from genesis.json + genesis_candidates = [ + str(settings.db_path.parent / 'genesis.json'), + f"{DATA_DIR}/data/{settings.chain_id}/genesis.json", + f'{DATA_DIR}/data/ait-mainnet/genesis.json', + ] + for genesis_path in genesis_candidates: + if os.path.exists(genesis_path): + with open(genesis_path, 'r') as f: + genesis_data = json.load(f) + if 'blocks' in genesis_data and len(genesis_data['blocks']) > 0: + genesis_block = genesis_data['blocks'][0] + credentials['genesis_block_hash'] = genesis_block.get('hash', '') + credentials['genesis_block'] = genesis_data + break + + # Get genesis address from keystore + keystore_path = str(KEYSTORE_DIR / 'validator_keys.json') + if os.path.exists(keystore_path): + with open(keystore_path, 'r') as f: + keys = json.load(f) + # Get first key's address + for key_id, key_data in keys.items(): + # Extract address from public key or use key_id + credentials['genesis_address'] = key_id + break + + # Add chain info + credentials['chain_id'] = self.island_chain_id + credentials['island_id'] = self.island_id + credentials['island_name'] = self.island_name + + # Add RPC endpoint (local) + rpc_host = self.local_address + if rpc_host in {"0.0.0.0", "127.0.0.1", "localhost", ""}: + rpc_host = settings.hub_discovery_url or socket.gethostname() + credentials['rpc_endpoint'] = f"http://{rpc_host}:8006" + credentials['p2p_port'] = self.local_port + + return credentials + except Exception as e: + logger.error(f"Failed to get blockchain credentials: {e}") + return {} + + async def handle_join_request(self, join_request: dict) -> Optional[dict]: + """ + Handle island join request from a new node + + Args: + join_request: Dictionary containing join request data + + Returns: + dict: Join response with member list and credentials, or None if failed + """ + try: + requested_island_id = join_request.get('island_id') + + # Validate island ID + if requested_island_id != self.island_id: + logger.warning(f"Join request for island {requested_island_id} does not match our island {self.island_id}") + return None + + # Get all island members + members = [] + for node_id, peer_info in self.peer_registry.items(): + if peer_info.island_id == self.island_id: + members.append({ + 'node_id': peer_info.node_id, + 'address': peer_info.address, + 'port': peer_info.port, + 'is_hub': peer_info.is_hub, + 'public_address': peer_info.public_address, + 'public_port': peer_info.public_port + }) + + # Include self in member list + members.append({ + 'node_id': self.local_node_id, + 'address': self.local_address, + 'port': self.local_port, + 'is_hub': True, + 'public_address': self.known_hubs.get(self.local_node_id, {}).public_address if self.local_node_id in self.known_hubs else None, + 'public_port': self.known_hubs.get(self.local_node_id, {}).public_port if self.local_node_id in self.known_hubs else None + }) + + # Get blockchain credentials + credentials = self._get_blockchain_credentials() + + # Build response + response = { + 'type': 'join_response', + 'island_id': self.island_id, + 'island_name': self.island_name, + 'island_chain_id': self.island_chain_id or f"ait-{self.island_id[:8]}", + 'members': members, + 'credentials': credentials + } + + logger.info(f"Sent join_response to node {join_request.get('node_id')} with {len(members)} members") + return response + + except Exception as e: + logger.error(f"Error handling join request: {e}") + return None + + def register_gpu_offer(self, offer_data: dict) -> bool: + """Register a GPU marketplace offer in the hub""" + try: + offer_id = offer_data.get('offer_id') + if offer_id: + self.gpu_offers[offer_id] = offer_data + logger.info(f"Registered GPU offer: {offer_id}") + return True + except Exception as e: + logger.error(f"Error registering GPU offer: {e}") + return False + + def register_gpu_bid(self, bid_data: dict) -> bool: + """Register a GPU marketplace bid in the hub""" + try: + bid_id = bid_data.get('bid_id') + if bid_id: + self.gpu_bids[bid_id] = bid_data + logger.info(f"Registered GPU bid: {bid_id}") + return True + except Exception as e: + logger.error(f"Error registering GPU bid: {e}") + return False + + def register_gpu_provider(self, node_id: str, gpu_info: dict) -> bool: + """Register a GPU provider in the hub""" + try: + self.gpu_providers[node_id] = gpu_info + logger.info(f"Registered GPU provider: {node_id}") + return True + except Exception as e: + logger.error(f"Error registering GPU provider: {e}") + return False + + def register_exchange_order(self, order_data: dict) -> bool: + """Register an exchange order in the hub""" + try: + order_id = order_data.get('order_id') + if order_id: + self.exchange_orders[order_id] = order_data + + # Update order book + pair = order_data.get('pair') + side = order_data.get('side') + if pair and side: + if pair not in self.exchange_order_books: + self.exchange_order_books[pair] = {'bids': [], 'asks': []} + + if side == 'buy': + self.exchange_order_books[pair]['bids'].append(order_data) + elif side == 'sell': + self.exchange_order_books[pair]['asks'].append(order_data) + + logger.info(f"Registered exchange order: {order_id}") + return True + except Exception as e: + logger.error(f"Error registering exchange order: {e}") + return False + + def get_gpu_offers(self) -> list: + """Get all GPU offers""" + return list(self.gpu_offers.values()) + + def get_gpu_bids(self) -> list: + """Get all GPU bids""" + return list(self.gpu_bids.values()) + + def get_gpu_providers(self) -> list: + """Get all GPU providers""" + return list(self.gpu_providers.values()) + + def get_exchange_order_book(self, pair: str) -> dict: + """Get order book for a specific trading pair""" + return self.exchange_order_books.get(pair, {'bids': [], 'asks': []}) + + async def register_as_hub(self, public_address: Optional[str] = None, public_port: Optional[int] = None) -> bool: + """Register this node as a hub""" + if self.is_hub: + logger.warning("Already registered as hub") + return False + + self.is_hub = True + self.hub_status = HubStatus.REGISTERED + self.registered_at = time.time() + + # Add self to known hubs + hub_info = HubInfo( + node_id=self.local_node_id, + address=self.local_address, + port=self.local_port, + island_id=self.island_id, + island_name=self.island_name, + public_address=public_address, + public_port=public_port, + registered_at=time.time(), + last_seen=time.time() + ) + self.known_hubs[self.local_node_id] = hub_info + + # Persist to Redis + await self._persist_hub_registration(hub_info) + + logger.info(f"Registered as hub for island {self.island_id}") + return True + + async def unregister_as_hub(self) -> bool: + """Unregister this node as a hub""" + if not self.is_hub: + logger.warning("Not registered as hub") + return False + + self.is_hub = False + self.hub_status = HubStatus.UNREGISTERED + self.registered_at = None + + # Remove from Redis + await self._remove_hub_registration(self.local_node_id) + + # Remove self from known hubs + if self.local_node_id in self.known_hubs: + del self.known_hubs[self.local_node_id] + + logger.info(f"Unregistered as hub for island {self.island_id}") + return True + + def register_peer(self, peer_info: PeerInfo) -> bool: + """Register a peer in the registry""" + self.peer_registry[peer_info.node_id] = peer_info + + # Add to island peers + if peer_info.island_id not in self.island_peers: + self.island_peers[peer_info.island_id] = set() + self.island_peers[peer_info.island_id].add(peer_info.node_id) + + # Update hub peer count if peer is a hub + if peer_info.is_hub and peer_info.node_id in self.known_hubs: + self.known_hubs[peer_info.node_id].peer_count = len(self.island_peers.get(peer_info.island_id, set())) + + logger.debug(f"Registered peer {peer_info.node_id} in island {peer_info.island_id}") + return True + + def unregister_peer(self, node_id: str) -> bool: + """Unregister a peer from the registry""" + if node_id not in self.peer_registry: + return False + + peer_info = self.peer_registry[node_id] + + # Remove from island peers + if peer_info.island_id in self.island_peers: + self.island_peers[peer_info.island_id].discard(node_id) + + del self.peer_registry[node_id] + + # Update hub peer count + if node_id in self.known_hubs: + self.known_hubs[node_id].peer_count = len(self.island_peers.get(self.known_hubs[node_id].island_id, set())) + + logger.debug(f"Unregistered peer {node_id}") + return True + + def add_known_hub(self, hub_info: HubInfo): + """Add a known hub to the registry""" + self.known_hubs[hub_info.node_id] = hub_info + logger.info(f"Added known hub {hub_info.node_id} for island {hub_info.island_id}") + + def remove_known_hub(self, node_id: str) -> bool: + """Remove a known hub from the registry""" + if node_id not in self.known_hubs: + return False + + del self.known_hubs[node_id] + logger.info(f"Removed known hub {node_id}") + return True + + def get_peer_list(self, island_id: str) -> List[PeerInfo]: + """Get peer list for a specific island""" + peers = [] + for node_id, peer_info in self.peer_registry.items(): + if peer_info.island_id == island_id: + peers.append(peer_info) + return peers + + def get_hub_list(self, island_id: Optional[str] = None) -> List[HubInfo]: + """Get list of known hubs, optionally filtered by island""" + hubs = [] + for hub_info in self.known_hubs.values(): + if island_id is None or hub_info.island_id == island_id: + hubs.append(hub_info) + return hubs + + def get_island_peers(self, island_id: str) -> Set[str]: + """Get set of peer node IDs in an island""" + return self.island_peers.get(island_id, set()).copy() + + def get_peer_count(self, island_id: str) -> int: + """Get number of peers in an island""" + return len(self.island_peers.get(island_id, set())) + + def get_hub_info(self, node_id: str) -> Optional[HubInfo]: + """Get information about a specific hub""" + return self.known_hubs.get(node_id) + + def get_peer_info(self, node_id: str) -> Optional[PeerInfo]: + """Get information about a specific peer""" + return self.peer_registry.get(node_id) + + def update_peer_last_seen(self, node_id: str): + """Update the last seen time for a peer""" + if node_id in self.peer_registry: + self.peer_registry[node_id].last_seen = time.time() + + if node_id in self.known_hubs: + self.known_hubs[node_id].last_seen = time.time() + + async def start(self): + """Start hub manager""" + self.running = True + logger.info(f"Starting hub manager for node {self.local_node_id}") + + # Start background tasks + tasks = [ + asyncio.create_task(self._hub_health_check()), + asyncio.create_task(self._peer_cleanup()) + ] + + try: + await asyncio.gather(*tasks) + except Exception as e: + logger.error(f"Hub manager error: {e}") + finally: + self.running = False + + async def stop(self): + """Stop hub manager""" + self.running = False + logger.info("Stopping hub manager") + + async def _hub_health_check(self): + """Check health of known hubs""" + while self.running: + try: + current_time = time.time() + + # Check for offline hubs (not seen for 10 minutes) + offline_hubs = [] + for node_id, hub_info in self.known_hubs.items(): + if current_time - hub_info.last_seen > 600: + offline_hubs.append(node_id) + logger.warning(f"Hub {node_id} appears to be offline") + + # Remove offline hubs (keep self if we're a hub) + for node_id in offline_hubs: + if node_id != self.local_node_id: + self.remove_known_hub(node_id) + + await asyncio.sleep(60) # Check every minute + + except Exception as e: + logger.error(f"Hub health check error: {e}") + await asyncio.sleep(10) + + async def _peer_cleanup(self): + """Clean up stale peer entries""" + while self.running: + try: + current_time = time.time() + + # Remove peers not seen for 5 minutes + stale_peers = [] + for node_id, peer_info in self.peer_registry.items(): + if current_time - peer_info.last_seen > 300: + stale_peers.append(node_id) + + for node_id in stale_peers: + self.unregister_peer(node_id) + logger.debug(f"Removed stale peer {node_id}") + + await asyncio.sleep(60) # Check every minute + + except Exception as e: + logger.error(f"Peer cleanup error: {e}") + await asyncio.sleep(10) + + +# Global hub manager instance +hub_manager_instance: Optional[HubManager] = None + + +def get_hub_manager() -> Optional[HubManager]: + """Get global hub manager instance""" + return hub_manager_instance + + +def create_hub_manager(node_id: str, address: str, port: int, island_id: str, island_name: str) -> HubManager: + """Create and set global hub manager instance""" + global hub_manager_instance + hub_manager_instance = HubManager(node_id, address, port, island_id, island_name) + return hub_manager_instance diff --git a/apps/coordinator-api/scripts/migrate_complete.py b/apps/coordinator-api/scripts/migrate_complete.py index c4deb5b4..782f7777 100755 --- a/apps/coordinator-api/scripts/migrate_complete.py +++ b/apps/coordinator-api/scripts/migrate_complete.py @@ -7,8 +7,10 @@ from psycopg2.extras import RealDictCursor import json from decimal import Decimal +from aitbc.constants import DATA_DIR + # Database configurations -SQLITE_DB = "/var/lib/aitbc/data/coordinator.db" +SQLITE_DB = str(DATA_DIR / "data/coordinator.db") PG_CONFIG = { "host": "localhost", "database": "aitbc_coordinator", diff --git a/apps/coordinator-api/src/app/agent_identity/wallet_adapter_enhanced.py b/apps/coordinator-api/src/app/agent_identity/wallet_adapter_enhanced.py index 5a1db7bd..26468ba7 100755 --- a/apps/coordinator-api/src/app/agent_identity/wallet_adapter_enhanced.py +++ b/apps/coordinator-api/src/app/agent_identity/wallet_adapter_enhanced.py @@ -12,7 +12,7 @@ from decimal import Decimal from enum import StrEnum from typing import Any -from aitbc import get_logger +from aitbc import get_logger, derive_ethereum_address, sign_transaction_hash, verify_signature, encrypt_private_key, Web3Client logger = get_logger(__name__) @@ -174,6 +174,8 @@ class EthereumWalletAdapter(EnhancedWalletAdapter): def __init__(self, chain_id: int, rpc_url: str, security_level: SecurityLevel = SecurityLevel.MEDIUM): super().__init__(chain_id, ChainType.ETHEREUM, rpc_url, security_level) self.chain_id = chain_id + # Initialize Web3 client for blockchain operations + self._web3_client = Web3Client(rpc_url) async def create_wallet(self, owner_address: str, security_config: dict[str, Any]) -> dict[str, Any]: """Create a new Ethereum wallet with enhanced security""" @@ -446,25 +448,36 @@ class EthereumWalletAdapter(EnhancedWalletAdapter): # Private helper methods async def _derive_address_from_private_key(self, private_key: str) -> str: """Derive Ethereum address from private key""" - # This would use actual Ethereum cryptography - # For now, return a mock address - return f"0x{hashlib.sha256(private_key.encode()).hexdigest()[:40]}" + try: + return derive_ethereum_address(private_key) + except Exception as e: + logger.error(f"Failed to derive address from private key: {e}") + raise async def _encrypt_private_key(self, private_key: str, security_config: dict[str, Any]) -> str: """Encrypt private key with security configuration""" - # This would use actual encryption - # For now, return mock encrypted key - return f"encrypted_{hashlib.sha256(private_key.encode()).hexdigest()}" + try: + password = security_config.get("encryption_password", "default_password") + return encrypt_private_key(private_key, password) + except Exception as e: + logger.error(f"Failed to encrypt private key: {e}") + raise async def _get_eth_balance(self, address: str) -> str: """Get ETH balance in wei""" - # Mock implementation - return "1000000000000000000" # 1 ETH in wei + try: + return self._web3_client.get_eth_balance(address) + except Exception as e: + logger.error(f"Failed to get ETH balance: {e}") + raise async def _get_token_balance(self, address: str, token_address: str) -> dict[str, Any]: """Get ERC-20 token balance""" - # Mock implementation - return {"balance": "100000000000000000000", "decimals": 18, "symbol": "TOKEN"} # 100 tokens + try: + return self._web3_client.get_token_balance(address, token_address) + except Exception as e: + logger.error(f"Failed to get token balance: {e}") + raise async def _create_erc20_transfer( self, from_address: str, to_address: str, token_address: str, amount: int @@ -493,78 +506,116 @@ class EthereumWalletAdapter(EnhancedWalletAdapter): async def _get_gas_price(self) -> int: """Get current gas price""" - # Mock implementation - return 20000000000 # 20 Gwei in wei + try: + return self._web3_client.get_gas_price() + except Exception as e: + logger.error(f"Failed to get gas price: {e}") + raise async def _get_gas_price_gwei(self) -> float: """Get current gas price in Gwei""" - gas_price_wei = await self._get_gas_price() - return gas_price_wei / 10**9 + try: + return self._web3_client.get_gas_price_gwei() + except Exception as e: + logger.error(f"Failed to get gas price in Gwei: {e}") + raise async def _get_nonce(self, address: str) -> int: """Get transaction nonce for address""" - # Mock implementation - return 0 + try: + return self._web3_client.get_nonce(address) + except Exception as e: + logger.error(f"Failed to get nonce: {e}") + raise async def _sign_transaction(self, transaction_data: dict[str, Any], from_address: str) -> str: """Sign transaction""" - # Mock implementation - return f"0xsigned_{hashlib.sha256(str(transaction_data).encode()).hexdigest()}" + try: + # Get the transaction hash + from eth_account import Account + # Remove 0x prefix if present + if from_address.startswith("0x"): + from_address = from_address[2:] + + account = Account.from_key(from_address) + + # Build transaction dict for signing + tx_dict = { + 'nonce': int(transaction_data.get('nonce', 0), 16), + 'gasPrice': int(transaction_data.get('gasPrice', 0), 16), + 'gas': int(transaction_data.get('gas', 0), 16), + 'to': transaction_data.get('to'), + 'value': int(transaction_data.get('value', '0x0'), 16), + 'data': transaction_data.get('data', '0x'), + 'chainId': transaction_data.get('chainId', 1) + } + + signed_tx = account.sign_transaction(tx_dict) + return signed_tx.raw_transaction.hex() + except ImportError: + raise ImportError("eth-account is required for transaction signing. Install with: pip install eth-account") + except Exception as e: + logger.error(f"Failed to sign transaction: {e}") + raise async def _send_raw_transaction(self, signed_transaction: str) -> str: """Send raw transaction""" - # Mock implementation - return f"0x{hashlib.sha256(signed_transaction.encode()).hexdigest()}" + try: + return self._web3_client.send_raw_transaction(signed_transaction) + except Exception as e: + logger.error(f"Failed to send raw transaction: {e}") + raise async def _get_transaction_receipt(self, tx_hash: str) -> dict[str, Any] | None: """Get transaction receipt""" - # Mock implementation - return { - "status": 1, - "blockNumber": "0x12345", - "blockHash": "0xabcdef", - "gasUsed": "0x5208", - "effectiveGasPrice": "0x4a817c800", - "logs": [], - } + try: + return self._web3_client.get_transaction_receipt(tx_hash) + except Exception as e: + logger.error(f"Failed to get transaction receipt: {e}") + raise async def _get_transaction_by_hash(self, tx_hash: str) -> dict[str, Any]: """Get transaction by hash""" - # Mock implementation - return {"from": "0xsender", "to": "0xreceiver", "value": "0xde0b6b3a7640000", "data": "0x"} # 1 ETH in wei + try: + return self._web3_client.get_transaction_by_hash(tx_hash) + except Exception as e: + logger.error(f"Failed to get transaction by hash: {e}") + raise async def _estimate_gas_call(self, call_data: dict[str, Any]) -> str: """Estimate gas for call""" - # Mock implementation - return "0x5208" # 21000 in hex + try: + gas_estimate = self._web3_client.estimate_gas(call_data) + return hex(gas_estimate) + except Exception as e: + logger.error(f"Failed to estimate gas: {e}") + raise async def _get_wallet_transactions( self, address: str, limit: int, offset: int, from_block: int | None, to_block: int | None ) -> list[dict[str, Any]]: """Get wallet transactions""" - # Mock implementation - return [ - { - "hash": f"0x{hashlib.sha256(f'tx_{i}'.encode()).hexdigest()}", - "from": address, - "to": f"0x{hashlib.sha256(f'to_{i}'.encode()).hexdigest()[:40]}", - "value": "0xde0b6b3a7640000", - "blockNumber": f"0x{12345 + i}", - "timestamp": datetime.utcnow().timestamp(), - "gasUsed": "0x5208", - } - for i in range(min(limit, 10)) - ] + try: + return self._web3_client.get_wallet_transactions(address, limit) + except Exception as e: + logger.error(f"Failed to get wallet transactions: {e}") + raise async def _sign_hash(self, message_hash: str, private_key: str) -> str: """Sign a hash with private key""" - # Mock implementation - return f"0x{hashlib.sha256(f'{message_hash}{private_key}'.encode()).hexdigest()}" + try: + return sign_transaction_hash(message_hash, private_key) + except Exception as e: + logger.error(f"Failed to sign hash: {e}") + raise async def _verify_signature(self, message_hash: str, signature: str, address: str) -> bool: """Verify a signature""" - # Mock implementation - return True + try: + return verify_signature(message_hash, signature, address) + except Exception as e: + logger.error(f"Failed to verify signature: {e}") + return False class PolygonWalletAdapter(EthereumWalletAdapter): diff --git a/apps/exchange/complete_cross_chain_exchange.py b/apps/exchange/complete_cross_chain_exchange.py index aac15e28..d97d8a14 100755 --- a/apps/exchange/complete_cross_chain_exchange.py +++ b/apps/exchange/complete_cross_chain_exchange.py @@ -7,7 +7,6 @@ Multi-chain trading with cross-chain swaps and bridging import sqlite3 import json import asyncio -import httpx from datetime import datetime, timedelta from typing import Dict, List, Optional, Any from fastapi import FastAPI, HTTPException, Query, BackgroundTasks @@ -17,8 +16,15 @@ import os import uuid import hashlib +from aitbc.http_client import AsyncAITBCHTTPClient +from aitbc.aitbc_logging import get_logger +from aitbc.exceptions import NetworkError + app = FastAPI(title="AITBC Complete Cross-Chain Exchange", version="3.0.0") +# Initialize logger +logger = get_logger(__name__) + # Database configuration DB_PATH = os.path.join(os.path.dirname(__file__), "exchange_multichain.db") @@ -368,10 +374,10 @@ async def health_check(): if chain_info["status"] == "active" and chain_info["blockchain_url"]: try: - async with httpx.AsyncClient() as client: - response = await client.get(f"{chain_info['blockchain_url']}/health", timeout=5.0) - chain_status[chain_id]["connected"] = response.status_code == 200 - except: + client = AsyncAITBCHTTPClient(base_url=chain_info['blockchain_url'], timeout=5) + response = await client.async_get("/health") + chain_status[chain_id]["connected"] = response is not None + except NetworkError: pass return { diff --git a/apps/exchange/multichain_exchange_api.py b/apps/exchange/multichain_exchange_api.py index 714ee1d9..00fc8db9 100755 --- a/apps/exchange/multichain_exchange_api.py +++ b/apps/exchange/multichain_exchange_api.py @@ -7,7 +7,6 @@ Complete multi-chain trading with chain isolation import sqlite3 import json import asyncio -import httpx from datetime import datetime, timedelta from typing import Dict, List, Optional, Any from fastapi import FastAPI, HTTPException, Query, BackgroundTasks @@ -15,8 +14,15 @@ from pydantic import BaseModel, Field import uvicorn import os +from aitbc.http_client import AsyncAITBCHTTPClient +from aitbc.aitbc_logging import get_logger +from aitbc.exceptions import NetworkError + app = FastAPI(title="AITBC Multi-Chain Exchange", version="2.0.0") +# Initialize logger +logger = get_logger(__name__) + # Database configuration DB_PATH = os.path.join(os.path.dirname(__file__), "exchange_multichain.db") @@ -145,10 +151,10 @@ async def verify_chain_transaction(chain_id: str, tx_hash: str) -> bool: return False try: - async with httpx.AsyncClient() as client: - response = await client.get(f"{chain_info['blockchain_url']}/api/v1/transactions/{tx_hash}") - return response.status_code == 200 - except: + client = AsyncAITBCHTTPClient(base_url=chain_info['blockchain_url'], timeout=5) + response = await client.async_get(f"/api/v1/transactions/{tx_hash}") + return response is not None + except NetworkError: return False async def submit_chain_transaction(chain_id: str, order_data: Dict) -> Optional[str]: @@ -161,16 +167,13 @@ async def submit_chain_transaction(chain_id: str, order_data: Dict) -> Optional[ return None try: - async with httpx.AsyncClient() as client: - response = await client.post( - f"{chain_info['blockchain_url']}/api/v1/transactions", - json=order_data - ) - if response.status_code == 200: - return response.json().get("tx_hash") - except Exception as e: - print(f"Chain transaction error: {e}") - + client = AsyncAITBCHTTPClient(base_url=chain_info['blockchain_url'], timeout=10) + response = await client.async_post("/api/v1/transactions", json=order_data) + if response: + return response.get("tx_hash") + except NetworkError as e: + logger.error(f"Chain transaction error: {e}") + return None # API Endpoints @@ -188,10 +191,10 @@ async def health_check(): if chain_info["status"] == "active" and chain_info["blockchain_url"]: try: - async with httpx.AsyncClient() as client: - response = await client.get(f"{chain_info['blockchain_url']}/health", timeout=5.0) - chain_status[chain_id]["connected"] = response.status_code == 200 - except: + client = AsyncAITBCHTTPClient(base_url=chain_info['blockchain_url'], timeout=5) + response = await client.async_get("/health") + chain_status[chain_id]["connected"] = response is not None + except NetworkError: pass return { diff --git a/apps/simple-explorer/main.py b/apps/simple-explorer/main.py index 21c064ab..25902180 100644 --- a/apps/simple-explorer/main.py +++ b/apps/simple-explorer/main.py @@ -4,7 +4,6 @@ Simple AITBC Blockchain Explorer - Demonstrating the issues described in the ana """ import asyncio -import httpx import re from datetime import datetime from typing import Dict, Any, Optional @@ -12,8 +11,15 @@ from fastapi import FastAPI, HTTPException from fastapi.responses import HTMLResponse import uvicorn +from aitbc.http_client import AsyncAITBCHTTPClient +from aitbc.aitbc_logging import get_logger +from aitbc.exceptions import NetworkError + app = FastAPI(title="Simple AITBC Explorer", version="0.1.0") +# Initialize logger +logger = get_logger(__name__) + # Configuration BLOCKCHAIN_RPC_URL = "http://localhost:8025" @@ -174,12 +180,12 @@ HTML_TEMPLATE = """ async def get_chain_head(): """Get current chain head""" try: - async with httpx.AsyncClient() as client: - response = await client.get(f"{BLOCKCHAIN_RPC_URL}/rpc/head") - if response.status_code == 200: - return response.json() - except Exception as e: - print(f"Error getting chain head: {e}") + client = AsyncAITBCHTTPClient(base_url=BLOCKCHAIN_RPC_URL, timeout=10) + response = await client.async_get("/rpc/head") + if response: + return response + except NetworkError as e: + logger.error(f"Error getting chain head: {e}") return {"height": 0, "hash": "", "timestamp": None} @app.get("/api/blocks/{height}") @@ -189,12 +195,12 @@ async def get_block(height: int): if height < 0 or height > 10000000: return {"height": height, "hash": "", "timestamp": None, "transactions": []} try: - async with httpx.AsyncClient() as client: - response = await client.get(f"{BLOCKCHAIN_RPC_URL}/rpc/blocks/{height}") - if response.status_code == 200: - return response.json() - except Exception as e: - print(f"Error getting block {height}: {e}") + client = AsyncAITBCHTTPClient(base_url=BLOCKCHAIN_RPC_URL, timeout=10) + response = await client.async_get(f"/rpc/blocks/{height}") + if response: + return response + except NetworkError as e: + logger.error(f"Error getting block: {e}") return {"height": height, "hash": "", "timestamp": None, "transactions": []} @app.get("/api/transactions/{tx_hash}") @@ -203,26 +209,21 @@ async def get_transaction(tx_hash: str): if not validate_tx_hash(tx_hash): return {"hash": tx_hash, "from": "unknown", "to": "unknown", "amount": 0, "timestamp": None} try: - async with httpx.AsyncClient() as client: - response = await client.get(f"{BLOCKCHAIN_RPC_URL}/rpc/tx/{tx_hash}") - if response.status_code == 200: - tx_data = response.json() - # Problem 2: Map RPC schema to UI schema - return { - "hash": tx_data.get("tx_hash", tx_hash), # tx_hash -> hash - "from": tx_data.get("sender", "unknown"), # sender -> from - "to": tx_data.get("recipient", "unknown"), # recipient -> to - "amount": tx_data.get("payload", {}).get("value", "0"), # payload.value -> amount - "fee": tx_data.get("payload", {}).get("fee", "0"), # payload.fee -> fee - "timestamp": tx_data.get("created_at"), # created_at -> timestamp - "block_height": tx_data.get("block_height", "pending") - } - elif response.status_code == 404: - raise HTTPException(status_code=404, detail="Transaction not found") - except HTTPException: - raise - except Exception as e: - print(f"Error getting transaction {tx_hash}: {e}") + client = AsyncAITBCHTTPClient(base_url=BLOCKCHAIN_RPC_URL, timeout=10) + response = await client.async_get(f"/rpc/tx/{tx_hash}") + if response: + # Problem 2: Map RPC schema to UI schema + return { + "hash": response.get("tx_hash", tx_hash), # tx_hash -> hash + "from": response.get("sender", "unknown"), # sender -> from + "to": response.get("recipient", "unknown"), # recipient -> to + "amount": response.get("payload", {}).get("value", "0"), # payload.value -> amount + "fee": response.get("payload", {}).get("fee", "0"), # payload.fee -> fee + "timestamp": response.get("created_at"), # created_at -> timestamp + "block_height": response.get("block_height", "pending") + } + except NetworkError as e: + logger.error(f"Error getting transaction {tx_hash}: {e}") raise HTTPException(status_code=500, detail=f"Failed to fetch transaction: {str(e)}") # Missing: @app.get("/api/transactions/{tx_hash}") - THIS IS THE PROBLEM diff --git a/apps/wallet/simple_daemon.py b/apps/wallet/simple_daemon.py index 95f34f14..21f70aea 100755 --- a/apps/wallet/simple_daemon.py +++ b/apps/wallet/simple_daemon.py @@ -16,6 +16,8 @@ from pathlib import Path import os import sys +from aitbc.constants import KEYSTORE_DIR + # Add CLI utils to path sys.path.insert(0, '/opt/aitbc/cli') @@ -23,7 +25,7 @@ sys.path.insert(0, '/opt/aitbc/cli') app = FastAPI(title="AITBC Wallet Daemon", debug=False) # Configuration -KEYSTORE_PATH = Path("/var/lib/aitbc/keystore") +KEYSTORE_PATH = KEYSTORE_DIR BLOCKCHAIN_RPC_URL = "http://localhost:8006" CHAIN_ID = "ait-mainnet" diff --git a/cli/aitbc_cli.py b/cli/aitbc_cli.py index 4c44a808..6dd7c250 100755 --- a/cli/aitbc_cli.py +++ b/cli/aitbc_cli.py @@ -42,6 +42,7 @@ from aitbc import ( AITBCHTTPClient, NetworkError, ValidationError, ConfigurationError, get_logger, get_keystore_path, ensure_dir, validate_address, validate_url ) +from aitbc.paths import get_blockchain_data_path, get_data_path # Initialize logger logger = get_logger(__name__) @@ -1188,8 +1189,9 @@ def agent_operations(action: str, **kwargs) -> Optional[Dict]: sys.path.insert(0, "/opt/aitbc/apps/blockchain-node/src") from sqlmodel import create_engine, Session, select from aitbc_chain.models import Transaction - - engine = create_engine("sqlite:////var/lib/aitbc/data/ait-mainnet/chain.db") + + chain_db_path = get_blockchain_data_path("ait-mainnet") / "chain.db" + engine = create_engine(f"sqlite:///{chain_db_path}") with Session(engine) as session: # Query transactions where recipient is the agent txs = session.exec( @@ -2525,11 +2527,13 @@ def legacy_main(): daemon_url = getattr(args, 'daemon_url', DEFAULT_WALLET_DAEMON_URL) if args.wallet_action == "backup": print(f"Wallet backup: {args.name}") - print(f" Backup created: /var/lib/aitbc/backups/{args.name}_$(date +%Y%m%d).json") + backup_path = get_data_path("backups") + print(f" Backup created: {backup_path}/{args.name}_$(date +%Y%m%d).json") print(f" Status: completed") elif args.wallet_action == "export": print(f"Wallet export: {args.name}") - print(f" Export file: /var/lib/aitbc/exports/{args.name}_private.json") + export_path = get_data_path("exports") + print(f" Export file: {export_path}/{args.name}_private.json") print(f" Status: completed") elif args.wallet_action == "sync": if args.all: @@ -2602,12 +2606,14 @@ def legacy_main(): elif args.command == "wallet-backup": print(f"Wallet backup: {args.name}") - print(f" Backup created: /var/lib/aitbc/backups/{args.name}_$(date +%Y%m%d).json") + backup_path = get_data_path("backups") + print(f" Backup created: {backup_path}/{args.name}_$(date +%Y%m%d).json") print(f" Status: completed") - + elif args.command == "wallet-export": print(f"Wallet export: {args.name}") - print(f" Export file: /var/lib/aitbc/exports/{args.name}_private.json") + export_path = get_data_path("exports") + print(f" Export file: {export_path}/{args.name}_private.json") print(f" Status: completed") elif args.command == "wallet-sync": diff --git a/cli/enterprise_cli.py b/cli/enterprise_cli.py index 63b2f82a..c1f36b1c 100755 --- a/cli/enterprise_cli.py +++ b/cli/enterprise_cli.py @@ -12,9 +12,10 @@ from pathlib import Path from typing import Optional, Dict, Any, List import requests import getpass +from aitbc.paths import get_keystore_path # Default paths -DEFAULT_KEYSTORE_DIR = Path("/var/lib/aitbc/keystore") +DEFAULT_KEYSTORE_DIR = get_keystore_path() DEFAULT_RPC_URL = "http://localhost:8006" def get_password(password_arg: str = None, password_file: str = None) -> str: diff --git a/cli/extended_features.py b/cli/extended_features.py index 00cbcb4d..f5defdac 100644 --- a/cli/extended_features.py +++ b/cli/extended_features.py @@ -2,8 +2,9 @@ import json import os import time import uuid +from aitbc.paths import get_data_path -STATE_FILE = "/var/lib/aitbc/data/cli_extended_state.json" +STATE_FILE = str(get_data_path("data/cli_extended_state.json")) def load_state(): if os.path.exists(STATE_FILE): @@ -82,10 +83,10 @@ def handle_extended_command(command, args, kwargs): result["nodes_reached"] = 2 elif command == "wallet_backup": - result["path"] = f"/var/lib/aitbc/backups/{kwargs.get('name')}.backup" - + result["path"] = f"{get_data_path('backups')}/{kwargs.get('name')}.backup" + elif command == "wallet_export": - result["path"] = f"/var/lib/aitbc/exports/{kwargs.get('name')}.key" + result["path"] = f"{get_data_path('exports')}/{kwargs.get('name')}.key" elif command == "wallet_sync": result["status"] = "Wallets synchronized" @@ -275,7 +276,7 @@ def handle_extended_command(command, args, kwargs): elif command == "compliance_report": result["format"] = kwargs.get("format") - result["path"] = "/var/lib/aitbc/reports/compliance.pdf" + result["path"] = f"{get_data_path('reports')}/compliance.pdf" elif command == "script_run": result["file"] = kwargs.get("file") diff --git a/cli/handlers/wallet.py b/cli/handlers/wallet.py index fa2d373d..ee96ae60 100644 --- a/cli/handlers/wallet.py +++ b/cli/handlers/wallet.py @@ -2,6 +2,7 @@ import json import sys +from aitbc.paths import get_data_path def handle_wallet_create(args, create_wallet, read_password, first): @@ -140,7 +141,8 @@ def handle_wallet_backup(args, first): print("Error: Wallet name is required") sys.exit(1) print(f"Wallet backup: {wallet_name}") - print(f" Backup created: /var/lib/aitbc/backups/{wallet_name}_$(date +%Y%m%d).json") + backup_path = get_data_path("backups") + print(f" Backup created: {backup_path}/{wallet_name}_$(date +%Y%m%d).json") print(" Status: completed") diff --git a/cli/keystore_auth.py b/cli/keystore_auth.py index 9a5bf1f7..0415c326 100644 --- a/cli/keystore_auth.py +++ b/cli/keystore_auth.py @@ -42,8 +42,10 @@ def decrypt_private_key(keystore_data: Dict[str, Any], password: str) -> str: return decrypted.decode() -def load_keystore(address: str, keystore_dir: Path | str = "/var/lib/aitbc/keystore") -> Dict[str, Any]: +def load_keystore(address: str, keystore_dir: Path | str = None) -> Dict[str, Any]: """Load keystore file for a given address.""" + if keystore_dir is None: + keystore_dir = get_keystore_path() keystore_dir = Path(keystore_dir) keystore_file = keystore_dir / f"{address}.json"