Expand aitbc package with new utility modules and enhanced HTTP client

- Add new exception types: RetryError, CircuitBreakerOpenError, RateLimitError
- Enhance AITBCHTTPClient with retry logic, caching, circuit breaker, and rate limiting
- Add AsyncAITBCHTTPClient for async HTTP operations
- Add crypto module with Ethereum key derivation, signing, encryption, and hashing utilities
- Add web3_utils module with Web3Client and create_web3_client
- Add security module with token generation, API key management
This commit is contained in:
aitbc
2026-04-25 07:46:44 +02:00
parent dea9550dc9
commit ad5c147789
25 changed files with 4354 additions and 154 deletions

View File

@@ -29,6 +29,9 @@ from .exceptions import (
DatabaseError,
ValidationError,
BridgeError,
RetryError,
CircuitBreakerOpenError,
RateLimitError,
)
from .env import (
get_env_var,
@@ -60,7 +63,7 @@ from .json_utils import (
set_nested_value,
flatten_json,
)
from .http_client import AITBCHTTPClient
from .http_client import AITBCHTTPClient, AsyncAITBCHTTPClient
from .config import BaseAITBCConfig, AITBCConfig
from .decorators import (
retry,
@@ -105,6 +108,142 @@ from .monitoring import (
PerformanceTimer,
HealthChecker,
)
from .crypto import (
derive_ethereum_address,
sign_transaction_hash,
verify_signature,
encrypt_private_key,
decrypt_private_key,
generate_secure_random_bytes,
keccak256_hash,
sha256_hash,
validate_ethereum_address,
generate_ethereum_private_key,
)
from .web3_utils import Web3Client, create_web3_client
from .security import (
generate_token,
generate_api_key,
validate_token_format,
validate_api_key,
SessionManager,
APIKeyManager,
generate_secure_random_string,
generate_secure_random_int,
SecretManager,
hash_password,
verify_password,
generate_nonce,
generate_hmac,
verify_hmac,
)
from .time_utils import (
get_utc_now,
get_timestamp_utc,
format_iso8601,
parse_iso8601,
timestamp_to_iso,
iso_to_timestamp,
format_duration,
format_duration_precise,
parse_duration,
add_duration,
subtract_duration,
get_time_until,
get_time_since,
calculate_deadline,
is_deadline_passed,
get_deadline_remaining,
format_time_ago,
format_time_in,
to_timezone,
get_timezone_offset,
is_business_hours,
get_start_of_day,
get_end_of_day,
get_start_of_week,
get_end_of_week,
get_start_of_month,
get_end_of_month,
sleep_until,
retry_until_deadline,
Timer,
)
from .api_utils import (
APIResponse,
PaginatedResponse,
success_response,
error_response,
not_found_response,
unauthorized_response,
forbidden_response,
validation_error_response,
conflict_response,
internal_error_response,
PaginationParams,
paginate_items,
build_paginated_response,
RateLimitHeaders,
build_cors_headers,
build_standard_headers,
validate_sort_field,
validate_sort_order,
build_sort_params,
filter_fields,
exclude_fields,
sanitize_response,
merge_responses,
get_client_ip,
get_user_agent,
build_request_metadata,
)
from .events import (
Event,
EventPriority,
EventBus,
AsyncEventBus,
event_handler,
publish_event,
get_global_event_bus,
set_global_event_bus,
EventFilter,
EventAggregator,
EventRouter,
)
from .queue import (
Job,
JobStatus,
JobPriority,
TaskQueue,
JobScheduler,
BackgroundTaskManager,
WorkerPool,
debounce,
throttle,
)
from .state import (
StateTransition,
StateTransitionError,
StatePersistenceError,
StateMachine,
ConfigurableStateMachine,
StatePersistence,
AsyncStateMachine,
StateMonitor,
StateValidator,
StateSnapshot,
)
from .testing import (
MockFactory,
TestDataGenerator,
TestHelpers,
MockResponse,
MockDatabase,
MockCache,
mock_async_call,
create_mock_config,
create_test_scenario,
)
__version__ = "0.6.0"
__all__ = [
@@ -135,6 +274,9 @@ __all__ = [
"DatabaseError",
"ValidationError",
"BridgeError",
"RetryError",
"CircuitBreakerOpenError",
"RateLimitError",
# Environment helpers
"get_env_var",
"get_required_env_var",
@@ -164,6 +306,7 @@ __all__ = [
"flatten_json",
# HTTP client
"AITBCHTTPClient",
"AsyncAITBCHTTPClient",
# Configuration
"BaseAITBCConfig",
"AITBCConfig",
@@ -205,4 +348,134 @@ __all__ = [
"MetricsCollector",
"PerformanceTimer",
"HealthChecker",
# Cryptography
"derive_ethereum_address",
"sign_transaction_hash",
"verify_signature",
"encrypt_private_key",
"decrypt_private_key",
"generate_secure_random_bytes",
"keccak256_hash",
"sha256_hash",
"validate_ethereum_address",
"generate_ethereum_private_key",
# Web3 utilities
"Web3Client",
"create_web3_client",
# Security
"generate_token",
"generate_api_key",
"validate_token_format",
"validate_api_key",
"SessionManager",
"APIKeyManager",
"generate_secure_random_string",
"generate_secure_random_int",
"SecretManager",
"hash_password",
"verify_password",
"generate_nonce",
"generate_hmac",
"verify_hmac",
# Time utilities
"get_utc_now",
"get_timestamp_utc",
"format_iso8601",
"parse_iso8601",
"timestamp_to_iso",
"iso_to_timestamp",
"format_duration",
"format_duration_precise",
"parse_duration",
"add_duration",
"subtract_duration",
"get_time_until",
"get_time_since",
"calculate_deadline",
"is_deadline_passed",
"get_deadline_remaining",
"format_time_ago",
"format_time_in",
"to_timezone",
"get_timezone_offset",
"is_business_hours",
"get_start_of_day",
"get_end_of_day",
"get_start_of_week",
"get_end_of_week",
"get_start_of_month",
"get_end_of_month",
"sleep_until",
"retry_until_deadline",
"Timer",
# API utilities
"APIResponse",
"PaginatedResponse",
"success_response",
"error_response",
"not_found_response",
"unauthorized_response",
"forbidden_response",
"validation_error_response",
"conflict_response",
"internal_error_response",
"PaginationParams",
"paginate_items",
"build_paginated_response",
"RateLimitHeaders",
"build_cors_headers",
"build_standard_headers",
"validate_sort_field",
"validate_sort_order",
"build_sort_params",
"filter_fields",
"exclude_fields",
"sanitize_response",
"merge_responses",
"get_client_ip",
"get_user_agent",
"build_request_metadata",
# Events
"Event",
"EventPriority",
"EventBus",
"AsyncEventBus",
"event_handler",
"publish_event",
"get_global_event_bus",
"set_global_event_bus",
"EventFilter",
"EventAggregator",
"EventRouter",
# Queue
"Job",
"JobStatus",
"JobPriority",
"TaskQueue",
"JobScheduler",
"BackgroundTaskManager",
"WorkerPool",
"debounce",
"throttle",
# State
"StateTransition",
"StateTransitionError",
"StatePersistenceError",
"StateMachine",
"ConfigurableStateMachine",
"StatePersistence",
"AsyncStateMachine",
"StateMonitor",
"StateValidator",
"StateSnapshot",
# Testing
"MockFactory",
"TestDataGenerator",
"TestHelpers",
"MockResponse",
"MockDatabase",
"MockCache",
"mock_async_call",
"create_mock_config",
"create_test_scenario",
]

322
aitbc/api_utils.py Normal file
View File

@@ -0,0 +1,322 @@
"""
API utilities for AITBC
Provides standard response formatters, pagination helpers, error response builders, and rate limit headers helpers
"""
from typing import Any, Optional, List, Dict, Union
from datetime import datetime
from fastapi import HTTPException, status
from pydantic import BaseModel
class APIResponse(BaseModel):
"""Standard API response model"""
success: bool
message: str
data: Optional[Any] = None
error: Optional[str] = None
timestamp: str = None
def __init__(self, **data):
if 'timestamp' not in data:
data['timestamp'] = datetime.utcnow().isoformat()
super().__init__(**data)
class PaginatedResponse(BaseModel):
"""Paginated API response model"""
success: bool
message: str
data: List[Any]
pagination: Dict[str, Any]
timestamp: str = None
def __init__(self, **data):
if 'timestamp' not in data:
data['timestamp'] = datetime.utcnow().isoformat()
super().__init__(**data)
def success_response(message: str = "Success", data: Optional[Any] = None) -> APIResponse:
"""Create a success response"""
return APIResponse(success=True, message=message, data=data)
def error_response(message: str, error: Optional[str] = None, status_code: int = 400) -> HTTPException:
"""Create an error response"""
return HTTPException(
status_code=status_code,
detail={"success": False, "message": message, "error": error}
)
def not_found_response(resource: str = "Resource") -> HTTPException:
"""Create a not found response"""
return error_response(
message=f"{resource} not found",
error="NOT_FOUND",
status_code=404
)
def unauthorized_response(message: str = "Unauthorized") -> HTTPException:
"""Create an unauthorized response"""
return error_response(
message=message,
error="UNAUTHORIZED",
status_code=401
)
def forbidden_response(message: str = "Forbidden") -> HTTPException:
"""Create a forbidden response"""
return error_response(
message=message,
error="FORBIDDEN",
status_code=403
)
def validation_error_response(errors: List[str]) -> HTTPException:
"""Create a validation error response"""
return error_response(
message="Validation failed",
error="VALIDATION_ERROR",
status_code=422
)
def conflict_response(message: str = "Resource conflict") -> HTTPException:
"""Create a conflict response"""
return error_response(
message=message,
error="CONFLICT",
status_code=409
)
def internal_error_response(message: str = "Internal server error") -> HTTPException:
"""Create an internal server error response"""
return error_response(
message=message,
error="INTERNAL_ERROR",
status_code=500
)
class PaginationParams:
"""Pagination parameters"""
def __init__(self, page: int = 1, page_size: int = 10, max_page_size: int = 100):
"""Initialize pagination parameters"""
self.page = max(1, page)
self.page_size = min(max_page_size, max(1, page_size))
self.offset = (self.page - 1) * self.page_size
def get_limit(self) -> int:
"""Get SQL limit"""
return self.page_size
def get_offset(self) -> int:
"""Get SQL offset"""
return self.offset
def paginate_items(items: List[Any], page: int = 1, page_size: int = 10) -> Dict[str, Any]:
"""Paginate a list of items"""
total = len(items)
params = PaginationParams(page, page_size)
paginated_items = items[params.offset:params.offset + params.page_size]
total_pages = (total + params.page_size - 1) // params.page_size
return {
"items": paginated_items,
"pagination": {
"page": params.page,
"page_size": params.page_size,
"total": total,
"total_pages": total_pages,
"has_next": params.page < total_pages,
"has_prev": params.page > 1
}
}
def build_paginated_response(
items: List[Any],
page: int = 1,
page_size: int = 10,
message: str = "Success"
) -> PaginatedResponse:
"""Build a paginated API response"""
pagination_data = paginate_items(items, page, page_size)
return PaginatedResponse(
success=True,
message=message,
data=pagination_data["items"],
pagination=pagination_data["pagination"]
)
class RateLimitHeaders:
"""Rate limit headers helper"""
@staticmethod
def get_headers(
limit: int,
remaining: int,
reset: int,
window: int
) -> Dict[str, str]:
"""Get rate limit headers"""
return {
"X-RateLimit-Limit": str(limit),
"X-RateLimit-Remaining": str(remaining),
"X-RateLimit-Reset": str(reset),
"X-RateLimit-Window": str(window)
}
@staticmethod
def get_retry_after(retry_after: int) -> Dict[str, str]:
"""Get retry after header"""
return {"Retry-After": str(retry_after)}
def build_cors_headers(
allowed_origins: List[str] = ["*"],
allowed_methods: List[str] = ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allowed_headers: List[str] = ["*"],
max_age: int = 3600
) -> Dict[str, str]:
"""Build CORS headers"""
return {
"Access-Control-Allow-Origin": ", ".join(allowed_origins),
"Access-Control-Allow-Methods": ", ".join(allowed_methods),
"Access-Control-Allow-Headers": ", ".join(allowed_headers),
"Access-Control-Max-Age": str(max_age)
}
def build_standard_headers(
content_type: str = "application/json",
cache_control: Optional[str] = None,
x_request_id: Optional[str] = None
) -> Dict[str, str]:
"""Build standard response headers"""
headers = {
"Content-Type": content_type,
}
if cache_control:
headers["Cache-Control"] = cache_control
if x_request_id:
headers["X-Request-ID"] = x_request_id
return headers
def validate_sort_field(field: str, allowed_fields: List[str]) -> str:
"""Validate and return sort field"""
if field not in allowed_fields:
raise ValueError(f"Invalid sort field: {field}. Allowed fields: {', '.join(allowed_fields)}")
return field
def validate_sort_order(order: str) -> str:
"""Validate and return sort order"""
order = order.upper()
if order not in ["ASC", "DESC"]:
raise ValueError(f"Invalid sort order: {order}. Must be 'ASC' or 'DESC'")
return order
def build_sort_params(
sort_by: Optional[str] = None,
sort_order: str = "ASC",
allowed_fields: Optional[List[str]] = None
) -> Dict[str, Any]:
"""Build sort parameters"""
if sort_by and allowed_fields:
sort_by = validate_sort_field(sort_by, allowed_fields)
sort_order = validate_sort_order(sort_order)
return {"sort_by": sort_by, "sort_order": sort_order}
return {}
def filter_fields(data: Dict[str, Any], fields: List[str]) -> Dict[str, Any]:
"""Filter dictionary to only include specified fields"""
return {k: v for k, v in data.items() if k in fields}
def exclude_fields(data: Dict[str, Any], fields: List[str]) -> Dict[str, Any]:
"""Exclude specified fields from dictionary"""
return {k: v for k, v in data.items() if k not in fields}
def sanitize_response(data: Any, sensitive_fields: List[str] = None) -> Any:
"""Sanitize response by masking sensitive fields"""
if sensitive_fields is None:
sensitive_fields = ["password", "token", "api_key", "secret", "private_key"]
if isinstance(data, dict):
return {
k: "***" if any(sensitive in k.lower() for sensitive in sensitive_fields) else sanitize_response(v, sensitive_fields)
for k, v in data.items()
}
elif isinstance(data, list):
return [sanitize_response(item, sensitive_fields) for item in data]
else:
return data
def merge_responses(*responses: Union[APIResponse, Dict[str, Any]]) -> Dict[str, Any]:
"""Merge multiple responses into one"""
merged = {"data": {}}
for response in responses:
if isinstance(response, APIResponse):
if response.data:
if isinstance(response.data, dict):
merged["data"].update(response.data)
else:
merged["data"] = response.data
elif isinstance(response, dict):
if "data" in response:
if isinstance(response["data"], dict):
merged["data"].update(response["data"])
else:
merged["data"] = response["data"]
return merged
def get_client_ip(request) -> str:
"""Get client IP address from request"""
# Check for forwarded headers first
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
return request.client.host if request.client else "unknown"
def get_user_agent(request) -> str:
"""Get user agent from request"""
return request.headers.get("User-Agent", "unknown")
def build_request_metadata(request) -> Dict[str, str]:
"""Build request metadata"""
return {
"client_ip": get_client_ip(request),
"user_agent": get_user_agent(request),
"request_id": request.headers.get("X-Request-ID", "unknown"),
"timestamp": datetime.utcnow().isoformat()
}

174
aitbc/crypto.py Normal file
View File

@@ -0,0 +1,174 @@
"""
Cryptographic utilities for AITBC
Provides Ethereum-specific cryptographic operations and security functions
"""
from typing import Any, Optional
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
import base64
import os
import hashlib
def derive_ethereum_address(private_key: str) -> str:
"""Derive Ethereum address from private key using eth-account"""
try:
from eth_account import Account
# Remove 0x prefix if present
if private_key.startswith("0x"):
private_key = private_key[2:]
account = Account.from_key(private_key)
return account.address
except ImportError:
raise ImportError("eth-account is required for Ethereum address derivation. Install with: pip install eth-account")
except Exception as e:
raise ValueError(f"Failed to derive address from private key: {e}")
def sign_transaction_hash(transaction_hash: str, private_key: str) -> str:
"""Sign a transaction hash with private key using eth-account"""
try:
from eth_account import Account
# Remove 0x prefix if present
if private_key.startswith("0x"):
private_key = private_key[2:]
if transaction_hash.startswith("0x"):
transaction_hash = transaction_hash[2:]
account = Account.from_key(private_key)
signed_message = account.sign_hash(bytes.fromhex(transaction_hash))
return signed_message.signature.hex()
except ImportError:
raise ImportError("eth-account is required for signing. Install with: pip install eth-account")
except Exception as e:
raise ValueError(f"Failed to sign transaction hash: {e}")
def verify_signature(message_hash: str, signature: str, address: str) -> bool:
"""Verify a signature using eth-account"""
try:
from eth_account import Account
from eth_utils import to_bytes
# Remove 0x prefixes if present
if message_hash.startswith("0x"):
message_hash = message_hash[2:]
if signature.startswith("0x"):
signature = signature[2:]
if address.startswith("0x"):
address = address[2:]
message_bytes = to_bytes(hexstr=message_hash)
signature_bytes = to_bytes(hexstr=signature)
recovered_address = Account.recover_message(message_bytes, signature_bytes)
return recovered_address.lower() == address.lower()
except ImportError:
raise ImportError("eth-account and eth-utils are required for signature verification. Install with: pip install eth-account eth-utils")
except Exception as e:
raise ValueError(f"Failed to verify signature: {e}")
def encrypt_private_key(private_key: str, password: str) -> str:
"""Encrypt private key using Fernet symmetric encryption"""
try:
# Derive key from password
password_bytes = password.encode('utf-8')
salt = os.urandom(16)
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
)
key = base64.urlsafe_b64encode(kdf.derive(password_bytes))
# Encrypt private key
fernet = Fernet(key)
encrypted_key = fernet.encrypt(private_key.encode('utf-8'))
# Combine salt and encrypted key
combined = salt + encrypted_key
return base64.urlsafe_b64encode(combined).decode('utf-8')
except Exception as e:
raise ValueError(f"Failed to encrypt private key: {e}")
def decrypt_private_key(encrypted_key: str, password: str) -> str:
"""Decrypt private key using Fernet symmetric encryption"""
try:
# Decode combined data
combined = base64.urlsafe_b64decode(encrypted_key.encode('utf-8'))
salt = combined[:16]
encrypted_data = combined[16:]
# Derive key from password
password_bytes = password.encode('utf-8')
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
)
key = base64.urlsafe_b64encode(kdf.derive(password_bytes))
# Decrypt private key
fernet = Fernet(key)
decrypted_key = fernet.decrypt(encrypted_data)
return decrypted_key.decode('utf-8')
except Exception as e:
raise ValueError(f"Failed to decrypt private key: {e}")
def generate_secure_random_bytes(length: int = 32) -> str:
"""Generate cryptographically secure random bytes as hex string"""
return os.urandom(length).hex()
def keccak256_hash(data: str) -> str:
"""Compute Keccak-256 hash of data"""
try:
from eth_hash.auto import keccak
if isinstance(data, str):
data = data.encode('utf-8')
return keccak(data).hex()
except ImportError:
raise ImportError("eth-hash is required for Keccak-256 hashing. Install with: pip install eth-hash")
except Exception as e:
raise ValueError(f"Failed to compute Keccak-256 hash: {e}")
def sha256_hash(data: str) -> str:
"""Compute SHA-256 hash of data"""
try:
if isinstance(data, str):
data = data.encode('utf-8')
return hashlib.sha256(data).hexdigest()
except Exception as e:
raise ValueError(f"Failed to compute SHA-256 hash: {e}")
def validate_ethereum_address(address: str) -> bool:
"""Validate Ethereum address format and checksum"""
try:
from eth_utils import is_address, is_checksum_address
return is_address(address) and is_checksum_address(address)
except ImportError:
raise ImportError("eth-utils is required for address validation. Install with: pip install eth-utils")
except Exception:
return False
def generate_ethereum_private_key() -> str:
"""Generate a new Ethereum private key"""
try:
from eth_account import Account
account = Account.create()
return account.key.hex()
except ImportError:
raise ImportError("eth-account is required for private key generation. Install with: pip install eth-account")
except Exception as e:
raise ValueError(f"Failed to generate private key: {e}")

267
aitbc/events.py Normal file
View File

@@ -0,0 +1,267 @@
"""
Event utilities for AITBC
Provides event bus implementation, pub/sub patterns, and event decorators
"""
import asyncio
from typing import Any, Callable, Dict, List, Optional, TypeVar, Generic
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
import inspect
import functools
T = TypeVar('T')
class EventPriority(Enum):
"""Event priority levels"""
LOW = 1
MEDIUM = 2
HIGH = 3
CRITICAL = 4
@dataclass
class Event:
"""Base event class"""
event_type: str
data: Dict[str, Any]
timestamp: datetime = None
priority: EventPriority = EventPriority.MEDIUM
source: Optional[str] = None
def __post_init__(self):
if self.timestamp is None:
self.timestamp = datetime.utcnow()
class EventBus:
"""Simple in-memory event bus for pub/sub patterns"""
def __init__(self):
"""Initialize event bus"""
self.subscribers: Dict[str, List[Callable]] = {}
self.event_history: List[Event] = []
self.max_history = 1000
def subscribe(self, event_type: str, handler: Callable) -> None:
"""Subscribe to an event type"""
if event_type not in self.subscribers:
self.subscribers[event_type] = []
self.subscribers[event_type].append(handler)
def unsubscribe(self, event_type: str, handler: Callable) -> bool:
"""Unsubscribe from an event type"""
if event_type in self.subscribers:
try:
self.subscribers[event_type].remove(handler)
return True
except ValueError:
pass
return False
async def publish(self, event: Event) -> None:
"""Publish an event to all subscribers"""
# Add to history
self.event_history.append(event)
if len(self.event_history) > self.max_history:
self.event_history.pop(0)
# Notify subscribers
handlers = self.subscribers.get(event.event_type, [])
for handler in handlers:
try:
if inspect.iscoroutinefunction(handler):
await handler(event)
else:
handler(event)
except Exception as e:
print(f"Error in event handler: {e}")
def publish_sync(self, event: Event) -> None:
"""Publish an event synchronously"""
asyncio.run(self.publish(event))
def get_event_history(self, event_type: Optional[str] = None, limit: int = 100) -> List[Event]:
"""Get event history"""
events = self.event_history
if event_type:
events = [e for e in events if e.event_type == event_type]
return events[-limit:]
def clear_history(self) -> None:
"""Clear event history"""
self.event_history.clear()
class AsyncEventBus(EventBus):
"""Async event bus with additional features"""
def __init__(self, max_concurrent_handlers: int = 10):
"""Initialize async event bus"""
super().__init__()
self.semaphore = asyncio.Semaphore(max_concurrent_handlers)
async def publish(self, event: Event) -> None:
"""Publish event with concurrency control"""
self.event_history.append(event)
if len(self.event_history) > self.max_history:
self.event_history.pop(0)
handlers = self.subscribers.get(event.event_type, [])
tasks = []
for handler in handlers:
async def safe_handler():
async with self.semaphore:
try:
if inspect.iscoroutinefunction(handler):
await handler(event)
else:
handler(event)
except Exception as e:
print(f"Error in event handler: {e}")
tasks.append(safe_handler())
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
def event_handler(event_type: str, event_bus: Optional[EventBus] = None):
"""Decorator to register event handler"""
def decorator(func: Callable) -> Callable:
# Use global event bus if none provided
bus = event_bus or get_global_event_bus()
bus.subscribe(event_type, func)
return func
return decorator
def publish_event(event_type: str, data: Dict[str, Any], event_bus: Optional[EventBus] = None) -> None:
"""Helper to publish an event"""
bus = event_bus or get_global_event_bus()
event = Event(event_type=event_type, data=data)
bus.publish_sync(event)
# Global event bus instance
_global_event_bus: Optional[EventBus] = None
def get_global_event_bus() -> EventBus:
"""Get or create global event bus"""
global _global_event_bus
if _global_event_bus is None:
_global_event_bus = EventBus()
return _global_event_bus
def set_global_event_bus(bus: EventBus) -> None:
"""Set global event bus"""
global _global_event_bus
_global_event_bus = bus
class EventFilter:
"""Filter events based on criteria"""
def __init__(self, event_bus: Optional[EventBus] = None):
"""Initialize event filter"""
self.event_bus = event_bus or get_global_event_bus()
self.filters: List[Callable[[Event], bool]] = []
def add_filter(self, filter_func: Callable[[Event], bool]) -> None:
"""Add a filter function"""
self.filters.append(filter_func)
def matches(self, event: Event) -> bool:
"""Check if event matches all filters"""
return all(f(event) for f in self.filters)
def get_filtered_events(self, event_type: Optional[str] = None, limit: int = 100) -> List[Event]:
"""Get filtered events"""
events = self.event_bus.get_event_history(event_type, limit)
return [e for e in events if self.matches(e)]
class EventAggregator:
"""Aggregate events over time windows"""
def __init__(self, window_seconds: int = 60):
"""Initialize event aggregator"""
self.window_seconds = window_seconds
self.aggregated_events: Dict[str, Dict[str, Any]] = {}
def add_event(self, event: Event) -> None:
"""Add event to aggregation"""
key = event.event_type
now = datetime.utcnow()
if key not in self.aggregated_events:
self.aggregated_events[key] = {
"count": 0,
"first_seen": now,
"last_seen": now,
"data": {}
}
agg = self.aggregated_events[key]
agg["count"] += 1
agg["last_seen"] = now
# Merge data
for k, v in event.data.items():
if k not in agg["data"]:
agg["data"][k] = v
elif isinstance(v, (int, float)):
agg["data"][k] = agg["data"].get(k, 0) + v
def get_aggregated_events(self) -> Dict[str, Dict[str, Any]]:
"""Get aggregated events"""
# Remove old events
now = datetime.utcnow()
cutoff = now.timestamp() - self.window_seconds
to_remove = []
for key, agg in self.aggregated_events.items():
if agg["last_seen"].timestamp() < cutoff:
to_remove.append(key)
for key in to_remove:
del self.aggregated_events[key]
return self.aggregated_events
def clear(self) -> None:
"""Clear all aggregated events"""
self.aggregated_events.clear()
class EventRouter:
"""Route events to different handlers based on criteria"""
def __init__(self):
"""Initialize event router"""
self.routes: List[Callable[[Event], Optional[Callable]]] = []
def add_route(self, condition: Callable[[Event], bool], handler: Callable) -> None:
"""Add a route"""
self.routes.append((condition, handler))
async def route(self, event: Event) -> bool:
"""Route event to matching handler"""
for condition, handler in self.routes:
if condition(event):
try:
if inspect.iscoroutinefunction(handler):
await handler(event)
else:
handler(event)
return True
except Exception as e:
print(f"Error in routed handler: {e}")
return False

View File

@@ -42,3 +42,18 @@ class ValidationError(AITBCError):
class BridgeError(AITBCError):
"""Base exception for bridge errors"""
pass
class RetryError(AITBCError):
"""Raised when retry attempts are exhausted"""
pass
class CircuitBreakerOpenError(AITBCError):
"""Raised when circuit breaker is open and requests are rejected"""
pass
class RateLimitError(AITBCError):
"""Raised when rate limit is exceeded"""
pass

View File

@@ -4,8 +4,13 @@ Base HTTP client with common utilities for AITBC applications
"""
import requests
import time
import asyncio
from typing import Dict, Any, Optional, Union
from .exceptions import NetworkError
from datetime import datetime, timedelta
from functools import lru_cache
from .exceptions import NetworkError, RetryError, CircuitBreakerOpenError, RateLimitError
from .aitbc_logging import get_logger
class AITBCHTTPClient:
@@ -18,7 +23,13 @@ class AITBCHTTPClient:
self,
base_url: str = "",
timeout: int = 30,
headers: Optional[Dict[str, str]] = None
headers: Optional[Dict[str, str]] = None,
max_retries: int = 3,
enable_cache: bool = False,
cache_ttl: int = 300,
enable_logging: bool = False,
circuit_breaker_threshold: int = 5,
rate_limit: Optional[int] = None
):
"""
Initialize HTTP client.
@@ -27,12 +38,37 @@ class AITBCHTTPClient:
base_url: Base URL for all requests
timeout: Request timeout in seconds
headers: Default headers for all requests
max_retries: Maximum retry attempts with exponential backoff
enable_cache: Enable request/response caching for GET requests
cache_ttl: Cache time-to-live in seconds
enable_logging: Enable request/response logging
circuit_breaker_threshold: Failures before opening circuit breaker
rate_limit: Rate limit in requests per minute
"""
self.base_url = base_url.rstrip("/")
self.timeout = timeout
self.headers = headers or {}
self.max_retries = max_retries
self.enable_cache = enable_cache
self.cache_ttl = cache_ttl
self.enable_logging = enable_logging
self.circuit_breaker_threshold = circuit_breaker_threshold
self.rate_limit = rate_limit
self.session = requests.Session()
self.session.headers.update(self.headers)
self.logger = get_logger(__name__)
# Cache storage: {url: (data, timestamp)}
self._cache: Dict[str, tuple] = {}
# Circuit breaker state
self._failure_count = 0
self._circuit_open = False
self._circuit_open_time = None
# Rate limiting state
self._request_times: list = []
def _build_url(self, endpoint: str) -> str:
"""
@@ -48,6 +84,98 @@ class AITBCHTTPClient:
return endpoint
return f"{self.base_url}/{endpoint.lstrip('/')}"
def _check_circuit_breaker(self) -> None:
"""Check if circuit breaker is open and raise exception if so."""
if self._circuit_open:
# Check if circuit should be reset (after 60 seconds)
if self._circuit_open_time and (datetime.now() - self._circuit_open_time).total_seconds() > 60:
self._circuit_open = False
self._failure_count = 0
self.logger.info("Circuit breaker reset to half-open state")
else:
raise CircuitBreakerOpenError("Circuit breaker is open, rejecting request")
def _record_failure(self) -> None:
"""Record a failure and potentially open circuit breaker."""
self._failure_count += 1
if self._failure_count >= self.circuit_breaker_threshold:
self._circuit_open = True
self._circuit_open_time = datetime.now()
self.logger.warning(f"Circuit breaker opened after {self._failure_count} failures")
def _check_rate_limit(self) -> None:
"""Check if rate limit is exceeded and raise exception if so."""
if not self.rate_limit:
return
now = datetime.now()
# Remove requests older than 1 minute
self._request_times = [t for t in self._request_times if (now - t).total_seconds() < 60]
if len(self._request_times) >= self.rate_limit:
raise RateLimitError(f"Rate limit exceeded: {self.rate_limit} requests per minute")
def _record_request(self) -> None:
"""Record a request timestamp for rate limiting."""
if self.rate_limit:
self._request_times.append(datetime.now())
def _get_cache_key(self, url: str, params: Optional[Dict[str, Any]] = None) -> str:
"""Generate cache key from URL and params."""
if params:
import hashlib
param_str = str(sorted(params.items()))
return f"{url}:{hashlib.md5(param_str.encode()).hexdigest()}"
return url
def _get_cache(self, cache_key: str) -> Optional[Dict[str, Any]]:
"""Get cached response if available and not expired."""
if not self.enable_cache:
return None
if cache_key in self._cache:
data, timestamp = self._cache[cache_key]
if (datetime.now() - timestamp).total_seconds() < self.cache_ttl:
if self.enable_logging:
self.logger.info(f"Cache hit for {cache_key}")
return data
else:
# Expired, remove from cache
del self._cache[cache_key]
return None
def _set_cache(self, cache_key: str, data: Dict[str, Any]) -> None:
"""Cache response data."""
if self.enable_cache:
self._cache[cache_key] = (data, datetime.now())
if self.enable_logging:
self.logger.info(f"Cached response for {cache_key}")
def _retry_request(self, request_func, *args, **kwargs) -> Dict[str, Any]:
"""Execute request with retry logic and exponential backoff."""
last_error = None
for attempt in range(self.max_retries + 1):
try:
if attempt > 0:
backoff_time = 2 ** (attempt - 1)
if self.enable_logging:
self.logger.info(f"Retry attempt {attempt}/{self.max_retries} after {backoff_time}s backoff")
time.sleep(backoff_time)
return request_func(*args, **kwargs)
except requests.RequestException as e:
last_error = e
if attempt < self.max_retries:
if self.enable_logging:
self.logger.warning(f"Request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}")
continue
else:
if self.enable_logging:
self.logger.error(f"All retry attempts exhausted: {e}")
raise RetryError(f"Retry attempts exhausted: {e}")
raise NetworkError(f"Request failed: {last_error}")
def get(
self,
endpoint: str,
@@ -67,11 +195,29 @@ class AITBCHTTPClient:
Raises:
NetworkError: If request fails
CircuitBreakerOpenError: If circuit breaker is open
RateLimitError: If rate limit is exceeded
"""
url = self._build_url(endpoint)
cache_key = self._get_cache_key(url, params)
# Check cache first
cached_data = self._get_cache(cache_key)
if cached_data is not None:
return cached_data
# Check circuit breaker and rate limit
self._check_circuit_breaker()
self._check_rate_limit()
req_headers = {**self.headers, **(headers or {})}
try:
if self.enable_logging:
self.logger.info(f"GET {url} with params={params}")
start_time = datetime.now()
def _make_request():
response = self.session.get(
url,
params=params,
@@ -80,7 +226,26 @@ class AITBCHTTPClient:
)
response.raise_for_status()
return response.json()
try:
result = self._retry_request(_make_request)
# Cache successful GET requests
self._set_cache(cache_key, result)
# Record success for circuit breaker
self._failure_count = 0
self._record_request()
if self.enable_logging:
elapsed = (datetime.now() - start_time).total_seconds()
self.logger.info(f"GET {url} succeeded in {elapsed:.3f}s")
return result
except (RetryError, CircuitBreakerOpenError, RateLimitError):
raise
except requests.RequestException as e:
self._record_failure()
raise NetworkError(f"GET request failed: {e}")
def post(
@@ -104,11 +269,23 @@ class AITBCHTTPClient:
Raises:
NetworkError: If request fails
CircuitBreakerOpenError: If circuit breaker is open
RateLimitError: If rate limit is exceeded
"""
url = self._build_url(endpoint)
# Check circuit breaker and rate limit
self._check_circuit_breaker()
self._check_rate_limit()
req_headers = {**self.headers, **(headers or {})}
try:
if self.enable_logging:
self.logger.info(f"POST {url} with json={json}")
start_time = datetime.now()
def _make_request():
response = self.session.post(
url,
data=data,
@@ -118,7 +295,23 @@ class AITBCHTTPClient:
)
response.raise_for_status()
return response.json()
try:
result = self._retry_request(_make_request)
# Record success for circuit breaker
self._failure_count = 0
self._record_request()
if self.enable_logging:
elapsed = (datetime.now() - start_time).total_seconds()
self.logger.info(f"POST {url} succeeded in {elapsed:.3f}s")
return result
except (RetryError, CircuitBreakerOpenError, RateLimitError):
raise
except requests.RequestException as e:
self._record_failure()
raise NetworkError(f"POST request failed: {e}")
def put(
@@ -142,11 +335,23 @@ class AITBCHTTPClient:
Raises:
NetworkError: If request fails
CircuitBreakerOpenError: If circuit breaker is open
RateLimitError: If rate limit is exceeded
"""
url = self._build_url(endpoint)
# Check circuit breaker and rate limit
self._check_circuit_breaker()
self._check_rate_limit()
req_headers = {**self.headers, **(headers or {})}
try:
if self.enable_logging:
self.logger.info(f"PUT {url} with json={json}")
start_time = datetime.now()
def _make_request():
response = self.session.put(
url,
data=data,
@@ -156,7 +361,23 @@ class AITBCHTTPClient:
)
response.raise_for_status()
return response.json()
try:
result = self._retry_request(_make_request)
# Record success for circuit breaker
self._failure_count = 0
self._record_request()
if self.enable_logging:
elapsed = (datetime.now() - start_time).total_seconds()
self.logger.info(f"PUT {url} succeeded in {elapsed:.3f}s")
return result
except (RetryError, CircuitBreakerOpenError, RateLimitError):
raise
except requests.RequestException as e:
self._record_failure()
raise NetworkError(f"PUT request failed: {e}")
def delete(
@@ -178,11 +399,23 @@ class AITBCHTTPClient:
Raises:
NetworkError: If request fails
CircuitBreakerOpenError: If circuit breaker is open
RateLimitError: If rate limit is exceeded
"""
url = self._build_url(endpoint)
# Check circuit breaker and rate limit
self._check_circuit_breaker()
self._check_rate_limit()
req_headers = {**self.headers, **(headers or {})}
try:
if self.enable_logging:
self.logger.info(f"DELETE {url} with params={params}")
start_time = datetime.now()
def _make_request():
response = self.session.delete(
url,
params=params,
@@ -191,7 +424,23 @@ class AITBCHTTPClient:
)
response.raise_for_status()
return response.json() if response.content else {}
try:
result = self._retry_request(_make_request)
# Record success for circuit breaker
self._failure_count = 0
self._record_request()
if self.enable_logging:
elapsed = (datetime.now() - start_time).total_seconds()
self.logger.info(f"DELETE {url} succeeded in {elapsed:.3f}s")
return result
except (RetryError, CircuitBreakerOpenError, RateLimitError):
raise
except requests.RequestException as e:
self._record_failure()
raise NetworkError(f"DELETE request failed: {e}")
def close(self) -> None:
@@ -205,3 +454,279 @@ class AITBCHTTPClient:
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit."""
self.close()
class AsyncAITBCHTTPClient:
"""
Async HTTP client for AITBC applications.
Provides async HTTP methods with error handling.
"""
def __init__(
self,
base_url: str = "",
timeout: int = 30,
headers: Optional[Dict[str, str]] = None,
max_retries: int = 3,
enable_cache: bool = False,
cache_ttl: int = 300,
enable_logging: bool = False,
circuit_breaker_threshold: int = 5,
rate_limit: Optional[int] = None
):
"""
Initialize async HTTP client.
Args:
base_url: Base URL for all requests
timeout: Request timeout in seconds
headers: Default headers for all requests
max_retries: Maximum retry attempts with exponential backoff
enable_cache: Enable request/response caching for GET requests
cache_ttl: Cache time-to-live in seconds
enable_logging: Enable request/response logging
circuit_breaker_threshold: Failures before opening circuit breaker
rate_limit: Rate limit in requests per minute
"""
self.base_url = base_url.rstrip("/")
self.timeout = timeout
self.headers = headers or {}
self.max_retries = max_retries
self.enable_cache = enable_cache
self.cache_ttl = cache_ttl
self.enable_logging = enable_logging
self.circuit_breaker_threshold = circuit_breaker_threshold
self.rate_limit = rate_limit
self.logger = get_logger(__name__)
self._client = None
# Cache storage: {url: (data, timestamp)}
self._cache: Dict[str, tuple] = {}
# Circuit breaker state
self._failure_count = 0
self._circuit_open = False
self._circuit_open_time = None
# Rate limiting state
self._request_times: list = []
async def __aenter__(self):
"""Async context manager entry."""
import httpx
self._client = httpx.AsyncClient(timeout=self.timeout, headers=self.headers)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
if self._client:
await self._client.aclose()
def _build_url(self, endpoint: str) -> str:
"""Build full URL from base URL and endpoint."""
if endpoint.startswith("http://") or endpoint.startswith("https://"):
return endpoint
return f"{self.base_url}/{endpoint.lstrip('/')}"
def _check_circuit_breaker(self) -> None:
"""Check if circuit breaker is open and raise exception if so."""
if self._circuit_open:
if self._circuit_open_time and (datetime.now() - self._circuit_open_time).total_seconds() > 60:
self._circuit_open = False
self._failure_count = 0
self.logger.info("Circuit breaker reset to half-open state")
else:
raise CircuitBreakerOpenError("Circuit breaker is open, rejecting request")
def _record_failure(self) -> None:
"""Record a failure and potentially open circuit breaker."""
self._failure_count += 1
if self._failure_count >= self.circuit_breaker_threshold:
self._circuit_open = True
self._circuit_open_time = datetime.now()
self.logger.warning(f"Circuit breaker opened after {self._failure_count} failures")
def _check_rate_limit(self) -> None:
"""Check if rate limit is exceeded and raise exception if so."""
if not self.rate_limit:
return
now = datetime.now()
self._request_times = [t for t in self._request_times if (now - t).total_seconds() < 60]
if len(self._request_times) >= self.rate_limit:
raise RateLimitError(f"Rate limit exceeded: {self.rate_limit} requests per minute")
def _record_request(self) -> None:
"""Record a request timestamp for rate limiting."""
if self.rate_limit:
self._request_times.append(datetime.now())
def _get_cache_key(self, url: str, params: Optional[Dict[str, Any]] = None) -> str:
"""Generate cache key from URL and params."""
if params:
import hashlib
param_str = str(sorted(params.items()))
return f"{url}:{hashlib.md5(param_str.encode()).hexdigest()}"
return url
def _get_cache(self, cache_key: str) -> Optional[Dict[str, Any]]:
"""Get cached response if available and not expired."""
if not self.enable_cache:
return None
if cache_key in self._cache:
data, timestamp = self._cache[cache_key]
if (datetime.now() - timestamp).total_seconds() < self.cache_ttl:
if self.enable_logging:
self.logger.info(f"Cache hit for {cache_key}")
return data
else:
del self._cache[cache_key]
return None
def _set_cache(self, cache_key: str, data: Dict[str, Any]) -> None:
"""Cache response data."""
if self.enable_cache:
self._cache[cache_key] = (data, datetime.now())
if self.enable_logging:
self.logger.info(f"Cached response for {cache_key}")
async def _retry_request(self, request_func, *args, **kwargs) -> Dict[str, Any]:
"""Execute async request with retry logic and exponential backoff."""
last_error = None
for attempt in range(self.max_retries + 1):
try:
if attempt > 0:
backoff_time = 2 ** (attempt - 1)
if self.enable_logging:
self.logger.info(f"Retry attempt {attempt}/{self.max_retries} after {backoff_time}s backoff")
await asyncio.sleep(backoff_time)
return await request_func(*args, **kwargs)
except Exception as e:
last_error = e
if attempt < self.max_retries:
if self.enable_logging:
self.logger.warning(f"Request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}")
continue
else:
if self.enable_logging:
self.logger.error(f"All retry attempts exhausted: {e}")
raise RetryError(f"Retry attempts exhausted: {e}")
raise NetworkError(f"Request failed: {last_error}")
async def async_get(
self,
endpoint: str,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
"""
Perform async GET request.
Args:
endpoint: API endpoint
params: Query parameters
headers: Additional headers
Returns:
Response data as dictionary
"""
if not self._client:
raise RuntimeError("Async client not initialized. Use async context manager.")
url = self._build_url(endpoint)
cache_key = self._get_cache_key(url, params)
cached_data = self._get_cache(cache_key)
if cached_data is not None:
return cached_data
self._check_circuit_breaker()
self._check_rate_limit()
req_headers = {**self.headers, **(headers or {})}
if self.enable_logging:
self.logger.info(f"ASYNC GET {url} with params={params}")
start_time = datetime.now()
async def _make_request():
response = await self._client.get(url, params=params, headers=req_headers)
response.raise_for_status()
return response.json()
try:
result = await self._retry_request(_make_request)
self._set_cache(cache_key, result)
self._failure_count = 0
self._record_request()
if self.enable_logging:
elapsed = (datetime.now() - start_time).total_seconds()
self.logger.info(f"ASYNC GET {url} succeeded in {elapsed:.3f}s")
return result
except (RetryError, CircuitBreakerOpenError, RateLimitError):
raise
except Exception as e:
self._record_failure()
raise NetworkError(f"ASYNC GET request failed: {e}")
async def async_post(
self,
endpoint: str,
data: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
"""
Perform async POST request.
Args:
endpoint: API endpoint
data: Form data
json: JSON data
headers: Additional headers
Returns:
Response data as dictionary
"""
if not self._client:
raise RuntimeError("Async client not initialized. Use async context manager.")
url = self._build_url(endpoint)
self._check_circuit_breaker()
self._check_rate_limit()
req_headers = {**self.headers, **(headers or {})}
if self.enable_logging:
self.logger.info(f"ASYNC POST {url} with json={json}")
start_time = datetime.now()
async def _make_request():
response = await self._client.post(url, data=data, json=json, headers=req_headers)
response.raise_for_status()
return response.json()
try:
result = await self._retry_request(_make_request)
self._failure_count = 0
self._record_request()
if self.enable_logging:
elapsed = (datetime.now() - start_time).total_seconds()
self.logger.info(f"ASYNC POST {url} succeeded in {elapsed:.3f}s")
return result
except (RetryError, CircuitBreakerOpenError, RateLimitError):
raise
except Exception as e:
self._record_failure()
raise NetworkError(f"ASYNC POST request failed: {e}")

431
aitbc/queue.py Normal file
View File

@@ -0,0 +1,431 @@
"""
Queue utilities for AITBC
Provides task queue helpers, job scheduling, and background task management
"""
import asyncio
import heapq
import time
from typing import Any, Callable, Dict, List, Optional, TypeVar
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
import uuid
T = TypeVar('T')
class JobStatus(Enum):
"""Job status enumeration"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class JobPriority(Enum):
"""Job priority levels"""
LOW = 1
MEDIUM = 2
HIGH = 3
CRITICAL = 4
@dataclass(order=True)
class Job:
"""Background job"""
priority: int
job_id: str = field(compare=False)
func: Callable = field(compare=False)
args: tuple = field(default_factory=tuple, compare=False)
kwargs: dict = field(default_factory=dict, compare=False)
status: JobStatus = field(default=JobStatus.PENDING, compare=False)
created_at: datetime = field(default_factory=datetime.utcnow, compare=False)
started_at: Optional[datetime] = field(default=None, compare=False)
completed_at: Optional[datetime] = field(default=None, compare=False)
result: Any = field(default=None, compare=False)
error: Optional[str] = field(default=None, compare=False)
retry_count: int = field(default=0, compare=False)
max_retries: int = field(default=3, compare=False)
def __post_init__(self):
if self.job_id is None:
self.job_id = str(uuid.uuid4())
class TaskQueue:
"""Priority-based task queue"""
def __init__(self):
"""Initialize task queue"""
self.queue: List[Job] = []
self.jobs: Dict[str, Job] = {}
self.lock = asyncio.Lock()
async def enqueue(
self,
func: Callable,
args: tuple = (),
kwargs: dict = None,
priority: JobPriority = JobPriority.MEDIUM,
max_retries: int = 3
) -> str:
"""Enqueue a task"""
if kwargs is None:
kwargs = {}
job = Job(
priority=priority.value,
func=func,
args=args,
kwargs=kwargs,
max_retries=max_retries
)
async with self.lock:
heapq.heappush(self.queue, job)
self.jobs[job.job_id] = job
return job.job_id
async def dequeue(self) -> Optional[Job]:
"""Dequeue a task"""
async with self.lock:
if not self.queue:
return None
job = heapq.heappop(self.queue)
return job
async def get_job(self, job_id: str) -> Optional[Job]:
"""Get job by ID"""
return self.jobs.get(job_id)
async def cancel_job(self, job_id: str) -> bool:
"""Cancel a job"""
async with self.lock:
job = self.jobs.get(job_id)
if job and job.status == JobStatus.PENDING:
job.status = JobStatus.CANCELLED
# Remove from queue
self.queue = [j for j in self.queue if j.job_id != job_id]
heapq.heapify(self.queue)
return True
return False
async def get_queue_size(self) -> int:
"""Get queue size"""
return len(self.queue)
async def get_jobs_by_status(self, status: JobStatus) -> List[Job]:
"""Get jobs by status"""
return [job for job in self.jobs.values() if job.status == status]
class JobScheduler:
"""Job scheduler for delayed and recurring tasks"""
def __init__(self):
"""Initialize job scheduler"""
self.scheduled_jobs: Dict[str, Dict[str, Any]] = {}
self.running = False
self.task: Optional[asyncio.Task] = None
async def schedule(
self,
func: Callable,
delay: float = 0,
interval: Optional[float] = None,
job_id: Optional[str] = None,
args: tuple = (),
kwargs: dict = None
) -> str:
"""Schedule a job"""
if job_id is None:
job_id = str(uuid.uuid4())
if kwargs is None:
kwargs = {}
run_at = time.time() + delay
self.scheduled_jobs[job_id] = {
"func": func,
"args": args,
"kwargs": kwargs,
"run_at": run_at,
"interval": interval,
"job_id": job_id
}
return job_id
async def cancel_scheduled_job(self, job_id: str) -> bool:
"""Cancel a scheduled job"""
if job_id in self.scheduled_jobs:
del self.scheduled_jobs[job_id]
return True
return False
async def start(self) -> None:
"""Start the scheduler"""
if self.running:
return
self.running = True
self.task = asyncio.create_task(self._run_scheduler())
async def stop(self) -> None:
"""Stop the scheduler"""
self.running = False
if self.task:
self.task.cancel()
try:
await self.task
except asyncio.CancelledError:
pass
async def _run_scheduler(self) -> None:
"""Run the scheduler loop"""
while self.running:
now = time.time()
to_run = []
for job_id, job in list(self.scheduled_jobs.items()):
if job["run_at"] <= now:
to_run.append(job)
for job in to_run:
try:
if asyncio.iscoroutinefunction(job["func"]):
await job["func"](*job["args"], **job["kwargs"])
else:
job["func"](*job["args"], **job["kwargs"])
if job["interval"]:
job["run_at"] = now + job["interval"]
else:
del self.scheduled_jobs[job["job_id"]]
except Exception as e:
print(f"Error running scheduled job {job['job_id']}: {e}")
if not job["interval"]:
del self.scheduled_jobs[job["job_id"]]
await asyncio.sleep(0.1)
class BackgroundTaskManager:
"""Manage background tasks"""
def __init__(self, max_concurrent_tasks: int = 10):
"""Initialize background task manager"""
self.max_concurrent_tasks = max_concurrent_tasks
self.semaphore = asyncio.Semaphore(max_concurrent_tasks)
self.tasks: Dict[str, asyncio.Task] = {}
self.task_info: Dict[str, Dict[str, Any]] = {}
async def run_task(
self,
func: Callable,
task_id: Optional[str] = None,
args: tuple = (),
kwargs: dict = None
) -> str:
"""Run a background task"""
if task_id is None:
task_id = str(uuid.uuid4())
if kwargs is None:
kwargs = {}
async def wrapped_task():
async with self.semaphore:
try:
self.task_info[task_id]["status"] = "running"
self.task_info[task_id]["started_at"] = datetime.utcnow()
if asyncio.iscoroutinefunction(func):
result = await func(*args, **kwargs)
else:
result = func(*args, **kwargs)
self.task_info[task_id]["status"] = "completed"
self.task_info[task_id]["result"] = result
self.task_info[task_id]["completed_at"] = datetime.utcnow()
except Exception as e:
self.task_info[task_id]["status"] = "failed"
self.task_info[task_id]["error"] = str(e)
self.task_info[task_id]["completed_at"] = datetime.utcnow()
finally:
if task_id in self.tasks:
del self.tasks[task_id]
self.task_info[task_id] = {
"status": "pending",
"created_at": datetime.utcnow(),
"started_at": None,
"completed_at": None,
"result": None,
"error": None
}
task = asyncio.create_task(wrapped_task())
self.tasks[task_id] = task
return task_id
async def cancel_task(self, task_id: str) -> bool:
"""Cancel a background task"""
if task_id in self.tasks:
self.tasks[task_id].cancel()
try:
await self.tasks[task_id]
except asyncio.CancelledError:
pass
self.task_info[task_id]["status"] = "cancelled"
self.task_info[task_id]["completed_at"] = datetime.utcnow()
del self.tasks[task_id]
return True
return False
async def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
"""Get task status"""
return self.task_info.get(task_id)
async def get_all_tasks(self) -> Dict[str, Dict[str, Any]]:
"""Get all tasks"""
return self.task_info.copy()
async def wait_for_task(self, task_id: str, timeout: Optional[float] = None) -> Any:
"""Wait for task completion"""
if task_id not in self.tasks:
raise ValueError(f"Task {task_id} not found")
try:
await asyncio.wait_for(self.tasks[task_id], timeout)
except asyncio.TimeoutError:
await self.cancel_task(task_id)
raise TimeoutError(f"Task {task_id} timed out")
info = self.task_info.get(task_id)
if info["status"] == "failed":
raise Exception(info["error"])
return info["result"]
class WorkerPool:
"""Worker pool for parallel task execution"""
def __init__(self, num_workers: int = 4):
"""Initialize worker pool"""
self.num_workers = num_workers
self.queue: asyncio.Queue = asyncio.Queue()
self.workers: List[asyncio.Task] = []
self.running = False
async def start(self) -> None:
"""Start worker pool"""
if self.running:
return
self.running = True
for i in range(self.num_workers):
worker = asyncio.create_task(self._worker(i))
self.workers.append(worker)
async def stop(self) -> None:
"""Stop worker pool"""
self.running = False
# Cancel all workers
for worker in self.workers:
worker.cancel()
# Wait for workers to finish
await asyncio.gather(*self.workers, return_exceptions=True)
self.workers.clear()
async def submit(self, func: Callable, *args, **kwargs) -> Any:
"""Submit task to worker pool"""
future = asyncio.Future()
await self.queue.put((func, args, kwargs, future))
return await future
async def _worker(self, worker_id: int) -> None:
"""Worker coroutine"""
while self.running:
try:
func, args, kwargs, future = await self.queue.get()
try:
if asyncio.iscoroutinefunction(func):
result = await func(*args, **kwargs)
else:
result = func(*args, **kwargs)
future.set_result(result)
except Exception as e:
future.set_exception(e)
finally:
self.queue.task_done()
except asyncio.CancelledError:
break
except Exception as e:
print(f"Worker {worker_id} error: {e}")
async def get_queue_size(self) -> int:
"""Get queue size"""
return self.queue.qsize()
def debounce(delay: float = 0.5):
"""Decorator to debounce function calls"""
def decorator(func: Callable) -> Callable:
last_called = [0]
timer = [None]
async def wrapped(*args, **kwargs):
async def call():
await asyncio.sleep(delay)
if asyncio.get_event_loop().time() - last_called[0] >= delay:
if asyncio.iscoroutinefunction(func):
return await func(*args, **kwargs)
else:
return func(*args, **kwargs)
last_called[0] = asyncio.get_event_loop().time()
if timer[0]:
timer[0].cancel()
timer[0] = asyncio.create_task(call())
return await timer[0]
return wrapped
return decorator
def throttle(calls_per_second: float = 1.0):
"""Decorator to throttle function calls"""
def decorator(func: Callable) -> Callable:
min_interval = 1.0 / calls_per_second
last_called = [0]
async def wrapped(*args, **kwargs):
now = asyncio.get_event_loop().time()
elapsed = now - last_called[0]
if elapsed < min_interval:
await asyncio.sleep(min_interval - elapsed)
last_called[0] = asyncio.get_event_loop().time()
if asyncio.iscoroutinefunction(func):
return await func(*args, **kwargs)
else:
return func(*args, **kwargs)
return wrapped
return decorator

282
aitbc/security.py Normal file
View File

@@ -0,0 +1,282 @@
"""
Security utilities for AITBC
Provides token generation, session management, API key management, and secret management
"""
import os
import secrets
import hashlib
import time
import json
from typing import Optional, Dict, Any
from datetime import datetime, timedelta
from cryptography.fernet import Fernet
def generate_token(length: int = 32, prefix: str = "") -> str:
"""Generate a secure random token"""
token = secrets.token_urlsafe(length)
return f"{prefix}{token}" if prefix else token
def generate_api_key(prefix: str = "aitbc") -> str:
"""Generate a secure API key with prefix"""
random_part = secrets.token_urlsafe(32)
return f"{prefix}_{random_part}"
def validate_token_format(token: str, min_length: int = 16) -> bool:
"""Validate token format"""
return bool(token) and len(token) >= min_length and all(c.isalnum() or c in '-_' for c in token)
def validate_api_key(api_key: str, prefix: str = "aitbc") -> bool:
"""Validate API key format"""
if not api_key or not api_key.startswith(f"{prefix}_"):
return False
token_part = api_key[len(prefix)+1:]
return validate_token_format(token_part)
class SessionManager:
"""Simple in-memory session manager"""
def __init__(self, session_timeout: int = 3600):
"""Initialize session manager with timeout in seconds"""
self.sessions: Dict[str, Dict[str, Any]] = {}
self.session_timeout = session_timeout
def create_session(self, user_id: str, data: Optional[Dict[str, Any]] = None) -> str:
"""Create a new session"""
session_id = generate_token()
self.sessions[session_id] = {
"user_id": user_id,
"data": data or {},
"created_at": time.time(),
"expires_at": time.time() + self.session_timeout
}
return session_id
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
"""Get session data"""
session = self.sessions.get(session_id)
if not session:
return None
# Check if session expired
if time.time() > session["expires_at"]:
del self.sessions[session_id]
return None
return session
def update_session(self, session_id: str, data: Dict[str, Any]) -> bool:
"""Update session data"""
session = self.get_session(session_id)
if not session:
return False
session["data"].update(data)
return True
def delete_session(self, session_id: str) -> bool:
"""Delete a session"""
if session_id in self.sessions:
del self.sessions[session_id]
return True
return False
def cleanup_expired_sessions(self) -> int:
"""Clean up expired sessions"""
current_time = time.time()
expired_keys = [
key for key, session in self.sessions.items()
if current_time > session["expires_at"]
]
for key in expired_keys:
del self.sessions[key]
return len(expired_keys)
class APIKeyManager:
"""API key management with storage"""
def __init__(self, storage_path: Optional[str] = None):
"""Initialize API key manager"""
self.storage_path = storage_path
self.keys: Dict[str, Dict[str, Any]] = {}
if storage_path:
self._load_keys()
def create_api_key(self, user_id: str, scopes: Optional[list[str]] = None, name: Optional[str] = None) -> str:
"""Create a new API key"""
api_key = generate_api_key()
self.keys[api_key] = {
"user_id": user_id,
"scopes": scopes or ["read"],
"name": name,
"created_at": datetime.utcnow().isoformat(),
"last_used": None
}
if self.storage_path:
self._save_keys()
return api_key
def validate_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
"""Validate API key and return key data"""
key_data = self.keys.get(api_key)
if not key_data:
return None
# Update last used
key_data["last_used"] = datetime.utcnow().isoformat()
if self.storage_path:
self._save_keys()
return key_data
def revoke_api_key(self, api_key: str) -> bool:
"""Revoke an API key"""
if api_key in self.keys:
del self.keys[api_key]
if self.storage_path:
self._save_keys()
return True
return False
def list_user_keys(self, user_id: str) -> list[str]:
"""List all API keys for a user"""
return [
key for key, data in self.items()
if data["user_id"] == user_id
]
def _load_keys(self):
"""Load keys from storage"""
if self.storage_path and os.path.exists(self.storage_path):
try:
with open(self.storage_path, 'r') as f:
self.keys = json.load(f)
except Exception:
self.keys = {}
def _save_keys(self):
"""Save keys to storage"""
if self.storage_path:
try:
with open(self.storage_path, 'w') as f:
json.dump(self.keys, f)
except Exception:
pass
def items(self):
"""Return key items"""
return self.keys.items()
def generate_secure_random_string(length: int = 32) -> str:
"""Generate a cryptographically secure random string"""
return secrets.token_urlsafe(length)
def generate_secure_random_int(min_val: int = 0, max_val: int = 2**32) -> int:
"""Generate a cryptographically secure random integer"""
return secrets.randbelow(max_val - min_val) + min_val
class SecretManager:
"""Simple secret management with encryption"""
def __init__(self, encryption_key: Optional[str] = None):
"""Initialize secret manager"""
if encryption_key:
self.fernet = Fernet(encryption_key)
else:
# Generate a new key if none provided
self.fernet = Fernet(Fernet.generate_key())
self.secrets: Dict[str, str] = {}
def set_secret(self, key: str, value: str) -> None:
"""Store an encrypted secret"""
encrypted = self.fernet.encrypt(value.encode('utf-8'))
self.secrets[key] = encrypted.decode('utf-8')
def get_secret(self, key: str) -> Optional[str]:
"""Retrieve and decrypt a secret"""
encrypted = self.secrets.get(key)
if not encrypted:
return None
try:
decrypted = self.fernet.decrypt(encrypted.encode('utf-8'))
return decrypted.decode('utf-8')
except Exception:
return None
def delete_secret(self, key: str) -> bool:
"""Delete a secret"""
if key in self.secrets:
del self.secrets[key]
return True
return False
def list_secrets(self) -> list[str]:
"""List all secret keys"""
return list(self.secrets.keys())
def get_encryption_key(self) -> str:
"""Get the encryption key (for backup purposes)"""
return self.fernet._signing_key.decode('utf-8')
def hash_password(password: str, salt: Optional[str] = None) -> tuple[str, str]:
"""Hash a password with salt"""
if salt is None:
salt = secrets.token_hex(16)
# Use PBKDF2 for password hashing
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives import hashes
import base64
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt.encode('utf-8'),
iterations=100000,
)
hashed = kdf.derive(password.encode('utf-8'))
return base64.b64encode(hashed).decode('utf-8'), salt
def verify_password(password: str, hashed_password: str, salt: str) -> bool:
"""Verify a password against a hash"""
new_hash, _ = hash_password(password, salt)
return new_hash == hashed_password
def generate_nonce(length: int = 16) -> str:
"""Generate a nonce for cryptographic operations"""
return secrets.token_hex(length)
def generate_hmac(data: str, secret: str) -> str:
"""Generate HMAC-SHA256 signature"""
import hmac
return hmac.new(
secret.encode('utf-8'),
data.encode('utf-8'),
hashlib.sha256
).hexdigest()
def verify_hmac(data: str, signature: str, secret: str) -> bool:
"""Verify HMAC-SHA256 signature"""
computed = generate_hmac(data, secret)
return secrets.compare_digest(computed, signature)

348
aitbc/state.py Normal file
View File

@@ -0,0 +1,348 @@
"""
State management utilities for AITBC
Provides state machine base classes, state persistence, and state transition helpers
"""
import json
import os
from typing import Any, Callable, Dict, Optional, TypeVar, Generic, List
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from abc import ABC, abstractmethod
import asyncio
T = TypeVar('T')
class StateTransitionError(Exception):
"""Raised when invalid state transition is attempted"""
pass
class StatePersistenceError(Exception):
"""Raised when state persistence fails"""
pass
@dataclass
class StateTransition:
"""Record of a state transition"""
from_state: str
to_state: str
timestamp: datetime = field(default_factory=datetime.utcnow)
data: Dict[str, Any] = field(default_factory=dict)
class StateMachine(ABC):
"""Base class for state machines"""
def __init__(self, initial_state: str):
"""Initialize state machine"""
self.current_state = initial_state
self.transitions: List[StateTransition] = []
self.state_data: Dict[str, Dict[str, Any]] = {initial_state: {}}
@abstractmethod
def get_valid_transitions(self, state: str) -> List[str]:
"""Get valid transitions from a state"""
pass
def can_transition(self, to_state: str) -> bool:
"""Check if transition is valid"""
return to_state in self.get_valid_transitions(self.current_state)
def transition(self, to_state: str, data: Optional[Dict[str, Any]] = None) -> None:
"""Transition to a new state"""
if not self.can_transition(to_state):
raise StateTransitionError(
f"Invalid transition from {self.current_state} to {to_state}"
)
from_state = self.current_state
self.current_state = to_state
# Record transition
transition = StateTransition(
from_state=from_state,
to_state=to_state,
data=data or {}
)
self.transitions.append(transition)
# Initialize state data if needed
if to_state not in self.state_data:
self.state_data[to_state] = {}
def get_state_data(self, state: Optional[str] = None) -> Dict[str, Any]:
"""Get data for a state"""
state = state or self.current_state
return self.state_data.get(state, {}).copy()
def set_state_data(self, data: Dict[str, Any], state: Optional[str] = None) -> None:
"""Set data for a state"""
state = state or self.current_state
if state not in self.state_data:
self.state_data[state] = {}
self.state_data[state].update(data)
def get_transition_history(self, limit: Optional[int] = None) -> List[StateTransition]:
"""Get transition history"""
if limit:
return self.transitions[-limit:]
return self.transitions.copy()
def reset(self, initial_state: str) -> None:
"""Reset state machine to initial state"""
self.current_state = initial_state
self.transitions.clear()
self.state_data = {initial_state: {}}
class ConfigurableStateMachine(StateMachine):
"""State machine with configurable transitions"""
def __init__(self, initial_state: str, transitions: Dict[str, List[str]]):
"""Initialize configurable state machine"""
super().__init__(initial_state)
self.transitions_config = transitions
def get_valid_transitions(self, state: str) -> List[str]:
"""Get valid transitions from configuration"""
return self.transitions_config.get(state, [])
def add_transition(self, from_state: str, to_state: str) -> None:
"""Add a transition to configuration"""
if from_state not in self.transitions_config:
self.transitions_config[from_state] = []
if to_state not in self.transitions_config[from_state]:
self.transitions_config[from_state].append(to_state)
class StatePersistence:
"""State persistence to file"""
def __init__(self, storage_path: str):
"""Initialize state persistence"""
self.storage_path = storage_path
self._ensure_storage_dir()
def _ensure_storage_dir(self) -> None:
"""Ensure storage directory exists"""
os.makedirs(os.path.dirname(self.storage_path), exist_ok=True)
def save_state(self, state_machine: StateMachine) -> None:
"""Save state machine to file"""
try:
state_data = {
"current_state": state_machine.current_state,
"state_data": state_machine.state_data,
"transitions": [
{
"from_state": t.from_state,
"to_state": t.to_state,
"timestamp": t.timestamp.isoformat(),
"data": t.data
}
for t in state_machine.transitions
]
}
with open(self.storage_path, 'w') as f:
json.dump(state_data, f, indent=2)
except Exception as e:
raise StatePersistenceError(f"Failed to save state: {e}")
def load_state(self) -> Optional[Dict[str, Any]]:
"""Load state from file"""
try:
if not os.path.exists(self.storage_path):
return None
with open(self.storage_path, 'r') as f:
return json.load(f)
except Exception as e:
raise StatePersistenceError(f"Failed to load state: {e}")
def delete_state(self) -> None:
"""Delete persisted state"""
try:
if os.path.exists(self.storage_path):
os.remove(self.storage_path)
except Exception as e:
raise StatePersistenceError(f"Failed to delete state: {e}")
class AsyncStateMachine(StateMachine):
"""Async state machine with async transition handlers"""
def __init__(self, initial_state: str):
"""Initialize async state machine"""
super().__init__(initial_state)
self.transition_handlers: Dict[str, Callable] = {}
def on_transition(self, to_state: str, handler: Callable) -> None:
"""Register a handler for transition to a state"""
self.transition_handlers[to_state] = handler
async def transition_async(self, to_state: str, data: Optional[Dict[str, Any]] = None) -> None:
"""Async transition to a new state"""
if not self.can_transition(to_state):
raise StateTransitionError(
f"Invalid transition from {self.current_state} to {to_state}"
)
from_state = self.current_state
self.current_state = to_state
# Record transition
transition = StateTransition(
from_state=from_state,
to_state=to_state,
data=data or {}
)
self.transitions.append(transition)
# Initialize state data if needed
if to_state not in self.state_data:
self.state_data[to_state] = {}
# Call transition handler if exists
if to_state in self.transition_handlers:
handler = self.transition_handlers[to_state]
if asyncio.iscoroutinefunction(handler):
await handler(transition)
else:
handler(transition)
class StateMonitor:
"""Monitor state machine state and transitions"""
def __init__(self, state_machine: StateMachine):
"""Initialize state monitor"""
self.state_machine = state_machine
self.observers: List[Callable] = []
def add_observer(self, observer: Callable) -> None:
"""Add an observer for state changes"""
self.observers.append(observer)
def remove_observer(self, observer: Callable) -> bool:
"""Remove an observer"""
try:
self.observers.remove(observer)
return True
except ValueError:
return False
def notify_observers(self, transition: StateTransition) -> None:
"""Notify all observers of state change"""
for observer in self.observers:
try:
observer(transition)
except Exception as e:
print(f"Error in state observer: {e}")
def wrap_transition(self, original_transition: Callable) -> Callable:
"""Wrap transition method to notify observers"""
def wrapper(*args, **kwargs):
result = original_transition(*args, **kwargs)
# Get last transition
if self.state_machine.transitions:
self.notify_observers(self.state_machine.transitions[-1])
return result
return wrapper
class StateValidator:
"""Validate state machine configurations"""
@staticmethod
def validate_transitions(transitions: Dict[str, List[str]]) -> bool:
"""Validate that all target states exist"""
all_states = set(transitions.keys())
all_states.update(*transitions.values())
for from_state, to_states in transitions.items():
for to_state in to_states:
if to_state not in all_states:
return False
return True
@staticmethod
def check_for_deadlocks(transitions: Dict[str, List[str]]) -> List[str]:
"""Check for states with no outgoing transitions"""
deadlocks = []
for state, to_states in transitions.items():
if not to_states:
deadlocks.append(state)
return deadlocks
@staticmethod
def check_for_orphans(transitions: Dict[str, List[str]]) -> List[str]:
"""Check for states with no incoming transitions"""
incoming = set()
for to_states in transitions.values():
incoming.update(to_states)
orphans = []
for state in transitions.keys():
if state not in incoming:
orphans.append(state)
return orphans
class StateSnapshot:
"""Snapshot of state machine state"""
def __init__(self, state_machine: StateMachine):
"""Create snapshot"""
self.current_state = state_machine.current_state
self.state_data = state_machine.state_data.copy()
self.transitions = state_machine.transitions.copy()
self.timestamp = datetime.utcnow()
def restore(self, state_machine: StateMachine) -> None:
"""Restore state machine from snapshot"""
state_machine.current_state = self.current_state
state_machine.state_data = self.state_data.copy()
state_machine.transitions = self.transitions.copy()
def to_dict(self) -> Dict[str, Any]:
"""Convert snapshot to dict"""
return {
"current_state": self.current_state,
"state_data": self.state_data,
"transitions": [
{
"from_state": t.from_state,
"to_state": t.to_state,
"timestamp": t.timestamp.isoformat(),
"data": t.data
}
for t in self.transitions
],
"timestamp": self.timestamp.isoformat()
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'StateSnapshot':
"""Create snapshot from dict"""
snapshot = cls.__new__(cls)
snapshot.current_state = data["current_state"]
snapshot.state_data = data["state_data"]
snapshot.transitions = [
StateTransition(
from_state=t["from_state"],
to_state=t["to_state"],
timestamp=datetime.fromisoformat(t["timestamp"]),
data=t["data"]
)
for t in data["transitions"]
]
snapshot.timestamp = datetime.fromisoformat(data["timestamp"])
return snapshot

401
aitbc/testing.py Normal file
View File

@@ -0,0 +1,401 @@
"""
Testing utilities for AITBC
Provides mock factories, test data generators, and test helpers
"""
import secrets
import json
from typing import Any, Dict, List, Optional, Type, TypeVar, Callable
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from decimal import Decimal
import uuid
T = TypeVar('T')
class MockFactory:
"""Factory for creating mock objects for testing"""
@staticmethod
def generate_string(length: int = 10, prefix: str = "") -> str:
"""Generate a random string"""
random_part = secrets.token_urlsafe(length)[:length]
return f"{prefix}{random_part}"
@staticmethod
def generate_email() -> str:
"""Generate a random email address"""
return f"{MockFactory.generate_string(8)}@example.com"
@staticmethod
def generate_url() -> str:
"""Generate a random URL"""
return f"https://example.com/{MockFactory.generate_string(8)}"
@staticmethod
def generate_ip_address() -> str:
"""Generate a random IP address"""
return f"192.168.{secrets.randbelow(256)}.{secrets.randbelow(256)}"
@staticmethod
def generate_ethereum_address() -> str:
"""Generate a random Ethereum address"""
return f"0x{''.join(secrets.choice('0123456789abcdef') for _ in range(40))}"
@staticmethod
def generate_bitcoin_address() -> str:
"""Generate a random Bitcoin-like address"""
return f"1{''.join(secrets.choice('123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz') for _ in range(33))}"
@staticmethod
def generate_uuid() -> str:
"""Generate a UUID"""
return str(uuid.uuid4())
@staticmethod
def generate_hash(length: int = 64) -> str:
"""Generate a random hash string"""
return secrets.token_hex(length)[:length]
class TestDataGenerator:
"""Generate test data for various use cases"""
@staticmethod
def generate_user_data(**overrides) -> Dict[str, Any]:
"""Generate mock user data"""
data = {
"id": MockFactory.generate_uuid(),
"email": MockFactory.generate_email(),
"username": MockFactory.generate_string(8),
"first_name": MockFactory.generate_string(6),
"last_name": MockFactory.generate_string(6),
"created_at": datetime.utcnow().isoformat(),
"updated_at": datetime.utcnow().isoformat(),
"is_active": True,
"role": "user"
}
data.update(overrides)
return data
@staticmethod
def generate_transaction_data(**overrides) -> Dict[str, Any]:
"""Generate mock transaction data"""
data = {
"id": MockFactory.generate_uuid(),
"from_address": MockFactory.generate_ethereum_address(),
"to_address": MockFactory.generate_ethereum_address(),
"amount": str(secrets.randbelow(1000000000000000000)),
"gas_price": str(secrets.randbelow(100000000000)),
"gas_limit": secrets.randbelow(100000),
"nonce": secrets.randbelow(1000),
"timestamp": datetime.utcnow().isoformat(),
"status": "pending"
}
data.update(overrides)
return data
@staticmethod
def generate_block_data(**overrides) -> Dict[str, Any]:
"""Generate mock block data"""
data = {
"number": secrets.randbelow(10000000),
"hash": MockFactory.generate_hash(),
"parent_hash": MockFactory.generate_hash(),
"timestamp": datetime.utcnow().isoformat(),
"transactions": [],
"gas_used": str(secrets.randbelow(10000000)),
"gas_limit": str(15000000),
"miner": MockFactory.generate_ethereum_address()
}
data.update(overrides)
return data
@staticmethod
def generate_api_key_data(**overrides) -> Dict[str, Any]:
"""Generate mock API key data"""
data = {
"id": MockFactory.generate_uuid(),
"api_key": f"aitbc_{secrets.token_urlsafe(32)}",
"user_id": MockFactory.generate_uuid(),
"name": MockFactory.generate_string(10),
"scopes": ["read", "write"],
"created_at": datetime.utcnow().isoformat(),
"last_used": None,
"is_active": True
}
data.update(overrides)
return data
@staticmethod
def generate_wallet_data(**overrides) -> Dict[str, Any]:
"""Generate mock wallet data"""
data = {
"id": MockFactory.generate_uuid(),
"address": MockFactory.generate_ethereum_address(),
"chain_id": 1,
"balance": str(secrets.randbelow(1000000000000000000)),
"created_at": datetime.utcnow().isoformat(),
"is_active": True
}
data.update(overrides)
return data
class TestHelpers:
"""Helper functions for testing"""
@staticmethod
def assert_dict_contains(subset: Dict[str, Any], superset: Dict[str, Any]) -> bool:
"""Check if superset contains all key-value pairs from subset"""
for key, value in subset.items():
if key not in superset:
return False
if superset[key] != value:
return False
return True
@staticmethod
def assert_lists_equal_unordered(list1: List[Any], list2: List[Any]) -> bool:
"""Check if two lists contain the same elements regardless of order"""
return sorted(list1) == sorted(list2)
@staticmethod
def compare_json_objects(obj1: Any, obj2: Any) -> bool:
"""Compare two JSON-serializable objects"""
return json.dumps(obj1, sort_keys=True) == json.dumps(obj2, sort_keys=True)
@staticmethod
def wait_for_condition(
condition: Callable[[], bool],
timeout: float = 10.0,
interval: float = 0.1
) -> bool:
"""Wait for a condition to become true"""
import time
start = time.time()
while time.time() - start < timeout:
if condition():
return True
time.sleep(interval)
return False
@staticmethod
def measure_execution_time(func: Callable, *args, **kwargs) -> tuple[Any, float]:
"""Measure execution time of a function"""
import time
start = time.time()
result = func(*args, **kwargs)
elapsed = time.time() - start
return result, elapsed
@staticmethod
def generate_test_file_path(extension: str = ".tmp") -> str:
"""Generate a unique test file path"""
return f"/tmp/test_{secrets.token_hex(8)}{extension}"
@staticmethod
def cleanup_test_files(prefix: str = "test_") -> int:
"""Clean up test files in /tmp"""
import os
import glob
count = 0
for file_path in glob.glob(f"/tmp/{prefix}*"):
try:
os.remove(file_path)
count += 1
except:
pass
return count
class MockResponse:
"""Mock HTTP response for testing"""
def __init__(
self,
status_code: int = 200,
json_data: Optional[Dict[str, Any]] = None,
text: Optional[str] = None,
headers: Optional[Dict[str, str]] = None
):
"""Initialize mock response"""
self.status_code = status_code
self._json_data = json_data
self._text = text
self.headers = headers or {}
def json(self) -> Dict[str, Any]:
"""Return JSON data"""
if self._json_data is None:
raise ValueError("No JSON data available")
return self._json_data
def text(self) -> str:
"""Return text data"""
if self._text is None:
return ""
return self._text
def raise_for_status(self) -> None:
"""Raise exception if status code indicates error"""
if self.status_code >= 400:
raise Exception(f"HTTP Error: {self.status_code}")
class MockDatabase:
"""Mock database for testing"""
def __init__(self):
"""Initialize mock database"""
self.data: Dict[str, List[Dict[str, Any]]] = {}
self.tables: List[str] = []
def create_table(self, table_name: str) -> None:
"""Create a table"""
if table_name not in self.tables:
self.tables.append(table_name)
self.data[table_name] = []
def insert(self, table_name: str, record: Dict[str, Any]) -> None:
"""Insert a record"""
if table_name not in self.tables:
self.create_table(table_name)
record['id'] = record.get('id', MockFactory.generate_uuid())
self.data[table_name].append(record)
def select(self, table_name: str, **filters) -> List[Dict[str, Any]]:
"""Select records with optional filters"""
if table_name not in self.tables:
return []
records = self.data[table_name]
if not filters:
return records
filtered = []
for record in records:
match = True
for key, value in filters.items():
if record.get(key) != value:
match = False
break
if match:
filtered.append(record)
return filtered
def update(self, table_name: str, record_id: str, updates: Dict[str, Any]) -> bool:
"""Update a record"""
if table_name not in self.tables:
return False
for record in self.data[table_name]:
if record.get('id') == record_id:
record.update(updates)
return True
return False
def delete(self, table_name: str, record_id: str) -> bool:
"""Delete a record"""
if table_name not in self.tables:
return False
for i, record in enumerate(self.data[table_name]):
if record.get('id') == record_id:
del self.data[table_name][i]
return True
return False
def clear(self) -> None:
"""Clear all data"""
self.data.clear()
self.tables.clear()
class MockCache:
"""Mock cache for testing"""
def __init__(self, ttl: int = 3600):
"""Initialize mock cache"""
self.cache: Dict[str, tuple[Any, float]] = {}
self.ttl = ttl
def get(self, key: str) -> Optional[Any]:
"""Get value from cache"""
if key not in self.cache:
return None
value, timestamp = self.cache[key]
if time.time() - timestamp > self.ttl:
del self.cache[key]
return None
return value
def set(self, key: str, value: Any) -> None:
"""Set value in cache"""
self.cache[key] = (value, time.time())
def delete(self, key: str) -> bool:
"""Delete value from cache"""
if key in self.cache:
del self.cache[key]
return True
return False
def clear(self) -> None:
"""Clear cache"""
self.cache.clear()
def size(self) -> int:
"""Get cache size"""
return len(self.cache)
def mock_async_call(return_value: Any = None, delay: float = 0):
"""Decorator to mock async calls with optional delay"""
def decorator(func: Callable) -> Callable:
async def wrapper(*args, **kwargs):
if delay > 0:
await asyncio.sleep(delay)
return return_value
return wrapper
return decorator
def create_mock_config(**overrides) -> Dict[str, Any]:
"""Create mock configuration"""
config = {
"debug": False,
"log_level": "INFO",
"database_url": "sqlite:///test.db",
"redis_url": "redis://localhost:6379",
"api_host": "localhost",
"api_port": 8080,
"secret_key": MockFactory.generate_string(32),
"max_workers": 4,
"timeout": 30
}
config.update(overrides)
return config
import time
def create_test_scenario(name: str, steps: List[Callable]) -> Callable:
"""Create a test scenario with multiple steps"""
def scenario():
print(f"Running test scenario: {name}")
results = []
for i, step in enumerate(steps):
try:
result = step()
results.append({"step": i + 1, "status": "passed", "result": result})
except Exception as e:
results.append({"step": i + 1, "status": "failed", "error": str(e)})
return results
return scenario

321
aitbc/time_utils.py Normal file
View File

@@ -0,0 +1,321 @@
"""
Time utilities for AITBC
Provides timestamp helpers, duration helpers, timezone handling, and deadline calculations
"""
from datetime import datetime, timedelta, timezone
from typing import Optional, Union
import time
def get_utc_now() -> datetime:
"""Get current UTC datetime"""
return datetime.now(timezone.utc)
def get_timestamp_utc() -> float:
"""Get current UTC timestamp"""
return time.time()
def format_iso8601(dt: Optional[datetime] = None) -> str:
"""Format datetime as ISO 8601 string in UTC"""
if dt is None:
dt = get_utc_now()
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt.isoformat()
def parse_iso8601(iso_string: str) -> datetime:
"""Parse ISO 8601 string to datetime"""
dt = datetime.fromisoformat(iso_string)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt
def timestamp_to_iso(timestamp: float) -> str:
"""Convert timestamp to ISO 8601 string"""
return datetime.fromtimestamp(timestamp, timezone.utc).isoformat()
def iso_to_timestamp(iso_string: str) -> float:
"""Convert ISO 8601 string to timestamp"""
dt = parse_iso8601(iso_string)
return dt.timestamp()
def format_duration(seconds: Union[int, float]) -> str:
"""Format duration in seconds to human-readable string"""
if seconds < 60:
return f"{int(seconds)}s"
elif seconds < 3600:
minutes = int(seconds / 60)
return f"{minutes}m"
elif seconds < 86400:
hours = int(seconds / 3600)
return f"{hours}h"
else:
days = int(seconds / 86400)
return f"{days}d"
def format_duration_precise(seconds: Union[int, float]) -> str:
"""Format duration with precise breakdown"""
days = int(seconds // 86400)
hours = int((seconds % 86400) // 3600)
minutes = int((seconds % 3600) // 60)
secs = int(seconds % 60)
parts = []
if days > 0:
parts.append(f"{days}d")
if hours > 0:
parts.append(f"{hours}h")
if minutes > 0:
parts.append(f"{minutes}m")
if secs > 0 or not parts:
parts.append(f"{secs}s")
return " ".join(parts)
def parse_duration(duration_str: str) -> float:
"""Parse duration string to seconds"""
duration_str = duration_str.strip().lower()
if duration_str.endswith('s'):
return float(duration_str[:-1])
elif duration_str.endswith('m'):
return float(duration_str[:-1]) * 60
elif duration_str.endswith('h'):
return float(duration_str[:-1]) * 3600
elif duration_str.endswith('d'):
return float(duration_str[:-1]) * 86400
else:
return float(duration_str)
def add_duration(dt: datetime, duration: Union[str, timedelta]) -> datetime:
"""Add duration to datetime"""
if isinstance(duration, str):
duration = timedelta(seconds=parse_duration(duration))
return dt + duration
def subtract_duration(dt: datetime, duration: Union[str, timedelta]) -> datetime:
"""Subtract duration from datetime"""
if isinstance(duration, str):
duration = timedelta(seconds=parse_duration(duration))
return dt - duration
def get_time_until(dt: datetime) -> timedelta:
"""Get time until a future datetime"""
now = get_utc_now()
return dt - now
def get_time_since(dt: datetime) -> timedelta:
"""Get time since a past datetime"""
now = get_utc_now()
return now - dt
def calculate_deadline(duration: Union[str, timedelta], from_dt: Optional[datetime] = None) -> datetime:
"""Calculate deadline from duration"""
if from_dt is None:
from_dt = get_utc_now()
return add_duration(from_dt, duration)
def is_deadline_passed(deadline: datetime) -> bool:
"""Check if deadline has passed"""
return get_utc_now() >= deadline
def get_deadline_remaining(deadline: datetime) -> float:
"""Get remaining seconds until deadline"""
delta = deadline - get_utc_now()
return max(0, delta.total_seconds())
def format_time_ago(dt: datetime) -> str:
"""Format datetime as "time ago" string"""
delta = get_time_since(dt)
seconds = delta.total_seconds()
if seconds < 60:
return "just now"
elif seconds < 3600:
minutes = int(seconds / 60)
return f"{minutes} minute{'s' if minutes > 1 else ''} ago"
elif seconds < 86400:
hours = int(seconds / 3600)
return f"{hours} hour{'s' if hours > 1 else ''} ago"
elif seconds < 604800:
days = int(seconds / 86400)
return f"{days} day{'s' if days > 1 else ''} ago"
else:
weeks = int(seconds / 604800)
return f"{weeks} week{'s' if weeks > 1 else ''} ago"
def format_time_in(dt: datetime) -> str:
"""Format datetime as "time in" string"""
delta = get_time_until(dt)
seconds = delta.total_seconds()
if seconds < 0:
return format_time_ago(dt)
if seconds < 60:
return "in a moment"
elif seconds < 3600:
minutes = int(seconds / 60)
return f"in {minutes} minute{'s' if minutes > 1 else ''}"
elif seconds < 86400:
hours = int(seconds / 3600)
return f"in {hours} hour{'s' if hours > 1 else ''}"
elif seconds < 604800:
days = int(seconds / 86400)
return f"in {days} day{'s' if days > 1 else ''}"
else:
weeks = int(seconds / 604800)
return f"in {weeks} week{'s' if weeks > 1 else ''}"
def to_timezone(dt: datetime, tz_name: str) -> datetime:
"""Convert datetime to specific timezone"""
try:
import pytz
tz = pytz.timezone(tz_name)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt.astimezone(tz)
except ImportError:
raise ImportError("pytz is required for timezone conversion. Install with: pip install pytz")
except Exception as e:
raise ValueError(f"Failed to convert timezone: {e}")
def get_timezone_offset(tz_name: str) -> timedelta:
"""Get timezone offset from UTC"""
try:
import pytz
tz = pytz.timezone(tz_name)
now = datetime.now(timezone.utc)
offset = tz.utcoffset(now)
return offset if offset else timedelta(0)
except ImportError:
raise ImportError("pytz is required for timezone operations. Install with: pip install pytz")
def is_business_hours(dt: Optional[datetime] = None, start_hour: int = 9, end_hour: int = 17, timezone: str = "UTC") -> bool:
"""Check if datetime is within business hours"""
if dt is None:
dt = get_utc_now()
try:
import pytz
tz = pytz.timezone(timezone)
dt_local = dt.astimezone(tz)
return start_hour <= dt_local.hour < end_hour
except ImportError:
raise ImportError("pytz is required for business hours check. Install with: pip install pytz")
def get_start_of_day(dt: Optional[datetime] = None) -> datetime:
"""Get start of day (00:00:00) for given datetime"""
if dt is None:
dt = get_utc_now()
return dt.replace(hour=0, minute=0, second=0, microsecond=0)
def get_end_of_day(dt: Optional[datetime] = None) -> datetime:
"""Get end of day (23:59:59) for given datetime"""
if dt is None:
dt = get_utc_now()
return dt.replace(hour=23, minute=59, second=59, microsecond=999999)
def get_start_of_week(dt: Optional[datetime] = None) -> datetime:
"""Get start of week (Monday) for given datetime"""
if dt is None:
dt = get_utc_now()
return dt - timedelta(days=dt.weekday())
def get_end_of_week(dt: Optional[datetime] = None) -> datetime:
"""Get end of week (Sunday) for given datetime"""
if dt is None:
dt = get_utc_now()
return dt + timedelta(days=(6 - dt.weekday()))
def get_start_of_month(dt: Optional[datetime] = None) -> datetime:
"""Get start of month for given datetime"""
if dt is None:
dt = get_utc_now()
return dt.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
def get_end_of_month(dt: Optional[datetime] = None) -> datetime:
"""Get end of month for given datetime"""
if dt is None:
dt = get_utc_now()
if dt.month == 12:
next_month = dt.replace(year=dt.year + 1, month=1, day=1)
else:
next_month = dt.replace(month=dt.month + 1, day=1)
return next_month - timedelta(seconds=1)
def sleep_until(dt: datetime) -> None:
"""Sleep until a specific datetime"""
now = get_utc_now()
if dt > now:
sleep_seconds = (dt - now).total_seconds()
time.sleep(sleep_seconds)
def retry_until_deadline(func, deadline: datetime, interval: float = 1.0) -> bool:
"""Retry a function until deadline is reached"""
while not is_deadline_passed(deadline):
try:
result = func()
if result:
return True
except Exception:
pass
time.sleep(interval)
return False
class Timer:
"""Simple timer context manager for measuring execution time"""
def __init__(self):
"""Initialize timer"""
self.start_time = None
self.end_time = None
self.elapsed = None
def __enter__(self):
"""Start timer"""
self.start_time = time.time()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Stop timer"""
self.end_time = time.time()
self.elapsed = self.end_time - self.start_time
def get_elapsed(self) -> Optional[float]:
"""Get elapsed time in seconds"""
if self.elapsed is not None:
return self.elapsed
elif self.start_time is not None:
return time.time() - self.start_time
return None

206
aitbc/web3_utils.py Normal file
View File

@@ -0,0 +1,206 @@
"""
Web3 utilities for AITBC
Provides Ethereum blockchain interaction utilities using web3.py
"""
from typing import Any, Optional
from decimal import Decimal
class Web3Client:
"""Web3 client wrapper for blockchain operations"""
def __init__(self, rpc_url: str, timeout: int = 30):
"""Initialize Web3 client with RPC URL"""
try:
from web3 import Web3
from web3.middleware import geth_poa_middleware
self.w3 = Web3(Web3.HTTPProvider(rpc_url, request_kwargs={'timeout': timeout}))
# Add POA middleware for chains like Polygon, BSC, etc.
self.w3.middleware_onion.inject(geth_poa_middleware, layer=0)
if not self.w3.is_connected():
raise ConnectionError(f"Failed to connect to RPC URL: {rpc_url}")
except ImportError:
raise ImportError("web3 is required for blockchain operations. Install with: pip install web3")
except Exception as e:
raise ConnectionError(f"Failed to initialize Web3 client: {e}")
def get_eth_balance(self, address: str) -> str:
"""Get ETH balance in wei"""
try:
balance_wei = self.w3.eth.get_balance(address)
return str(balance_wei)
except Exception as e:
raise ValueError(f"Failed to get ETH balance: {e}")
def get_token_balance(self, address: str, token_address: str) -> dict[str, Any]:
"""Get ERC-20 token balance"""
try:
# ERC-20 balanceOf function signature: 0x70a08231
balance_of_signature = '0x70a08231'
# Pad address to 32 bytes
padded_address = address[2:].lower().zfill(64)
call_data = balance_of_signature + padded_address
result = self.w3.eth.call({
'to': token_address,
'data': f'0x{call_data}'
})
balance = int(result.hex(), 16)
# Get token decimals
decimals_signature = '0x313ce567'
decimals_result = self.w3.eth.call({
'to': token_address,
'data': decimals_signature
})
decimals = int(decimals_result.hex(), 16)
# Get token symbol (optional, may fail for some tokens)
try:
symbol_signature = '0x95d89b41'
symbol_result = self.w3.eth.call({
'to': token_address,
'data': symbol_signature
})
symbol_bytes = bytes.fromhex(symbol_result.hex()[2:])
symbol = symbol_bytes.rstrip(b'\x00').decode('utf-8')
except:
symbol = "TOKEN"
return {
"balance": str(balance),
"decimals": decimals,
"symbol": symbol
}
except Exception as e:
raise ValueError(f"Failed to get token balance: {e}")
def get_gas_price(self) -> int:
"""Get current gas price in wei"""
try:
gas_price = self.w3.eth.gas_price
return gas_price
except Exception as e:
raise ValueError(f"Failed to get gas price: {e}")
def get_gas_price_gwei(self) -> float:
"""Get current gas price in Gwei"""
try:
gas_price_wei = self.get_gas_price()
return float(gas_price_wei) / 10**9
except Exception as e:
raise ValueError(f"Failed to get gas price in Gwei: {e}")
def get_nonce(self, address: str) -> int:
"""Get transaction nonce for address"""
try:
nonce = self.w3.eth.get_transaction_count(address)
return nonce
except Exception as e:
raise ValueError(f"Failed to get nonce: {e}")
def send_raw_transaction(self, signed_transaction: str) -> str:
"""Send raw transaction to blockchain"""
try:
tx_hash = self.w3.eth.send_raw_transaction(signed_transaction)
return tx_hash.hex()
except Exception as e:
raise ValueError(f"Failed to send raw transaction: {e}")
def get_transaction_receipt(self, tx_hash: str) -> Optional[dict[str, Any]]:
"""Get transaction receipt"""
try:
receipt = self.w3.eth.get_transaction_receipt(tx_hash)
if receipt is None:
return None
return {
"status": receipt['status'],
"blockNumber": hex(receipt['blockNumber']),
"blockHash": receipt['blockHash'].hex(),
"gasUsed": hex(receipt['gasUsed']),
"effectiveGasPrice": hex(receipt['effectiveGasPrice']),
"logs": receipt['logs'],
}
except Exception as e:
raise ValueError(f"Failed to get transaction receipt: {e}")
def get_transaction_by_hash(self, tx_hash: str) -> dict[str, Any]:
"""Get transaction by hash"""
try:
tx = self.w3.eth.get_transaction(tx_hash)
return {
"from": tx['from'],
"to": tx['to'],
"value": hex(tx['value']),
"data": tx['input'].hex() if hasattr(tx['input'], 'hex') else tx['input'],
"nonce": tx['nonce'],
"gas": tx['gas'],
"gasPrice": hex(tx['gasPrice']),
"blockNumber": hex(tx['blockNumber']) if tx['blockNumber'] else None,
}
except Exception as e:
raise ValueError(f"Failed to get transaction by hash: {e}")
def estimate_gas(self, transaction: dict[str, Any]) -> int:
"""Estimate gas for transaction"""
try:
gas_estimate = self.w3.eth.estimate_gas(transaction)
return gas_estimate
except Exception as e:
raise ValueError(f"Failed to estimate gas: {e}")
def get_block_number(self) -> int:
"""Get current block number"""
try:
return self.w3.eth.block_number
except Exception as e:
raise ValueError(f"Failed to get block number: {e}")
def get_wallet_transactions(self, address: str, limit: int = 100) -> list[dict[str, Any]]:
"""Get wallet transactions (simplified implementation)"""
try:
# This is a simplified version - in production you'd want to use
# event logs or a blockchain explorer API for this
transactions = []
current_block = self.get_block_number()
# Look back at recent blocks for transactions from/to this address
start_block = max(0, current_block - 1000)
for block_num in range(current_block, start_block, -1):
if len(transactions) >= limit:
break
try:
block = self.w3.eth.get_block(block_num, full_transactions=True)
for tx in block['transactions']:
if tx['from'].lower() == address.lower() or \
(tx['to'] and tx['to'].lower() == address.lower()):
transactions.append({
"hash": tx['hash'].hex(),
"from": tx['from'],
"to": tx['to'].hex() if tx['to'] else None,
"value": hex(tx['value']),
"blockNumber": hex(tx['blockNumber']),
"timestamp": block['timestamp'],
"gasUsed": hex(tx['gas']),
})
if len(transactions) >= limit:
break
except:
continue
return transactions
except Exception as e:
raise ValueError(f"Failed to get wallet transactions: {e}")
def create_web3_client(rpc_url: str, timeout: int = 30) -> Web3Client:
"""Factory function to create Web3 client"""
return Web3Client(rpc_url, timeout)

View File

@@ -1048,19 +1048,15 @@ async def search_transactions(
response = await client.get(f"{rpc_url}/rpc/search/transactions", params=params)
if response.status_code == 200:
return response.json()
elif response.status_code == 404:
return []
else:
# Return mock data for demonstration
return [
{
"hash": "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef",
"type": tx_type or "transfer",
"from": "0xabcdef1234567890abcdef1234567890abcdef1234",
"to": "0x1234567890abcdef1234567890abcdef12345678",
"amount": "1.5",
"fee": "0.001",
"timestamp": datetime.now().isoformat()
}
]
raise HTTPException(
status_code=response.status_code,
detail=f"Failed to fetch transactions from blockchain RPC: {response.text}"
)
except httpx.RequestError as e:
raise HTTPException(status_code=503, detail=f"Blockchain RPC unavailable: {str(e)}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
@@ -1095,17 +1091,15 @@ async def search_blocks(
response = await client.get(f"{rpc_url}/rpc/search/blocks", params=params)
if response.status_code == 200:
return response.json()
elif response.status_code == 404:
return []
else:
# Return mock data for demonstration
return [
{
"height": 12345,
"hash": "0xabcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890",
"validator": validator or "0x1234567890abcdef1234567890abcdef12345678",
"tx_count": min_tx or 5,
"timestamp": datetime.now().isoformat()
}
]
raise HTTPException(
status_code=response.status_code,
detail=f"Failed to fetch blocks from blockchain RPC: {response.text}"
)
except httpx.RequestError as e:
raise HTTPException(status_code=503, detail=f"Blockchain RPC unavailable: {str(e)}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")

View File

@@ -0,0 +1,564 @@
"""
Hub Manager
Manages hub operations, peer list sharing, and hub registration for federated mesh
"""
import asyncio
import time
import json
import os
import socket
from typing import Dict, List, Optional, Set
from dataclasses import dataclass, field, asdict
from enum import Enum
from ..config import settings
from aitbc import get_logger, DATA_DIR, KEYSTORE_DIR
logger = get_logger(__name__)
class HubStatus(Enum):
"""Hub registration status"""
REGISTERED = "registered"
UNREGISTERED = "unregistered"
PENDING = "pending"
@dataclass
class HubInfo:
"""Information about a hub node"""
node_id: str
address: str
port: int
island_id: str
island_name: str
public_address: Optional[str] = None
public_port: Optional[int] = None
registered_at: float = 0
last_seen: float = 0
peer_count: int = 0
@dataclass
class PeerInfo:
"""Information about a peer"""
node_id: str
address: str
port: int
island_id: str
is_hub: bool
public_address: Optional[str] = None
public_port: Optional[int] = None
last_seen: float = 0
class HubManager:
"""Manages hub operations for federated mesh"""
def __init__(self, local_node_id: str, local_address: str, local_port: int, island_id: str, island_name: str, redis_url: Optional[str] = None):
self.local_node_id = local_node_id
self.local_address = local_address
self.local_port = local_port
self.island_id = island_id
self.island_name = island_name
self.island_chain_id = settings.island_chain_id or settings.chain_id or f"ait-{island_id[:8]}"
self.redis_url = redis_url or "redis://localhost:6379"
# Hub registration status
self.is_hub = False
self.hub_status = HubStatus.UNREGISTERED
self.registered_at: Optional[float] = None
# Known hubs
self.known_hubs: Dict[str, HubInfo] = {} # node_id -> HubInfo
# Peer registry (for providing peer lists)
self.peer_registry: Dict[str, PeerInfo] = {} # node_id -> PeerInfo
# Island peers (island_id -> set of node_ids)
self.island_peers: Dict[str, Set[str]] = {}
self.running = False
self._redis = None
# Initialize island peers for our island
self.island_peers[self.island_id] = set()
async def _connect_redis(self):
"""Connect to Redis"""
try:
import redis.asyncio as redis
self._redis = redis.from_url(self.redis_url)
await self._redis.ping()
logger.info(f"Connected to Redis for hub persistence: {self.redis_url}")
return True
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
return False
async def _persist_hub_registration(self, hub_info: HubInfo) -> bool:
"""Persist hub registration to Redis"""
try:
if not self._redis:
await self._connect_redis()
if not self._redis:
logger.warning("Redis not available, skipping persistence")
return False
key = f"hub:{hub_info.node_id}"
value = json.dumps(asdict(hub_info), default=str)
await self._redis.setex(key, 3600, value) # TTL: 1 hour
logger.info(f"Persisted hub registration to Redis: {key}")
return True
except Exception as e:
logger.error(f"Failed to persist hub registration: {e}")
return False
async def _remove_hub_registration(self, node_id: str) -> bool:
"""Remove hub registration from Redis"""
try:
if not self._redis:
await self._connect_redis()
if not self._redis:
logger.warning("Redis not available, skipping removal")
return False
key = f"hub:{node_id}"
await self._redis.delete(key)
logger.info(f"Removed hub registration from Redis: {key}")
return True
except Exception as e:
logger.error(f"Failed to remove hub registration: {e}")
return False
async def _load_hub_registration(self) -> Optional[HubInfo]:
"""Load hub registration from Redis"""
try:
if not self._redis:
await self._connect_redis()
if not self._redis:
return None
key = f"hub:{self.local_node_id}"
value = await self._redis.get(key)
if value:
data = json.loads(value)
return HubInfo(**data)
return None
except Exception as e:
logger.error(f"Failed to load hub registration: {e}")
return None
def _get_blockchain_credentials(self) -> dict:
"""Get blockchain credentials from keystore"""
try:
credentials = {}
# Get genesis block hash from genesis.json
genesis_candidates = [
str(settings.db_path.parent / 'genesis.json'),
f"{DATA_DIR}/data/{settings.chain_id}/genesis.json",
f'{DATA_DIR}/data/ait-mainnet/genesis.json',
]
for genesis_path in genesis_candidates:
if os.path.exists(genesis_path):
with open(genesis_path, 'r') as f:
genesis_data = json.load(f)
if 'blocks' in genesis_data and len(genesis_data['blocks']) > 0:
genesis_block = genesis_data['blocks'][0]
credentials['genesis_block_hash'] = genesis_block.get('hash', '')
credentials['genesis_block'] = genesis_data
break
# Get genesis address from keystore
keystore_path = str(KEYSTORE_DIR / 'validator_keys.json')
if os.path.exists(keystore_path):
with open(keystore_path, 'r') as f:
keys = json.load(f)
# Get first key's address
for key_id, key_data in keys.items():
# Extract address from public key or use key_id
credentials['genesis_address'] = key_id
break
# Add chain info
credentials['chain_id'] = self.island_chain_id
credentials['island_id'] = self.island_id
credentials['island_name'] = self.island_name
# Add RPC endpoint (local)
rpc_host = self.local_address
if rpc_host in {"0.0.0.0", "127.0.0.1", "localhost", ""}:
rpc_host = settings.hub_discovery_url or socket.gethostname()
credentials['rpc_endpoint'] = f"http://{rpc_host}:8006"
credentials['p2p_port'] = self.local_port
return credentials
except Exception as e:
logger.error(f"Failed to get blockchain credentials: {e}")
return {}
async def handle_join_request(self, join_request: dict) -> Optional[dict]:
"""
Handle island join request from a new node
Args:
join_request: Dictionary containing join request data
Returns:
dict: Join response with member list and credentials, or None if failed
"""
try:
requested_island_id = join_request.get('island_id')
# Validate island ID
if requested_island_id != self.island_id:
logger.warning(f"Join request for island {requested_island_id} does not match our island {self.island_id}")
return None
# Get all island members
members = []
for node_id, peer_info in self.peer_registry.items():
if peer_info.island_id == self.island_id:
members.append({
'node_id': peer_info.node_id,
'address': peer_info.address,
'port': peer_info.port,
'is_hub': peer_info.is_hub,
'public_address': peer_info.public_address,
'public_port': peer_info.public_port
})
# Include self in member list
members.append({
'node_id': self.local_node_id,
'address': self.local_address,
'port': self.local_port,
'is_hub': True,
'public_address': self.known_hubs.get(self.local_node_id, {}).public_address if self.local_node_id in self.known_hubs else None,
'public_port': self.known_hubs.get(self.local_node_id, {}).public_port if self.local_node_id in self.known_hubs else None
})
# Get blockchain credentials
credentials = self._get_blockchain_credentials()
# Build response
response = {
'type': 'join_response',
'island_id': self.island_id,
'island_name': self.island_name,
'island_chain_id': self.island_chain_id or f"ait-{self.island_id[:8]}",
'members': members,
'credentials': credentials
}
logger.info(f"Sent join_response to node {join_request.get('node_id')} with {len(members)} members")
return response
except Exception as e:
logger.error(f"Error handling join request: {e}")
return None
def register_gpu_offer(self, offer_data: dict) -> bool:
"""Register a GPU marketplace offer in the hub"""
try:
offer_id = offer_data.get('offer_id')
if offer_id:
self.gpu_offers[offer_id] = offer_data
logger.info(f"Registered GPU offer: {offer_id}")
return True
except Exception as e:
logger.error(f"Error registering GPU offer: {e}")
return False
def register_gpu_bid(self, bid_data: dict) -> bool:
"""Register a GPU marketplace bid in the hub"""
try:
bid_id = bid_data.get('bid_id')
if bid_id:
self.gpu_bids[bid_id] = bid_data
logger.info(f"Registered GPU bid: {bid_id}")
return True
except Exception as e:
logger.error(f"Error registering GPU bid: {e}")
return False
def register_gpu_provider(self, node_id: str, gpu_info: dict) -> bool:
"""Register a GPU provider in the hub"""
try:
self.gpu_providers[node_id] = gpu_info
logger.info(f"Registered GPU provider: {node_id}")
return True
except Exception as e:
logger.error(f"Error registering GPU provider: {e}")
return False
def register_exchange_order(self, order_data: dict) -> bool:
"""Register an exchange order in the hub"""
try:
order_id = order_data.get('order_id')
if order_id:
self.exchange_orders[order_id] = order_data
# Update order book
pair = order_data.get('pair')
side = order_data.get('side')
if pair and side:
if pair not in self.exchange_order_books:
self.exchange_order_books[pair] = {'bids': [], 'asks': []}
if side == 'buy':
self.exchange_order_books[pair]['bids'].append(order_data)
elif side == 'sell':
self.exchange_order_books[pair]['asks'].append(order_data)
logger.info(f"Registered exchange order: {order_id}")
return True
except Exception as e:
logger.error(f"Error registering exchange order: {e}")
return False
def get_gpu_offers(self) -> list:
"""Get all GPU offers"""
return list(self.gpu_offers.values())
def get_gpu_bids(self) -> list:
"""Get all GPU bids"""
return list(self.gpu_bids.values())
def get_gpu_providers(self) -> list:
"""Get all GPU providers"""
return list(self.gpu_providers.values())
def get_exchange_order_book(self, pair: str) -> dict:
"""Get order book for a specific trading pair"""
return self.exchange_order_books.get(pair, {'bids': [], 'asks': []})
async def register_as_hub(self, public_address: Optional[str] = None, public_port: Optional[int] = None) -> bool:
"""Register this node as a hub"""
if self.is_hub:
logger.warning("Already registered as hub")
return False
self.is_hub = True
self.hub_status = HubStatus.REGISTERED
self.registered_at = time.time()
# Add self to known hubs
hub_info = HubInfo(
node_id=self.local_node_id,
address=self.local_address,
port=self.local_port,
island_id=self.island_id,
island_name=self.island_name,
public_address=public_address,
public_port=public_port,
registered_at=time.time(),
last_seen=time.time()
)
self.known_hubs[self.local_node_id] = hub_info
# Persist to Redis
await self._persist_hub_registration(hub_info)
logger.info(f"Registered as hub for island {self.island_id}")
return True
async def unregister_as_hub(self) -> bool:
"""Unregister this node as a hub"""
if not self.is_hub:
logger.warning("Not registered as hub")
return False
self.is_hub = False
self.hub_status = HubStatus.UNREGISTERED
self.registered_at = None
# Remove from Redis
await self._remove_hub_registration(self.local_node_id)
# Remove self from known hubs
if self.local_node_id in self.known_hubs:
del self.known_hubs[self.local_node_id]
logger.info(f"Unregistered as hub for island {self.island_id}")
return True
def register_peer(self, peer_info: PeerInfo) -> bool:
"""Register a peer in the registry"""
self.peer_registry[peer_info.node_id] = peer_info
# Add to island peers
if peer_info.island_id not in self.island_peers:
self.island_peers[peer_info.island_id] = set()
self.island_peers[peer_info.island_id].add(peer_info.node_id)
# Update hub peer count if peer is a hub
if peer_info.is_hub and peer_info.node_id in self.known_hubs:
self.known_hubs[peer_info.node_id].peer_count = len(self.island_peers.get(peer_info.island_id, set()))
logger.debug(f"Registered peer {peer_info.node_id} in island {peer_info.island_id}")
return True
def unregister_peer(self, node_id: str) -> bool:
"""Unregister a peer from the registry"""
if node_id not in self.peer_registry:
return False
peer_info = self.peer_registry[node_id]
# Remove from island peers
if peer_info.island_id in self.island_peers:
self.island_peers[peer_info.island_id].discard(node_id)
del self.peer_registry[node_id]
# Update hub peer count
if node_id in self.known_hubs:
self.known_hubs[node_id].peer_count = len(self.island_peers.get(self.known_hubs[node_id].island_id, set()))
logger.debug(f"Unregistered peer {node_id}")
return True
def add_known_hub(self, hub_info: HubInfo):
"""Add a known hub to the registry"""
self.known_hubs[hub_info.node_id] = hub_info
logger.info(f"Added known hub {hub_info.node_id} for island {hub_info.island_id}")
def remove_known_hub(self, node_id: str) -> bool:
"""Remove a known hub from the registry"""
if node_id not in self.known_hubs:
return False
del self.known_hubs[node_id]
logger.info(f"Removed known hub {node_id}")
return True
def get_peer_list(self, island_id: str) -> List[PeerInfo]:
"""Get peer list for a specific island"""
peers = []
for node_id, peer_info in self.peer_registry.items():
if peer_info.island_id == island_id:
peers.append(peer_info)
return peers
def get_hub_list(self, island_id: Optional[str] = None) -> List[HubInfo]:
"""Get list of known hubs, optionally filtered by island"""
hubs = []
for hub_info in self.known_hubs.values():
if island_id is None or hub_info.island_id == island_id:
hubs.append(hub_info)
return hubs
def get_island_peers(self, island_id: str) -> Set[str]:
"""Get set of peer node IDs in an island"""
return self.island_peers.get(island_id, set()).copy()
def get_peer_count(self, island_id: str) -> int:
"""Get number of peers in an island"""
return len(self.island_peers.get(island_id, set()))
def get_hub_info(self, node_id: str) -> Optional[HubInfo]:
"""Get information about a specific hub"""
return self.known_hubs.get(node_id)
def get_peer_info(self, node_id: str) -> Optional[PeerInfo]:
"""Get information about a specific peer"""
return self.peer_registry.get(node_id)
def update_peer_last_seen(self, node_id: str):
"""Update the last seen time for a peer"""
if node_id in self.peer_registry:
self.peer_registry[node_id].last_seen = time.time()
if node_id in self.known_hubs:
self.known_hubs[node_id].last_seen = time.time()
async def start(self):
"""Start hub manager"""
self.running = True
logger.info(f"Starting hub manager for node {self.local_node_id}")
# Start background tasks
tasks = [
asyncio.create_task(self._hub_health_check()),
asyncio.create_task(self._peer_cleanup())
]
try:
await asyncio.gather(*tasks)
except Exception as e:
logger.error(f"Hub manager error: {e}")
finally:
self.running = False
async def stop(self):
"""Stop hub manager"""
self.running = False
logger.info("Stopping hub manager")
async def _hub_health_check(self):
"""Check health of known hubs"""
while self.running:
try:
current_time = time.time()
# Check for offline hubs (not seen for 10 minutes)
offline_hubs = []
for node_id, hub_info in self.known_hubs.items():
if current_time - hub_info.last_seen > 600:
offline_hubs.append(node_id)
logger.warning(f"Hub {node_id} appears to be offline")
# Remove offline hubs (keep self if we're a hub)
for node_id in offline_hubs:
if node_id != self.local_node_id:
self.remove_known_hub(node_id)
await asyncio.sleep(60) # Check every minute
except Exception as e:
logger.error(f"Hub health check error: {e}")
await asyncio.sleep(10)
async def _peer_cleanup(self):
"""Clean up stale peer entries"""
while self.running:
try:
current_time = time.time()
# Remove peers not seen for 5 minutes
stale_peers = []
for node_id, peer_info in self.peer_registry.items():
if current_time - peer_info.last_seen > 300:
stale_peers.append(node_id)
for node_id in stale_peers:
self.unregister_peer(node_id)
logger.debug(f"Removed stale peer {node_id}")
await asyncio.sleep(60) # Check every minute
except Exception as e:
logger.error(f"Peer cleanup error: {e}")
await asyncio.sleep(10)
# Global hub manager instance
hub_manager_instance: Optional[HubManager] = None
def get_hub_manager() -> Optional[HubManager]:
"""Get global hub manager instance"""
return hub_manager_instance
def create_hub_manager(node_id: str, address: str, port: int, island_id: str, island_name: str) -> HubManager:
"""Create and set global hub manager instance"""
global hub_manager_instance
hub_manager_instance = HubManager(node_id, address, port, island_id, island_name)
return hub_manager_instance

View File

@@ -7,8 +7,10 @@ from psycopg2.extras import RealDictCursor
import json
from decimal import Decimal
from aitbc.constants import DATA_DIR
# Database configurations
SQLITE_DB = "/var/lib/aitbc/data/coordinator.db"
SQLITE_DB = str(DATA_DIR / "data/coordinator.db")
PG_CONFIG = {
"host": "localhost",
"database": "aitbc_coordinator",

View File

@@ -12,7 +12,7 @@ from decimal import Decimal
from enum import StrEnum
from typing import Any
from aitbc import get_logger
from aitbc import get_logger, derive_ethereum_address, sign_transaction_hash, verify_signature, encrypt_private_key, Web3Client
logger = get_logger(__name__)
@@ -174,6 +174,8 @@ class EthereumWalletAdapter(EnhancedWalletAdapter):
def __init__(self, chain_id: int, rpc_url: str, security_level: SecurityLevel = SecurityLevel.MEDIUM):
super().__init__(chain_id, ChainType.ETHEREUM, rpc_url, security_level)
self.chain_id = chain_id
# Initialize Web3 client for blockchain operations
self._web3_client = Web3Client(rpc_url)
async def create_wallet(self, owner_address: str, security_config: dict[str, Any]) -> dict[str, Any]:
"""Create a new Ethereum wallet with enhanced security"""
@@ -446,25 +448,36 @@ class EthereumWalletAdapter(EnhancedWalletAdapter):
# Private helper methods
async def _derive_address_from_private_key(self, private_key: str) -> str:
"""Derive Ethereum address from private key"""
# This would use actual Ethereum cryptography
# For now, return a mock address
return f"0x{hashlib.sha256(private_key.encode()).hexdigest()[:40]}"
try:
return derive_ethereum_address(private_key)
except Exception as e:
logger.error(f"Failed to derive address from private key: {e}")
raise
async def _encrypt_private_key(self, private_key: str, security_config: dict[str, Any]) -> str:
"""Encrypt private key with security configuration"""
# This would use actual encryption
# For now, return mock encrypted key
return f"encrypted_{hashlib.sha256(private_key.encode()).hexdigest()}"
try:
password = security_config.get("encryption_password", "default_password")
return encrypt_private_key(private_key, password)
except Exception as e:
logger.error(f"Failed to encrypt private key: {e}")
raise
async def _get_eth_balance(self, address: str) -> str:
"""Get ETH balance in wei"""
# Mock implementation
return "1000000000000000000" # 1 ETH in wei
try:
return self._web3_client.get_eth_balance(address)
except Exception as e:
logger.error(f"Failed to get ETH balance: {e}")
raise
async def _get_token_balance(self, address: str, token_address: str) -> dict[str, Any]:
"""Get ERC-20 token balance"""
# Mock implementation
return {"balance": "100000000000000000000", "decimals": 18, "symbol": "TOKEN"} # 100 tokens
try:
return self._web3_client.get_token_balance(address, token_address)
except Exception as e:
logger.error(f"Failed to get token balance: {e}")
raise
async def _create_erc20_transfer(
self, from_address: str, to_address: str, token_address: str, amount: int
@@ -493,78 +506,116 @@ class EthereumWalletAdapter(EnhancedWalletAdapter):
async def _get_gas_price(self) -> int:
"""Get current gas price"""
# Mock implementation
return 20000000000 # 20 Gwei in wei
try:
return self._web3_client.get_gas_price()
except Exception as e:
logger.error(f"Failed to get gas price: {e}")
raise
async def _get_gas_price_gwei(self) -> float:
"""Get current gas price in Gwei"""
gas_price_wei = await self._get_gas_price()
return gas_price_wei / 10**9
try:
return self._web3_client.get_gas_price_gwei()
except Exception as e:
logger.error(f"Failed to get gas price in Gwei: {e}")
raise
async def _get_nonce(self, address: str) -> int:
"""Get transaction nonce for address"""
# Mock implementation
return 0
try:
return self._web3_client.get_nonce(address)
except Exception as e:
logger.error(f"Failed to get nonce: {e}")
raise
async def _sign_transaction(self, transaction_data: dict[str, Any], from_address: str) -> str:
"""Sign transaction"""
# Mock implementation
return f"0xsigned_{hashlib.sha256(str(transaction_data).encode()).hexdigest()}"
try:
# Get the transaction hash
from eth_account import Account
# Remove 0x prefix if present
if from_address.startswith("0x"):
from_address = from_address[2:]
account = Account.from_key(from_address)
# Build transaction dict for signing
tx_dict = {
'nonce': int(transaction_data.get('nonce', 0), 16),
'gasPrice': int(transaction_data.get('gasPrice', 0), 16),
'gas': int(transaction_data.get('gas', 0), 16),
'to': transaction_data.get('to'),
'value': int(transaction_data.get('value', '0x0'), 16),
'data': transaction_data.get('data', '0x'),
'chainId': transaction_data.get('chainId', 1)
}
signed_tx = account.sign_transaction(tx_dict)
return signed_tx.raw_transaction.hex()
except ImportError:
raise ImportError("eth-account is required for transaction signing. Install with: pip install eth-account")
except Exception as e:
logger.error(f"Failed to sign transaction: {e}")
raise
async def _send_raw_transaction(self, signed_transaction: str) -> str:
"""Send raw transaction"""
# Mock implementation
return f"0x{hashlib.sha256(signed_transaction.encode()).hexdigest()}"
try:
return self._web3_client.send_raw_transaction(signed_transaction)
except Exception as e:
logger.error(f"Failed to send raw transaction: {e}")
raise
async def _get_transaction_receipt(self, tx_hash: str) -> dict[str, Any] | None:
"""Get transaction receipt"""
# Mock implementation
return {
"status": 1,
"blockNumber": "0x12345",
"blockHash": "0xabcdef",
"gasUsed": "0x5208",
"effectiveGasPrice": "0x4a817c800",
"logs": [],
}
try:
return self._web3_client.get_transaction_receipt(tx_hash)
except Exception as e:
logger.error(f"Failed to get transaction receipt: {e}")
raise
async def _get_transaction_by_hash(self, tx_hash: str) -> dict[str, Any]:
"""Get transaction by hash"""
# Mock implementation
return {"from": "0xsender", "to": "0xreceiver", "value": "0xde0b6b3a7640000", "data": "0x"} # 1 ETH in wei
try:
return self._web3_client.get_transaction_by_hash(tx_hash)
except Exception as e:
logger.error(f"Failed to get transaction by hash: {e}")
raise
async def _estimate_gas_call(self, call_data: dict[str, Any]) -> str:
"""Estimate gas for call"""
# Mock implementation
return "0x5208" # 21000 in hex
try:
gas_estimate = self._web3_client.estimate_gas(call_data)
return hex(gas_estimate)
except Exception as e:
logger.error(f"Failed to estimate gas: {e}")
raise
async def _get_wallet_transactions(
self, address: str, limit: int, offset: int, from_block: int | None, to_block: int | None
) -> list[dict[str, Any]]:
"""Get wallet transactions"""
# Mock implementation
return [
{
"hash": f"0x{hashlib.sha256(f'tx_{i}'.encode()).hexdigest()}",
"from": address,
"to": f"0x{hashlib.sha256(f'to_{i}'.encode()).hexdigest()[:40]}",
"value": "0xde0b6b3a7640000",
"blockNumber": f"0x{12345 + i}",
"timestamp": datetime.utcnow().timestamp(),
"gasUsed": "0x5208",
}
for i in range(min(limit, 10))
]
try:
return self._web3_client.get_wallet_transactions(address, limit)
except Exception as e:
logger.error(f"Failed to get wallet transactions: {e}")
raise
async def _sign_hash(self, message_hash: str, private_key: str) -> str:
"""Sign a hash with private key"""
# Mock implementation
return f"0x{hashlib.sha256(f'{message_hash}{private_key}'.encode()).hexdigest()}"
try:
return sign_transaction_hash(message_hash, private_key)
except Exception as e:
logger.error(f"Failed to sign hash: {e}")
raise
async def _verify_signature(self, message_hash: str, signature: str, address: str) -> bool:
"""Verify a signature"""
# Mock implementation
return True
try:
return verify_signature(message_hash, signature, address)
except Exception as e:
logger.error(f"Failed to verify signature: {e}")
return False
class PolygonWalletAdapter(EthereumWalletAdapter):

View File

@@ -7,7 +7,6 @@ Multi-chain trading with cross-chain swaps and bridging
import sqlite3
import json
import asyncio
import httpx
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from fastapi import FastAPI, HTTPException, Query, BackgroundTasks
@@ -17,8 +16,15 @@ import os
import uuid
import hashlib
from aitbc.http_client import AsyncAITBCHTTPClient
from aitbc.aitbc_logging import get_logger
from aitbc.exceptions import NetworkError
app = FastAPI(title="AITBC Complete Cross-Chain Exchange", version="3.0.0")
# Initialize logger
logger = get_logger(__name__)
# Database configuration
DB_PATH = os.path.join(os.path.dirname(__file__), "exchange_multichain.db")
@@ -368,10 +374,10 @@ async def health_check():
if chain_info["status"] == "active" and chain_info["blockchain_url"]:
try:
async with httpx.AsyncClient() as client:
response = await client.get(f"{chain_info['blockchain_url']}/health", timeout=5.0)
chain_status[chain_id]["connected"] = response.status_code == 200
except:
client = AsyncAITBCHTTPClient(base_url=chain_info['blockchain_url'], timeout=5)
response = await client.async_get("/health")
chain_status[chain_id]["connected"] = response is not None
except NetworkError:
pass
return {

View File

@@ -7,7 +7,6 @@ Complete multi-chain trading with chain isolation
import sqlite3
import json
import asyncio
import httpx
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from fastapi import FastAPI, HTTPException, Query, BackgroundTasks
@@ -15,8 +14,15 @@ from pydantic import BaseModel, Field
import uvicorn
import os
from aitbc.http_client import AsyncAITBCHTTPClient
from aitbc.aitbc_logging import get_logger
from aitbc.exceptions import NetworkError
app = FastAPI(title="AITBC Multi-Chain Exchange", version="2.0.0")
# Initialize logger
logger = get_logger(__name__)
# Database configuration
DB_PATH = os.path.join(os.path.dirname(__file__), "exchange_multichain.db")
@@ -145,10 +151,10 @@ async def verify_chain_transaction(chain_id: str, tx_hash: str) -> bool:
return False
try:
async with httpx.AsyncClient() as client:
response = await client.get(f"{chain_info['blockchain_url']}/api/v1/transactions/{tx_hash}")
return response.status_code == 200
except:
client = AsyncAITBCHTTPClient(base_url=chain_info['blockchain_url'], timeout=5)
response = await client.async_get(f"/api/v1/transactions/{tx_hash}")
return response is not None
except NetworkError:
return False
async def submit_chain_transaction(chain_id: str, order_data: Dict) -> Optional[str]:
@@ -161,15 +167,12 @@ async def submit_chain_transaction(chain_id: str, order_data: Dict) -> Optional[
return None
try:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{chain_info['blockchain_url']}/api/v1/transactions",
json=order_data
)
if response.status_code == 200:
return response.json().get("tx_hash")
except Exception as e:
print(f"Chain transaction error: {e}")
client = AsyncAITBCHTTPClient(base_url=chain_info['blockchain_url'], timeout=10)
response = await client.async_post("/api/v1/transactions", json=order_data)
if response:
return response.get("tx_hash")
except NetworkError as e:
logger.error(f"Chain transaction error: {e}")
return None
@@ -188,10 +191,10 @@ async def health_check():
if chain_info["status"] == "active" and chain_info["blockchain_url"]:
try:
async with httpx.AsyncClient() as client:
response = await client.get(f"{chain_info['blockchain_url']}/health", timeout=5.0)
chain_status[chain_id]["connected"] = response.status_code == 200
except:
client = AsyncAITBCHTTPClient(base_url=chain_info['blockchain_url'], timeout=5)
response = await client.async_get("/health")
chain_status[chain_id]["connected"] = response is not None
except NetworkError:
pass
return {

View File

@@ -4,7 +4,6 @@ Simple AITBC Blockchain Explorer - Demonstrating the issues described in the ana
"""
import asyncio
import httpx
import re
from datetime import datetime
from typing import Dict, Any, Optional
@@ -12,8 +11,15 @@ from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
import uvicorn
from aitbc.http_client import AsyncAITBCHTTPClient
from aitbc.aitbc_logging import get_logger
from aitbc.exceptions import NetworkError
app = FastAPI(title="Simple AITBC Explorer", version="0.1.0")
# Initialize logger
logger = get_logger(__name__)
# Configuration
BLOCKCHAIN_RPC_URL = "http://localhost:8025"
@@ -174,12 +180,12 @@ HTML_TEMPLATE = """
async def get_chain_head():
"""Get current chain head"""
try:
async with httpx.AsyncClient() as client:
response = await client.get(f"{BLOCKCHAIN_RPC_URL}/rpc/head")
if response.status_code == 200:
return response.json()
except Exception as e:
print(f"Error getting chain head: {e}")
client = AsyncAITBCHTTPClient(base_url=BLOCKCHAIN_RPC_URL, timeout=10)
response = await client.async_get("/rpc/head")
if response:
return response
except NetworkError as e:
logger.error(f"Error getting chain head: {e}")
return {"height": 0, "hash": "", "timestamp": None}
@app.get("/api/blocks/{height}")
@@ -189,12 +195,12 @@ async def get_block(height: int):
if height < 0 or height > 10000000:
return {"height": height, "hash": "", "timestamp": None, "transactions": []}
try:
async with httpx.AsyncClient() as client:
response = await client.get(f"{BLOCKCHAIN_RPC_URL}/rpc/blocks/{height}")
if response.status_code == 200:
return response.json()
except Exception as e:
print(f"Error getting block {height}: {e}")
client = AsyncAITBCHTTPClient(base_url=BLOCKCHAIN_RPC_URL, timeout=10)
response = await client.async_get(f"/rpc/blocks/{height}")
if response:
return response
except NetworkError as e:
logger.error(f"Error getting block: {e}")
return {"height": height, "hash": "", "timestamp": None, "transactions": []}
@app.get("/api/transactions/{tx_hash}")
@@ -203,26 +209,21 @@ async def get_transaction(tx_hash: str):
if not validate_tx_hash(tx_hash):
return {"hash": tx_hash, "from": "unknown", "to": "unknown", "amount": 0, "timestamp": None}
try:
async with httpx.AsyncClient() as client:
response = await client.get(f"{BLOCKCHAIN_RPC_URL}/rpc/tx/{tx_hash}")
if response.status_code == 200:
tx_data = response.json()
client = AsyncAITBCHTTPClient(base_url=BLOCKCHAIN_RPC_URL, timeout=10)
response = await client.async_get(f"/rpc/tx/{tx_hash}")
if response:
# Problem 2: Map RPC schema to UI schema
return {
"hash": tx_data.get("tx_hash", tx_hash), # tx_hash -> hash
"from": tx_data.get("sender", "unknown"), # sender -> from
"to": tx_data.get("recipient", "unknown"), # recipient -> to
"amount": tx_data.get("payload", {}).get("value", "0"), # payload.value -> amount
"fee": tx_data.get("payload", {}).get("fee", "0"), # payload.fee -> fee
"timestamp": tx_data.get("created_at"), # created_at -> timestamp
"block_height": tx_data.get("block_height", "pending")
"hash": response.get("tx_hash", tx_hash), # tx_hash -> hash
"from": response.get("sender", "unknown"), # sender -> from
"to": response.get("recipient", "unknown"), # recipient -> to
"amount": response.get("payload", {}).get("value", "0"), # payload.value -> amount
"fee": response.get("payload", {}).get("fee", "0"), # payload.fee -> fee
"timestamp": response.get("created_at"), # created_at -> timestamp
"block_height": response.get("block_height", "pending")
}
elif response.status_code == 404:
raise HTTPException(status_code=404, detail="Transaction not found")
except HTTPException:
raise
except Exception as e:
print(f"Error getting transaction {tx_hash}: {e}")
except NetworkError as e:
logger.error(f"Error getting transaction {tx_hash}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to fetch transaction: {str(e)}")
# Missing: @app.get("/api/transactions/{tx_hash}") - THIS IS THE PROBLEM

View File

@@ -16,6 +16,8 @@ from pathlib import Path
import os
import sys
from aitbc.constants import KEYSTORE_DIR
# Add CLI utils to path
sys.path.insert(0, '/opt/aitbc/cli')
@@ -23,7 +25,7 @@ sys.path.insert(0, '/opt/aitbc/cli')
app = FastAPI(title="AITBC Wallet Daemon", debug=False)
# Configuration
KEYSTORE_PATH = Path("/var/lib/aitbc/keystore")
KEYSTORE_PATH = KEYSTORE_DIR
BLOCKCHAIN_RPC_URL = "http://localhost:8006"
CHAIN_ID = "ait-mainnet"

View File

@@ -42,6 +42,7 @@ from aitbc import (
AITBCHTTPClient, NetworkError, ValidationError, ConfigurationError,
get_logger, get_keystore_path, ensure_dir, validate_address, validate_url
)
from aitbc.paths import get_blockchain_data_path, get_data_path
# Initialize logger
logger = get_logger(__name__)
@@ -1189,7 +1190,8 @@ def agent_operations(action: str, **kwargs) -> Optional[Dict]:
from sqlmodel import create_engine, Session, select
from aitbc_chain.models import Transaction
engine = create_engine("sqlite:////var/lib/aitbc/data/ait-mainnet/chain.db")
chain_db_path = get_blockchain_data_path("ait-mainnet") / "chain.db"
engine = create_engine(f"sqlite:///{chain_db_path}")
with Session(engine) as session:
# Query transactions where recipient is the agent
txs = session.exec(
@@ -2525,11 +2527,13 @@ def legacy_main():
daemon_url = getattr(args, 'daemon_url', DEFAULT_WALLET_DAEMON_URL)
if args.wallet_action == "backup":
print(f"Wallet backup: {args.name}")
print(f" Backup created: /var/lib/aitbc/backups/{args.name}_$(date +%Y%m%d).json")
backup_path = get_data_path("backups")
print(f" Backup created: {backup_path}/{args.name}_$(date +%Y%m%d).json")
print(f" Status: completed")
elif args.wallet_action == "export":
print(f"Wallet export: {args.name}")
print(f" Export file: /var/lib/aitbc/exports/{args.name}_private.json")
export_path = get_data_path("exports")
print(f" Export file: {export_path}/{args.name}_private.json")
print(f" Status: completed")
elif args.wallet_action == "sync":
if args.all:
@@ -2602,12 +2606,14 @@ def legacy_main():
elif args.command == "wallet-backup":
print(f"Wallet backup: {args.name}")
print(f" Backup created: /var/lib/aitbc/backups/{args.name}_$(date +%Y%m%d).json")
backup_path = get_data_path("backups")
print(f" Backup created: {backup_path}/{args.name}_$(date +%Y%m%d).json")
print(f" Status: completed")
elif args.command == "wallet-export":
print(f"Wallet export: {args.name}")
print(f" Export file: /var/lib/aitbc/exports/{args.name}_private.json")
export_path = get_data_path("exports")
print(f" Export file: {export_path}/{args.name}_private.json")
print(f" Status: completed")
elif args.command == "wallet-sync":

View File

@@ -12,9 +12,10 @@ from pathlib import Path
from typing import Optional, Dict, Any, List
import requests
import getpass
from aitbc.paths import get_keystore_path
# Default paths
DEFAULT_KEYSTORE_DIR = Path("/var/lib/aitbc/keystore")
DEFAULT_KEYSTORE_DIR = get_keystore_path()
DEFAULT_RPC_URL = "http://localhost:8006"
def get_password(password_arg: str = None, password_file: str = None) -> str:

View File

@@ -2,8 +2,9 @@ import json
import os
import time
import uuid
from aitbc.paths import get_data_path
STATE_FILE = "/var/lib/aitbc/data/cli_extended_state.json"
STATE_FILE = str(get_data_path("data/cli_extended_state.json"))
def load_state():
if os.path.exists(STATE_FILE):
@@ -82,10 +83,10 @@ def handle_extended_command(command, args, kwargs):
result["nodes_reached"] = 2
elif command == "wallet_backup":
result["path"] = f"/var/lib/aitbc/backups/{kwargs.get('name')}.backup"
result["path"] = f"{get_data_path('backups')}/{kwargs.get('name')}.backup"
elif command == "wallet_export":
result["path"] = f"/var/lib/aitbc/exports/{kwargs.get('name')}.key"
result["path"] = f"{get_data_path('exports')}/{kwargs.get('name')}.key"
elif command == "wallet_sync":
result["status"] = "Wallets synchronized"
@@ -275,7 +276,7 @@ def handle_extended_command(command, args, kwargs):
elif command == "compliance_report":
result["format"] = kwargs.get("format")
result["path"] = "/var/lib/aitbc/reports/compliance.pdf"
result["path"] = f"{get_data_path('reports')}/compliance.pdf"
elif command == "script_run":
result["file"] = kwargs.get("file")

View File

@@ -2,6 +2,7 @@
import json
import sys
from aitbc.paths import get_data_path
def handle_wallet_create(args, create_wallet, read_password, first):
@@ -140,7 +141,8 @@ def handle_wallet_backup(args, first):
print("Error: Wallet name is required")
sys.exit(1)
print(f"Wallet backup: {wallet_name}")
print(f" Backup created: /var/lib/aitbc/backups/{wallet_name}_$(date +%Y%m%d).json")
backup_path = get_data_path("backups")
print(f" Backup created: {backup_path}/{wallet_name}_$(date +%Y%m%d).json")
print(" Status: completed")

View File

@@ -42,8 +42,10 @@ def decrypt_private_key(keystore_data: Dict[str, Any], password: str) -> str:
return decrypted.decode()
def load_keystore(address: str, keystore_dir: Path | str = "/var/lib/aitbc/keystore") -> Dict[str, Any]:
def load_keystore(address: str, keystore_dir: Path | str = None) -> Dict[str, Any]:
"""Load keystore file for a given address."""
if keystore_dir is None:
keystore_dir = get_keystore_path()
keystore_dir = Path(keystore_dir)
keystore_file = keystore_dir / f"{address}.json"