Expand aitbc package with new utility modules and enhanced HTTP client
- Add new exception types: RetryError, CircuitBreakerOpenError, RateLimitError - Enhance AITBCHTTPClient with retry logic, caching, circuit breaker, and rate limiting - Add AsyncAITBCHTTPClient for async HTTP operations - Add crypto module with Ethereum key derivation, signing, encryption, and hashing utilities - Add web3_utils module with Web3Client and create_web3_client - Add security module with token generation, API key management
This commit is contained in:
@@ -29,6 +29,9 @@ from .exceptions import (
|
||||
DatabaseError,
|
||||
ValidationError,
|
||||
BridgeError,
|
||||
RetryError,
|
||||
CircuitBreakerOpenError,
|
||||
RateLimitError,
|
||||
)
|
||||
from .env import (
|
||||
get_env_var,
|
||||
@@ -60,7 +63,7 @@ from .json_utils import (
|
||||
set_nested_value,
|
||||
flatten_json,
|
||||
)
|
||||
from .http_client import AITBCHTTPClient
|
||||
from .http_client import AITBCHTTPClient, AsyncAITBCHTTPClient
|
||||
from .config import BaseAITBCConfig, AITBCConfig
|
||||
from .decorators import (
|
||||
retry,
|
||||
@@ -105,6 +108,142 @@ from .monitoring import (
|
||||
PerformanceTimer,
|
||||
HealthChecker,
|
||||
)
|
||||
from .crypto import (
|
||||
derive_ethereum_address,
|
||||
sign_transaction_hash,
|
||||
verify_signature,
|
||||
encrypt_private_key,
|
||||
decrypt_private_key,
|
||||
generate_secure_random_bytes,
|
||||
keccak256_hash,
|
||||
sha256_hash,
|
||||
validate_ethereum_address,
|
||||
generate_ethereum_private_key,
|
||||
)
|
||||
from .web3_utils import Web3Client, create_web3_client
|
||||
from .security import (
|
||||
generate_token,
|
||||
generate_api_key,
|
||||
validate_token_format,
|
||||
validate_api_key,
|
||||
SessionManager,
|
||||
APIKeyManager,
|
||||
generate_secure_random_string,
|
||||
generate_secure_random_int,
|
||||
SecretManager,
|
||||
hash_password,
|
||||
verify_password,
|
||||
generate_nonce,
|
||||
generate_hmac,
|
||||
verify_hmac,
|
||||
)
|
||||
from .time_utils import (
|
||||
get_utc_now,
|
||||
get_timestamp_utc,
|
||||
format_iso8601,
|
||||
parse_iso8601,
|
||||
timestamp_to_iso,
|
||||
iso_to_timestamp,
|
||||
format_duration,
|
||||
format_duration_precise,
|
||||
parse_duration,
|
||||
add_duration,
|
||||
subtract_duration,
|
||||
get_time_until,
|
||||
get_time_since,
|
||||
calculate_deadline,
|
||||
is_deadline_passed,
|
||||
get_deadline_remaining,
|
||||
format_time_ago,
|
||||
format_time_in,
|
||||
to_timezone,
|
||||
get_timezone_offset,
|
||||
is_business_hours,
|
||||
get_start_of_day,
|
||||
get_end_of_day,
|
||||
get_start_of_week,
|
||||
get_end_of_week,
|
||||
get_start_of_month,
|
||||
get_end_of_month,
|
||||
sleep_until,
|
||||
retry_until_deadline,
|
||||
Timer,
|
||||
)
|
||||
from .api_utils import (
|
||||
APIResponse,
|
||||
PaginatedResponse,
|
||||
success_response,
|
||||
error_response,
|
||||
not_found_response,
|
||||
unauthorized_response,
|
||||
forbidden_response,
|
||||
validation_error_response,
|
||||
conflict_response,
|
||||
internal_error_response,
|
||||
PaginationParams,
|
||||
paginate_items,
|
||||
build_paginated_response,
|
||||
RateLimitHeaders,
|
||||
build_cors_headers,
|
||||
build_standard_headers,
|
||||
validate_sort_field,
|
||||
validate_sort_order,
|
||||
build_sort_params,
|
||||
filter_fields,
|
||||
exclude_fields,
|
||||
sanitize_response,
|
||||
merge_responses,
|
||||
get_client_ip,
|
||||
get_user_agent,
|
||||
build_request_metadata,
|
||||
)
|
||||
from .events import (
|
||||
Event,
|
||||
EventPriority,
|
||||
EventBus,
|
||||
AsyncEventBus,
|
||||
event_handler,
|
||||
publish_event,
|
||||
get_global_event_bus,
|
||||
set_global_event_bus,
|
||||
EventFilter,
|
||||
EventAggregator,
|
||||
EventRouter,
|
||||
)
|
||||
from .queue import (
|
||||
Job,
|
||||
JobStatus,
|
||||
JobPriority,
|
||||
TaskQueue,
|
||||
JobScheduler,
|
||||
BackgroundTaskManager,
|
||||
WorkerPool,
|
||||
debounce,
|
||||
throttle,
|
||||
)
|
||||
from .state import (
|
||||
StateTransition,
|
||||
StateTransitionError,
|
||||
StatePersistenceError,
|
||||
StateMachine,
|
||||
ConfigurableStateMachine,
|
||||
StatePersistence,
|
||||
AsyncStateMachine,
|
||||
StateMonitor,
|
||||
StateValidator,
|
||||
StateSnapshot,
|
||||
)
|
||||
from .testing import (
|
||||
MockFactory,
|
||||
TestDataGenerator,
|
||||
TestHelpers,
|
||||
MockResponse,
|
||||
MockDatabase,
|
||||
MockCache,
|
||||
mock_async_call,
|
||||
create_mock_config,
|
||||
create_test_scenario,
|
||||
)
|
||||
|
||||
__version__ = "0.6.0"
|
||||
__all__ = [
|
||||
@@ -135,6 +274,9 @@ __all__ = [
|
||||
"DatabaseError",
|
||||
"ValidationError",
|
||||
"BridgeError",
|
||||
"RetryError",
|
||||
"CircuitBreakerOpenError",
|
||||
"RateLimitError",
|
||||
# Environment helpers
|
||||
"get_env_var",
|
||||
"get_required_env_var",
|
||||
@@ -164,6 +306,7 @@ __all__ = [
|
||||
"flatten_json",
|
||||
# HTTP client
|
||||
"AITBCHTTPClient",
|
||||
"AsyncAITBCHTTPClient",
|
||||
# Configuration
|
||||
"BaseAITBCConfig",
|
||||
"AITBCConfig",
|
||||
@@ -205,4 +348,134 @@ __all__ = [
|
||||
"MetricsCollector",
|
||||
"PerformanceTimer",
|
||||
"HealthChecker",
|
||||
# Cryptography
|
||||
"derive_ethereum_address",
|
||||
"sign_transaction_hash",
|
||||
"verify_signature",
|
||||
"encrypt_private_key",
|
||||
"decrypt_private_key",
|
||||
"generate_secure_random_bytes",
|
||||
"keccak256_hash",
|
||||
"sha256_hash",
|
||||
"validate_ethereum_address",
|
||||
"generate_ethereum_private_key",
|
||||
# Web3 utilities
|
||||
"Web3Client",
|
||||
"create_web3_client",
|
||||
# Security
|
||||
"generate_token",
|
||||
"generate_api_key",
|
||||
"validate_token_format",
|
||||
"validate_api_key",
|
||||
"SessionManager",
|
||||
"APIKeyManager",
|
||||
"generate_secure_random_string",
|
||||
"generate_secure_random_int",
|
||||
"SecretManager",
|
||||
"hash_password",
|
||||
"verify_password",
|
||||
"generate_nonce",
|
||||
"generate_hmac",
|
||||
"verify_hmac",
|
||||
# Time utilities
|
||||
"get_utc_now",
|
||||
"get_timestamp_utc",
|
||||
"format_iso8601",
|
||||
"parse_iso8601",
|
||||
"timestamp_to_iso",
|
||||
"iso_to_timestamp",
|
||||
"format_duration",
|
||||
"format_duration_precise",
|
||||
"parse_duration",
|
||||
"add_duration",
|
||||
"subtract_duration",
|
||||
"get_time_until",
|
||||
"get_time_since",
|
||||
"calculate_deadline",
|
||||
"is_deadline_passed",
|
||||
"get_deadline_remaining",
|
||||
"format_time_ago",
|
||||
"format_time_in",
|
||||
"to_timezone",
|
||||
"get_timezone_offset",
|
||||
"is_business_hours",
|
||||
"get_start_of_day",
|
||||
"get_end_of_day",
|
||||
"get_start_of_week",
|
||||
"get_end_of_week",
|
||||
"get_start_of_month",
|
||||
"get_end_of_month",
|
||||
"sleep_until",
|
||||
"retry_until_deadline",
|
||||
"Timer",
|
||||
# API utilities
|
||||
"APIResponse",
|
||||
"PaginatedResponse",
|
||||
"success_response",
|
||||
"error_response",
|
||||
"not_found_response",
|
||||
"unauthorized_response",
|
||||
"forbidden_response",
|
||||
"validation_error_response",
|
||||
"conflict_response",
|
||||
"internal_error_response",
|
||||
"PaginationParams",
|
||||
"paginate_items",
|
||||
"build_paginated_response",
|
||||
"RateLimitHeaders",
|
||||
"build_cors_headers",
|
||||
"build_standard_headers",
|
||||
"validate_sort_field",
|
||||
"validate_sort_order",
|
||||
"build_sort_params",
|
||||
"filter_fields",
|
||||
"exclude_fields",
|
||||
"sanitize_response",
|
||||
"merge_responses",
|
||||
"get_client_ip",
|
||||
"get_user_agent",
|
||||
"build_request_metadata",
|
||||
# Events
|
||||
"Event",
|
||||
"EventPriority",
|
||||
"EventBus",
|
||||
"AsyncEventBus",
|
||||
"event_handler",
|
||||
"publish_event",
|
||||
"get_global_event_bus",
|
||||
"set_global_event_bus",
|
||||
"EventFilter",
|
||||
"EventAggregator",
|
||||
"EventRouter",
|
||||
# Queue
|
||||
"Job",
|
||||
"JobStatus",
|
||||
"JobPriority",
|
||||
"TaskQueue",
|
||||
"JobScheduler",
|
||||
"BackgroundTaskManager",
|
||||
"WorkerPool",
|
||||
"debounce",
|
||||
"throttle",
|
||||
# State
|
||||
"StateTransition",
|
||||
"StateTransitionError",
|
||||
"StatePersistenceError",
|
||||
"StateMachine",
|
||||
"ConfigurableStateMachine",
|
||||
"StatePersistence",
|
||||
"AsyncStateMachine",
|
||||
"StateMonitor",
|
||||
"StateValidator",
|
||||
"StateSnapshot",
|
||||
# Testing
|
||||
"MockFactory",
|
||||
"TestDataGenerator",
|
||||
"TestHelpers",
|
||||
"MockResponse",
|
||||
"MockDatabase",
|
||||
"MockCache",
|
||||
"mock_async_call",
|
||||
"create_mock_config",
|
||||
"create_test_scenario",
|
||||
]
|
||||
|
||||
322
aitbc/api_utils.py
Normal file
322
aitbc/api_utils.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
API utilities for AITBC
|
||||
Provides standard response formatters, pagination helpers, error response builders, and rate limit headers helpers
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, List, Dict, Union
|
||||
from datetime import datetime
|
||||
from fastapi import HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class APIResponse(BaseModel):
    """Standard API response envelope.

    Fields:
        success: whether the request succeeded.
        message: human-readable status message.
        data: optional payload.
        error: optional machine-readable error code.
        timestamp: ISO-8601 creation time; auto-filled when omitted.
    """
    success: bool
    message: str
    data: Optional[Any] = None
    error: Optional[str] = None
    # Fix: a bare `str = None` default is only implicitly optional in
    # pydantic v1 and is rejected outright by pydantic v2.
    timestamp: Optional[str] = None

    def __init__(self, **data):
        # Auto-stamp the response when the caller did not supply a timestamp
        # (naive UTC, matching the rest of this module).
        if 'timestamp' not in data:
            data['timestamp'] = datetime.utcnow().isoformat()
        super().__init__(**data)
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel):
    """API response envelope for paginated list endpoints.

    Fields:
        success: whether the request succeeded.
        message: human-readable status message.
        data: the items for the current page.
        pagination: page/page_size/total/total_pages/has_next/has_prev.
        timestamp: ISO-8601 creation time; auto-filled when omitted.
    """
    success: bool
    message: str
    data: List[Any]
    pagination: Dict[str, Any]
    # Fix: implicit-None annotation; pydantic v2 rejects `str = None`.
    timestamp: Optional[str] = None

    def __init__(self, **data):
        # Auto-stamp creation time when not supplied (naive UTC).
        if 'timestamp' not in data:
            data['timestamp'] = datetime.utcnow().isoformat()
        super().__init__(**data)
|
||||
|
||||
|
||||
def success_response(message: str = "Success", data: Optional[Any] = None) -> APIResponse:
    """Build a successful APIResponse carrying *data*."""
    return APIResponse(
        success=True,
        message=message,
        data=data,
    )
|
||||
|
||||
|
||||
def error_response(message: str, error: Optional[str] = None, status_code: int = 400) -> HTTPException:
    """Build an HTTPException whose detail follows the standard envelope."""
    detail = {"success": False, "message": message, "error": error}
    return HTTPException(status_code=status_code, detail=detail)
|
||||
|
||||
|
||||
def not_found_response(resource: str = "Resource") -> HTTPException:
    """Build a 404 response reporting that *resource* is missing."""
    return error_response(f"{resource} not found", error="NOT_FOUND", status_code=404)
|
||||
|
||||
|
||||
def unauthorized_response(message: str = "Unauthorized") -> HTTPException:
    """Build a 401 response for unauthenticated requests."""
    return error_response(message, error="UNAUTHORIZED", status_code=401)
|
||||
|
||||
|
||||
def forbidden_response(message: str = "Forbidden") -> HTTPException:
    """Build a 403 response for requests the caller is not allowed to make."""
    return error_response(message, error="FORBIDDEN", status_code=403)
|
||||
|
||||
|
||||
def validation_error_response(errors: List[str]) -> HTTPException:
    """Create a 422 validation error response listing *errors*.

    Bug fix: the *errors* argument was previously accepted but never used,
    so callers lost all detail about what failed. The error strings are now
    joined into the message.
    """
    detail = "; ".join(errors)
    message = f"Validation failed: {detail}" if detail else "Validation failed"
    return error_response(
        message=message,
        error="VALIDATION_ERROR",
        status_code=422
    )
|
||||
|
||||
|
||||
def conflict_response(message: str = "Resource conflict") -> HTTPException:
    """Build a 409 response for conflicting resource state."""
    return error_response(message, error="CONFLICT", status_code=409)
|
||||
|
||||
|
||||
def internal_error_response(message: str = "Internal server error") -> HTTPException:
    """Build a 500 response for unexpected server-side failures."""
    return error_response(message, error="INTERNAL_ERROR", status_code=500)
|
||||
|
||||
|
||||
class PaginationParams:
    """Normalized pagination parameters.

    Clamps *page* to >= 1 and *page_size* to the range [1, max_page_size],
    then precomputes the zero-based offset for SQL-style slicing.
    """

    def __init__(self, page: int = 1, page_size: int = 10, max_page_size: int = 100):
        """Normalize and store pagination inputs."""
        if page < 1:
            page = 1
        if page_size < 1:
            page_size = 1
        elif page_size > max_page_size:
            page_size = max_page_size
        self.page = page
        self.page_size = page_size
        self.offset = (page - 1) * page_size

    def get_limit(self) -> int:
        """Return the SQL LIMIT value (items per page)."""
        return self.page_size

    def get_offset(self) -> int:
        """Return the SQL OFFSET value."""
        return self.offset
|
||||
|
||||
|
||||
def paginate_items(items: List[Any], page: int = 1, page_size: int = 10) -> Dict[str, Any]:
    """Slice *items* for the requested page and describe the pagination state."""
    params = PaginationParams(page, page_size)
    total = len(items)
    # Ceiling division: number of pages needed to hold all items.
    total_pages = (total + params.page_size - 1) // params.page_size
    start = params.offset
    page_items = items[start:start + params.page_size]

    return {
        "items": page_items,
        "pagination": {
            "page": params.page,
            "page_size": params.page_size,
            "total": total,
            "total_pages": total_pages,
            "has_next": params.page < total_pages,
            "has_prev": params.page > 1,
        },
    }
|
||||
|
||||
|
||||
def build_paginated_response(
    items: List[Any],
    page: int = 1,
    page_size: int = 10,
    message: str = "Success"
) -> PaginatedResponse:
    """Paginate *items* and wrap the resulting slice in a PaginatedResponse."""
    result = paginate_items(items, page, page_size)
    return PaginatedResponse(
        success=True,
        message=message,
        data=result["items"],
        pagination=result["pagination"],
    )
|
||||
|
||||
|
||||
class RateLimitHeaders:
    """Builders for standard X-RateLimit-* and Retry-After response headers."""

    @staticmethod
    def get_headers(
        limit: int,
        remaining: int,
        reset: int,
        window: int
    ) -> Dict[str, str]:
        """Return the four standard rate-limit headers, values stringified."""
        values = {
            "X-RateLimit-Limit": limit,
            "X-RateLimit-Remaining": remaining,
            "X-RateLimit-Reset": reset,
            "X-RateLimit-Window": window,
        }
        return {name: str(value) for name, value in values.items()}

    @staticmethod
    def get_retry_after(retry_after: int) -> Dict[str, str]:
        """Return a Retry-After header for throttled clients."""
        return {"Retry-After": str(retry_after)}
|
||||
|
||||
|
||||
def build_cors_headers(
    allowed_origins: Optional[List[str]] = None,
    allowed_methods: Optional[List[str]] = None,
    allowed_headers: Optional[List[str]] = None,
    max_age: int = 3600
) -> Dict[str, str]:
    """Build CORS headers.

    Defaults are permissive: any origin, common HTTP methods, any header.
    Bug fix: the mutable list defaults were replaced with None sentinels so
    a caller mutating a default list cannot corrupt later calls.
    """
    if allowed_origins is None:
        allowed_origins = ["*"]
    if allowed_methods is None:
        allowed_methods = ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
    if allowed_headers is None:
        allowed_headers = ["*"]

    return {
        "Access-Control-Allow-Origin": ", ".join(allowed_origins),
        "Access-Control-Allow-Methods": ", ".join(allowed_methods),
        "Access-Control-Allow-Headers": ", ".join(allowed_headers),
        "Access-Control-Max-Age": str(max_age)
    }
|
||||
|
||||
|
||||
def build_standard_headers(
    content_type: str = "application/json",
    cache_control: Optional[str] = None,
    x_request_id: Optional[str] = None
) -> Dict[str, str]:
    """Build standard response headers; optional entries added only when set."""
    headers: Dict[str, str] = {"Content-Type": content_type}
    optional = {
        "Cache-Control": cache_control,
        "X-Request-ID": x_request_id,
    }
    for name, value in optional.items():
        if value:
            headers[name] = value
    return headers
|
||||
|
||||
|
||||
def validate_sort_field(field: str, allowed_fields: List[str]) -> str:
    """Return *field* unchanged if it is allowed; raise ValueError otherwise."""
    if field in allowed_fields:
        return field
    allowed = ', '.join(allowed_fields)
    raise ValueError(f"Invalid sort field: {field}. Allowed fields: {allowed}")
|
||||
|
||||
|
||||
def validate_sort_order(order: str) -> str:
    """Return *order* upper-cased; raise ValueError unless it is ASC or DESC."""
    normalized = order.upper()
    if normalized in ("ASC", "DESC"):
        return normalized
    raise ValueError(f"Invalid sort order: {normalized}. Must be 'ASC' or 'DESC'")
|
||||
|
||||
|
||||
def build_sort_params(
    sort_by: Optional[str] = None,
    sort_order: str = "ASC",
    allowed_fields: Optional[List[str]] = None
) -> Dict[str, Any]:
    """Validate and package sort options.

    Returns an empty dict when no sort field or no whitelist is supplied.
    """
    if not (sort_by and allowed_fields):
        return {}
    return {
        "sort_by": validate_sort_field(sort_by, allowed_fields),
        "sort_order": validate_sort_order(sort_order),
    }
|
||||
|
||||
|
||||
def filter_fields(data: Dict[str, Any], fields: List[str]) -> Dict[str, Any]:
    """Return a copy of *data* keeping only the keys listed in *fields*."""
    wanted = set(fields)
    return {key: value for key, value in data.items() if key in wanted}
|
||||
|
||||
|
||||
def exclude_fields(data: Dict[str, Any], fields: List[str]) -> Dict[str, Any]:
    """Return a copy of *data* with every key in *fields* removed."""
    banned = set(fields)
    return {key: value for key, value in data.items() if key not in banned}
|
||||
|
||||
|
||||
def sanitize_response(data: Any, sensitive_fields: Optional[List[str]] = None) -> Any:
    """Recursively mask values whose key contains a sensitive substring.

    Matching is case-insensitive and substring-based (e.g. "api_key" masks
    "user_api_key"). Dicts and lists are walked recursively; any other value
    is returned unchanged. Fixed the annotation: the default is None, so the
    parameter type is Optional[List[str]], not List[str].
    """
    if sensitive_fields is None:
        sensitive_fields = ["password", "token", "api_key", "secret", "private_key"]

    if isinstance(data, dict):
        return {
            k: "***" if any(sensitive in k.lower() for sensitive in sensitive_fields) else sanitize_response(v, sensitive_fields)
            for k, v in data.items()
        }
    elif isinstance(data, list):
        return [sanitize_response(item, sensitive_fields) for item in data]
    else:
        return data
|
||||
|
||||
|
||||
def merge_responses(*responses: Union[APIResponse, Dict[str, Any]]) -> Dict[str, Any]:
    """Merge the data payloads of several responses into a single dict.

    Dict payloads are merged key-by-key (later responses win on conflicts);
    a non-dict payload replaces whatever has been accumulated so far.
    """
    merged = {"data": {}}

    for response in responses:
        if isinstance(response, APIResponse):
            payload = response.data
            # Truthy payloads only: None/empty data from APIResponse is skipped.
            if payload:
                if isinstance(payload, dict):
                    merged["data"].update(payload)
                else:
                    merged["data"] = payload
        elif isinstance(response, dict):
            # Plain dicts contribute whenever the "data" key is present,
            # even if its value is falsy.
            if "data" in response:
                payload = response["data"]
                if isinstance(payload, dict):
                    merged["data"].update(payload)
                else:
                    merged["data"] = payload

    return merged
|
||||
|
||||
|
||||
def get_client_ip(request) -> str:
    """Best-effort client IP: X-Forwarded-For, then X-Real-IP, then socket peer."""
    forwarded_for = request.headers.get("X-Forwarded-For")
    if forwarded_for:
        # The first entry in the comma-separated chain is the originating client.
        return forwarded_for.split(",")[0].strip()

    real_ip = request.headers.get("X-Real-IP")
    if real_ip:
        return real_ip

    if request.client:
        return request.client.host
    return "unknown"
|
||||
|
||||
|
||||
def get_user_agent(request) -> str:
    """Return the User-Agent header, or "unknown" when it is absent."""
    headers = request.headers
    return headers.get("User-Agent", "unknown")
|
||||
|
||||
|
||||
def build_request_metadata(request) -> Dict[str, str]:
    """Collect client IP, user agent, request id, and a UTC timestamp for *request*."""
    headers = request.headers
    return {
        "client_ip": get_client_ip(request),
        "user_agent": get_user_agent(request),
        "request_id": headers.get("X-Request-ID", "unknown"),
        "timestamp": datetime.utcnow().isoformat(),
    }
|
||||
174
aitbc/crypto.py
Normal file
174
aitbc/crypto.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
Cryptographic utilities for AITBC
|
||||
Provides Ethereum-specific cryptographic operations and security functions
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
from cryptography.fernet import Fernet
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
import base64
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
|
||||
def derive_ethereum_address(private_key: str) -> str:
    """Derive the Ethereum address for *private_key* (hex, optional 0x prefix)."""
    try:
        from eth_account import Account

        # Tolerate an optional "0x" prefix on the key.
        key = private_key[2:] if private_key.startswith("0x") else private_key
        return Account.from_key(key).address
    except ImportError:
        raise ImportError("eth-account is required for Ethereum address derivation. Install with: pip install eth-account")
    except Exception as e:
        raise ValueError(f"Failed to derive address from private key: {e}")
|
||||
|
||||
|
||||
def sign_transaction_hash(transaction_hash: str, private_key: str) -> str:
    """Sign *transaction_hash* (hex) with *private_key* (hex); return the hex signature."""
    try:
        from eth_account import Account

        # Tolerate optional "0x" prefixes on both inputs.
        key = private_key[2:] if private_key.startswith("0x") else private_key
        digest = transaction_hash[2:] if transaction_hash.startswith("0x") else transaction_hash

        signer = Account.from_key(key)
        # NOTE(review): sign_hash is a non-public eth-account API — confirm it
        # exists in the pinned eth-account version.
        signed_message = signer.sign_hash(bytes.fromhex(digest))
        return signed_message.signature.hex()
    except ImportError:
        raise ImportError("eth-account is required for signing. Install with: pip install eth-account")
    except Exception as e:
        raise ValueError(f"Failed to sign transaction hash: {e}")
|
||||
|
||||
|
||||
def verify_signature(message_hash: str, signature: str, address: str) -> bool:
    """Verify that *signature* over *message_hash* was produced by *address*.

    Fixes two defects in the previous implementation:
    - the "0x" prefix was stripped from *address* before comparing it with
      the recovered address, which keeps its "0x" prefix — so the comparison
      could never succeed;
    - ``Account.recover_message`` expects a SignableMessage (and its second
      positional parameter is ``vrs``, not the signature); for the raw-hash
      signatures produced by ``sign_transaction_hash`` the matching recovery
      call is ``Account._recover_hash(hash, signature=...)``.
    """
    try:
        from eth_account import Account
        from eth_utils import to_bytes

        # Tolerate optional "0x" prefixes on the hex inputs.
        if message_hash.startswith("0x"):
            message_hash = message_hash[2:]
        if signature.startswith("0x"):
            signature = signature[2:]

        message_bytes = to_bytes(hexstr=message_hash)
        signature_bytes = to_bytes(hexstr=signature)

        # NOTE(review): _recover_hash is semi-private eth-account API, but it
        # is the recovery counterpart of Account.sign_hash used by
        # sign_transaction_hash — confirm against the pinned version.
        recovered_address = Account._recover_hash(message_bytes, signature=signature_bytes)

        # Compare case-insensitively, ignoring any "0x" prefix on either side.
        expected = address[2:] if address.startswith("0x") else address
        recovered = recovered_address[2:] if recovered_address.startswith("0x") else recovered_address
        return recovered.lower() == expected.lower()
    except ImportError:
        raise ImportError("eth-account and eth-utils are required for signature verification. Install with: pip install eth-account eth-utils")
    except Exception as e:
        raise ValueError(f"Failed to verify signature: {e}")
|
||||
|
||||
|
||||
def encrypt_private_key(private_key: str, password: str) -> str:
    """Encrypt *private_key* with a key derived from *password*.

    Output layout: base64url( 16-byte random salt || Fernet token ). The key
    is derived with PBKDF2-SHA256 (100k iterations); the salt is random per
    call, so identical inputs yield different ciphertexts.
    """
    try:
        salt = os.urandom(16)
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=100000,
        )
        fernet_key = base64.urlsafe_b64encode(kdf.derive(password.encode('utf-8')))

        token = Fernet(fernet_key).encrypt(private_key.encode('utf-8'))

        # Prepend the salt so decrypt_private_key can re-derive the same key.
        return base64.urlsafe_b64encode(salt + token).decode('utf-8')
    except Exception as e:
        raise ValueError(f"Failed to encrypt private key: {e}")
|
||||
|
||||
|
||||
def decrypt_private_key(encrypted_key: str, password: str) -> str:
    """Decrypt a private key produced by encrypt_private_key.

    Expects base64url( 16-byte salt || Fernet token ); the Fernet key is
    re-derived from *password* with PBKDF2-SHA256 (100k iterations) using the
    stored salt.
    """
    try:
        blob = base64.urlsafe_b64decode(encrypted_key.encode('utf-8'))
        salt, token = blob[:16], blob[16:]

        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=100000,
        )
        fernet_key = base64.urlsafe_b64encode(kdf.derive(password.encode('utf-8')))

        plaintext = Fernet(fernet_key).decrypt(token)
        return plaintext.decode('utf-8')
    except Exception as e:
        raise ValueError(f"Failed to decrypt private key: {e}")
|
||||
|
||||
|
||||
def generate_secure_random_bytes(length: int = 32) -> str:
    """Return *length* cryptographically secure random bytes, hex-encoded."""
    raw = os.urandom(length)
    return raw.hex()
|
||||
|
||||
|
||||
def keccak256_hash(data: str) -> str:
    """Return the hex Keccak-256 digest of *data* (str is UTF-8 encoded first)."""
    try:
        from eth_hash.auto import keccak

        payload = data.encode('utf-8') if isinstance(data, str) else data
        return keccak(payload).hex()
    except ImportError:
        raise ImportError("eth-hash is required for Keccak-256 hashing. Install with: pip install eth-hash")
    except Exception as e:
        raise ValueError(f"Failed to compute Keccak-256 hash: {e}")
|
||||
|
||||
|
||||
def sha256_hash(data: str) -> str:
    """Return the hex SHA-256 digest of *data* (str is UTF-8 encoded first)."""
    try:
        payload = data.encode('utf-8') if isinstance(data, str) else data
        return hashlib.sha256(payload).hexdigest()
    except Exception as e:
        raise ValueError(f"Failed to compute SHA-256 hash: {e}")
|
||||
|
||||
|
||||
def validate_ethereum_address(address: str) -> bool:
    """Return True when *address* is a well-formed, checksummed Ethereum address."""
    try:
        from eth_utils import is_address, is_checksum_address
    except ImportError:
        raise ImportError("eth-utils is required for address validation. Install with: pip install eth-utils")
    try:
        # Both a structurally valid address AND a correct EIP-55 checksum.
        return is_address(address) and is_checksum_address(address)
    except Exception:
        # Any validation failure is treated as "not a valid address".
        return False
|
||||
|
||||
|
||||
def generate_ethereum_private_key() -> str:
    """Generate a fresh random Ethereum private key as a hex string."""
    try:
        from eth_account import Account
    except ImportError:
        raise ImportError("eth-account is required for private key generation. Install with: pip install eth-account")
    try:
        return Account.create().key.hex()
    except Exception as e:
        raise ValueError(f"Failed to generate private key: {e}")
|
||||
267
aitbc/events.py
Normal file
267
aitbc/events.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""
|
||||
Event utilities for AITBC
|
||||
Provides event bus implementation, pub/sub patterns, and event decorators
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar, Generic
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
import inspect
|
||||
import functools
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class EventPriority(Enum):
    """Priority levels for events, ascending from LOW (1) to CRITICAL (4)."""
    LOW = 1
    MEDIUM = 2
    HIGH = 3
    CRITICAL = 4
|
||||
|
||||
|
||||
@dataclass
class Event:
    """Base event carried on the event bus."""
    event_type: str                                 # routing key used by EventBus.subscribe
    data: Dict[str, Any]                            # arbitrary event payload
    # Fix: annotated Optional — the default is None and the field is filled
    # in __post_init__ when the caller does not supply a value.
    timestamp: Optional[datetime] = None
    priority: EventPriority = EventPriority.MEDIUM  # defaults to MEDIUM
    source: Optional[str] = None                    # optional emitter identifier

    def __post_init__(self):
        # Stamp creation time when not provided (naive UTC, matching the module).
        if self.timestamp is None:
            self.timestamp = datetime.utcnow()
|
||||
|
||||
|
||||
class EventBus:
    """Simple in-memory publish/subscribe event bus.

    Handlers are registered per event type; publishing invokes each handler
    (sync or async) in registration order and records the event in a bounded
    history buffer (most recent 1000 events).
    """

    def __init__(self):
        """Initialize an empty bus with a 1000-event history cap."""
        self.subscribers: Dict[str, List[Callable]] = {}
        self.event_history: List[Event] = []
        self.max_history = 1000

    def subscribe(self, event_type: str, handler: Callable) -> None:
        """Register *handler* to receive events of *event_type*."""
        self.subscribers.setdefault(event_type, []).append(handler)

    def unsubscribe(self, event_type: str, handler: Callable) -> bool:
        """Remove *handler* for *event_type*; return True if it was registered."""
        handlers = self.subscribers.get(event_type)
        if handlers is None or handler not in handlers:
            return False
        handlers.remove(handler)
        return True

    async def publish(self, event: Event) -> None:
        """Record *event* in history and invoke every matching handler."""
        self._record(event)

        for handler in self.subscribers.get(event.event_type, []):
            try:
                if inspect.iscoroutinefunction(handler):
                    await handler(event)
                else:
                    handler(event)
            except Exception as e:
                # Handler failures are isolated so one bad subscriber cannot
                # break delivery to the others.
                print(f"Error in event handler: {e}")

    def publish_sync(self, event: Event) -> None:
        """Synchronous wrapper around publish(); spins up its own event loop."""
        asyncio.run(self.publish(event))

    def get_event_history(self, event_type: Optional[str] = None, limit: int = 100) -> List[Event]:
        """Return up to *limit* most recent events, optionally filtered by type."""
        events = self.event_history
        if event_type:
            events = [e for e in events if e.event_type == event_type]
        return events[-limit:]

    def clear_history(self) -> None:
        """Forget all recorded events."""
        self.event_history.clear()

    def _record(self, event: Event) -> None:
        """Append *event* to history, evicting the oldest entry past the cap."""
        self.event_history.append(event)
        if len(self.event_history) > self.max_history:
            self.event_history.pop(0)
|
||||
|
||||
|
||||
class AsyncEventBus(EventBus):
    """Event bus that runs handlers concurrently, bounded by a semaphore."""

    def __init__(self, max_concurrent_handlers: int = 10):
        """Initialize async event bus.

        Args:
            max_concurrent_handlers: upper bound on handlers running at once.
        """
        super().__init__()
        self.semaphore = asyncio.Semaphore(max_concurrent_handlers)

    async def publish(self, event: Event) -> None:
        """Publish *event*, running all matching handlers concurrently.

        Handler errors are printed and swallowed so one failing subscriber
        cannot affect the others.
        """
        self.event_history.append(event)
        if len(self.event_history) > self.max_history:
            self.event_history.pop(0)

        handlers = self.subscribers.get(event.event_type, [])

        # BUG FIX: the previous implementation defined `safe_handler` as a
        # closure over the loop variable `handler`; closures bind late, so
        # every scheduled coroutine invoked the *last* registered handler.
        # Passing the handler as a parameter binds it per coroutine.
        async def run_handler(handler: Callable) -> None:
            async with self.semaphore:
                try:
                    if inspect.iscoroutinefunction(handler):
                        await handler(event)
                    else:
                        handler(event)
                except Exception as e:
                    print(f"Error in event handler: {e}")

        tasks = [run_handler(h) for h in handlers]
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
def event_handler(event_type: str, event_bus: Optional[EventBus] = None):
    """Decorator registering the wrapped function as a handler for *event_type*.

    Falls back to the global event bus when *event_bus* is not supplied.
    """
    def decorator(func: Callable) -> Callable:
        target_bus = event_bus or get_global_event_bus()
        target_bus.subscribe(event_type, func)
        return func

    return decorator
|
||||
|
||||
|
||||
def publish_event(event_type: str, data: Dict[str, Any], event_bus: Optional[EventBus] = None) -> None:
    """Construct an Event and publish it synchronously on *event_bus* (or the global bus)."""
    bus = event_bus or get_global_event_bus()
    bus.publish_sync(Event(event_type=event_type, data=data))
|
||||
|
||||
|
||||
# Global event bus instance
|
||||
_global_event_bus: Optional[EventBus] = None
|
||||
|
||||
|
||||
def get_global_event_bus() -> EventBus:
    """Return the process-wide event bus, creating it lazily on first use."""
    global _global_event_bus
    # Lazy singleton: only instantiate when first requested.
    if _global_event_bus is None:
        _global_event_bus = EventBus()
    return _global_event_bus
|
||||
|
||||
|
||||
def set_global_event_bus(bus: EventBus) -> None:
    """Replace the process-wide event bus (e.g. with an AsyncEventBus or a test double)."""
    global _global_event_bus
    _global_event_bus = bus
|
||||
|
||||
|
||||
class EventFilter:
    """Filter a bus's event history through user-supplied predicates."""

    def __init__(self, event_bus: Optional[EventBus] = None):
        """Bind the filter to *event_bus*, defaulting to the global bus."""
        self.event_bus = event_bus or get_global_event_bus()
        self.filters: List[Callable[[Event], bool]] = []

    def add_filter(self, filter_func: Callable[[Event], bool]) -> None:
        """Register a predicate; events must satisfy every registered predicate."""
        self.filters.append(filter_func)

    def matches(self, event: Event) -> bool:
        """Return True when *event* passes all predicates (vacuously true if none)."""
        for predicate in self.filters:
            if not predicate(event):
                return False
        return True

    def get_filtered_events(self, event_type: Optional[str] = None, limit: int = 100) -> List[Event]:
        """Fetch history from the bound bus and keep only matching events."""
        history = self.event_bus.get_event_history(event_type, limit)
        return [event for event in history if self.matches(event)]
|
||||
|
||||
|
||||
class EventAggregator:
    """Aggregate events over sliding time windows.

    Events of the same type are folded into a single summary record
    (count, first/last seen timestamps, merged data). Records whose most
    recent event is older than the window are evicted lazily on read.
    """

    def __init__(self, window_seconds: int = 60):
        """Initialize event aggregator.

        Args:
            window_seconds: How long (seconds) an aggregate survives after
                its most recent event before being evicted.
        """
        self.window_seconds = window_seconds
        # event_type -> {"count", "first_seen", "last_seen", "data"}
        self.aggregated_events: Dict[str, Dict[str, Any]] = {}

    def add_event(self, event: "Event") -> None:
        """Fold *event* into the aggregate for its event type.

        Numeric data values are summed; non-numeric values keep the first
        value seen (later conflicting values are ignored).
        """
        key = event.event_type
        now = datetime.utcnow()

        if key not in self.aggregated_events:
            self.aggregated_events[key] = {
                "count": 0,
                "first_seen": now,
                "last_seen": now,
                "data": {}
            }

        agg = self.aggregated_events[key]
        agg["count"] += 1
        agg["last_seen"] = now

        # Merge payload: sum numeric values, keep the first value otherwise.
        for k, v in event.data.items():
            if k not in agg["data"]:
                agg["data"][k] = v
            elif isinstance(v, (int, float)) and isinstance(agg["data"][k], (int, float)):
                # Guard both sides: previously a numeric update arriving after
                # a non-numeric stored value raised TypeError on the addition.
                agg["data"][k] = agg["data"][k] + v

    def get_aggregated_events(self) -> Dict[str, Dict[str, Any]]:
        """Return live aggregates, evicting entries older than the window."""
        now = datetime.utcnow()
        cutoff = now.timestamp() - self.window_seconds

        expired = [
            key for key, agg in self.aggregated_events.items()
            if agg["last_seen"].timestamp() < cutoff
        ]
        for key in expired:
            del self.aggregated_events[key]

        return self.aggregated_events

    def clear(self) -> None:
        """Clear all aggregated events."""
        self.aggregated_events.clear()
|
||||
|
||||
|
||||
class EventRouter:
    """Route events to different handlers based on matching conditions."""

    def __init__(self):
        """Initialize event router."""
        # Each route is a (condition, handler) pair. Annotated as tuples:
        # the previous List[Callable[...]] annotation did not match what
        # add_route actually stores.
        self.routes: List[tuple] = []

    def add_route(self, condition: "Callable[[Event], bool]", handler: Callable) -> None:
        """Register *handler* to run for events accepted by *condition*."""
        self.routes.append((condition, handler))

    async def route(self, event: "Event") -> bool:
        """Dispatch *event* to the first matching route.

        Returns True when a handler ran successfully. A handler that raises
        is reported and the search continues, so an event whose only
        matching handler fails yields False.
        """
        for condition, handler in self.routes:
            if condition(event):
                try:
                    if inspect.iscoroutinefunction(handler):
                        await handler(event)
                    else:
                        handler(event)
                    return True
                except Exception as e:
                    print(f"Error in routed handler: {e}")
        return False
|
||||
@@ -42,3 +42,18 @@ class ValidationError(AITBCError):
|
||||
class BridgeError(AITBCError):
    """Base exception for bridge errors"""
|
||||
|
||||
|
||||
class RetryError(AITBCError):
    """Raised when retry attempts are exhausted"""
|
||||
|
||||
|
||||
class CircuitBreakerOpenError(AITBCError):
    """Raised when circuit breaker is open and requests are rejected"""
|
||||
|
||||
|
||||
class RateLimitError(AITBCError):
    """Raised when rate limit is exceeded"""
|
||||
|
||||
@@ -4,8 +4,13 @@ Base HTTP client with common utilities for AITBC applications
|
||||
"""
|
||||
|
||||
import requests
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional, Union
|
||||
from .exceptions import NetworkError
|
||||
from datetime import datetime, timedelta
|
||||
from functools import lru_cache
|
||||
from .exceptions import NetworkError, RetryError, CircuitBreakerOpenError, RateLimitError
|
||||
from .aitbc_logging import get_logger
|
||||
|
||||
|
||||
class AITBCHTTPClient:
|
||||
@@ -18,7 +23,13 @@ class AITBCHTTPClient:
|
||||
self,
|
||||
base_url: str = "",
|
||||
timeout: int = 30,
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
max_retries: int = 3,
|
||||
enable_cache: bool = False,
|
||||
cache_ttl: int = 300,
|
||||
enable_logging: bool = False,
|
||||
circuit_breaker_threshold: int = 5,
|
||||
rate_limit: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
Initialize HTTP client.
|
||||
@@ -27,12 +38,37 @@ class AITBCHTTPClient:
|
||||
base_url: Base URL for all requests
|
||||
timeout: Request timeout in seconds
|
||||
headers: Default headers for all requests
|
||||
max_retries: Maximum retry attempts with exponential backoff
|
||||
enable_cache: Enable request/response caching for GET requests
|
||||
cache_ttl: Cache time-to-live in seconds
|
||||
enable_logging: Enable request/response logging
|
||||
circuit_breaker_threshold: Failures before opening circuit breaker
|
||||
rate_limit: Rate limit in requests per minute
|
||||
"""
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.timeout = timeout
|
||||
self.headers = headers or {}
|
||||
self.max_retries = max_retries
|
||||
self.enable_cache = enable_cache
|
||||
self.cache_ttl = cache_ttl
|
||||
self.enable_logging = enable_logging
|
||||
self.circuit_breaker_threshold = circuit_breaker_threshold
|
||||
self.rate_limit = rate_limit
|
||||
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update(self.headers)
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
# Cache storage: {url: (data, timestamp)}
|
||||
self._cache: Dict[str, tuple] = {}
|
||||
|
||||
# Circuit breaker state
|
||||
self._failure_count = 0
|
||||
self._circuit_open = False
|
||||
self._circuit_open_time = None
|
||||
|
||||
# Rate limiting state
|
||||
self._request_times: list = []
|
||||
|
||||
def _build_url(self, endpoint: str) -> str:
|
||||
"""
|
||||
@@ -48,6 +84,98 @@ class AITBCHTTPClient:
|
||||
return endpoint
|
||||
return f"{self.base_url}/{endpoint.lstrip('/')}"
|
||||
|
||||
def _check_circuit_breaker(self) -> None:
|
||||
"""Check if circuit breaker is open and raise exception if so."""
|
||||
if self._circuit_open:
|
||||
# Check if circuit should be reset (after 60 seconds)
|
||||
if self._circuit_open_time and (datetime.now() - self._circuit_open_time).total_seconds() > 60:
|
||||
self._circuit_open = False
|
||||
self._failure_count = 0
|
||||
self.logger.info("Circuit breaker reset to half-open state")
|
||||
else:
|
||||
raise CircuitBreakerOpenError("Circuit breaker is open, rejecting request")
|
||||
|
||||
def _record_failure(self) -> None:
|
||||
"""Record a failure and potentially open circuit breaker."""
|
||||
self._failure_count += 1
|
||||
if self._failure_count >= self.circuit_breaker_threshold:
|
||||
self._circuit_open = True
|
||||
self._circuit_open_time = datetime.now()
|
||||
self.logger.warning(f"Circuit breaker opened after {self._failure_count} failures")
|
||||
|
||||
def _check_rate_limit(self) -> None:
|
||||
"""Check if rate limit is exceeded and raise exception if so."""
|
||||
if not self.rate_limit:
|
||||
return
|
||||
|
||||
now = datetime.now()
|
||||
# Remove requests older than 1 minute
|
||||
self._request_times = [t for t in self._request_times if (now - t).total_seconds() < 60]
|
||||
|
||||
if len(self._request_times) >= self.rate_limit:
|
||||
raise RateLimitError(f"Rate limit exceeded: {self.rate_limit} requests per minute")
|
||||
|
||||
def _record_request(self) -> None:
|
||||
"""Record a request timestamp for rate limiting."""
|
||||
if self.rate_limit:
|
||||
self._request_times.append(datetime.now())
|
||||
|
||||
def _get_cache_key(self, url: str, params: Optional[Dict[str, Any]] = None) -> str:
|
||||
"""Generate cache key from URL and params."""
|
||||
if params:
|
||||
import hashlib
|
||||
param_str = str(sorted(params.items()))
|
||||
return f"{url}:{hashlib.md5(param_str.encode()).hexdigest()}"
|
||||
return url
|
||||
|
||||
def _get_cache(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get cached response if available and not expired."""
|
||||
if not self.enable_cache:
|
||||
return None
|
||||
|
||||
if cache_key in self._cache:
|
||||
data, timestamp = self._cache[cache_key]
|
||||
if (datetime.now() - timestamp).total_seconds() < self.cache_ttl:
|
||||
if self.enable_logging:
|
||||
self.logger.info(f"Cache hit for {cache_key}")
|
||||
return data
|
||||
else:
|
||||
# Expired, remove from cache
|
||||
del self._cache[cache_key]
|
||||
return None
|
||||
|
||||
def _set_cache(self, cache_key: str, data: Dict[str, Any]) -> None:
|
||||
"""Cache response data."""
|
||||
if self.enable_cache:
|
||||
self._cache[cache_key] = (data, datetime.now())
|
||||
if self.enable_logging:
|
||||
self.logger.info(f"Cached response for {cache_key}")
|
||||
|
||||
def _retry_request(self, request_func, *args, **kwargs) -> Dict[str, Any]:
|
||||
"""Execute request with retry logic and exponential backoff."""
|
||||
last_error = None
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
if attempt > 0:
|
||||
backoff_time = 2 ** (attempt - 1)
|
||||
if self.enable_logging:
|
||||
self.logger.info(f"Retry attempt {attempt}/{self.max_retries} after {backoff_time}s backoff")
|
||||
time.sleep(backoff_time)
|
||||
|
||||
return request_func(*args, **kwargs)
|
||||
except requests.RequestException as e:
|
||||
last_error = e
|
||||
if attempt < self.max_retries:
|
||||
if self.enable_logging:
|
||||
self.logger.warning(f"Request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}")
|
||||
continue
|
||||
else:
|
||||
if self.enable_logging:
|
||||
self.logger.error(f"All retry attempts exhausted: {e}")
|
||||
raise RetryError(f"Retry attempts exhausted: {e}")
|
||||
|
||||
raise NetworkError(f"Request failed: {last_error}")
|
||||
|
||||
def get(
|
||||
self,
|
||||
endpoint: str,
|
||||
@@ -67,11 +195,29 @@ class AITBCHTTPClient:
|
||||
|
||||
Raises:
|
||||
NetworkError: If request fails
|
||||
CircuitBreakerOpenError: If circuit breaker is open
|
||||
RateLimitError: If rate limit is exceeded
|
||||
"""
|
||||
url = self._build_url(endpoint)
|
||||
cache_key = self._get_cache_key(url, params)
|
||||
|
||||
# Check cache first
|
||||
cached_data = self._get_cache(cache_key)
|
||||
if cached_data is not None:
|
||||
return cached_data
|
||||
|
||||
# Check circuit breaker and rate limit
|
||||
self._check_circuit_breaker()
|
||||
self._check_rate_limit()
|
||||
|
||||
req_headers = {**self.headers, **(headers or {})}
|
||||
|
||||
try:
|
||||
if self.enable_logging:
|
||||
self.logger.info(f"GET {url} with params={params}")
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
def _make_request():
|
||||
response = self.session.get(
|
||||
url,
|
||||
params=params,
|
||||
@@ -80,7 +226,26 @@ class AITBCHTTPClient:
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
try:
|
||||
result = self._retry_request(_make_request)
|
||||
|
||||
# Cache successful GET requests
|
||||
self._set_cache(cache_key, result)
|
||||
|
||||
# Record success for circuit breaker
|
||||
self._failure_count = 0
|
||||
self._record_request()
|
||||
|
||||
if self.enable_logging:
|
||||
elapsed = (datetime.now() - start_time).total_seconds()
|
||||
self.logger.info(f"GET {url} succeeded in {elapsed:.3f}s")
|
||||
|
||||
return result
|
||||
except (RetryError, CircuitBreakerOpenError, RateLimitError):
|
||||
raise
|
||||
except requests.RequestException as e:
|
||||
self._record_failure()
|
||||
raise NetworkError(f"GET request failed: {e}")
|
||||
|
||||
def post(
|
||||
@@ -104,11 +269,23 @@ class AITBCHTTPClient:
|
||||
|
||||
Raises:
|
||||
NetworkError: If request fails
|
||||
CircuitBreakerOpenError: If circuit breaker is open
|
||||
RateLimitError: If rate limit is exceeded
|
||||
"""
|
||||
url = self._build_url(endpoint)
|
||||
|
||||
# Check circuit breaker and rate limit
|
||||
self._check_circuit_breaker()
|
||||
self._check_rate_limit()
|
||||
|
||||
req_headers = {**self.headers, **(headers or {})}
|
||||
|
||||
try:
|
||||
if self.enable_logging:
|
||||
self.logger.info(f"POST {url} with json={json}")
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
def _make_request():
|
||||
response = self.session.post(
|
||||
url,
|
||||
data=data,
|
||||
@@ -118,7 +295,23 @@ class AITBCHTTPClient:
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
try:
|
||||
result = self._retry_request(_make_request)
|
||||
|
||||
# Record success for circuit breaker
|
||||
self._failure_count = 0
|
||||
self._record_request()
|
||||
|
||||
if self.enable_logging:
|
||||
elapsed = (datetime.now() - start_time).total_seconds()
|
||||
self.logger.info(f"POST {url} succeeded in {elapsed:.3f}s")
|
||||
|
||||
return result
|
||||
except (RetryError, CircuitBreakerOpenError, RateLimitError):
|
||||
raise
|
||||
except requests.RequestException as e:
|
||||
self._record_failure()
|
||||
raise NetworkError(f"POST request failed: {e}")
|
||||
|
||||
def put(
|
||||
@@ -142,11 +335,23 @@ class AITBCHTTPClient:
|
||||
|
||||
Raises:
|
||||
NetworkError: If request fails
|
||||
CircuitBreakerOpenError: If circuit breaker is open
|
||||
RateLimitError: If rate limit is exceeded
|
||||
"""
|
||||
url = self._build_url(endpoint)
|
||||
|
||||
# Check circuit breaker and rate limit
|
||||
self._check_circuit_breaker()
|
||||
self._check_rate_limit()
|
||||
|
||||
req_headers = {**self.headers, **(headers or {})}
|
||||
|
||||
try:
|
||||
if self.enable_logging:
|
||||
self.logger.info(f"PUT {url} with json={json}")
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
def _make_request():
|
||||
response = self.session.put(
|
||||
url,
|
||||
data=data,
|
||||
@@ -156,7 +361,23 @@ class AITBCHTTPClient:
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
try:
|
||||
result = self._retry_request(_make_request)
|
||||
|
||||
# Record success for circuit breaker
|
||||
self._failure_count = 0
|
||||
self._record_request()
|
||||
|
||||
if self.enable_logging:
|
||||
elapsed = (datetime.now() - start_time).total_seconds()
|
||||
self.logger.info(f"PUT {url} succeeded in {elapsed:.3f}s")
|
||||
|
||||
return result
|
||||
except (RetryError, CircuitBreakerOpenError, RateLimitError):
|
||||
raise
|
||||
except requests.RequestException as e:
|
||||
self._record_failure()
|
||||
raise NetworkError(f"PUT request failed: {e}")
|
||||
|
||||
def delete(
|
||||
@@ -178,11 +399,23 @@ class AITBCHTTPClient:
|
||||
|
||||
Raises:
|
||||
NetworkError: If request fails
|
||||
CircuitBreakerOpenError: If circuit breaker is open
|
||||
RateLimitError: If rate limit is exceeded
|
||||
"""
|
||||
url = self._build_url(endpoint)
|
||||
|
||||
# Check circuit breaker and rate limit
|
||||
self._check_circuit_breaker()
|
||||
self._check_rate_limit()
|
||||
|
||||
req_headers = {**self.headers, **(headers or {})}
|
||||
|
||||
try:
|
||||
if self.enable_logging:
|
||||
self.logger.info(f"DELETE {url} with params={params}")
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
def _make_request():
|
||||
response = self.session.delete(
|
||||
url,
|
||||
params=params,
|
||||
@@ -191,7 +424,23 @@ class AITBCHTTPClient:
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json() if response.content else {}
|
||||
|
||||
try:
|
||||
result = self._retry_request(_make_request)
|
||||
|
||||
# Record success for circuit breaker
|
||||
self._failure_count = 0
|
||||
self._record_request()
|
||||
|
||||
if self.enable_logging:
|
||||
elapsed = (datetime.now() - start_time).total_seconds()
|
||||
self.logger.info(f"DELETE {url} succeeded in {elapsed:.3f}s")
|
||||
|
||||
return result
|
||||
except (RetryError, CircuitBreakerOpenError, RateLimitError):
|
||||
raise
|
||||
except requests.RequestException as e:
|
||||
self._record_failure()
|
||||
raise NetworkError(f"DELETE request failed: {e}")
|
||||
|
||||
def close(self) -> None:
|
||||
@@ -205,3 +454,279 @@ class AITBCHTTPClient:
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
    """Context manager exit: delegate cleanup to close()."""
    self.close()
|
||||
|
||||
|
||||
class AsyncAITBCHTTPClient:
    """
    Async HTTP client for AITBC applications.

    Mirrors AITBCHTTPClient (retries with exponential backoff, GET caching,
    circuit breaker, per-minute rate limiting) on top of httpx.AsyncClient.
    Must be used as an async context manager so the underlying client is
    opened and closed correctly.
    """

    def __init__(
        self,
        base_url: str = "",
        timeout: int = 30,
        headers: Optional[Dict[str, str]] = None,
        max_retries: int = 3,
        enable_cache: bool = False,
        cache_ttl: int = 300,
        enable_logging: bool = False,
        circuit_breaker_threshold: int = 5,
        rate_limit: Optional[int] = None
    ):
        """
        Initialize async HTTP client.

        Args:
            base_url: Base URL for all requests
            timeout: Request timeout in seconds
            headers: Default headers for all requests
            max_retries: Maximum retry attempts with exponential backoff
            enable_cache: Enable request/response caching for GET requests
            cache_ttl: Cache time-to-live in seconds
            enable_logging: Enable request/response logging
            circuit_breaker_threshold: Failures before opening circuit breaker
            rate_limit: Rate limit in requests per minute
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.headers = headers or {}
        self.max_retries = max_retries
        self.enable_cache = enable_cache
        self.cache_ttl = cache_ttl
        self.enable_logging = enable_logging
        self.circuit_breaker_threshold = circuit_breaker_threshold
        self.rate_limit = rate_limit

        self.logger = get_logger(__name__)
        # httpx.AsyncClient, created in __aenter__ and closed in __aexit__.
        self._client: Optional[Any] = None

        # Cache storage: {cache_key: (data, timestamp)}
        self._cache: Dict[str, tuple] = {}

        # Circuit breaker state
        self._failure_count = 0
        self._circuit_open = False
        self._circuit_open_time = None

        # Rate limiting state: timestamps of requests in the last minute
        self._request_times: list = []

    async def __aenter__(self):
        """Async context manager entry: open the underlying httpx client."""
        # Imported lazily so httpx is only required when the async client is used.
        import httpx
        self._client = httpx.AsyncClient(timeout=self.timeout, headers=self.headers)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: close the underlying httpx client."""
        if self._client:
            await self._client.aclose()

    def _build_url(self, endpoint: str) -> str:
        """Join *endpoint* onto base_url; absolute URLs pass through untouched."""
        if endpoint.startswith("http://") or endpoint.startswith("https://"):
            return endpoint
        return f"{self.base_url}/{endpoint.lstrip('/')}"

    def _check_circuit_breaker(self) -> None:
        """Raise CircuitBreakerOpenError while the breaker is open.

        The breaker auto-resets (half-open) 60 seconds after it tripped.
        """
        if self._circuit_open:
            if self._circuit_open_time and (datetime.now() - self._circuit_open_time).total_seconds() > 60:
                self._circuit_open = False
                self._failure_count = 0
                self.logger.info("Circuit breaker reset to half-open state")
            else:
                raise CircuitBreakerOpenError("Circuit breaker is open, rejecting request")

    def _record_failure(self) -> None:
        """Record a failure and open the breaker once the threshold is hit."""
        self._failure_count += 1
        if self._failure_count >= self.circuit_breaker_threshold:
            self._circuit_open = True
            self._circuit_open_time = datetime.now()
            self.logger.warning(f"Circuit breaker opened after {self._failure_count} failures")

    def _check_rate_limit(self) -> None:
        """Raise RateLimitError when the per-minute budget is exhausted."""
        if not self.rate_limit:
            return

        now = datetime.now()
        # Keep only requests from the trailing 60-second window.
        self._request_times = [t for t in self._request_times if (now - t).total_seconds() < 60]

        if len(self._request_times) >= self.rate_limit:
            raise RateLimitError(f"Rate limit exceeded: {self.rate_limit} requests per minute")

    def _record_request(self) -> None:
        """Record a request timestamp for rate limiting."""
        if self.rate_limit:
            self._request_times.append(datetime.now())

    def _get_cache_key(self, url: str, params: Optional[Dict[str, Any]] = None) -> str:
        """Generate a cache key from the URL plus a digest of sorted params."""
        if params:
            import hashlib
            param_str = str(sorted(params.items()))
            return f"{url}:{hashlib.md5(param_str.encode()).hexdigest()}"
        return url

    def _get_cache(self, cache_key: str) -> Optional[Dict[str, Any]]:
        """Return a fresh cached response, evicting expired entries."""
        if not self.enable_cache:
            return None

        if cache_key in self._cache:
            data, timestamp = self._cache[cache_key]
            if (datetime.now() - timestamp).total_seconds() < self.cache_ttl:
                if self.enable_logging:
                    self.logger.info(f"Cache hit for {cache_key}")
                return data
            else:
                del self._cache[cache_key]
        return None

    def _set_cache(self, cache_key: str, data: Dict[str, Any]) -> None:
        """Cache response data (no-op unless caching is enabled)."""
        if self.enable_cache:
            self._cache[cache_key] = (data, datetime.now())
            if self.enable_logging:
                self.logger.info(f"Cached response for {cache_key}")

    async def _retry_request(self, request_func, *args, **kwargs) -> Dict[str, Any]:
        """Execute an async request with retries and exponential backoff.

        Raises RetryError (chained to the last underlying error) once
        max_retries extra attempts have failed.
        """
        last_error = None
        for attempt in range(self.max_retries + 1):
            try:
                if attempt > 0:
                    backoff_time = 2 ** (attempt - 1)
                    if self.enable_logging:
                        self.logger.info(f"Retry attempt {attempt}/{self.max_retries} after {backoff_time}s backoff")
                    await asyncio.sleep(backoff_time)

                return await request_func(*args, **kwargs)
            # Broad on purpose: httpx is imported lazily, so its exception
            # types cannot be referenced at module scope.
            except Exception as e:
                last_error = e
                if attempt < self.max_retries:
                    if self.enable_logging:
                        self.logger.warning(f"Request failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}")
                    continue
                if self.enable_logging:
                    self.logger.error(f"All retry attempts exhausted: {e}")
                # Chain the underlying error so tracebacks stay useful.
                raise RetryError(f"Retry attempts exhausted: {e}") from e

        # Defensive fallback; unreachable when max_retries >= 0.
        raise NetworkError(f"Request failed: {last_error}")

    async def async_get(
        self,
        endpoint: str,
        params: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None
    ) -> Dict[str, Any]:
        """
        Perform async GET request.

        Args:
            endpoint: API endpoint
            params: Query parameters
            headers: Additional headers

        Returns:
            Response data as dictionary

        Raises:
            RuntimeError: If used outside the async context manager
            NetworkError: If the request ultimately fails
            CircuitBreakerOpenError: If circuit breaker is open
            RateLimitError: If rate limit is exceeded
        """
        if not self._client:
            raise RuntimeError("Async client not initialized. Use async context manager.")

        url = self._build_url(endpoint)
        cache_key = self._get_cache_key(url, params)

        # A cache hit short-circuits breaker/rate-limit checks entirely.
        cached_data = self._get_cache(cache_key)
        if cached_data is not None:
            return cached_data

        self._check_circuit_breaker()
        self._check_rate_limit()

        req_headers = {**self.headers, **(headers or {})}

        if self.enable_logging:
            self.logger.info(f"ASYNC GET {url} with params={params}")

        start_time = datetime.now()

        async def _make_request():
            response = await self._client.get(url, params=params, headers=req_headers)
            response.raise_for_status()
            return response.json()

        try:
            result = await self._retry_request(_make_request)
            self._set_cache(cache_key, result)
            # Any success closes the failure streak for the breaker.
            self._failure_count = 0
            self._record_request()

            if self.enable_logging:
                elapsed = (datetime.now() - start_time).total_seconds()
                self.logger.info(f"ASYNC GET {url} succeeded in {elapsed:.3f}s")

            return result
        except (RetryError, CircuitBreakerOpenError, RateLimitError):
            raise
        except Exception as e:
            self._record_failure()
            raise NetworkError(f"ASYNC GET request failed: {e}") from e

    async def async_post(
        self,
        endpoint: str,
        data: Optional[Dict[str, Any]] = None,
        json: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None
    ) -> Dict[str, Any]:
        """
        Perform async POST request.

        Args:
            endpoint: API endpoint
            data: Form data
            json: JSON data
            headers: Additional headers

        Returns:
            Response data as dictionary

        Raises:
            RuntimeError: If used outside the async context manager
            NetworkError: If the request ultimately fails
            CircuitBreakerOpenError: If circuit breaker is open
            RateLimitError: If rate limit is exceeded
        """
        if not self._client:
            raise RuntimeError("Async client not initialized. Use async context manager.")

        url = self._build_url(endpoint)
        self._check_circuit_breaker()
        self._check_rate_limit()

        req_headers = {**self.headers, **(headers or {})}

        if self.enable_logging:
            self.logger.info(f"ASYNC POST {url} with json={json}")

        start_time = datetime.now()

        async def _make_request():
            response = await self._client.post(url, data=data, json=json, headers=req_headers)
            response.raise_for_status()
            return response.json()

        try:
            result = await self._retry_request(_make_request)
            # Any success closes the failure streak for the breaker.
            self._failure_count = 0
            self._record_request()

            if self.enable_logging:
                elapsed = (datetime.now() - start_time).total_seconds()
                self.logger.info(f"ASYNC POST {url} succeeded in {elapsed:.3f}s")

            return result
        except (RetryError, CircuitBreakerOpenError, RateLimitError):
            raise
        except Exception as e:
            self._record_failure()
            raise NetworkError(f"ASYNC POST request failed: {e}") from e
|
||||
|
||||
431
aitbc/queue.py
Normal file
431
aitbc/queue.py
Normal file
@@ -0,0 +1,431 @@
|
||||
"""
|
||||
Queue utilities for AITBC
|
||||
Provides task queue helpers, job scheduling, and background task management
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import heapq
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
import uuid
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class JobStatus(Enum):
    """Lifecycle states of a background job."""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class JobPriority(Enum):
    """Job priority levels; a higher value means more urgent."""
    LOW = 1
    MEDIUM = 2
    HIGH = 3
    CRITICAL = 4


@dataclass(order=True)
class Job:
    """Background job record.

    Ordering is defined solely by ``_sort_key`` (the negated priority), so a
    min-heap such as ``heapq`` pops the highest-priority job first; every
    other field is excluded from comparison.
    """
    priority: int = field(compare=False)
    # Comparison key, computed in __post_init__ as -priority so that
    # CRITICAL (4) sorts ahead of LOW (1) in a min-heap.
    _sort_key: int = field(init=False, repr=False, default=0)
    # job_id and func previously had no defaults, so TaskQueue.enqueue's
    # Job(...) call (which omits job_id) raised TypeError even though
    # __post_init__ was written to autogenerate a missing id.
    job_id: Optional[str] = field(default=None, compare=False)
    func: Optional[Callable] = field(default=None, compare=False)
    args: tuple = field(default_factory=tuple, compare=False)
    kwargs: dict = field(default_factory=dict, compare=False)
    status: JobStatus = field(default=JobStatus.PENDING, compare=False)
    created_at: datetime = field(default_factory=datetime.utcnow, compare=False)
    started_at: Optional[datetime] = field(default=None, compare=False)
    completed_at: Optional[datetime] = field(default=None, compare=False)
    result: Any = field(default=None, compare=False)
    error: Optional[str] = field(default=None, compare=False)
    retry_count: int = field(default=0, compare=False)
    max_retries: int = field(default=3, compare=False)

    def __post_init__(self):
        self._sort_key = -self.priority
        if self.job_id is None:
            self.job_id = str(uuid.uuid4())
|
||||
|
||||
|
||||
class TaskQueue:
    """Priority-based task queue backed by a heap."""

    def __init__(self):
        """Initialize an empty queue with an async lock guarding mutation."""
        self.queue: List[Job] = []        # pending jobs, heap-ordered
        self.jobs: Dict[str, Job] = {}    # every job ever enqueued, by id
        self.lock = asyncio.Lock()

    async def enqueue(
        self,
        func: Callable,
        args: tuple = (),
        kwargs: dict = None,
        priority: JobPriority = JobPriority.MEDIUM,
        max_retries: int = 3
    ) -> str:
        """Add a task to the queue and return the new job's id."""
        new_job = Job(
            priority=priority.value,
            func=func,
            args=args,
            kwargs={} if kwargs is None else kwargs,
            max_retries=max_retries
        )

        async with self.lock:
            heapq.heappush(self.queue, new_job)
            self.jobs[new_job.job_id] = new_job

        return new_job.job_id

    async def dequeue(self) -> Optional[Job]:
        """Pop the next job off the heap, or None when the queue is empty."""
        async with self.lock:
            if not self.queue:
                return None
            return heapq.heappop(self.queue)

    async def get_job(self, job_id: str) -> Optional[Job]:
        """Look up any known job by id (including finished/cancelled ones)."""
        return self.jobs.get(job_id)

    async def cancel_job(self, job_id: str) -> bool:
        """Cancel a still-pending job; False when unknown or already started."""
        async with self.lock:
            target = self.jobs.get(job_id)
            if target is None or target.status != JobStatus.PENDING:
                return False
            target.status = JobStatus.CANCELLED
            # Drop the cancelled entry and restore the heap invariant.
            self.queue = [j for j in self.queue if j.job_id != job_id]
            heapq.heapify(self.queue)
            return True

    async def get_queue_size(self) -> int:
        """Number of jobs still waiting in the queue."""
        return len(self.queue)

    async def get_jobs_by_status(self, status: JobStatus) -> List[Job]:
        """All known jobs currently in the given status."""
        return [j for j in self.jobs.values() if j.status == status]
|
||||
|
||||
|
||||
class JobScheduler:
    """Job scheduler for delayed and recurring tasks."""

    def __init__(self):
        """Initialize job scheduler."""
        # job_id -> {"func", "args", "kwargs", "run_at", "interval", "job_id"}
        self.scheduled_jobs: Dict[str, Dict[str, Any]] = {}
        self.running = False
        self.task: Optional[asyncio.Task] = None

    async def schedule(
        self,
        func: Callable,
        delay: float = 0,
        interval: Optional[float] = None,
        job_id: Optional[str] = None,
        args: tuple = (),
        kwargs: dict = None
    ) -> str:
        """Schedule *func* to run after *delay* seconds.

        Args:
            func: Sync or async callable to execute.
            delay: Seconds before the first run.
            interval: If set, re-run every *interval* seconds thereafter.
            job_id: Optional explicit id; autogenerated when omitted.
            args: Positional arguments forwarded to *func*.
            kwargs: Keyword arguments forwarded to *func*.

        Returns:
            The scheduled job's id.
        """
        if job_id is None:
            job_id = str(uuid.uuid4())

        if kwargs is None:
            kwargs = {}

        run_at = time.time() + delay

        self.scheduled_jobs[job_id] = {
            "func": func,
            "args": args,
            "kwargs": kwargs,
            "run_at": run_at,
            "interval": interval,
            "job_id": job_id
        }

        return job_id

    async def cancel_scheduled_job(self, job_id: str) -> bool:
        """Cancel a scheduled job; returns False when the id is unknown."""
        return self.scheduled_jobs.pop(job_id, None) is not None

    async def start(self) -> None:
        """Start the scheduler loop (idempotent)."""
        if self.running:
            return

        self.running = True
        self.task = asyncio.create_task(self._run_scheduler())

    async def stop(self) -> None:
        """Stop the scheduler and wait for the loop task to wind down."""
        self.running = False
        if self.task:
            self.task.cancel()
            try:
                await self.task
            except asyncio.CancelledError:
                pass

    async def _run_scheduler(self) -> None:
        """Scheduler loop: run due jobs, reschedule interval jobs, drop one-shots.

        Entries are removed with dict.pop(..., None) instead of ``del`` so that
        a job cancelled while its (possibly awaited) callback is running cannot
        crash the loop with a KeyError.
        """
        while self.running:
            now = time.time()
            due = [job for job in list(self.scheduled_jobs.values()) if job["run_at"] <= now]

            for job in due:
                try:
                    if asyncio.iscoroutinefunction(job["func"]):
                        await job["func"](*job["args"], **job["kwargs"])
                    else:
                        job["func"](*job["args"], **job["kwargs"])

                    if job["interval"]:
                        job["run_at"] = now + job["interval"]
                    else:
                        self.scheduled_jobs.pop(job["job_id"], None)
                except Exception as e:
                    print(f"Error running scheduled job {job['job_id']}: {e}")
                    if not job["interval"]:
                        self.scheduled_jobs.pop(job["job_id"], None)

            await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
class BackgroundTaskManager:
|
||||
"""Manage background tasks"""
|
||||
|
||||
def __init__(self, max_concurrent_tasks: int = 10):
    """Initialize background task manager.

    Args:
        max_concurrent_tasks: Upper bound on tasks running at once,
            enforced via an asyncio.Semaphore.
    """
    self.max_concurrent_tasks = max_concurrent_tasks
    self.semaphore = asyncio.Semaphore(max_concurrent_tasks)
    # Currently-running asyncio.Task objects, keyed by task id.
    self.tasks: Dict[str, asyncio.Task] = {}
    # Status/result metadata for every task ever started, keyed by task id.
    self.task_info: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
async def run_task(
|
||||
self,
|
||||
func: Callable,
|
||||
task_id: Optional[str] = None,
|
||||
args: tuple = (),
|
||||
kwargs: dict = None
|
||||
) -> str:
|
||||
"""Run a background task"""
|
||||
if task_id is None:
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
async def wrapped_task():
|
||||
async with self.semaphore:
|
||||
try:
|
||||
self.task_info[task_id]["status"] = "running"
|
||||
self.task_info[task_id]["started_at"] = datetime.utcnow()
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
result = await func(*args, **kwargs)
|
||||
else:
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
self.task_info[task_id]["status"] = "completed"
|
||||
self.task_info[task_id]["result"] = result
|
||||
self.task_info[task_id]["completed_at"] = datetime.utcnow()
|
||||
except Exception as e:
|
||||
self.task_info[task_id]["status"] = "failed"
|
||||
self.task_info[task_id]["error"] = str(e)
|
||||
self.task_info[task_id]["completed_at"] = datetime.utcnow()
|
||||
finally:
|
||||
if task_id in self.tasks:
|
||||
del self.tasks[task_id]
|
||||
|
||||
self.task_info[task_id] = {
|
||||
"status": "pending",
|
||||
"created_at": datetime.utcnow(),
|
||||
"started_at": None,
|
||||
"completed_at": None,
|
||||
"result": None,
|
||||
"error": None
|
||||
}
|
||||
|
||||
task = asyncio.create_task(wrapped_task())
|
||||
self.tasks[task_id] = task
|
||||
|
||||
return task_id
|
||||
|
||||
async def cancel_task(self, task_id: str) -> bool:
|
||||
"""Cancel a background task"""
|
||||
if task_id in self.tasks:
|
||||
self.tasks[task_id].cancel()
|
||||
try:
|
||||
await self.tasks[task_id]
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self.task_info[task_id]["status"] = "cancelled"
|
||||
self.task_info[task_id]["completed_at"] = datetime.utcnow()
|
||||
del self.tasks[task_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get task status"""
|
||||
return self.task_info.get(task_id)
|
||||
|
||||
async def get_all_tasks(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get all tasks"""
|
||||
return self.task_info.copy()
|
||||
|
||||
async def wait_for_task(self, task_id: str, timeout: Optional[float] = None) -> Any:
|
||||
"""Wait for task completion"""
|
||||
if task_id not in self.tasks:
|
||||
raise ValueError(f"Task {task_id} not found")
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self.tasks[task_id], timeout)
|
||||
except asyncio.TimeoutError:
|
||||
await self.cancel_task(task_id)
|
||||
raise TimeoutError(f"Task {task_id} timed out")
|
||||
|
||||
info = self.task_info.get(task_id)
|
||||
if info["status"] == "failed":
|
||||
raise Exception(info["error"])
|
||||
|
||||
return info["result"]
|
||||
|
||||
|
||||
class WorkerPool:
    """Pool of asyncio workers consuming submitted callables from a queue."""

    def __init__(self, num_workers: int = 4):
        """Create a pool with num_workers consumer coroutines (not started yet)."""
        self.num_workers = num_workers
        self.queue: asyncio.Queue = asyncio.Queue()
        self.workers: List[asyncio.Task] = []
        self.running = False

    async def start(self) -> None:
        """Spawn the worker tasks; no-op when already running."""
        if not self.running:
            self.running = True
            self.workers.extend(
                asyncio.create_task(self._worker(idx))
                for idx in range(self.num_workers)
            )

    async def stop(self) -> None:
        """Cancel every worker and wait for all of them to wind down."""
        self.running = False
        for worker in self.workers:
            worker.cancel()
        # Collect cancellations without letting them propagate.
        await asyncio.gather(*self.workers, return_exceptions=True)
        self.workers.clear()

    async def submit(self, func: Callable, *args, **kwargs) -> Any:
        """Enqueue a callable for a worker and await its result."""
        outcome: asyncio.Future = asyncio.Future()
        await self.queue.put((func, args, kwargs, outcome))
        return await outcome

    async def _worker(self, worker_id: int) -> None:
        """Worker coroutine: pull jobs off the queue until cancelled/stopped."""
        while self.running:
            try:
                func, call_args, call_kwargs, outcome = await self.queue.get()
                try:
                    if asyncio.iscoroutinefunction(func):
                        value = await func(*call_args, **call_kwargs)
                    else:
                        value = func(*call_args, **call_kwargs)
                    outcome.set_result(value)
                except Exception as exc:
                    # Hand the failure back to the submitter.
                    outcome.set_exception(exc)
                finally:
                    self.queue.task_done()
            except asyncio.CancelledError:
                break
            except Exception as exc:
                print(f"Worker {worker_id} error: {exc}")

    async def get_queue_size(self) -> int:
        """Number of queued, not-yet-claimed jobs."""
        return self.queue.qsize()
|
||||
|
||||
|
||||
def debounce(delay: float = 0.5):
    """Decorator that delays execution until calls stop for `delay` seconds.

    Each call schedules the wrapped callable to run after `delay`; a newer
    call cancels the previously scheduled run.  The call that actually
    executes returns the callable's result; superseded calls return None.

    BUG FIX: superseded calls previously leaked asyncio.CancelledError
    into their awaiting callers; that cancellation is now absorbed and the
    superseded call resolves to None instead.
    """
    def decorator(func: Callable) -> Callable:
        last_called = [0.0]
        pending = [None]

        async def wrapped(*args, **kwargs):
            async def invoke():
                await asyncio.sleep(delay)
                # Only run when no newer call arrived during the wait.
                if asyncio.get_event_loop().time() - last_called[0] >= delay:
                    if asyncio.iscoroutinefunction(func):
                        return await func(*args, **kwargs)
                    return func(*args, **kwargs)

            last_called[0] = asyncio.get_event_loop().time()
            if pending[0]:
                pending[0].cancel()

            task = asyncio.create_task(invoke())
            pending[0] = task
            try:
                return await task
            except asyncio.CancelledError:
                if task.cancelled():
                    # This call was superseded by a newer one.
                    return None
                # The awaiting caller itself was cancelled — propagate.
                raise
        return wrapped
    return decorator
|
||||
|
||||
|
||||
def throttle(calls_per_second: float = 1.0):
    """Decorator limiting how often the wrapped callable may run.

    Invocations arriving faster than calls_per_second are delayed until
    the minimum interval has elapsed.  The wrapper is always a coroutine
    function, even when the wrapped callable is synchronous.
    """
    def decorator(func: Callable) -> Callable:
        gap = 1.0 / calls_per_second
        last_time = [0]

        async def wrapped(*args, **kwargs):
            since_last = asyncio.get_event_loop().time() - last_time[0]
            if since_last < gap:
                # Sleep off the remainder of the minimum interval.
                await asyncio.sleep(gap - since_last)
            last_time[0] = asyncio.get_event_loop().time()

            if asyncio.iscoroutinefunction(func):
                return await func(*args, **kwargs)
            return func(*args, **kwargs)

        return wrapped
    return decorator
|
||||
282
aitbc/security.py
Normal file
282
aitbc/security.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
Security utilities for AITBC
|
||||
Provides token generation, session management, API key management, and secret management
|
||||
"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
import hashlib
|
||||
import time
|
||||
import json
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
|
||||
def generate_token(length: int = 32, prefix: str = "") -> str:
    """Return a URL-safe random token, optionally prefixed.

    `length` is the number of random bytes fed to token_urlsafe, so the
    random part is roughly 4/3 that many characters long.
    """
    return prefix + secrets.token_urlsafe(length)
|
||||
|
||||
|
||||
def generate_api_key(prefix: str = "aitbc") -> str:
    """Return an API key of the form '<prefix>_<url-safe random token>'."""
    return "_".join([prefix, secrets.token_urlsafe(32)])
|
||||
|
||||
|
||||
def validate_token_format(token: str, min_length: int = 16) -> bool:
    """Check a token is non-empty, long enough and URL-safe (alnum, '-', '_')."""
    if not token or len(token) < min_length:
        return False
    return all(ch.isalnum() or ch in "-_" for ch in token)


def validate_api_key(api_key: str, prefix: str = "aitbc") -> bool:
    """Check an API key carries the expected prefix and a well-formed token."""
    expected = f"{prefix}_"
    if not api_key or not api_key.startswith(expected):
        return False
    return validate_token_format(api_key[len(expected):])
|
||||
|
||||
|
||||
class SessionManager:
    """In-memory session store with a fixed absolute lifetime per session."""

    def __init__(self, session_timeout: int = 3600):
        """Create the manager; session_timeout is the session lifetime in seconds."""
        self.sessions: Dict[str, Dict[str, Any]] = {}
        self.session_timeout = session_timeout

    def create_session(self, user_id: str, data: Optional[Dict[str, Any]] = None) -> str:
        """Create a session for user_id and return its opaque id."""
        # Inline of generate_token() with its defaults (no prefix, 32 bytes).
        session_id = secrets.token_urlsafe(32)
        self.sessions[session_id] = {
            "user_id": user_id,
            "data": data or {},
            "created_at": time.time(),
            "expires_at": time.time() + self.session_timeout,
        }
        return session_id

    def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return the live session record, lazily expiring it when stale."""
        record = self.sessions.get(session_id)
        if record is None:
            return None
        if time.time() > record["expires_at"]:
            # Expired: evict and pretend it never existed.
            del self.sessions[session_id]
            return None
        return record

    def update_session(self, session_id: str, data: Dict[str, Any]) -> bool:
        """Merge data into the session payload; False when missing/expired."""
        record = self.get_session(session_id)
        if record is None:
            return False
        record["data"].update(data)
        return True

    def delete_session(self, session_id: str) -> bool:
        """Remove a session; True when it existed."""
        return self.sessions.pop(session_id, None) is not None

    def cleanup_expired_sessions(self) -> int:
        """Evict every expired session, returning how many were removed."""
        cutoff = time.time()
        stale = [
            sid for sid, record in self.sessions.items()
            if cutoff > record["expires_at"]
        ]
        for sid in stale:
            del self.sessions[sid]
        return len(stale)
|
||||
|
||||
|
||||
class APIKeyManager:
    """API key issuance/validation with optional JSON-file persistence."""

    def __init__(self, storage_path: Optional[str] = None):
        """Create the manager; when storage_path is set, keys persist there."""
        self.storage_path = storage_path
        self.keys: Dict[str, Dict[str, Any]] = {}
        if storage_path:
            self._load_keys()

    def create_api_key(self, user_id: str, scopes: Optional[list[str]] = None, name: Optional[str] = None) -> str:
        """Issue a new key for user_id; default scope is ["read"]."""
        # Inline of generate_api_key() with its default "aitbc" prefix.
        api_key = f"aitbc_{secrets.token_urlsafe(32)}"
        self.keys[api_key] = {
            "user_id": user_id,
            "scopes": scopes or ["read"],
            "name": name,
            "created_at": datetime.utcnow().isoformat(),
            "last_used": None,
        }
        self._persist()
        return api_key

    def validate_api_key(self, api_key: str) -> Optional[Dict[str, Any]]:
        """Look up a key, stamping last_used; returns its record or None."""
        record = self.keys.get(api_key)
        if record is None:
            return None
        record["last_used"] = datetime.utcnow().isoformat()
        self._persist()
        return record

    def revoke_api_key(self, api_key: str) -> bool:
        """Delete a key; True when it existed."""
        if self.keys.pop(api_key, None) is None:
            return False
        self._persist()
        return True

    def list_user_keys(self, user_id: str) -> list[str]:
        """All key strings that belong to user_id."""
        return [key for key, record in self.keys.items() if record["user_id"] == user_id]

    def _persist(self) -> None:
        """Write the store to disk when persistence is configured."""
        if self.storage_path:
            self._save_keys()

    def _load_keys(self):
        """Best-effort load; unreadable/corrupt files yield an empty store."""
        if self.storage_path and os.path.exists(self.storage_path):
            try:
                with open(self.storage_path, 'r') as f:
                    self.keys = json.load(f)
            except Exception:
                self.keys = {}

    def _save_keys(self):
        """Best-effort save; persistence errors are deliberately swallowed."""
        if self.storage_path:
            try:
                with open(self.storage_path, 'w') as f:
                    json.dump(self.keys, f)
            except Exception:
                pass

    def items(self):
        """Expose (key, record) pairs like a mapping."""
        return self.keys.items()
|
||||
|
||||
|
||||
def generate_secure_random_string(length: int = 32) -> str:
    """URL-safe random string derived from `length` CSPRNG bytes."""
    return secrets.token_urlsafe(length)
|
||||
|
||||
|
||||
def generate_secure_random_int(min_val: int = 0, max_val: int = 2**32) -> int:
    """Random integer in [min_val, max_val) drawn from the secrets CSPRNG."""
    span = max_val - min_val
    return min_val + secrets.randbelow(span)
|
||||
|
||||
|
||||
class SecretManager:
    """In-memory secret store; values are encrypted at rest with Fernet."""

    def __init__(self, encryption_key: Optional[str] = None):
        """Initialize with a Fernet key (str or bytes); generate one if omitted."""
        if encryption_key is None:
            # Generate a new key if none provided
            encryption_key = Fernet.generate_key()
        # BUG FIX: keep the complete key ourselves so get_encryption_key()
        # can return it.  Fernet's private ``_signing_key`` attribute (used
        # before) is only half of the key material, held as raw bytes whose
        # utf-8 decode can fail — it was never a usable backup key.
        if isinstance(encryption_key, str):
            encryption_key = encryption_key.encode('utf-8')
        self._encryption_key = encryption_key
        self.fernet = Fernet(self._encryption_key)
        self.secrets: Dict[str, str] = {}

    def set_secret(self, key: str, value: str) -> None:
        """Encrypt value and store it under key."""
        token = self.fernet.encrypt(value.encode('utf-8'))
        self.secrets[key] = token.decode('utf-8')

    def get_secret(self, key: str) -> Optional[str]:
        """Decrypt and return the secret, or None when missing or corrupt."""
        token = self.secrets.get(key)
        if not token:
            return None
        try:
            return self.fernet.decrypt(token.encode('utf-8')).decode('utf-8')
        except Exception:
            # Wrong key or tampered ciphertext — treat as absent.
            return None

    def delete_secret(self, key: str) -> bool:
        """Remove a secret; True when it existed."""
        if key in self.secrets:
            del self.secrets[key]
            return True
        return False

    def list_secrets(self) -> list[str]:
        """Names of all stored secrets."""
        return list(self.secrets.keys())

    def get_encryption_key(self) -> str:
        """Return the full Fernet key (base64 text) for backup purposes."""
        return self._encryption_key.decode('utf-8')
|
||||
|
||||
|
||||
def hash_password(password: str, salt: Optional[str] = None) -> tuple[str, str]:
    """Hash a password with PBKDF2-HMAC-SHA256; return (b64_hash, salt).

    A random 32-hex-char salt is generated when none is supplied.
    Uses stdlib ``hashlib.pbkdf2_hmac`` — byte-identical output to the
    previous cryptography.PBKDF2HMAC implementation (SHA-256, 32-byte key,
    100000 iterations) without the third-party dependency on this path.
    """
    import base64

    if salt is None:
        salt = secrets.token_hex(16)

    derived = hashlib.pbkdf2_hmac(
        'sha256',
        password.encode('utf-8'),
        salt.encode('utf-8'),
        100000,  # iteration count, matches the original KDF parameters
    )
    return base64.b64encode(derived).decode('utf-8'), salt
|
||||
|
||||
|
||||
def verify_password(password: str, hashed_password: str, salt: str) -> bool:
    """Check a password against a stored hash.

    FIX: uses secrets.compare_digest instead of ``==`` so the comparison
    runs in constant time and does not leak timing information about the
    stored hash.
    """
    candidate, _ = hash_password(password, salt)
    return secrets.compare_digest(candidate, hashed_password)
|
||||
|
||||
|
||||
def generate_nonce(length: int = 16) -> str:
    """Hex nonce of 2*length characters from the secrets CSPRNG."""
    return secrets.token_hex(length)
|
||||
|
||||
|
||||
def generate_hmac(data: str, secret: str) -> str:
    """Hex-encoded HMAC-SHA256 of `data` keyed by `secret`."""
    import hmac
    digest = hmac.new(secret.encode('utf-8'), data.encode('utf-8'), hashlib.sha256)
    return digest.hexdigest()
|
||||
|
||||
|
||||
def verify_hmac(data: str, signature: str, secret: str) -> bool:
    """Constant-time check of an HMAC-SHA256 hex signature for `data`."""
    expected = generate_hmac(data, secret)
    return secrets.compare_digest(expected, signature)
|
||||
348
aitbc/state.py
Normal file
348
aitbc/state.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""
|
||||
State management utilities for AITBC
|
||||
Provides state machine base classes, state persistence, and state transition helpers
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional, TypeVar, Generic, List
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class StateTransitionError(Exception):
    """Signals an attempt to make a transition the machine does not allow."""
|
||||
|
||||
|
||||
class StatePersistenceError(Exception):
    """Signals that saving, loading or deleting persisted state failed."""
|
||||
|
||||
|
||||
@dataclass
class StateTransition:
    """Record of a state transition"""
    # State the machine left and state it entered.
    from_state: str
    to_state: str
    # Moment the transition was recorded (naive UTC via datetime.utcnow).
    timestamp: datetime = field(default_factory=datetime.utcnow)
    # Arbitrary payload attached by the caller at transition time.
    data: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class StateMachine(ABC):
    """Abstract state machine tracking current state, per-state data and history.

    Subclasses define the transition graph via get_valid_transitions().
    """

    def __init__(self, initial_state: str):
        """Start the machine in initial_state with empty data and history."""
        self.current_state = initial_state
        # Sibling-class annotations kept as strings (forward references)
        # so they are not evaluated at definition time.
        self.transitions: "List[StateTransition]" = []
        self.state_data: Dict[str, Dict[str, Any]] = {initial_state: {}}

    @abstractmethod
    def get_valid_transitions(self, state: str) -> List[str]:
        """Return the states reachable in one step from `state`."""

    def can_transition(self, to_state: str) -> bool:
        """True when to_state is reachable from the current state."""
        return to_state in self.get_valid_transitions(self.current_state)

    def transition(self, to_state: str, data: Optional[Dict[str, Any]] = None) -> None:
        """Move to to_state, recording the transition.

        Raises StateTransitionError when the move is not allowed.
        """
        if not self.can_transition(to_state):
            raise StateTransitionError(
                f"Invalid transition from {self.current_state} to {to_state}"
            )

        previous = self.current_state
        self.current_state = to_state
        self.transitions.append(
            StateTransition(from_state=previous, to_state=to_state, data=data or {})
        )
        # Lazily create the data bucket for newly entered states.
        self.state_data.setdefault(to_state, {})

    def get_state_data(self, state: Optional[str] = None) -> Dict[str, Any]:
        """Copy of the data dict for `state` (default: current state)."""
        target = state or self.current_state
        return dict(self.state_data.get(target, {}))

    def set_state_data(self, data: Dict[str, Any], state: Optional[str] = None) -> None:
        """Merge data into the given state's bucket (default: current state)."""
        target = state or self.current_state
        self.state_data.setdefault(target, {}).update(data)

    def get_transition_history(self, limit: Optional[int] = None) -> "List[StateTransition]":
        """The most recent `limit` transitions (all of them when limit is falsy)."""
        if limit:
            return self.transitions[-limit:]
        return list(self.transitions)

    def reset(self, initial_state: str) -> None:
        """Forget all history and data, restarting in initial_state."""
        self.current_state = initial_state
        self.transitions.clear()
        self.state_data = {initial_state: {}}
|
||||
|
||||
|
||||
class ConfigurableStateMachine(StateMachine):
    """State machine whose transition graph is supplied as a plain dict."""

    def __init__(self, initial_state: str, transitions: Dict[str, List[str]]):
        """`transitions` maps each state to the states reachable from it."""
        super().__init__(initial_state)
        self.transitions_config = transitions

    def get_valid_transitions(self, state: str) -> List[str]:
        """Reachable states for `state`; unknown states have none."""
        return self.transitions_config.get(state, [])

    def add_transition(self, from_state: str, to_state: str) -> None:
        """Register from_state -> to_state, ignoring duplicates."""
        targets = self.transitions_config.setdefault(from_state, [])
        if to_state not in targets:
            targets.append(to_state)
|
||||
|
||||
|
||||
class StatePersistence:
    """Serialize a StateMachine's state, data and history to a JSON file."""

    def __init__(self, storage_path: str):
        """Remember the target path and make sure its directory exists."""
        self.storage_path = storage_path
        self._ensure_storage_dir()

    def _ensure_storage_dir(self) -> None:
        """Create the parent directory of storage_path if missing."""
        os.makedirs(os.path.dirname(self.storage_path), exist_ok=True)

    @staticmethod
    def _transition_to_dict(transition) -> Dict[str, Any]:
        """JSON-friendly form of one StateTransition record."""
        return {
            "from_state": transition.from_state,
            "to_state": transition.to_state,
            "timestamp": transition.timestamp.isoformat(),
            "data": transition.data,
        }

    def save_state(self, state_machine: "StateMachine") -> None:
        """Dump current state, per-state data and history as indented JSON.

        Raises StatePersistenceError on any failure.
        """
        try:
            payload = {
                "current_state": state_machine.current_state,
                "state_data": state_machine.state_data,
                "transitions": [
                    self._transition_to_dict(t) for t in state_machine.transitions
                ],
            }
            with open(self.storage_path, 'w') as fh:
                json.dump(payload, fh, indent=2)
        except Exception as e:
            raise StatePersistenceError(f"Failed to save state: {e}")

    def load_state(self) -> Optional[Dict[str, Any]]:
        """Load the persisted state dict, or None when no file exists.

        Raises StatePersistenceError on read/parse failure.
        """
        try:
            if not os.path.exists(self.storage_path):
                return None
            with open(self.storage_path, 'r') as fh:
                return json.load(fh)
        except Exception as e:
            raise StatePersistenceError(f"Failed to load state: {e}")

    def delete_state(self) -> None:
        """Remove the persisted state file if present.

        Raises StatePersistenceError on failure.
        """
        try:
            if os.path.exists(self.storage_path):
                os.remove(self.storage_path)
        except Exception as e:
            raise StatePersistenceError(f"Failed to delete state: {e}")
|
||||
|
||||
|
||||
class AsyncStateMachine(StateMachine):
    """State machine whose transitions can trigger sync or async callbacks."""

    def __init__(self, initial_state: str):
        """Start in initial_state with no transition handlers registered."""
        super().__init__(initial_state)
        self.transition_handlers: Dict[str, Callable] = {}

    def on_transition(self, to_state: str, handler: Callable) -> None:
        """Register `handler` to run whenever the machine enters to_state."""
        self.transition_handlers[to_state] = handler

    async def transition_async(self, to_state: str, data: Optional[Dict[str, Any]] = None) -> None:
        """Like transition(), then awaits/calls the handler registered for to_state.

        Raises StateTransitionError when the move is not allowed.
        """
        if not self.can_transition(to_state):
            raise StateTransitionError(
                f"Invalid transition from {self.current_state} to {to_state}"
            )

        previous = self.current_state
        self.current_state = to_state

        record = StateTransition(from_state=previous, to_state=to_state, data=data or {})
        self.transitions.append(record)
        self.state_data.setdefault(to_state, {})

        handler = self.transition_handlers.get(to_state)
        if handler is not None:
            # Handlers receive the recorded transition; coroutines are awaited.
            if asyncio.iscoroutinefunction(handler):
                await handler(record)
            else:
                handler(record)
|
||||
|
||||
|
||||
class StateMonitor:
    """Observer registry notified about a state machine's transitions."""

    def __init__(self, state_machine: "StateMachine"):
        """Attach to state_machine; observers fire via wrap_transition()."""
        self.state_machine = state_machine
        self.observers: List[Callable] = []

    def add_observer(self, observer: Callable) -> None:
        """Subscribe a callable taking a StateTransition."""
        self.observers.append(observer)

    def remove_observer(self, observer: Callable) -> bool:
        """Unsubscribe; True when the observer was registered."""
        if observer in self.observers:
            self.observers.remove(observer)
            return True
        return False

    def notify_observers(self, transition: "StateTransition") -> None:
        """Fan a transition out to every observer, isolating their errors."""
        for observer in self.observers:
            try:
                observer(transition)
            except Exception as e:
                print(f"Error in state observer: {e}")

    def wrap_transition(self, original_transition: Callable) -> Callable:
        """Wrap a transition callable so observers fire after each call."""
        def wrapper(*args, **kwargs):
            outcome = original_transition(*args, **kwargs)
            # Report the most recently recorded transition, if any.
            if self.state_machine.transitions:
                self.notify_observers(self.state_machine.transitions[-1])
            return outcome
        return wrapper
|
||||
|
||||
|
||||
class StateValidator:
    """Static checks for transition-graph configurations."""

    @staticmethod
    def validate_transitions(transitions: Dict[str, List[str]]) -> bool:
        """Check that every transition target is a declared state (a key).

        BUG FIX: the original added all targets into the very set it then
        checked membership against, so it could never return False.  A
        target is now valid only when it is declared as a state, i.e.
        appears as a key of the dict (terminal states should be declared
        with an empty target list).
        """
        declared = set(transitions.keys())
        for targets in transitions.values():
            for target in targets:
                if target not in declared:
                    return False
        return True

    @staticmethod
    def check_for_deadlocks(transitions: Dict[str, List[str]]) -> List[str]:
        """States with no outgoing transitions (terminal/deadlock states)."""
        return [state for state, targets in transitions.items() if not targets]

    @staticmethod
    def check_for_orphans(transitions: Dict[str, List[str]]) -> List[str]:
        """States no other state can reach (no incoming transitions)."""
        reachable = set()
        for targets in transitions.values():
            reachable.update(targets)
        return [state for state in transitions if state not in reachable]
|
||||
|
||||
|
||||
class StateSnapshot:
    """Point-in-time copy of a state machine's state, data and history."""

    def __init__(self, state_machine: "StateMachine"):
        """Capture the machine's current state (shallow copies)."""
        self.current_state = state_machine.current_state
        self.state_data = dict(state_machine.state_data)
        self.transitions = list(state_machine.transitions)
        self.timestamp = datetime.utcnow()

    def restore(self, state_machine: "StateMachine") -> None:
        """Overwrite the machine's state with this snapshot's copies."""
        state_machine.current_state = self.current_state
        state_machine.state_data = dict(self.state_data)
        state_machine.transitions = list(self.transitions)

    def to_dict(self) -> Dict[str, Any]:
        """JSON-friendly representation of the snapshot."""
        serialized = [
            {
                "from_state": t.from_state,
                "to_state": t.to_state,
                "timestamp": t.timestamp.isoformat(),
                "data": t.data,
            }
            for t in self.transitions
        ]
        return {
            "current_state": self.current_state,
            "state_data": self.state_data,
            "transitions": serialized,
            "timestamp": self.timestamp.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'StateSnapshot':
        """Rebuild a snapshot from its to_dict() form."""
        # Bypass __init__, which requires a live state machine.
        snapshot = cls.__new__(cls)
        snapshot.current_state = data["current_state"]
        snapshot.state_data = data["state_data"]
        snapshot.transitions = [
            StateTransition(
                from_state=item["from_state"],
                to_state=item["to_state"],
                timestamp=datetime.fromisoformat(item["timestamp"]),
                data=item["data"],
            )
            for item in data["transitions"]
        ]
        snapshot.timestamp = datetime.fromisoformat(data["timestamp"])
        return snapshot
|
||||
401
aitbc/testing.py
Normal file
401
aitbc/testing.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""
|
||||
Testing utilities for AITBC
|
||||
Provides mock factories, test data generators, and test helpers
|
||||
"""
|
||||
|
||||
import secrets
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar, Callable
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
from decimal import Decimal
|
||||
import uuid
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class MockFactory:
    """Random primitive-value generators for building test fixtures."""

    @staticmethod
    def generate_string(length: int = 10, prefix: str = "") -> str:
        """Random URL-safe string of exactly `length` chars, plus prefix."""
        return prefix + secrets.token_urlsafe(length)[:length]

    @staticmethod
    def generate_email() -> str:
        """Random address under example.com."""
        return f"{MockFactory.generate_string(8)}@example.com"

    @staticmethod
    def generate_url() -> str:
        """Random https URL under example.com."""
        return "https://example.com/" + MockFactory.generate_string(8)

    @staticmethod
    def generate_ip_address() -> str:
        """Random address in the 192.168.0.0/16 range."""
        third = secrets.randbelow(256)
        fourth = secrets.randbelow(256)
        return f"192.168.{third}.{fourth}"

    @staticmethod
    def generate_ethereum_address() -> str:
        """Random 0x-prefixed 40-hex-digit address (not checksummed)."""
        body = "".join(secrets.choice('0123456789abcdef') for _ in range(40))
        return "0x" + body

    @staticmethod
    def generate_bitcoin_address() -> str:
        """Random Base58-alphabet string starting with '1' (34 chars total)."""
        alphabet = '123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz'
        return "1" + "".join(secrets.choice(alphabet) for _ in range(33))

    @staticmethod
    def generate_uuid() -> str:
        """Random UUID4 as a string."""
        return str(uuid.uuid4())

    @staticmethod
    def generate_hash(length: int = 64) -> str:
        """Random hex string truncated to `length` characters."""
        return secrets.token_hex(length)[:length]
|
||||
|
||||
|
||||
class TestDataGenerator:
    """Ready-made dictionaries mimicking domain records for tests.

    Every generator accepts keyword overrides that replace the randomly
    generated defaults key-by-key.
    """

    @staticmethod
    def generate_user_data(**overrides) -> Dict[str, Any]:
        """Mock user record."""
        defaults = {
            "id": MockFactory.generate_uuid(),
            "email": MockFactory.generate_email(),
            "username": MockFactory.generate_string(8),
            "first_name": MockFactory.generate_string(6),
            "last_name": MockFactory.generate_string(6),
            "created_at": datetime.utcnow().isoformat(),
            "updated_at": datetime.utcnow().isoformat(),
            "is_active": True,
            "role": "user",
        }
        return {**defaults, **overrides}

    @staticmethod
    def generate_transaction_data(**overrides) -> Dict[str, Any]:
        """Mock blockchain transaction record."""
        defaults = {
            "id": MockFactory.generate_uuid(),
            "from_address": MockFactory.generate_ethereum_address(),
            "to_address": MockFactory.generate_ethereum_address(),
            # Amounts and gas price are decimal strings, like JSON-RPC payloads.
            "amount": str(secrets.randbelow(1000000000000000000)),
            "gas_price": str(secrets.randbelow(100000000000)),
            "gas_limit": secrets.randbelow(100000),
            "nonce": secrets.randbelow(1000),
            "timestamp": datetime.utcnow().isoformat(),
            "status": "pending",
        }
        return {**defaults, **overrides}

    @staticmethod
    def generate_block_data(**overrides) -> Dict[str, Any]:
        """Mock block record."""
        defaults = {
            "number": secrets.randbelow(10000000),
            "hash": MockFactory.generate_hash(),
            "parent_hash": MockFactory.generate_hash(),
            "timestamp": datetime.utcnow().isoformat(),
            "transactions": [],
            "gas_used": str(secrets.randbelow(10000000)),
            "gas_limit": str(15000000),
            "miner": MockFactory.generate_ethereum_address(),
        }
        return {**defaults, **overrides}

    @staticmethod
    def generate_api_key_data(**overrides) -> Dict[str, Any]:
        """Mock API key record."""
        defaults = {
            "id": MockFactory.generate_uuid(),
            "api_key": f"aitbc_{secrets.token_urlsafe(32)}",
            "user_id": MockFactory.generate_uuid(),
            "name": MockFactory.generate_string(10),
            "scopes": ["read", "write"],
            "created_at": datetime.utcnow().isoformat(),
            "last_used": None,
            "is_active": True,
        }
        return {**defaults, **overrides}

    @staticmethod
    def generate_wallet_data(**overrides) -> Dict[str, Any]:
        """Mock wallet record."""
        defaults = {
            "id": MockFactory.generate_uuid(),
            "address": MockFactory.generate_ethereum_address(),
            "chain_id": 1,
            "balance": str(secrets.randbelow(1000000000000000000)),
            "created_at": datetime.utcnow().isoformat(),
            "is_active": True,
        }
        return {**defaults, **overrides}
|
||||
|
||||
|
||||
class TestHelpers:
    """Helper functions for testing."""

    @staticmethod
    def assert_dict_contains(subset: Dict[str, Any], superset: Dict[str, Any]) -> bool:
        """Return True if superset contains every key/value pair from subset."""
        return all(
            key in superset and superset[key] == value
            for key, value in subset.items()
        )

    @staticmethod
    def assert_lists_equal_unordered(list1: List[Any], list2: List[Any]) -> bool:
        """Return True if two lists contain the same elements regardless of order."""
        return sorted(list1) == sorted(list2)

    @staticmethod
    def compare_json_objects(obj1: Any, obj2: Any) -> bool:
        """Compare two JSON-serializable objects via canonical serialization."""
        return json.dumps(obj1, sort_keys=True) == json.dumps(obj2, sort_keys=True)

    @staticmethod
    def wait_for_condition(
        condition: Callable[[], bool],
        timeout: float = 10.0,
        interval: float = 0.1
    ) -> bool:
        """Poll until condition() is true or `timeout` seconds elapse.

        Returns True if the condition became true, False on timeout.
        """
        import time
        # monotonic() is immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if condition():
                return True
            time.sleep(interval)
        return False

    @staticmethod
    def measure_execution_time(func: Callable, *args, **kwargs) -> tuple[Any, float]:
        """Run func(*args, **kwargs) and return (result, elapsed_seconds)."""
        import time
        # perf_counter() has much higher resolution than time() for short calls.
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        return result, elapsed

    @staticmethod
    def generate_test_file_path(extension: str = ".tmp") -> str:
        """Generate a unique test file path under /tmp."""
        return f"/tmp/test_{secrets.token_hex(8)}{extension}"

    @staticmethod
    def cleanup_test_files(prefix: str = "test_") -> int:
        """Delete matching files in /tmp; return how many were removed."""
        import os
        import glob
        count = 0
        for file_path in glob.glob(f"/tmp/{prefix}*"):
            try:
                os.remove(file_path)
                count += 1
            except OSError:
                # Best-effort cleanup: a file vanishing or being
                # permission-protected is not a failure.  The previous bare
                # `except:` also swallowed KeyboardInterrupt/SystemExit.
                pass
        return count
|
||||
|
||||
|
||||
class MockResponse:
    """Mock HTTP response for testing."""

    def __init__(
        self,
        status_code: int = 200,
        json_data: Optional[Dict[str, Any]] = None,
        text: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None
    ):
        """Store the canned response pieces for later retrieval."""
        self.status_code = status_code
        self.headers = headers if headers is not None else {}
        self._json_data = json_data
        self._text = text

    def json(self) -> Dict[str, Any]:
        """Return the canned JSON payload; raise ValueError if none was set."""
        if self._json_data is None:
            raise ValueError("No JSON data available")
        return self._json_data

    def text(self) -> str:
        """Return the canned text body, or an empty string if none was set."""
        return self._text if self._text is not None else ""

    def raise_for_status(self) -> None:
        """Raise a generic Exception for 4xx/5xx status codes."""
        if self.status_code >= 400:
            raise Exception(f"HTTP Error: {self.status_code}")
|
||||
|
||||
|
||||
class MockDatabase:
    """Mock in-memory database for testing.

    Records are plain dicts stored per table; each record is keyed by an
    'id' field that is generated on insert when absent.
    """

    def __init__(self):
        """Initialize an empty mock database."""
        # table name -> list of record dicts
        self.data: Dict[str, List[Dict[str, Any]]] = {}
        self.tables: List[str] = []

    def create_table(self, table_name: str) -> None:
        """Create a table if it does not already exist."""
        if table_name not in self.tables:
            self.tables.append(table_name)
            self.data[table_name] = []

    def insert(self, table_name: str, record: Dict[str, Any]) -> None:
        """Insert a record, creating the table and an 'id' as needed."""
        if table_name not in self.tables:
            self.create_table(table_name)
        # Only generate a UUID when the record has no id.  The previous
        # record.get('id', MockFactory.generate_uuid()) evaluated the
        # default eagerly, generating a throwaway UUID on every insert.
        if 'id' not in record:
            record['id'] = MockFactory.generate_uuid()
        self.data[table_name].append(record)

    def select(self, table_name: str, **filters) -> List[Dict[str, Any]]:
        """Select records matching all filters.

        Always returns a NEW list (the previous version handed back the
        internal storage list when no filters were given, so callers could
        corrupt the table by mutating the result).  The record dicts
        themselves are still shared, matching update() semantics.
        """
        if table_name not in self.tables:
            return []
        records = self.data[table_name]
        if not filters:
            return list(records)  # defensive copy of the internal list
        return [
            record for record in records
            if all(record.get(key) == value for key, value in filters.items())
        ]

    def update(self, table_name: str, record_id: str, updates: Dict[str, Any]) -> bool:
        """Merge `updates` into the record with the given id; True on success."""
        if table_name not in self.tables:
            return False
        for record in self.data[table_name]:
            if record.get('id') == record_id:
                record.update(updates)
                return True
        return False

    def delete(self, table_name: str, record_id: str) -> bool:
        """Delete the record with the given id; True on success."""
        if table_name not in self.tables:
            return False
        rows = self.data[table_name]
        for i, record in enumerate(rows):
            if record.get('id') == record_id:
                del rows[i]
                return True
        return False

    def clear(self) -> None:
        """Drop all tables and data."""
        self.data.clear()
        self.tables.clear()
|
||||
|
||||
|
||||
class MockCache:
    """Mock TTL cache for testing."""

    def __init__(self, ttl: int = 3600):
        """Create a cache whose entries expire `ttl` seconds after insertion."""
        # key -> (value, insertion timestamp)
        self.cache: Dict[str, tuple[Any, float]] = {}
        self.ttl = ttl

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value, or None if the key is missing or expired."""
        entry = self.cache.get(key)
        if entry is None:
            return None
        value, stored_at = entry
        if time.time() - stored_at > self.ttl:
            # Expired entries are evicted lazily, on access.
            del self.cache[key]
            return None
        return value

    def set(self, key: str, value: Any) -> None:
        """Store a value stamped with the current time."""
        self.cache[key] = (value, time.time())

    def delete(self, key: str) -> bool:
        """Remove a key; return True if it was present."""
        # Entries are always (value, timestamp) tuples, never None,
        # so a None from pop() unambiguously means "absent".
        return self.cache.pop(key, None) is not None

    def clear(self) -> None:
        """Remove all entries."""
        self.cache.clear()

    def size(self) -> int:
        """Number of entries currently stored (including not-yet-evicted expired ones)."""
        return len(self.cache)
|
||||
|
||||
|
||||
def mock_async_call(return_value: Any = None, delay: float = 0):
    """Decorator factory that replaces an async call with a canned result.

    The wrapped coroutine never runs the original function body: it
    optionally sleeps `delay` seconds, then resolves to `return_value`.
    """
    def _decorate(func: Callable) -> Callable:
        async def _stub(*args, **kwargs):
            if delay > 0:
                await asyncio.sleep(delay)
            return return_value
        return _stub
    return _decorate
|
||||
|
||||
|
||||
def create_mock_config(**overrides) -> Dict[str, Any]:
    """Create a mock configuration dict, applying keyword overrides last."""
    defaults: Dict[str, Any] = {
        "debug": False,
        "log_level": "INFO",
        "database_url": "sqlite:///test.db",
        "redis_url": "redis://localhost:6379",
        "api_host": "localhost",
        "api_port": 8080,
        "secret_key": MockFactory.generate_string(32),
        "max_workers": 4,
        "timeout": 30,
    }
    return {**defaults, **overrides}
|
||||
|
||||
|
||||
import time
|
||||
|
||||
|
||||
def create_test_scenario(name: str, steps: List[Callable]) -> Callable:
    """Build a callable that runs `steps` in order and reports per-step results.

    Each step is recorded as passed (with its result) or failed (with the
    exception text); a failing step does not stop the remaining steps.
    """
    def scenario():
        print(f"Running test scenario: {name}")
        outcomes = []
        for index, step in enumerate(steps, start=1):
            try:
                outcomes.append({"step": index, "status": "passed", "result": step()})
            except Exception as exc:
                outcomes.append({"step": index, "status": "failed", "error": str(exc)})
        return outcomes
    return scenario
|
||||
321
aitbc/time_utils.py
Normal file
321
aitbc/time_utils.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""
|
||||
Time utilities for AITBC
|
||||
Provides timestamp helpers, duration helpers, timezone handling, and deadline calculations
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Union
|
||||
import time
|
||||
|
||||
|
||||
def get_utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(tz=timezone.utc)
|
||||
|
||||
|
||||
def get_timestamp_utc() -> float:
    """Return the current Unix timestamp (seconds since the epoch)."""
    return datetime.now(timezone.utc).timestamp()
|
||||
|
||||
|
||||
def format_iso8601(dt: Optional[datetime] = None) -> str:
    """Format a datetime (default: now) as an ISO 8601 string in UTC.

    Naive inputs are assumed to already be in UTC.
    """
    when = get_utc_now() if dt is None else dt
    if when.tzinfo is None:
        when = when.replace(tzinfo=timezone.utc)
    return when.isoformat()
|
||||
|
||||
|
||||
def parse_iso8601(iso_string: str) -> datetime:
    """Parse an ISO 8601 string into a timezone-aware datetime.

    Naive strings are assumed to be UTC.  A trailing 'Z' (Zulu/UTC
    designator) is accepted even on Python versions where
    datetime.fromisoformat() rejects it (natively supported only from
    3.11); surrounding whitespace is tolerated.
    """
    normalized = iso_string.strip()
    if normalized.endswith(('Z', 'z')):
        # Rewrite the Zulu suffix as an explicit UTC offset.
        normalized = normalized[:-1] + '+00:00'
    dt = datetime.fromisoformat(normalized)
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt
|
||||
|
||||
|
||||
def timestamp_to_iso(timestamp: float) -> str:
    """Convert a Unix timestamp to an ISO 8601 string in UTC."""
    moment = datetime.fromtimestamp(timestamp, tz=timezone.utc)
    return moment.isoformat()
|
||||
|
||||
|
||||
def iso_to_timestamp(iso_string: str) -> float:
    """Convert an ISO 8601 string to a Unix timestamp."""
    return parse_iso8601(iso_string).timestamp()
|
||||
|
||||
|
||||
def format_duration(seconds: Union[int, float]) -> str:
    """Format a duration in seconds as a coarse human-readable string.

    Only the single largest applicable unit is shown (s, m, h, or d),
    truncated rather than rounded.
    """
    for limit, divisor, suffix in ((60, 1, "s"), (3600, 60, "m"), (86400, 3600, "h")):
        if seconds < limit:
            return f"{int(seconds / divisor)}{suffix}"
    return f"{int(seconds / 86400)}d"
|
||||
|
||||
|
||||
def format_duration_precise(seconds: Union[int, float]) -> str:
    """Format a duration as a full breakdown, e.g. '1d 2h 3m 4s'.

    Zero-valued units are omitted; '0s' is returned for a zero input.
    """
    total = int(seconds)
    days, remainder = divmod(total, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, secs = divmod(remainder, 60)

    pieces = []
    for amount, unit in ((days, "d"), (hours, "h"), (minutes, "m")):
        if amount > 0:
            pieces.append(f"{amount}{unit}")
    if secs > 0 or not pieces:
        pieces.append(f"{secs}s")
    return " ".join(pieces)
|
||||
|
||||
|
||||
def parse_duration(duration_str: str) -> float:
    """Parse a duration string into seconds.

    Accepts a numeric value with a single unit suffix ('30s', '5m', '2h',
    '1d', case-insensitive) or a bare number of seconds.
    """
    text = duration_str.strip().lower()
    multipliers = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}
    unit = text[-1:]
    if unit in multipliers:
        return float(text[:-1]) * multipliers[unit]
    return float(text)
|
||||
|
||||
|
||||
def add_duration(dt: datetime, duration: Union[str, timedelta]) -> datetime:
    """Add a duration (timedelta or duration string like '2h') to a datetime."""
    delta = timedelta(seconds=parse_duration(duration)) if isinstance(duration, str) else duration
    return dt + delta
|
||||
|
||||
|
||||
def subtract_duration(dt: datetime, duration: Union[str, timedelta]) -> datetime:
    """Subtract a duration (timedelta or duration string like '2h') from a datetime."""
    delta = timedelta(seconds=parse_duration(duration)) if isinstance(duration, str) else duration
    return dt - delta
|
||||
|
||||
|
||||
def get_time_until(dt: datetime) -> timedelta:
    """Return how far in the future `dt` is (negative if already past)."""
    return dt - get_utc_now()
|
||||
|
||||
|
||||
def get_time_since(dt: datetime) -> timedelta:
    """Return how long ago `dt` was (negative if it is in the future)."""
    return get_utc_now() - dt
|
||||
|
||||
|
||||
def calculate_deadline(duration: Union[str, timedelta], from_dt: Optional[datetime] = None) -> datetime:
    """Compute a deadline by adding `duration` to `from_dt` (default: now)."""
    base = from_dt if from_dt is not None else get_utc_now()
    if isinstance(duration, str):
        duration = timedelta(seconds=parse_duration(duration))
    return base + duration
|
||||
|
||||
|
||||
def is_deadline_passed(deadline: datetime) -> bool:
    """True once the current UTC time has reached or passed `deadline`."""
    return deadline <= get_utc_now()
|
||||
|
||||
|
||||
def get_deadline_remaining(deadline: datetime) -> float:
    """Seconds left until `deadline`, clamped at zero for past deadlines."""
    remaining = (deadline - get_utc_now()).total_seconds()
    return max(0, remaining)
|
||||
|
||||
|
||||
def format_time_ago(dt: datetime) -> str:
    """Describe how long ago `dt` was, e.g. '5 minutes ago'.

    Anything under a minute is reported as "just now"; the largest
    applicable unit up to weeks is used.
    """
    seconds = get_time_since(dt).total_seconds()

    if seconds < 60:
        return "just now"
    for threshold, divisor, unit in (
        (3600, 60, "minute"),
        (86400, 3600, "hour"),
        (604800, 86400, "day"),
    ):
        if seconds < threshold:
            count = int(seconds / divisor)
            return f"{count} {unit}{'s' if count > 1 else ''} ago"
    weeks = int(seconds / 604800)
    return f"{weeks} week{'s' if weeks > 1 else ''} ago"
|
||||
|
||||
|
||||
def format_time_in(dt: datetime) -> str:
    """Describe how far away a future `dt` is, e.g. 'in 5 minutes'.

    Falls back to format_time_ago() when `dt` is already in the past.
    """
    seconds = get_time_until(dt).total_seconds()

    if seconds < 0:
        return format_time_ago(dt)
    if seconds < 60:
        return "in a moment"
    for threshold, divisor, unit in (
        (3600, 60, "minute"),
        (86400, 3600, "hour"),
        (604800, 86400, "day"),
    ):
        if seconds < threshold:
            count = int(seconds / divisor)
            return f"in {count} {unit}{'s' if count > 1 else ''}"
    weeks = int(seconds / 604800)
    return f"in {weeks} week{'s' if weeks > 1 else ''}"
|
||||
|
||||
|
||||
def to_timezone(dt: datetime, tz_name: str) -> datetime:
    """Convert a datetime to the named timezone (naive input assumed UTC).

    Requires pytz; unknown zone names surface as ValueError.
    """
    try:
        import pytz
    except ImportError:
        raise ImportError("pytz is required for timezone conversion. Install with: pip install pytz")
    try:
        target = pytz.timezone(tz_name)
        aware = dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)
        return aware.astimezone(target)
    except Exception as e:
        raise ValueError(f"Failed to convert timezone: {e}")
|
||||
|
||||
|
||||
def get_timezone_offset(tz_name: str) -> timedelta:
    """Return the named timezone's current offset from UTC (requires pytz)."""
    try:
        import pytz
    except ImportError:
        raise ImportError("pytz is required for timezone operations. Install with: pip install pytz")
    offset = pytz.timezone(tz_name).utcoffset(datetime.now(timezone.utc))
    return offset if offset else timedelta(0)
|
||||
|
||||
|
||||
def is_business_hours(dt: Optional[datetime] = None, start_hour: int = 9, end_hour: int = 17, timezone: str = "UTC") -> bool:
    """Check whether `dt` (default: now) falls within business hours.

    NOTE(review): the `timezone` parameter is an IANA zone name and shadows
    the datetime.timezone import inside this function; the name is kept for
    caller compatibility.
    """
    moment = dt if dt is not None else get_utc_now()
    try:
        import pytz
    except ImportError:
        raise ImportError("pytz is required for business hours check. Install with: pip install pytz")
    localized = moment.astimezone(pytz.timezone(timezone))
    return start_hour <= localized.hour < end_hour
|
||||
|
||||
|
||||
def get_start_of_day(dt: Optional[datetime] = None) -> datetime:
    """Return midnight (00:00:00.000000) of the given datetime's day."""
    moment = dt if dt is not None else get_utc_now()
    return moment.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
|
||||
def get_end_of_day(dt: Optional[datetime] = None) -> datetime:
    """Return the last representable microsecond of the given datetime's day."""
    moment = dt if dt is not None else get_utc_now()
    return moment.replace(hour=23, minute=59, second=59, microsecond=999999)
|
||||
|
||||
|
||||
def get_start_of_week(dt: Optional[datetime] = None) -> datetime:
    """Return the Monday of the given datetime's week (time-of-day preserved)."""
    moment = dt if dt is not None else get_utc_now()
    return moment - timedelta(days=moment.weekday())
|
||||
|
||||
|
||||
def get_end_of_week(dt: Optional[datetime] = None) -> datetime:
    """Return the Sunday of the given datetime's week (time-of-day preserved)."""
    moment = dt if dt is not None else get_utc_now()
    return moment + timedelta(days=6 - moment.weekday())
|
||||
|
||||
|
||||
def get_start_of_month(dt: Optional[datetime] = None) -> datetime:
    """Return midnight on the first day of the given datetime's month."""
    moment = dt if dt is not None else get_utc_now()
    return moment.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
|
||||
def get_end_of_month(dt: Optional[datetime] = None) -> datetime:
    """Return 23:59:59 on the last day of the given datetime's month.

    BUG FIX: the previous version subtracted one second from "first of
    next month at dt's time-of-day", so for e.g. Jan 15 14:30 it returned
    Feb 1 14:29:59 — a timestamp inside the *next* month.  The start of
    the next month must be taken at midnight first.
    """
    if dt is None:
        dt = get_utc_now()
    if dt.month == 12:
        next_month_start = dt.replace(year=dt.year + 1, month=1, day=1,
                                      hour=0, minute=0, second=0, microsecond=0)
    else:
        next_month_start = dt.replace(month=dt.month + 1, day=1,
                                      hour=0, minute=0, second=0, microsecond=0)
    return next_month_start - timedelta(seconds=1)
|
||||
|
||||
|
||||
def sleep_until(dt: datetime) -> None:
    """Block until the (timezone-aware) datetime `dt` is reached.

    Returns immediately if `dt` is not in the future.
    """
    remaining = (dt - get_utc_now()).total_seconds()
    if remaining > 0:
        time.sleep(remaining)
|
||||
|
||||
|
||||
def retry_until_deadline(func, deadline: datetime, interval: float = 1.0) -> bool:
    """Call `func` repeatedly until it returns truthy or `deadline` passes.

    Exceptions raised by `func` are swallowed and treated as a failed
    attempt; `interval` seconds elapse between attempts.
    """
    while not is_deadline_passed(deadline):
        try:
            if func():
                return True
        except Exception:
            pass
        time.sleep(interval)
    return False
|
||||
|
||||
|
||||
class Timer:
    """Context manager that measures wall-clock execution time."""

    def __init__(self):
        """Create an idle timer; timing attributes are populated on use."""
        self.start_time = None
        self.end_time = None
        self.elapsed = None

    def __enter__(self):
        """Record the start time and return the timer itself."""
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Record the end time and compute the elapsed duration."""
        self.end_time = time.time()
        self.elapsed = self.end_time - self.start_time

    def get_elapsed(self) -> Optional[float]:
        """Elapsed seconds.

        After exit: the final duration.  While still inside the context:
        the running total.  Before the timer has started: None.
        """
        if self.elapsed is not None:
            return self.elapsed
        if self.start_time is not None:
            return time.time() - self.start_time
        return None
|
||||
206
aitbc/web3_utils.py
Normal file
206
aitbc/web3_utils.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Web3 utilities for AITBC
|
||||
Provides Ethereum blockchain interaction utilities using web3.py
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
class Web3Client:
    """Web3 client wrapper for blockchain operations.

    Thin convenience layer over web3.py for balances, gas, nonces,
    transaction lookup/submission, and a naive recent-block wallet scan.
    Runtime failures are re-raised as ValueError; construction raises
    ImportError (web3 missing) or ConnectionError (endpoint unreachable).
    """

    # ERC-20 function selectors: first 4 bytes of keccak256(signature).
    _BALANCE_OF_SELECTOR = '0x70a08231'  # balanceOf(address)
    _DECIMALS_SELECTOR = '0x313ce567'    # decimals()
    _SYMBOL_SELECTOR = '0x95d89b41'      # symbol()

    def __init__(self, rpc_url: str, timeout: int = 30):
        """Initialize Web3 client with RPC URL."""
        try:
            from web3 import Web3
            from web3.middleware import geth_poa_middleware

            self.w3 = Web3(Web3.HTTPProvider(rpc_url, request_kwargs={'timeout': timeout}))

            # Add POA middleware for chains like Polygon, BSC, etc.
            self.w3.middleware_onion.inject(geth_poa_middleware, layer=0)

            if not self.w3.is_connected():
                raise ConnectionError(f"Failed to connect to RPC URL: {rpc_url}")
        except ImportError:
            raise ImportError("web3 is required for blockchain operations. Install with: pip install web3")
        except Exception as e:
            raise ConnectionError(f"Failed to initialize Web3 client: {e}")

    def get_eth_balance(self, address: str) -> str:
        """Get the ETH balance of `address`, in wei, as a decimal string."""
        try:
            return str(self.w3.eth.get_balance(address))
        except Exception as e:
            raise ValueError(f"Failed to get ETH balance: {e}")

    def get_token_balance(self, address: str, token_address: str) -> dict[str, Any]:
        """Get ERC-20 token balance for `address`.

        Returns {"balance": str, "decimals": int, "symbol": str}.
        """
        try:
            # balanceOf(address): selector followed by the address
            # left-padded to 32 bytes.
            padded_address = address[2:].lower().zfill(64)
            # BUG FIX: the selector constant already carries the '0x'
            # prefix; the previous code wrapped the call data in
            # f'0x{...}' a second time, producing invalid '0x0x70a08231…'
            # call data so every balanceOf call failed.
            result = self.w3.eth.call({
                'to': token_address,
                'data': self._BALANCE_OF_SELECTOR + padded_address,
            })
            balance = int(result.hex(), 16)

            # decimals()
            decimals_result = self.w3.eth.call({
                'to': token_address,
                'data': self._DECIMALS_SELECTOR,
            })
            decimals = int(decimals_result.hex(), 16)

            # symbol() may fail or return non-string data for non-standard
            # tokens (e.g. bytes32 symbols); fall back instead of failing.
            try:
                symbol_result = self.w3.eth.call({
                    'to': token_address,
                    'data': self._SYMBOL_SELECTOR,
                })
                symbol_bytes = bytes.fromhex(symbol_result.hex()[2:])
                symbol = symbol_bytes.rstrip(b'\x00').decode('utf-8')
            except Exception:
                symbol = "TOKEN"

            return {
                "balance": str(balance),
                "decimals": decimals,
                "symbol": symbol,
            }
        except Exception as e:
            raise ValueError(f"Failed to get token balance: {e}")

    def get_gas_price(self) -> int:
        """Get current gas price in wei."""
        try:
            return self.w3.eth.gas_price
        except Exception as e:
            raise ValueError(f"Failed to get gas price: {e}")

    def get_gas_price_gwei(self) -> float:
        """Get current gas price in Gwei."""
        try:
            return float(self.get_gas_price()) / 10**9
        except Exception as e:
            raise ValueError(f"Failed to get gas price in Gwei: {e}")

    def get_nonce(self, address: str) -> int:
        """Get the transaction count (next nonce) for `address`."""
        try:
            return self.w3.eth.get_transaction_count(address)
        except Exception as e:
            raise ValueError(f"Failed to get nonce: {e}")

    def send_raw_transaction(self, signed_transaction: str) -> str:
        """Broadcast a signed raw transaction; return its hash as hex."""
        try:
            return self.w3.eth.send_raw_transaction(signed_transaction).hex()
        except Exception as e:
            raise ValueError(f"Failed to send raw transaction: {e}")

    def get_transaction_receipt(self, tx_hash: str) -> Optional[dict[str, Any]]:
        """Get a transaction receipt, or None when the tx is not yet mined."""
        try:
            receipt = self.w3.eth.get_transaction_receipt(tx_hash)
            if receipt is None:
                return None

            return {
                "status": receipt['status'],
                "blockNumber": hex(receipt['blockNumber']),
                "blockHash": receipt['blockHash'].hex(),
                "gasUsed": hex(receipt['gasUsed']),
                "effectiveGasPrice": hex(receipt['effectiveGasPrice']),
                "logs": receipt['logs'],
            }
        except Exception as e:
            raise ValueError(f"Failed to get transaction receipt: {e}")

    def get_transaction_by_hash(self, tx_hash: str) -> dict[str, Any]:
        """Get transaction fields by hash.

        NOTE(review): assumes a legacy transaction ('gasPrice' present);
        EIP-1559 txs expose maxFeePerGas instead — confirm against the
        target chains.
        """
        try:
            tx = self.w3.eth.get_transaction(tx_hash)
            return {
                "from": tx['from'],
                "to": tx['to'],
                "value": hex(tx['value']),
                "data": tx['input'].hex() if hasattr(tx['input'], 'hex') else tx['input'],
                "nonce": tx['nonce'],
                "gas": tx['gas'],
                "gasPrice": hex(tx['gasPrice']),
                "blockNumber": hex(tx['blockNumber']) if tx['blockNumber'] else None,
            }
        except Exception as e:
            raise ValueError(f"Failed to get transaction by hash: {e}")

    def estimate_gas(self, transaction: dict[str, Any]) -> int:
        """Estimate gas for a transaction dict."""
        try:
            return self.w3.eth.estimate_gas(transaction)
        except Exception as e:
            raise ValueError(f"Failed to estimate gas: {e}")

    def get_block_number(self) -> int:
        """Get the current (latest) block number."""
        try:
            return self.w3.eth.block_number
        except Exception as e:
            raise ValueError(f"Failed to get block number: {e}")

    def get_wallet_transactions(self, address: str, limit: int = 100) -> list[dict[str, Any]]:
        """Scan the most recent ~1000 blocks for transactions touching `address`.

        Simplified implementation — production code should use event logs
        or a blockchain explorer/indexer API instead of walking blocks.
        """
        try:
            transactions: list[dict[str, Any]] = []
            current_block = self.get_block_number()
            start_block = max(0, current_block - 1000)
            wanted = address.lower()

            for block_num in range(current_block, start_block, -1):
                if len(transactions) >= limit:
                    break

                try:
                    block = self.w3.eth.get_block(block_num, full_transactions=True)
                    for tx in block['transactions']:
                        to_addr = tx['to']
                        if tx['from'].lower() == wanted or \
                                (to_addr and to_addr.lower() == wanted):
                            transactions.append({
                                "hash": tx['hash'].hex(),
                                "from": tx['from'],
                                # BUG FIX: tx['to'] is already a hex string
                                # (or None for contract creation); calling
                                # .hex() on it raised AttributeError for
                                # every matching transaction.
                                "to": to_addr,
                                "value": hex(tx['value']),
                                "blockNumber": hex(tx['blockNumber']),
                                "timestamp": block['timestamp'],
                                # NOTE(review): this is the tx gas *limit*,
                                # not actual gas used; key kept for
                                # backward compatibility.
                                "gasUsed": hex(tx['gas']),
                            })
                            if len(transactions) >= limit:
                                break
                except Exception:
                    # Best-effort scan by design: skip blocks that fail to
                    # fetch or decode (was a bare `except:`).
                    continue

            return transactions
        except Exception as e:
            raise ValueError(f"Failed to get wallet transactions: {e}")
|
||||
|
||||
|
||||
def create_web3_client(rpc_url: str, timeout: int = 30) -> Web3Client:
    """Factory function: create and return a connected Web3Client."""
    return Web3Client(rpc_url=rpc_url, timeout=timeout)
|
||||
@@ -1048,19 +1048,15 @@ async def search_transactions(
|
||||
response = await client.get(f"{rpc_url}/rpc/search/transactions", params=params)
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
elif response.status_code == 404:
|
||||
return []
|
||||
else:
|
||||
# Return mock data for demonstration
|
||||
return [
|
||||
{
|
||||
"hash": "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef",
|
||||
"type": tx_type or "transfer",
|
||||
"from": "0xabcdef1234567890abcdef1234567890abcdef1234",
|
||||
"to": "0x1234567890abcdef1234567890abcdef12345678",
|
||||
"amount": "1.5",
|
||||
"fee": "0.001",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
]
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=f"Failed to fetch transactions from blockchain RPC: {response.text}"
|
||||
)
|
||||
except httpx.RequestError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Blockchain RPC unavailable: {str(e)}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
|
||||
|
||||
@@ -1095,17 +1091,15 @@ async def search_blocks(
|
||||
response = await client.get(f"{rpc_url}/rpc/search/blocks", params=params)
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
elif response.status_code == 404:
|
||||
return []
|
||||
else:
|
||||
# Return mock data for demonstration
|
||||
return [
|
||||
{
|
||||
"height": 12345,
|
||||
"hash": "0xabcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890",
|
||||
"validator": validator or "0x1234567890abcdef1234567890abcdef12345678",
|
||||
"tx_count": min_tx or 5,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
]
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=f"Failed to fetch blocks from blockchain RPC: {response.text}"
|
||||
)
|
||||
except httpx.RequestError as e:
|
||||
raise HTTPException(status_code=503, detail=f"Blockchain RPC unavailable: {str(e)}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
|
||||
|
||||
|
||||
@@ -0,0 +1,564 @@
|
||||
"""
|
||||
Hub Manager
|
||||
Manages hub operations, peer list sharing, and hub registration for federated mesh
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
from typing import Dict, List, Optional, Set
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from enum import Enum
|
||||
from ..config import settings
|
||||
|
||||
from aitbc import get_logger, DATA_DIR, KEYSTORE_DIR
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class HubStatus(Enum):
    """Hub registration status for a node in the federated mesh."""
    # Node is registered as a hub
    REGISTERED = "registered"
    # Node is not currently registered as a hub (initial state)
    UNREGISTERED = "unregistered"
    # Registration has been initiated but not yet confirmed
    PENDING = "pending"
|
||||
|
||||
|
||||
@dataclass
class HubInfo:
    """Information about a hub node in the federated mesh."""
    # Unique identifier of the hub node
    node_id: str
    # Address/port the hub listens on
    address: str
    port: int
    # Island (sub-network) the hub belongs to
    island_id: str
    island_name: str
    # Externally reachable address/port, when different from the local ones
    public_address: Optional[str] = None
    public_port: Optional[int] = None
    # Unix timestamps; 0 means "never"
    registered_at: float = 0
    last_seen: float = 0
    # Number of peers this hub currently knows about
    peer_count: int = 0
|
||||
|
||||
|
||||
@dataclass
class PeerInfo:
    """Information about a peer node known to this hub."""
    # Unique identifier of the peer
    node_id: str
    # Address/port the peer listens on
    address: str
    port: int
    # Island (sub-network) the peer belongs to
    island_id: str
    # Whether this peer acts as a hub
    is_hub: bool
    # Externally reachable address/port, when different from the local ones
    public_address: Optional[str] = None
    public_port: Optional[int] = None
    # Unix timestamp of last contact; 0 means "never"
    last_seen: float = 0
|
||||
|
||||
|
||||
class HubManager:
|
||||
"""Manages hub operations for federated mesh"""
|
||||
|
||||
def __init__(self, local_node_id: str, local_address: str, local_port: int, island_id: str, island_name: str, redis_url: Optional[str] = None):
|
||||
self.local_node_id = local_node_id
|
||||
self.local_address = local_address
|
||||
self.local_port = local_port
|
||||
self.island_id = island_id
|
||||
self.island_name = island_name
|
||||
self.island_chain_id = settings.island_chain_id or settings.chain_id or f"ait-{island_id[:8]}"
|
||||
self.redis_url = redis_url or "redis://localhost:6379"
|
||||
|
||||
# Hub registration status
|
||||
self.is_hub = False
|
||||
self.hub_status = HubStatus.UNREGISTERED
|
||||
self.registered_at: Optional[float] = None
|
||||
|
||||
# Known hubs
|
||||
self.known_hubs: Dict[str, HubInfo] = {} # node_id -> HubInfo
|
||||
|
||||
# Peer registry (for providing peer lists)
|
||||
self.peer_registry: Dict[str, PeerInfo] = {} # node_id -> PeerInfo
|
||||
|
||||
# Island peers (island_id -> set of node_ids)
|
||||
self.island_peers: Dict[str, Set[str]] = {}
|
||||
|
||||
self.running = False
|
||||
self._redis = None
|
||||
|
||||
# Initialize island peers for our island
|
||||
self.island_peers[self.island_id] = set()
|
||||
|
||||
async def _connect_redis(self):
|
||||
"""Connect to Redis"""
|
||||
try:
|
||||
import redis.asyncio as redis
|
||||
self._redis = redis.from_url(self.redis_url)
|
||||
await self._redis.ping()
|
||||
logger.info(f"Connected to Redis for hub persistence: {self.redis_url}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Redis: {e}")
|
||||
return False
|
||||
|
||||
async def _persist_hub_registration(self, hub_info: HubInfo) -> bool:
|
||||
"""Persist hub registration to Redis"""
|
||||
try:
|
||||
if not self._redis:
|
||||
await self._connect_redis()
|
||||
|
||||
if not self._redis:
|
||||
logger.warning("Redis not available, skipping persistence")
|
||||
return False
|
||||
|
||||
key = f"hub:{hub_info.node_id}"
|
||||
value = json.dumps(asdict(hub_info), default=str)
|
||||
await self._redis.setex(key, 3600, value) # TTL: 1 hour
|
||||
logger.info(f"Persisted hub registration to Redis: {key}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to persist hub registration: {e}")
|
||||
return False
|
||||
|
||||
async def _remove_hub_registration(self, node_id: str) -> bool:
    """Remove hub registration from Redis.

    Args:
        node_id: Hub whose ``hub:<node_id>`` key should be deleted.

    Returns:
        bool: True on success, False when Redis is unavailable or the delete
        fails (errors are logged, never raised).
    """
    try:
        if not self._redis:
            await self._connect_redis()

        if not self._redis:
            logger.warning("Redis not available, skipping removal")
            return False

        key = f"hub:{node_id}"
        # DELETE on a missing key is a harmless no-op, so no existence check.
        await self._redis.delete(key)
        logger.info(f"Removed hub registration from Redis: {key}")
        return True
    except Exception as e:
        logger.error(f"Failed to remove hub registration: {e}")
        return False
|
||||
|
||||
async def _load_hub_registration(self) -> Optional[HubInfo]:
    """Load this node's own hub registration from Redis.

    Returns:
        Optional[HubInfo]: The stored record for ``hub:<local_node_id>``, or
        None when Redis is unavailable, the key is absent, or parsing fails.
    """
    try:
        if not self._redis:
            await self._connect_redis()

        if not self._redis:
            return None

        key = f"hub:{self.local_node_id}"
        value = await self._redis.get(key)
        if value:
            data = json.loads(value)
            # NOTE(review): assumes the stored JSON keys match HubInfo's
            # constructor exactly — a schema change would raise here and be
            # swallowed into the except below.
            return HubInfo(**data)
        return None
    except Exception as e:
        logger.error(f"Failed to load hub registration: {e}")
        return None
|
||||
|
||||
def _get_blockchain_credentials(self) -> dict:
    """Get blockchain credentials from keystore.

    Collects genesis block data, the genesis address, chain identifiers and
    a local RPC endpoint into one dict handed to joining nodes.

    Returns:
        dict: Credential fields (possibly partial when files are missing);
        empty dict on any unexpected error.
    """
    try:
        credentials = {}

        # Get genesis block hash from genesis.json — probe candidate
        # locations in priority order, first hit wins.
        genesis_candidates = [
            str(settings.db_path.parent / 'genesis.json'),
            f"{DATA_DIR}/data/{settings.chain_id}/genesis.json",
            f'{DATA_DIR}/data/ait-mainnet/genesis.json',
        ]
        for genesis_path in genesis_candidates:
            if os.path.exists(genesis_path):
                with open(genesis_path, 'r') as f:
                    genesis_data = json.load(f)
                if 'blocks' in genesis_data and len(genesis_data['blocks']) > 0:
                    genesis_block = genesis_data['blocks'][0]
                    credentials['genesis_block_hash'] = genesis_block.get('hash', '')
                    credentials['genesis_block'] = genesis_data
                    break

        # Get genesis address from keystore
        keystore_path = str(KEYSTORE_DIR / 'validator_keys.json')
        if os.path.exists(keystore_path):
            with open(keystore_path, 'r') as f:
                keys = json.load(f)
            # Get first key's address — dict iteration order is insertion
            # order, so this is the first key written to the keystore.
            for key_id, key_data in keys.items():
                # Extract address from public key or use key_id
                credentials['genesis_address'] = key_id
                break

        # Add chain info
        credentials['chain_id'] = self.island_chain_id
        credentials['island_id'] = self.island_id
        credentials['island_name'] = self.island_name

        # Add RPC endpoint (local) — replace non-routable bind addresses
        # with something a remote peer can actually reach.
        rpc_host = self.local_address
        if rpc_host in {"0.0.0.0", "127.0.0.1", "localhost", ""}:
            rpc_host = settings.hub_discovery_url or socket.gethostname()
        credentials['rpc_endpoint'] = f"http://{rpc_host}:8006"
        credentials['p2p_port'] = self.local_port

        return credentials
    except Exception as e:
        logger.error(f"Failed to get blockchain credentials: {e}")
        return {}
|
||||
|
||||
async def handle_join_request(self, join_request: dict) -> Optional[dict]:
    """
    Handle island join request from a new node

    Args:
        join_request: Dictionary containing join request data

    Returns:
        dict: Join response with member list and credentials, or None if failed
    """
    try:
        requested_island_id = join_request.get('island_id')

        # Validate island ID — we only serve joins for our own island.
        if requested_island_id != self.island_id:
            logger.warning(f"Join request for island {requested_island_id} does not match our island {self.island_id}")
            return None

        # Get all island members currently in the peer registry.
        members = []
        for node_id, peer_info in self.peer_registry.items():
            if peer_info.island_id == self.island_id:
                members.append({
                    'node_id': peer_info.node_id,
                    'address': peer_info.address,
                    'port': peer_info.port,
                    'is_hub': peer_info.is_hub,
                    'public_address': peer_info.public_address,
                    'public_port': peer_info.public_port
                })

        # Include self in member list. The membership test guards the
        # attribute access, so the `{}` .get() default is never dereferenced.
        members.append({
            'node_id': self.local_node_id,
            'address': self.local_address,
            'port': self.local_port,
            'is_hub': True,
            'public_address': self.known_hubs.get(self.local_node_id, {}).public_address if self.local_node_id in self.known_hubs else None,
            'public_port': self.known_hubs.get(self.local_node_id, {}).public_port if self.local_node_id in self.known_hubs else None
        })

        # Get blockchain credentials (genesis data, chain ids, RPC endpoint).
        credentials = self._get_blockchain_credentials()

        # Build response
        response = {
            'type': 'join_response',
            'island_id': self.island_id,
            'island_name': self.island_name,
            'island_chain_id': self.island_chain_id or f"ait-{self.island_id[:8]}",
            'members': members,
            'credentials': credentials
        }

        logger.info(f"Sent join_response to node {join_request.get('node_id')} with {len(members)} members")
        return response

    except Exception as e:
        logger.error(f"Error handling join request: {e}")
        return None
|
||||
|
||||
def register_gpu_offer(self, offer_data: dict) -> bool:
    """Register a GPU marketplace offer in the hub.

    Args:
        offer_data: Offer payload; must carry a truthy ``offer_id``.

    Returns:
        bool: True when stored; False when ``offer_id`` is missing or an
        unexpected error occurs (logged, never raised).
    """
    try:
        offer_id = offer_data.get('offer_id')
        # Fix: the original fell off the end and implicitly returned None
        # when offer_id was missing, despite the declared bool return.
        if not offer_id:
            logger.warning("GPU offer rejected: missing offer_id")
            return False
        self.gpu_offers[offer_id] = offer_data
        logger.info(f"Registered GPU offer: {offer_id}")
        return True
    except Exception as e:
        logger.error(f"Error registering GPU offer: {e}")
        return False
|
||||
|
||||
def register_gpu_bid(self, bid_data: dict) -> bool:
    """Register a GPU marketplace bid in the hub.

    Args:
        bid_data: Bid payload; must carry a truthy ``bid_id``.

    Returns:
        bool: True when stored; False when ``bid_id`` is missing or an
        unexpected error occurs (logged, never raised).
    """
    try:
        bid_id = bid_data.get('bid_id')
        # Fix: the original implicitly returned None (not the declared bool)
        # when bid_id was absent.
        if not bid_id:
            logger.warning("GPU bid rejected: missing bid_id")
            return False
        self.gpu_bids[bid_id] = bid_data
        logger.info(f"Registered GPU bid: {bid_id}")
        return True
    except Exception as e:
        logger.error(f"Error registering GPU bid: {e}")
        return False
|
||||
|
||||
def register_gpu_provider(self, node_id: str, gpu_info: dict) -> bool:
    """Record a GPU provider's hardware info, overwriting any prior entry.

    Returns:
        bool: True once stored; False only on an unexpected error.
    """
    try:
        self.gpu_providers[node_id] = gpu_info
    except Exception as e:
        logger.error(f"Error registering GPU provider: {e}")
        return False
    logger.info(f"Registered GPU provider: {node_id}")
    return True
|
||||
|
||||
def register_exchange_order(self, order_data: dict) -> bool:
    """Register an exchange order in the hub and index it in the order book.

    Args:
        order_data: Order payload; must carry a truthy ``order_id``.
            ``pair`` and ``side`` ('buy'/'sell') are optional; when present
            the order is also appended to the per-pair book.

    Returns:
        bool: True when stored; False when ``order_id`` is missing or an
        unexpected error occurs (logged, never raised).
    """
    try:
        order_id = order_data.get('order_id')
        # Fix: the original fell through and returned None (not the declared
        # bool) when order_id was missing.
        if not order_id:
            logger.warning("Exchange order rejected: missing order_id")
            return False

        self.exchange_orders[order_id] = order_data

        # Update order book — lazily create the book for new pairs.
        pair = order_data.get('pair')
        side = order_data.get('side')
        if pair and side:
            book = self.exchange_order_books.setdefault(pair, {'bids': [], 'asks': []})
            if side == 'buy':
                book['bids'].append(order_data)
            elif side == 'sell':
                book['asks'].append(order_data)

        logger.info(f"Registered exchange order: {order_id}")
        return True
    except Exception as e:
        logger.error(f"Error registering exchange order: {e}")
        return False
|
||||
|
||||
def get_gpu_offers(self) -> list:
    """Return every registered GPU offer as a fresh list."""
    return [offer for offer in self.gpu_offers.values()]
|
||||
|
||||
def get_gpu_bids(self) -> list:
    """Return every registered GPU bid as a fresh list."""
    return [bid for bid in self.gpu_bids.values()]
|
||||
|
||||
def get_gpu_providers(self) -> list:
    """Return every registered GPU provider record as a fresh list."""
    return [info for info in self.gpu_providers.values()]
|
||||
|
||||
def get_exchange_order_book(self, pair: str) -> dict:
    """Return the order book for *pair*; an empty book when the pair is unknown."""
    empty_book = {'bids': [], 'asks': []}
    return self.exchange_order_books.get(pair, empty_book)
|
||||
|
||||
async def register_as_hub(self, public_address: Optional[str] = None, public_port: Optional[int] = None) -> bool:
    """Register this node as a hub.

    Args:
        public_address: Externally reachable address to advertise, if any.
        public_port: Externally reachable port to advertise, if any.

    Returns:
        bool: True on success; False when already registered.
    """
    # Idempotence guard: a second registration is refused, not re-applied.
    if self.is_hub:
        logger.warning("Already registered as hub")
        return False

    self.is_hub = True
    self.hub_status = HubStatus.REGISTERED
    self.registered_at = time.time()

    # Add self to known hubs
    hub_info = HubInfo(
        node_id=self.local_node_id,
        address=self.local_address,
        port=self.local_port,
        island_id=self.island_id,
        island_name=self.island_name,
        public_address=public_address,
        public_port=public_port,
        registered_at=time.time(),
        last_seen=time.time()
    )
    self.known_hubs[self.local_node_id] = hub_info

    # Persist to Redis (best effort — failures are logged inside, not raised).
    await self._persist_hub_registration(hub_info)

    logger.info(f"Registered as hub for island {self.island_id}")
    return True
|
||||
|
||||
async def unregister_as_hub(self) -> bool:
    """Unregister this node as a hub.

    Returns:
        bool: True on success; False when not currently registered.
    """
    if not self.is_hub:
        logger.warning("Not registered as hub")
        return False

    self.is_hub = False
    self.hub_status = HubStatus.UNREGISTERED
    self.registered_at = None

    # Remove from Redis (best effort — failures are logged inside).
    await self._remove_hub_registration(self.local_node_id)

    # Remove self from known hubs
    if self.local_node_id in self.known_hubs:
        del self.known_hubs[self.local_node_id]

    logger.info(f"Unregistered as hub for island {self.island_id}")
    return True
|
||||
|
||||
def register_peer(self, peer_info: PeerInfo) -> bool:
    """Record *peer_info* in the registry and in its island's peer set.

    Returns:
        bool: Always True (registration overwrites any existing entry).
    """
    node_id = peer_info.node_id
    island = peer_info.island_id

    self.peer_registry[node_id] = peer_info

    # Track island membership, creating the set on first sight of the island.
    self.island_peers.setdefault(island, set()).add(node_id)

    # Keep the hub's advertised peer count in sync when this peer is a hub.
    if peer_info.is_hub and node_id in self.known_hubs:
        self.known_hubs[node_id].peer_count = len(self.island_peers.get(island, set()))

    logger.debug(f"Registered peer {node_id} in island {island}")
    return True
|
||||
|
||||
def unregister_peer(self, node_id: str) -> bool:
    """Drop a peer from the registry and its island set.

    Returns:
        bool: True when the peer was removed; False when it was unknown.
    """
    peer_info = self.peer_registry.pop(node_id, None)
    if peer_info is None:
        return False

    # Remove from the island's membership set, if that island is tracked.
    island_set = self.island_peers.get(peer_info.island_id)
    if island_set is not None:
        island_set.discard(node_id)

    # Keep the hub's advertised peer count in sync.
    hub = self.known_hubs.get(node_id)
    if hub is not None:
        hub.peer_count = len(self.island_peers.get(hub.island_id, set()))

    logger.debug(f"Unregistered peer {node_id}")
    return True
|
||||
|
||||
def add_known_hub(self, hub_info: HubInfo):
    """Track *hub_info* as a known hub, overwriting any prior entry."""
    node_id = hub_info.node_id
    self.known_hubs[node_id] = hub_info
    logger.info(f"Added known hub {node_id} for island {hub_info.island_id}")
|
||||
|
||||
def remove_known_hub(self, node_id: str) -> bool:
    """Forget a known hub.

    Returns:
        bool: True when the hub was removed; False when it was not tracked.
    """
    if self.known_hubs.pop(node_id, None) is None:
        return False
    logger.info(f"Removed known hub {node_id}")
    return True
|
||||
|
||||
def get_peer_list(self, island_id: str) -> List[PeerInfo]:
    """Return all registered peers belonging to *island_id*."""
    return [
        info for info in self.peer_registry.values()
        if info.island_id == island_id
    ]
|
||||
|
||||
def get_hub_list(self, island_id: Optional[str] = None) -> List[HubInfo]:
    """Return known hubs, restricted to *island_id* when one is given."""
    if island_id is None:
        return list(self.known_hubs.values())
    return [hub for hub in self.known_hubs.values() if hub.island_id == island_id]
|
||||
|
||||
def get_island_peers(self, island_id: str) -> Set[str]:
    """Return a defensive copy of the node-id set for *island_id*."""
    members = self.island_peers.get(island_id)
    return set(members) if members is not None else set()
|
||||
|
||||
def get_peer_count(self, island_id: str) -> int:
    """Return how many peers are registered in *island_id* (0 when unknown)."""
    members = self.island_peers.get(island_id)
    return len(members) if members is not None else 0
|
||||
|
||||
def get_hub_info(self, node_id: str) -> Optional[HubInfo]:
    """Look up a known hub by node id; None when absent."""
    try:
        return self.known_hubs[node_id]
    except KeyError:
        return None
|
||||
|
||||
def get_peer_info(self, node_id: str) -> Optional[PeerInfo]:
    """Look up a registered peer by node id; None when absent."""
    try:
        return self.peer_registry[node_id]
    except KeyError:
        return None
|
||||
|
||||
def update_peer_last_seen(self, node_id: str):
    """Refresh the last-seen timestamp for a peer and, if tracked, its hub record."""
    peer = self.peer_registry.get(node_id)
    if peer is not None:
        peer.last_seen = time.time()

    hub = self.known_hubs.get(node_id)
    if hub is not None:
        hub.last_seen = time.time()
|
||||
|
||||
async def start(self):
    """Start hub manager.

    Runs the periodic hub health-check and stale-peer cleanup loops until
    one of them raises or ``stop()`` clears the running flag. Blocks the
    caller for the manager's lifetime.
    """
    self.running = True
    logger.info(f"Starting hub manager for node {self.local_node_id}")

    # Start background tasks
    tasks = [
        asyncio.create_task(self._hub_health_check()),
        asyncio.create_task(self._peer_cleanup())
    ]

    try:
        await asyncio.gather(*tasks)
    except Exception as e:
        logger.error(f"Hub manager error: {e}")
    finally:
        # Ensure loops observe shutdown even when gather exits via an error.
        self.running = False
|
||||
|
||||
async def stop(self):
    """Signal the background loops to exit on their next iteration."""
    self.running = False
    logger.info("Stopping hub manager")
|
||||
|
||||
async def _hub_health_check(self):
    """Check health of known hubs.

    Background loop: once per minute, evicts hubs that have been silent for
    more than 10 minutes. Never evicts this node's own entry. Errors are
    logged and the loop retries after a short back-off.
    """
    while self.running:
        try:
            current_time = time.time()

            # Check for offline hubs (not seen for 10 minutes).
            # Collected first, then removed — avoids mutating the dict
            # while iterating it.
            offline_hubs = []
            for node_id, hub_info in self.known_hubs.items():
                if current_time - hub_info.last_seen > 600:
                    offline_hubs.append(node_id)
                    logger.warning(f"Hub {node_id} appears to be offline")

            # Remove offline hubs (keep self if we're a hub)
            for node_id in offline_hubs:
                if node_id != self.local_node_id:
                    self.remove_known_hub(node_id)

            await asyncio.sleep(60)  # Check every minute

        except Exception as e:
            logger.error(f"Hub health check error: {e}")
            await asyncio.sleep(10)  # brief back-off before retrying
|
||||
|
||||
async def _peer_cleanup(self):
    """Clean up stale peer entries.

    Background loop: once per minute, unregisters peers that have not been
    seen for more than 5 minutes. Errors are logged and the loop retries
    after a short back-off.
    """
    while self.running:
        try:
            current_time = time.time()

            # Remove peers not seen for 5 minutes. Collected first, then
            # removed, so the registry is not mutated during iteration.
            stale_peers = []
            for node_id, peer_info in self.peer_registry.items():
                if current_time - peer_info.last_seen > 300:
                    stale_peers.append(node_id)

            for node_id in stale_peers:
                self.unregister_peer(node_id)
                logger.debug(f"Removed stale peer {node_id}")

            await asyncio.sleep(60)  # Check every minute

        except Exception as e:
            logger.error(f"Peer cleanup error: {e}")
            await asyncio.sleep(10)  # brief back-off before retrying
|
||||
|
||||
|
||||
# Global hub manager instance
|
||||
hub_manager_instance: Optional[HubManager] = None
|
||||
|
||||
|
||||
def get_hub_manager() -> Optional[HubManager]:
|
||||
"""Get global hub manager instance"""
|
||||
return hub_manager_instance
|
||||
|
||||
|
||||
def create_hub_manager(node_id: str, address: str, port: int, island_id: str, island_name: str) -> HubManager:
    """Create a HubManager, install it as the module-level singleton, and return it."""
    global hub_manager_instance
    manager = HubManager(node_id, address, port, island_id, island_name)
    hub_manager_instance = manager
    return manager
|
||||
|
||||
@@ -7,8 +7,10 @@ from psycopg2.extras import RealDictCursor
|
||||
import json
|
||||
from decimal import Decimal
|
||||
|
||||
from aitbc.constants import DATA_DIR
|
||||
|
||||
# Database configurations
|
||||
SQLITE_DB = "/var/lib/aitbc/data/coordinator.db"
|
||||
SQLITE_DB = str(DATA_DIR / "data/coordinator.db")
|
||||
PG_CONFIG = {
|
||||
"host": "localhost",
|
||||
"database": "aitbc_coordinator",
|
||||
|
||||
@@ -12,7 +12,7 @@ from decimal import Decimal
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from aitbc import get_logger
|
||||
from aitbc import get_logger, derive_ethereum_address, sign_transaction_hash, verify_signature, encrypt_private_key, Web3Client
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -174,6 +174,8 @@ class EthereumWalletAdapter(EnhancedWalletAdapter):
|
||||
def __init__(self, chain_id: int, rpc_url: str, security_level: SecurityLevel = SecurityLevel.MEDIUM):
    """Initialize the Ethereum adapter and its Web3 RPC client.

    Args:
        chain_id: Numeric EVM chain id.
        rpc_url: JSON-RPC endpoint of the Ethereum node.
        security_level: Wallet security tier (defaults to MEDIUM).
    """
    super().__init__(chain_id, ChainType.ETHEREUM, rpc_url, security_level)
    self.chain_id = chain_id  # NOTE(review): base class may also store this — confirm the duplication is intended
    # Initialize Web3 client for blockchain operations
    self._web3_client = Web3Client(rpc_url)
|
||||
|
||||
async def create_wallet(self, owner_address: str, security_config: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Create a new Ethereum wallet with enhanced security"""
|
||||
@@ -446,25 +448,36 @@ class EthereumWalletAdapter(EnhancedWalletAdapter):
|
||||
# Private helper methods
|
||||
async def _derive_address_from_private_key(self, private_key: str) -> str:
    """Derive Ethereum address from private key.

    Delegates to the shared crypto helper and re-raises failures so callers
    never act on a bogus address.
    """
    # Fix: removed the leftover sha256-based mock return that made the real
    # derivation below unreachable (merge/diff residue).
    try:
        return derive_ethereum_address(private_key)
    except Exception as e:
        logger.error(f"Failed to derive address from private key: {e}")
        raise
|
||||
|
||||
async def _encrypt_private_key(self, private_key: str, security_config: dict[str, Any]) -> str:
    """Encrypt private key with security configuration.

    Uses ``security_config['encryption_password']`` as the passphrase.
    Re-raises on failure.
    """
    # Fix: removed the leftover mock return that made the real encryption
    # below unreachable (merge/diff residue).
    try:
        # NOTE(review): falling back to "default_password" weakens the
        # encryption — consider requiring an explicit password.
        password = security_config.get("encryption_password", "default_password")
        return encrypt_private_key(private_key, password)
    except Exception as e:
        logger.error(f"Failed to encrypt private key: {e}")
        raise
|
||||
|
||||
async def _get_eth_balance(self, address: str) -> str:
    """Get ETH balance in wei via the Web3 client; re-raises on failure."""
    # Fix: removed the hard-coded "1 ETH" mock return that made the real
    # RPC call below unreachable (merge/diff residue).
    try:
        return self._web3_client.get_eth_balance(address)
    except Exception as e:
        logger.error(f"Failed to get ETH balance: {e}")
        raise
|
||||
|
||||
async def _get_token_balance(self, address: str, token_address: str) -> dict[str, Any]:
    """Get ERC-20 token balance via the Web3 client; re-raises on failure."""
    # Fix: removed the hard-coded mock balance that made the real RPC call
    # below unreachable (merge/diff residue).
    try:
        return self._web3_client.get_token_balance(address, token_address)
    except Exception as e:
        logger.error(f"Failed to get token balance: {e}")
        raise
|
||||
|
||||
async def _create_erc20_transfer(
|
||||
self, from_address: str, to_address: str, token_address: str, amount: int
|
||||
@@ -493,78 +506,116 @@ class EthereumWalletAdapter(EnhancedWalletAdapter):
|
||||
|
||||
async def _get_gas_price(self) -> int:
    """Get current gas price in wei via the Web3 client; re-raises on failure."""
    # Fix: removed the hard-coded 20 Gwei mock return that made the real RPC
    # call below unreachable (merge/diff residue).
    try:
        return self._web3_client.get_gas_price()
    except Exception as e:
        logger.error(f"Failed to get gas price: {e}")
        raise
|
||||
|
||||
async def _get_gas_price_gwei(self) -> float:
    """Get current gas price in Gwei via the Web3 client; re-raises on failure."""
    # Fix: removed the superseded wei->Gwei conversion path that returned
    # early and made the client call below unreachable (merge/diff residue).
    try:
        return self._web3_client.get_gas_price_gwei()
    except Exception as e:
        logger.error(f"Failed to get gas price in Gwei: {e}")
        raise
|
||||
|
||||
async def _get_nonce(self, address: str) -> int:
    """Get transaction nonce for address via the Web3 client; re-raises on failure."""
    # Fix: removed the hard-coded `return 0` mock that made the real RPC
    # call below unreachable (merge/diff residue).
    try:
        return self._web3_client.get_nonce(address)
    except Exception as e:
        logger.error(f"Failed to get nonce: {e}")
        raise
|
||||
|
||||
async def _sign_transaction(self, transaction_data: dict[str, Any], from_address: str) -> str:
    """Sign a transaction and return the raw signed transaction hex.

    Args:
        transaction_data: Transaction fields with hex-string values
            (``nonce``, ``gasPrice``, ``gas``, ``value``) plus ``to``,
            ``data`` and ``chainId``.
        from_address: Key material handed to eth-account (see note below).

    Raises:
        ImportError: When eth-account is not installed.
        Exception: Re-raised signing failures.
    """
    # Fix: removed the leftover mock return that made the real signing below
    # unreachable (merge/diff residue).
    try:
        from eth_account import Account

        # NOTE(review): this passes *from_address* to Account.from_key — an
        # address is not a private key. Confirm callers actually supply key
        # material here, or thread the private key through explicitly.
        key_material = from_address[2:] if from_address.startswith("0x") else from_address
        account = Account.from_key(key_material)

        # Fix: defaults must be hex *strings* — the original used integer
        # defaults (e.g. get('nonce', 0)), and int(0, 16) raises TypeError.
        tx_dict = {
            'nonce': int(transaction_data.get('nonce', '0x0'), 16),
            'gasPrice': int(transaction_data.get('gasPrice', '0x0'), 16),
            'gas': int(transaction_data.get('gas', '0x0'), 16),
            'to': transaction_data.get('to'),
            'value': int(transaction_data.get('value', '0x0'), 16),
            'data': transaction_data.get('data', '0x'),
            'chainId': transaction_data.get('chainId', 1)
        }

        signed_tx = account.sign_transaction(tx_dict)
        return signed_tx.raw_transaction.hex()
    except ImportError:
        raise ImportError("eth-account is required for transaction signing. Install with: pip install eth-account")
    except Exception as e:
        logger.error(f"Failed to sign transaction: {e}")
        raise
|
||||
|
||||
async def _send_raw_transaction(self, signed_transaction: str) -> str:
    """Broadcast a signed raw transaction; returns the tx hash, re-raises on failure."""
    # Fix: removed the sha256-based mock hash that made the real broadcast
    # below unreachable (merge/diff residue).
    try:
        return self._web3_client.send_raw_transaction(signed_transaction)
    except Exception as e:
        logger.error(f"Failed to send raw transaction: {e}")
        raise
|
||||
|
||||
async def _get_transaction_receipt(self, tx_hash: str) -> dict[str, Any] | None:
    """Get transaction receipt via the Web3 client; re-raises on failure."""
    # Fix: removed the hard-coded mock receipt that made the real RPC call
    # below unreachable (merge/diff residue).
    try:
        return self._web3_client.get_transaction_receipt(tx_hash)
    except Exception as e:
        logger.error(f"Failed to get transaction receipt: {e}")
        raise
|
||||
|
||||
async def _get_transaction_by_hash(self, tx_hash: str) -> dict[str, Any]:
    """Get transaction by hash via the Web3 client; re-raises on failure."""
    # Fix: removed the hard-coded mock transaction that made the real RPC
    # call below unreachable (merge/diff residue).
    try:
        return self._web3_client.get_transaction_by_hash(tx_hash)
    except Exception as e:
        logger.error(f"Failed to get transaction by hash: {e}")
        raise
|
||||
|
||||
async def _estimate_gas_call(self, call_data: dict[str, Any]) -> str:
    """Estimate gas for a call; returns the estimate as a hex string, re-raises on failure."""
    # Fix: removed the hard-coded "0x5208" mock that made the real estimate
    # below unreachable (merge/diff residue).
    try:
        gas_estimate = self._web3_client.estimate_gas(call_data)
        return hex(gas_estimate)
    except Exception as e:
        logger.error(f"Failed to estimate gas: {e}")
        raise
|
||||
|
||||
async def _get_wallet_transactions(
    self, address: str, limit: int, offset: int, from_block: int | None, to_block: int | None
) -> list[dict[str, Any]]:
    """Get wallet transactions via the Web3 client; re-raises on failure.

    NOTE(review): the client call only forwards ``address`` and ``limit`` —
    ``offset``, ``from_block`` and ``to_block`` are accepted for interface
    compatibility but currently ignored; confirm whether the client supports
    them.
    """
    # Fix: removed the synthetic mock transaction list that made the real
    # client call below unreachable (merge/diff residue).
    try:
        return self._web3_client.get_wallet_transactions(address, limit)
    except Exception as e:
        logger.error(f"Failed to get wallet transactions: {e}")
        raise
|
||||
|
||||
async def _sign_hash(self, message_hash: str, private_key: str) -> str:
    """Sign a hash with private key via the shared crypto helper; re-raises on failure."""
    # Fix: removed the sha256-based mock signature that made the real
    # signing below unreachable (merge/diff residue).
    try:
        return sign_transaction_hash(message_hash, private_key)
    except Exception as e:
        logger.error(f"Failed to sign hash: {e}")
        raise
|
||||
|
||||
async def _verify_signature(self, message_hash: str, signature: str, address: str) -> bool:
    """Verify a signature against *address*.

    Returns:
        bool: True when valid; False when invalid or when verification
        errors (fail-closed — errors are logged, never raised).
    """
    # Fix: removed the unconditional `return True` mock that made the real
    # verification below unreachable (merge/diff residue) — and that would
    # have accepted ANY signature.
    try:
        return verify_signature(message_hash, signature, address)
    except Exception as e:
        logger.error(f"Failed to verify signature: {e}")
        return False
|
||||
|
||||
|
||||
class PolygonWalletAdapter(EthereumWalletAdapter):
|
||||
|
||||
@@ -7,7 +7,6 @@ Multi-chain trading with cross-chain swaps and bridging
|
||||
import sqlite3
|
||||
import json
|
||||
import asyncio
|
||||
import httpx
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
from fastapi import FastAPI, HTTPException, Query, BackgroundTasks
|
||||
@@ -17,8 +16,15 @@ import os
|
||||
import uuid
|
||||
import hashlib
|
||||
|
||||
from aitbc.http_client import AsyncAITBCHTTPClient
|
||||
from aitbc.aitbc_logging import get_logger
|
||||
from aitbc.exceptions import NetworkError
|
||||
|
||||
app = FastAPI(title="AITBC Complete Cross-Chain Exchange", version="3.0.0")
|
||||
|
||||
# Initialize logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Database configuration
|
||||
DB_PATH = os.path.join(os.path.dirname(__file__), "exchange_multichain.db")
|
||||
|
||||
@@ -368,10 +374,10 @@ async def health_check():
|
||||
|
||||
if chain_info["status"] == "active" and chain_info["blockchain_url"]:
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"{chain_info['blockchain_url']}/health", timeout=5.0)
|
||||
chain_status[chain_id]["connected"] = response.status_code == 200
|
||||
except:
|
||||
client = AsyncAITBCHTTPClient(base_url=chain_info['blockchain_url'], timeout=5)
|
||||
response = await client.async_get("/health")
|
||||
chain_status[chain_id]["connected"] = response is not None
|
||||
except NetworkError:
|
||||
pass
|
||||
|
||||
return {
|
||||
|
||||
@@ -7,7 +7,6 @@ Complete multi-chain trading with chain isolation
|
||||
import sqlite3
|
||||
import json
|
||||
import asyncio
|
||||
import httpx
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
from fastapi import FastAPI, HTTPException, Query, BackgroundTasks
|
||||
@@ -15,8 +14,15 @@ from pydantic import BaseModel, Field
|
||||
import uvicorn
|
||||
import os
|
||||
|
||||
from aitbc.http_client import AsyncAITBCHTTPClient
|
||||
from aitbc.aitbc_logging import get_logger
|
||||
from aitbc.exceptions import NetworkError
|
||||
|
||||
app = FastAPI(title="AITBC Multi-Chain Exchange", version="2.0.0")
|
||||
|
||||
# Initialize logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Database configuration
|
||||
DB_PATH = os.path.join(os.path.dirname(__file__), "exchange_multichain.db")
|
||||
|
||||
@@ -145,10 +151,10 @@ async def verify_chain_transaction(chain_id: str, tx_hash: str) -> bool:
|
||||
return False
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"{chain_info['blockchain_url']}/api/v1/transactions/{tx_hash}")
|
||||
return response.status_code == 200
|
||||
except:
|
||||
client = AsyncAITBCHTTPClient(base_url=chain_info['blockchain_url'], timeout=5)
|
||||
response = await client.async_get(f"/api/v1/transactions/{tx_hash}")
|
||||
return response is not None
|
||||
except NetworkError:
|
||||
return False
|
||||
|
||||
async def submit_chain_transaction(chain_id: str, order_data: Dict) -> Optional[str]:
|
||||
@@ -161,16 +167,13 @@ async def submit_chain_transaction(chain_id: str, order_data: Dict) -> Optional[
|
||||
return None
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{chain_info['blockchain_url']}/api/v1/transactions",
|
||||
json=order_data
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return response.json().get("tx_hash")
|
||||
except Exception as e:
|
||||
print(f"Chain transaction error: {e}")
|
||||
|
||||
client = AsyncAITBCHTTPClient(base_url=chain_info['blockchain_url'], timeout=10)
|
||||
response = await client.async_post("/api/v1/transactions", json=order_data)
|
||||
if response:
|
||||
return response.get("tx_hash")
|
||||
except NetworkError as e:
|
||||
logger.error(f"Chain transaction error: {e}")
|
||||
|
||||
return None
|
||||
|
||||
# API Endpoints
|
||||
@@ -188,10 +191,10 @@ async def health_check():
|
||||
|
||||
if chain_info["status"] == "active" and chain_info["blockchain_url"]:
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"{chain_info['blockchain_url']}/health", timeout=5.0)
|
||||
chain_status[chain_id]["connected"] = response.status_code == 200
|
||||
except:
|
||||
client = AsyncAITBCHTTPClient(base_url=chain_info['blockchain_url'], timeout=5)
|
||||
response = await client.async_get("/health")
|
||||
chain_status[chain_id]["connected"] = response is not None
|
||||
except NetworkError:
|
||||
pass
|
||||
|
||||
return {
|
||||
|
||||
@@ -4,7 +4,6 @@ Simple AITBC Blockchain Explorer - Demonstrating the issues described in the ana
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
@@ -12,8 +11,15 @@ from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import HTMLResponse
|
||||
import uvicorn
|
||||
|
||||
from aitbc.http_client import AsyncAITBCHTTPClient
|
||||
from aitbc.aitbc_logging import get_logger
|
||||
from aitbc.exceptions import NetworkError
|
||||
|
||||
app = FastAPI(title="Simple AITBC Explorer", version="0.1.0")
|
||||
|
||||
# Initialize logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Configuration
|
||||
BLOCKCHAIN_RPC_URL = "http://localhost:8025"
|
||||
|
||||
@@ -174,12 +180,12 @@ HTML_TEMPLATE = """
|
||||
async def get_chain_head():
    """Get current chain head.

    Fetches ``/rpc/head`` from the blockchain RPC; on network failure the
    UI still renders by falling back to an empty head.
    """
    # Fix: removed the superseded httpx code path that was fused with the
    # new client code (merge/diff residue), leaving only the intended
    # AsyncAITBCHTTPClient implementation.
    try:
        client = AsyncAITBCHTTPClient(base_url=BLOCKCHAIN_RPC_URL, timeout=10)
        response = await client.async_get("/rpc/head")
        if response:
            return response
    except NetworkError as e:
        logger.error(f"Error getting chain head: {e}")
    return {"height": 0, "hash": "", "timestamp": None}
|
||||
|
||||
@app.get("/api/blocks/{height}")
async def get_block(height: int):
    """Return block data for *height*; an empty block on invalid height or error."""
    # Bound the height so absurd values are never forwarded to the RPC node.
    if height < 0 or height > 10000000:
        return {"height": height, "hash": "", "timestamp": None, "transactions": []}
    # Fix: removed the superseded httpx code path fused with the new client
    # code (merge/diff residue).
    try:
        client = AsyncAITBCHTTPClient(base_url=BLOCKCHAIN_RPC_URL, timeout=10)
        response = await client.async_get(f"/rpc/blocks/{height}")
        if response:
            return response
    except NetworkError as e:
        logger.error(f"Error getting block: {e}")
    return {"height": height, "hash": "", "timestamp": None, "transactions": []}
|
||||
|
||||
@app.get("/api/transactions/{tx_hash}")
|
||||
@@ -203,26 +209,21 @@ async def get_transaction(tx_hash: str):
|
||||
if not validate_tx_hash(tx_hash):
|
||||
return {"hash": tx_hash, "from": "unknown", "to": "unknown", "amount": 0, "timestamp": None}
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(f"{BLOCKCHAIN_RPC_URL}/rpc/tx/{tx_hash}")
|
||||
if response.status_code == 200:
|
||||
tx_data = response.json()
|
||||
# Problem 2: Map RPC schema to UI schema
|
||||
return {
|
||||
"hash": tx_data.get("tx_hash", tx_hash), # tx_hash -> hash
|
||||
"from": tx_data.get("sender", "unknown"), # sender -> from
|
||||
"to": tx_data.get("recipient", "unknown"), # recipient -> to
|
||||
"amount": tx_data.get("payload", {}).get("value", "0"), # payload.value -> amount
|
||||
"fee": tx_data.get("payload", {}).get("fee", "0"), # payload.fee -> fee
|
||||
"timestamp": tx_data.get("created_at"), # created_at -> timestamp
|
||||
"block_height": tx_data.get("block_height", "pending")
|
||||
}
|
||||
elif response.status_code == 404:
|
||||
raise HTTPException(status_code=404, detail="Transaction not found")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"Error getting transaction {tx_hash}: {e}")
|
||||
client = AsyncAITBCHTTPClient(base_url=BLOCKCHAIN_RPC_URL, timeout=10)
|
||||
response = await client.async_get(f"/rpc/tx/{tx_hash}")
|
||||
if response:
|
||||
# Problem 2: Map RPC schema to UI schema
|
||||
return {
|
||||
"hash": response.get("tx_hash", tx_hash), # tx_hash -> hash
|
||||
"from": response.get("sender", "unknown"), # sender -> from
|
||||
"to": response.get("recipient", "unknown"), # recipient -> to
|
||||
"amount": response.get("payload", {}).get("value", "0"), # payload.value -> amount
|
||||
"fee": response.get("payload", {}).get("fee", "0"), # payload.fee -> fee
|
||||
"timestamp": response.get("created_at"), # created_at -> timestamp
|
||||
"block_height": response.get("block_height", "pending")
|
||||
}
|
||||
except NetworkError as e:
|
||||
logger.error(f"Error getting transaction {tx_hash}: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to fetch transaction: {str(e)}")
|
||||
|
||||
# Missing: @app.get("/api/transactions/{tx_hash}") - THIS IS THE PROBLEM
|
||||
|
||||
@@ -16,6 +16,8 @@ from pathlib import Path
|
||||
import os
|
||||
import sys
|
||||
|
||||
from aitbc.constants import KEYSTORE_DIR
|
||||
|
||||
# Add CLI utils to path
|
||||
sys.path.insert(0, '/opt/aitbc/cli')
|
||||
|
||||
@@ -23,7 +25,7 @@ sys.path.insert(0, '/opt/aitbc/cli')
|
||||
app = FastAPI(title="AITBC Wallet Daemon", debug=False)
|
||||
|
||||
# Configuration
|
||||
KEYSTORE_PATH = Path("/var/lib/aitbc/keystore")
|
||||
KEYSTORE_PATH = KEYSTORE_DIR
|
||||
BLOCKCHAIN_RPC_URL = "http://localhost:8006"
|
||||
CHAIN_ID = "ait-mainnet"
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ from aitbc import (
|
||||
AITBCHTTPClient, NetworkError, ValidationError, ConfigurationError,
|
||||
get_logger, get_keystore_path, ensure_dir, validate_address, validate_url
|
||||
)
|
||||
from aitbc.paths import get_blockchain_data_path, get_data_path
|
||||
|
||||
# Initialize logger
|
||||
logger = get_logger(__name__)
|
||||
@@ -1188,8 +1189,9 @@ def agent_operations(action: str, **kwargs) -> Optional[Dict]:
|
||||
sys.path.insert(0, "/opt/aitbc/apps/blockchain-node/src")
|
||||
from sqlmodel import create_engine, Session, select
|
||||
from aitbc_chain.models import Transaction
|
||||
|
||||
engine = create_engine("sqlite:////var/lib/aitbc/data/ait-mainnet/chain.db")
|
||||
|
||||
chain_db_path = get_blockchain_data_path("ait-mainnet") / "chain.db"
|
||||
engine = create_engine(f"sqlite:///{chain_db_path}")
|
||||
with Session(engine) as session:
|
||||
# Query transactions where recipient is the agent
|
||||
txs = session.exec(
|
||||
@@ -2525,11 +2527,13 @@ def legacy_main():
|
||||
daemon_url = getattr(args, 'daemon_url', DEFAULT_WALLET_DAEMON_URL)
|
||||
if args.wallet_action == "backup":
|
||||
print(f"Wallet backup: {args.name}")
|
||||
print(f" Backup created: /var/lib/aitbc/backups/{args.name}_$(date +%Y%m%d).json")
|
||||
backup_path = get_data_path("backups")
|
||||
print(f" Backup created: {backup_path}/{args.name}_$(date +%Y%m%d).json")
|
||||
print(f" Status: completed")
|
||||
elif args.wallet_action == "export":
|
||||
print(f"Wallet export: {args.name}")
|
||||
print(f" Export file: /var/lib/aitbc/exports/{args.name}_private.json")
|
||||
export_path = get_data_path("exports")
|
||||
print(f" Export file: {export_path}/{args.name}_private.json")
|
||||
print(f" Status: completed")
|
||||
elif args.wallet_action == "sync":
|
||||
if args.all:
|
||||
@@ -2602,12 +2606,14 @@ def legacy_main():
|
||||
|
||||
elif args.command == "wallet-backup":
|
||||
print(f"Wallet backup: {args.name}")
|
||||
print(f" Backup created: /var/lib/aitbc/backups/{args.name}_$(date +%Y%m%d).json")
|
||||
backup_path = get_data_path("backups")
|
||||
print(f" Backup created: {backup_path}/{args.name}_$(date +%Y%m%d).json")
|
||||
print(f" Status: completed")
|
||||
|
||||
|
||||
elif args.command == "wallet-export":
|
||||
print(f"Wallet export: {args.name}")
|
||||
print(f" Export file: /var/lib/aitbc/exports/{args.name}_private.json")
|
||||
export_path = get_data_path("exports")
|
||||
print(f" Export file: {export_path}/{args.name}_private.json")
|
||||
print(f" Status: completed")
|
||||
|
||||
elif args.command == "wallet-sync":
|
||||
|
||||
@@ -12,9 +12,10 @@ from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List
|
||||
import requests
|
||||
import getpass
|
||||
from aitbc.paths import get_keystore_path
|
||||
|
||||
# Default paths
|
||||
DEFAULT_KEYSTORE_DIR = Path("/var/lib/aitbc/keystore")
|
||||
DEFAULT_KEYSTORE_DIR = get_keystore_path()
|
||||
DEFAULT_RPC_URL = "http://localhost:8006"
|
||||
|
||||
def get_password(password_arg: str = None, password_file: str = None) -> str:
|
||||
|
||||
@@ -2,8 +2,9 @@ import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from aitbc.paths import get_data_path
|
||||
|
||||
STATE_FILE = "/var/lib/aitbc/data/cli_extended_state.json"
|
||||
STATE_FILE = str(get_data_path("data/cli_extended_state.json"))
|
||||
|
||||
def load_state():
|
||||
if os.path.exists(STATE_FILE):
|
||||
@@ -82,10 +83,10 @@ def handle_extended_command(command, args, kwargs):
|
||||
result["nodes_reached"] = 2
|
||||
|
||||
elif command == "wallet_backup":
|
||||
result["path"] = f"/var/lib/aitbc/backups/{kwargs.get('name')}.backup"
|
||||
|
||||
result["path"] = f"{get_data_path('backups')}/{kwargs.get('name')}.backup"
|
||||
|
||||
elif command == "wallet_export":
|
||||
result["path"] = f"/var/lib/aitbc/exports/{kwargs.get('name')}.key"
|
||||
result["path"] = f"{get_data_path('exports')}/{kwargs.get('name')}.key"
|
||||
|
||||
elif command == "wallet_sync":
|
||||
result["status"] = "Wallets synchronized"
|
||||
@@ -275,7 +276,7 @@ def handle_extended_command(command, args, kwargs):
|
||||
|
||||
elif command == "compliance_report":
|
||||
result["format"] = kwargs.get("format")
|
||||
result["path"] = "/var/lib/aitbc/reports/compliance.pdf"
|
||||
result["path"] = f"{get_data_path('reports')}/compliance.pdf"
|
||||
|
||||
elif command == "script_run":
|
||||
result["file"] = kwargs.get("file")
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import json
|
||||
import sys
|
||||
from aitbc.paths import get_data_path
|
||||
|
||||
|
||||
def handle_wallet_create(args, create_wallet, read_password, first):
|
||||
@@ -140,7 +141,8 @@ def handle_wallet_backup(args, first):
|
||||
print("Error: Wallet name is required")
|
||||
sys.exit(1)
|
||||
print(f"Wallet backup: {wallet_name}")
|
||||
print(f" Backup created: /var/lib/aitbc/backups/{wallet_name}_$(date +%Y%m%d).json")
|
||||
backup_path = get_data_path("backups")
|
||||
print(f" Backup created: {backup_path}/{wallet_name}_$(date +%Y%m%d).json")
|
||||
print(" Status: completed")
|
||||
|
||||
|
||||
|
||||
@@ -42,8 +42,10 @@ def decrypt_private_key(keystore_data: Dict[str, Any], password: str) -> str:
|
||||
return decrypted.decode()
|
||||
|
||||
|
||||
def load_keystore(address: str, keystore_dir: Path | str = "/var/lib/aitbc/keystore") -> Dict[str, Any]:
|
||||
def load_keystore(address: str, keystore_dir: Path | str = None) -> Dict[str, Any]:
|
||||
"""Load keystore file for a given address."""
|
||||
if keystore_dir is None:
|
||||
keystore_dir = get_keystore_path()
|
||||
keystore_dir = Path(keystore_dir)
|
||||
keystore_file = keystore_dir / f"{address}.json"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user