feat: add marketplace metrics, privacy features, and service registry endpoints

- Add Prometheus metrics for marketplace API throughput and error rates with new dashboard panels
- Implement confidential transaction models with encryption support and access control
- Add key management system with registration, rotation, and audit logging
- Create services and registry routers for service discovery and management
- Integrate ZK proof generation for privacy-preserving receipts
- Add metrics instru
This commit is contained in:
oib
2025-12-22 10:33:23 +01:00
parent fa5a6fddf3
commit a4d4be4a1e
260 changed files with 59033 additions and 351 deletions

View File

@@ -0,0 +1,19 @@
"""
API modules for AITBC Python SDK
"""
from .jobs import JobsAPI, MultiNetworkJobsAPI
from .marketplace import MarketplaceAPI
from .wallet import WalletAPI
from .receipts import ReceiptsAPI
from .settlement import SettlementAPI, MultiNetworkSettlementAPI
__all__ = [
"JobsAPI",
"MultiNetworkJobsAPI",
"MarketplaceAPI",
"WalletAPI",
"ReceiptsAPI",
"SettlementAPI",
"MultiNetworkSettlementAPI",
]

View File

@@ -0,0 +1,94 @@
"""
Jobs API for AITBC Python SDK
"""
from typing import Dict, Any, Optional, List
import logging
from ..transport import Transport
from ..transport.multinetwork import MultiNetworkClient
logger = logging.getLogger(__name__)
class JobsAPI:
"""Jobs API client"""
def __init__(self, transport: Transport):
self.transport = transport
async def create(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Create a new job"""
return await self.transport.request('POST', '/v1/jobs', data=data)
async def get(self, job_id: str) -> Dict[str, Any]:
"""Get job details"""
return await self.transport.request('GET', f'/v1/jobs/{job_id}')
async def list(self, **params) -> List[Dict[str, Any]]:
"""List jobs"""
response = await self.transport.request('GET', '/v1/jobs', params=params)
return response.get('jobs', [])
async def update(self, job_id: str, data: Dict[str, Any]) -> Dict[str, Any]:
"""Update job"""
return await self.transport.request('PUT', f'/v1/jobs/{job_id}', data=data)
async def delete(self, job_id: str) -> None:
"""Delete job"""
await self.transport.request('DELETE', f'/v1/jobs/{job_id}')
async def wait_for_completion(
self,
job_id: str,
timeout: Optional[int] = None,
poll_interval: int = 5
) -> Dict[str, Any]:
"""Wait for job completion"""
# Implementation would poll job status until complete
pass
class MultiNetworkJobsAPI(JobsAPI):
"""Multi-network Jobs API client"""
def __init__(self, client: MultiNetworkClient):
self.client = client
async def create(
self,
data: Dict[str, Any],
chain_id: Optional[int] = None
) -> Dict[str, Any]:
"""Create a new job on specific network"""
transport = self.client.get_transport(chain_id)
return await transport.request('POST', '/v1/jobs', data=data)
async def get(
self,
job_id: str,
chain_id: Optional[int] = None
) -> Dict[str, Any]:
"""Get job details from specific network"""
transport = self.client.get_transport(chain_id)
return await transport.request('GET', f'/v1/jobs/{job_id}')
async def list(
self,
chain_id: Optional[int] = None,
**params
) -> List[Dict[str, Any]]:
"""List jobs from specific network"""
transport = self.client.get_transport(chain_id)
response = await transport.request('GET', '/v1/jobs', params=params)
return response.get('jobs', [])
async def broadcast_create(
self,
data: Dict[str, Any],
chain_ids: Optional[List[int]] = None
) -> Dict[int, Dict[str, Any]]:
"""Create job on multiple networks"""
return await self.client.broadcast_request(
'POST', '/v1/jobs', data=data, chain_ids=chain_ids
)

View File

@@ -0,0 +1,46 @@
"""
Marketplace API for AITBC Python SDK
"""
from typing import Dict, Any, Optional, List
import logging
from ..transport import Transport
logger = logging.getLogger(__name__)
class MarketplaceAPI:
"""Marketplace API client"""
def __init__(self, transport: Transport):
self.transport = transport
async def list_offers(self, **params) -> List[Dict[str, Any]]:
"""List marketplace offers"""
response = await self.transport.request('GET', '/v1/marketplace/offers', params=params)
return response.get('offers', [])
async def create_offer(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Create a new offer"""
return await self.transport.request('POST', '/v1/marketplace/offers', data=data)
async def get_offer(self, offer_id: str) -> Dict[str, Any]:
"""Get offer details"""
return await self.transport.request('GET', f'/v1/marketplace/offers/{offer_id}')
async def update_offer(self, offer_id: str, data: Dict[str, Any]) -> Dict[str, Any]:
"""Update offer"""
return await self.transport.request('PUT', f'/v1/marketplace/offers/{offer_id}', data=data)
async def delete_offer(self, offer_id: str) -> None:
"""Delete offer"""
await self.transport.request('DELETE', f'/v1/marketplace/offers/{offer_id}')
async def accept_offer(self, offer_id: str, data: Dict[str, Any]) -> Dict[str, Any]:
"""Accept an offer"""
return await self.transport.request('POST', f'/v1/marketplace/offers/{offer_id}/accept', data=data)
async def get_stats(self) -> Dict[str, Any]:
"""Get marketplace statistics"""
return await self.transport.request('GET', '/v1/marketplace/stats')

View File

@@ -0,0 +1,34 @@
"""
Receipts API for AITBC Python SDK
"""
from typing import Dict, Any, Optional, List
import logging
from ..transport import Transport
logger = logging.getLogger(__name__)
class ReceiptsAPI:
"""Receipts API client"""
def __init__(self, transport: Transport):
self.transport = transport
async def get(self, job_id: str) -> Dict[str, Any]:
"""Get job receipt"""
return await self.transport.request('GET', f'/v1/receipts/{job_id}')
async def verify(self, receipt: Dict[str, Any]) -> Dict[str, Any]:
"""Verify receipt"""
return await self.transport.request('POST', '/v1/receipts/verify', data=receipt)
async def list(self, **params) -> List[Dict[str, Any]]:
"""List receipts"""
response = await self.transport.request('GET', '/v1/receipts', params=params)
return response.get('receipts', [])
async def stream(self, **params):
"""Stream new receipts"""
return self.transport.stream('GET', '/v1/receipts/stream', params=params)

View File

@@ -0,0 +1,100 @@
"""
Settlement API for AITBC Python SDK
"""
from typing import Dict, Any, Optional, List
import logging
from ..transport import Transport
from ..transport.multinetwork import MultiNetworkClient
logger = logging.getLogger(__name__)
class SettlementAPI:
"""Settlement API client"""
def __init__(self, transport: Transport):
self.transport = transport
async def settle_cross_chain(
self,
job_id: str,
target_chain_id: int,
bridge_name: Optional[str] = None
) -> Dict[str, Any]:
"""Initiate cross-chain settlement"""
data = {
'job_id': job_id,
'target_chain_id': target_chain_id,
'bridge_name': bridge_name
}
return await self.transport.request('POST', '/v1/settlement/cross-chain', data=data)
async def get_settlement_status(self, message_id: str) -> Dict[str, Any]:
"""Get settlement status"""
return await self.transport.request('GET', f'/v1/settlement/{message_id}/status')
async def estimate_cost(
self,
job_id: str,
target_chain_id: int,
bridge_name: Optional[str] = None
) -> Dict[str, Any]:
"""Estimate settlement cost"""
data = {
'job_id': job_id,
'target_chain_id': target_chain_id,
'bridge_name': bridge_name
}
return await self.transport.request('POST', '/v1/settlement/estimate-cost', data=data)
async def list_bridges(self) -> Dict[str, Any]:
"""List supported bridges"""
return await self.transport.request('GET', '/v1/settlement/bridges')
async def list_chains(self) -> Dict[str, Any]:
"""List supported chains"""
return await self.transport.request('GET', '/v1/settlement/chains')
async def refund_settlement(self, message_id: str) -> Dict[str, Any]:
"""Refund failed settlement"""
return await self.transport.request('POST', f'/v1/settlement/{message_id}/refund')
class MultiNetworkSettlementAPI(SettlementAPI):
"""Multi-network Settlement API client"""
def __init__(self, client: MultiNetworkClient):
self.client = client
async def settle_cross_chain(
self,
job_id: str,
target_chain_id: int,
source_chain_id: Optional[int] = None,
bridge_name: Optional[str] = None
) -> Dict[str, Any]:
"""Initiate cross-chain settlement from specific network"""
transport = self.client.get_transport(source_chain_id)
data = {
'job_id': job_id,
'target_chain_id': target_chain_id,
'bridge_name': bridge_name
}
return await transport.request('POST', '/v1/settlement/cross-chain', data=data)
async def batch_settle(
self,
job_ids: List[str],
target_chain_id: int,
bridge_name: Optional[str] = None
) -> List[Dict[str, Any]]:
"""Batch settle multiple jobs"""
data = {
'job_ids': job_ids,
'target_chain_id': target_chain_id,
'bridge_name': bridge_name
}
transport = self.client.get_transport()
return await transport.request('POST', '/v1/settlement/batch', data=data)

View File

@@ -0,0 +1,50 @@
"""
Wallet API for AITBC Python SDK
"""
from typing import Dict, Any, Optional, List
import logging
from ..transport import Transport
logger = logging.getLogger(__name__)
class WalletAPI:
"""Wallet API client"""
def __init__(self, transport: Transport):
self.transport = transport
async def create(self) -> Dict[str, Any]:
"""Create a new wallet"""
return await self.transport.request('POST', '/v1/wallet')
async def get_balance(self, token: Optional[str] = None) -> Dict[str, Any]:
"""Get wallet balance"""
params = {}
if token:
params['token'] = token
return await self.transport.request('GET', '/v1/wallet/balance', params=params)
async def send(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Send tokens"""
return await self.transport.request('POST', '/v1/wallet/send', data=data)
async def get_address(self) -> str:
"""Get wallet address"""
response = await self.transport.request('GET', '/v1/wallet/address')
return response.get('address')
async def get_transactions(self, **params) -> List[Dict[str, Any]]:
"""Get transaction history"""
response = await self.transport.request('GET', '/v1/wallet/transactions', params=params)
return response.get('transactions', [])
async def stake(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Stake tokens"""
return await self.transport.request('POST', '/v1/wallet/stake', data=data)
async def unstake(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Unstake tokens"""
return await self.transport.request('POST', '/v1/wallet/unstake', data=data)

364
python-sdk/aitbc/client.py Normal file
View File

@@ -0,0 +1,364 @@
"""
Main AITBC client with pluggable transport abstraction
"""
import asyncio
import logging
from typing import Dict, Any, Optional, Union, List
from datetime import datetime
from .transport import (
Transport,
HTTPTransport,
WebSocketTransport,
MultiNetworkClient,
NetworkConfig,
TransportError
)
from .transport.base import BatchTransport, CachedTransport, RateLimitedTransport
from .apis.jobs import JobsAPI, MultiNetworkJobsAPI
from .apis.marketplace import MarketplaceAPI
from .apis.wallet import WalletAPI
from .apis.receipts import ReceiptsAPI
from .apis.settlement import SettlementAPI, MultiNetworkSettlementAPI
logger = logging.getLogger(__name__)
class AITBCClient:
"""AITBC client with pluggable transports and multi-network support"""
def __init__(
self,
transport: Optional[Union[Transport, Dict[str, Any]]] = None,
multi_network: bool = False,
config: Optional[Dict[str, Any]] = None
):
"""
Initialize AITBC client
Args:
transport: Transport instance or configuration
multi_network: Enable multi-network mode
config: Additional configuration options
"""
self.config = config or {}
self._connected = False
self._apis = {}
# Initialize transport layer
if multi_network:
self._init_multi_network(transport or {})
else:
self._init_single_network(transport or self._get_default_config())
# Initialize API clients
self._init_apis()
def _get_default_config(self) -> Dict[str, Any]:
"""Get default configuration for backward compatibility"""
return {
'type': 'http',
'base_url': self.config.get('base_url', 'https://api.aitbc.io'),
'timeout': self.config.get('timeout', 30),
'api_key': self.config.get('api_key'),
'default_headers': {
'User-Agent': f'AITBC-Python-SDK/{self._get_version()}',
'Content-Type': 'application/json'
}
}
def _init_single_network(self, transport_config: Union[Transport, Dict[str, Any]]) -> None:
"""Initialize single network client"""
if isinstance(transport_config, Transport):
self.transport = transport_config
else:
# Create transport from config
self.transport = self._create_transport(transport_config)
self.multi_network = False
self.multi_network_client = None
def _init_multi_network(self, configs: Dict[str, Any]) -> None:
"""Initialize multi-network client"""
self.multi_network_client = MultiNetworkClient(configs)
self.multi_network = True
self.transport = None # Use multi_network_client instead
def _create_transport(self, config: Dict[str, Any]) -> Transport:
"""Create transport from configuration"""
transport_type = config.get('type', 'http')
# Add API key to headers if provided
if 'api_key' in config and 'default_headers' not in config:
config['default_headers'] = {
'X-API-Key': config['api_key'],
'User-Agent': f'AITBC-Python-SDK/{self._get_version()}',
'Content-Type': 'application/json'
}
# Create base transport
if transport_type == 'http':
transport = HTTPTransport(config)
elif transport_type == 'websocket':
transport = WebSocketTransport(config)
elif transport_type == 'crosschain':
# Will be implemented later
raise NotImplementedError("CrossChain transport not yet implemented")
else:
raise ValueError(f"Unknown transport type: {transport_type}")
# Apply mixins if enabled
if config.get('cached', False):
transport = CachedTransport(config)
if config.get('rate_limited', False):
transport = RateLimitedTransport(config)
if config.get('batch', False):
transport = BatchTransport(config)
return transport
def _init_apis(self) -> None:
"""Initialize API clients"""
if self.multi_network:
# Multi-network APIs
self.jobs = MultiNetworkJobsAPI(self.multi_network_client)
self.settlement = MultiNetworkSettlementAPI(self.multi_network_client)
# Single-network APIs (use default network)
default_transport = self.multi_network_client.get_transport()
self.marketplace = MarketplaceAPI(default_transport)
self.wallet = WalletAPI(default_transport)
self.receipts = ReceiptsAPI(default_transport)
else:
# Single-network APIs
self.jobs = JobsAPI(self.transport)
self.marketplace = MarketplaceAPI(self.transport)
self.wallet = WalletAPI(self.transport)
self.receipts = ReceiptsAPI(self.transport)
self.settlement = SettlementAPI(self.transport)
async def connect(self) -> None:
"""Connect to network(s)"""
if self.multi_network:
await self.multi_network_client.connect_all()
else:
await self.transport.connect()
self._connected = True
logger.info("AITBC client connected")
async def disconnect(self) -> None:
"""Disconnect from network(s)"""
if self.multi_network:
await self.multi_network_client.disconnect_all()
elif self.transport:
await self.transport.disconnect()
self._connected = False
logger.info("AITBC client disconnected")
@property
def is_connected(self) -> bool:
"""Check if client is connected"""
if self.multi_network:
return self.multi_network_client._connected
elif self.transport:
return self.transport.is_connected
return False
# Multi-network methods
def add_network(self, network_config: NetworkConfig) -> None:
"""Add a network (multi-network mode only)"""
if not self.multi_network:
raise RuntimeError("Multi-network mode not enabled")
self.multi_network_client.add_network(network_config)
def remove_network(self, chain_id: int) -> None:
"""Remove a network (multi-network mode only)"""
if not self.multi_network:
raise RuntimeError("Multi-network mode not enabled")
self.multi_network_client.remove_network(chain_id)
def get_networks(self) -> List[NetworkConfig]:
"""Get all configured networks"""
if not self.multi_network:
raise RuntimeError("Multi-network mode not enabled")
return self.multi_network_client.list_networks()
def set_default_network(self, chain_id: int) -> None:
"""Set default network (multi-network mode only)"""
if not self.multi_network:
raise RuntimeError("Multi-network mode not enabled")
self.multi_network_client.set_default_network(chain_id)
async def switch_network(self, chain_id: int) -> None:
"""Switch to a different network (multi-network mode only)"""
if not self.multi_network:
raise RuntimeError("Multi-network mode not enabled")
await self.multi_network_client.switch_network(chain_id)
async def health_check(self) -> Union[bool, Dict[int, bool]]:
"""Check health of connection(s)"""
if self.multi_network:
return await self.multi_network_client.health_check_all()
elif self.transport:
return await self.transport.health_check()
return False
# Backward compatibility methods
def get_api_key(self) -> Optional[str]:
"""Get API key (backward compatibility)"""
if self.multi_network:
# Get from default network
default_network = self.multi_network_client.get_default_network()
if default_network:
return default_network.transport.get_config('api_key')
elif self.transport:
return self.transport.get_config('api_key')
return None
def set_api_key(self, api_key: str) -> None:
"""Set API key (backward compatibility)"""
if self.multi_network:
# Update all networks
for network in self.multi_network_client.networks.values():
network.transport.update_config({'api_key': api_key})
elif self.transport:
self.transport.update_config({'api_key': api_key})
def get_base_url(self) -> Optional[str]:
"""Get base URL (backward compatibility)"""
if self.multi_network:
default_network = self.multi_network_client.get_default_network()
if default_network:
return default_network.transport.get_config('base_url')
elif self.transport:
return self.transport.get_config('base_url')
return None
# Utility methods
def _get_version(self) -> str:
"""Get SDK version"""
try:
from . import __version__
return __version__
except ImportError:
return "1.0.0"
def get_stats(self) -> Dict[str, Any]:
"""Get client statistics"""
stats = {
'multi_network': self.multi_network,
'connected': self._connected,
'version': self._get_version()
}
if self.multi_network:
stats['networks'] = self.multi_network_client.get_network_stats()
elif self.transport:
if hasattr(self.transport, 'get_stats'):
stats['transport'] = self.transport.get_stats()
else:
stats['transport'] = {
'connected': self.transport.is_connected,
'chain_id': self.transport.chain_id
}
return stats
# Context managers
async def __aenter__(self):
"""Async context manager entry"""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit"""
await self.disconnect()
# Convenience functions for backward compatibility
def create_client(
api_key: Optional[str] = None,
base_url: Optional[str] = None,
timeout: Optional[int] = None,
transport: Optional[Union[Transport, str]] = None,
**kwargs
) -> AITBCClient:
"""
Create AITBC client with backward-compatible interface
Args:
api_key: API key for authentication
base_url: Base URL for the API
timeout: Request timeout in seconds
transport: Transport type ('http', 'websocket') or Transport instance
**kwargs: Additional configuration options
Returns:
AITBCClient instance
"""
config = {}
# Build configuration
if api_key:
config['api_key'] = api_key
if base_url:
config['base_url'] = base_url
if timeout:
config['timeout'] = timeout
# Add other config
config.update(kwargs)
# Handle transport parameter
if isinstance(transport, Transport):
return AITBCClient(transport=transport, config=config)
elif transport:
config['type'] = transport
return AITBCClient(transport=config, config=config)
def create_multi_network_client(
networks: Dict[str, Dict[str, Any]],
default_network: Optional[str] = None,
**kwargs
) -> AITBCClient:
"""
Create multi-network AITBC client
Args:
networks: Dictionary of network configurations
default_network: Name of default network
**kwargs: Additional configuration options
Returns:
AITBCClient instance with multi-network support
"""
config = {
'networks': networks,
**kwargs
}
client = AITBCClient(multi_network=True, config=config)
# Set default network if specified
if default_network:
network = client.multi_network_client.find_network_by_name(default_network)
if network:
client.set_default_network(network.chain_id)
return client
# Legacy aliases for backward compatibility
Client = AITBCClient

View File

@@ -0,0 +1,17 @@
"""
Transport layer for AITBC Python SDK
"""
from .base import Transport, TransportError
from .http import HTTPTransport
from .websocket import WebSocketTransport
from .multinetwork import MultiNetworkClient, NetworkConfig
__all__ = [
"Transport",
"TransportError",
"HTTPTransport",
"WebSocketTransport",
"MultiNetworkClient",
"NetworkConfig",
]

View File

@@ -0,0 +1,264 @@
"""
Base transport interface for AITBC Python SDK
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, AsyncIterator, Union, List
import asyncio
import logging
from datetime import timedelta
logger = logging.getLogger(__name__)
class TransportError(Exception):
"""Base exception for transport errors"""
pass
class TransportConnectionError(TransportError):
"""Raised when transport fails to connect"""
pass
class TransportRequestError(TransportError):
"""Raised when transport request fails"""
def __init__(self, message: str, status_code: Optional[int] = None, response: Optional[Dict[str, Any]] = None):
super().__init__(message)
self.status_code = status_code
self.response = response
class Transport(ABC):
"""Abstract base class for all transports"""
def __init__(self, config: Dict[str, Any]):
self.config = config
self._connected = False
self._lock = asyncio.Lock()
self._connection_attempts = 0
self._max_connection_attempts = config.get('max_connection_attempts', 3)
self._retry_delay = config.get('retry_delay', 1)
@abstractmethod
async def connect(self) -> None:
"""Establish connection"""
pass
@abstractmethod
async def disconnect(self) -> None:
"""Close connection"""
pass
@abstractmethod
async def request(
self,
method: str,
path: str,
data: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[float] = None
) -> Dict[str, Any]:
"""Make a request"""
pass
@abstractmethod
async def stream(
self,
method: str,
path: str,
data: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None
) -> AsyncIterator[Dict[str, Any]]:
"""Stream responses"""
pass
async def health_check(self) -> bool:
"""Check if transport is healthy"""
try:
if not self._connected:
return False
# Default health check - make a ping request
await self.request('GET', '/health')
return True
except Exception as e:
logger.warning(f"Transport health check failed: {e}")
return False
async def ensure_connected(self) -> None:
"""Ensure transport is connected, with retry logic"""
async with self._lock:
if self._connected:
return
while self._connection_attempts < self._max_connection_attempts:
try:
await self.connect()
self._connection_attempts = 0
return
except Exception as e:
self._connection_attempts += 1
logger.warning(f"Connection attempt {self._connection_attempts} failed: {e}")
if self._connection_attempts < self._max_connection_attempts:
await asyncio.sleep(self._retry_delay * self._connection_attempts)
else:
raise TransportConnectionError(
f"Failed to connect after {self._max_connection_attempts} attempts"
)
@property
def is_connected(self) -> bool:
"""Check if transport is connected"""
return self._connected
@property
def chain_id(self) -> Optional[int]:
"""Get the chain ID this transport is connected to"""
return self.config.get('chain_id')
@property
def network_name(self) -> Optional[str]:
"""Get the network name"""
return self.config.get('network_name')
def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value"""
return self.config.get(key, default)
def update_config(self, updates: Dict[str, Any]) -> None:
"""Update configuration"""
self.config.update(updates)
async def __aenter__(self):
"""Async context manager entry"""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit"""
await self.disconnect()
class BatchTransport(Transport):
"""Transport mixin for batch operations"""
@abstractmethod
async def batch_request(
self,
requests: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Make multiple requests in batch"""
pass
class CachedTransport(Transport):
"""Transport mixin for caching responses"""
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
self._cache: Dict[str, Any] = {}
self._cache_ttl = config.get('cache_ttl', 300) # 5 minutes
self._cache_timestamps: Dict[str, float] = {}
async def cached_request(
self,
method: str,
path: str,
data: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
cache_key: Optional[str] = None
) -> Dict[str, Any]:
"""Make request with caching"""
# Only cache GET requests
if method.upper() != 'GET':
return await self.request(method, path, data, params, headers)
# Generate cache key
if not cache_key:
import hashlib
import json
cache_data = json.dumps({
'method': method,
'path': path,
'params': params
}, sort_keys=True)
cache_key = hashlib.md5(cache_data.encode()).hexdigest()
# Check cache
if cache_key in self._cache:
timestamp = self._cache_timestamps.get(cache_key, 0)
if asyncio.get_event_loop().time() - timestamp < self._cache_ttl:
return self._cache[cache_key]
# Make request
response = await self.request(method, path, data, params, headers)
# Cache response
self._cache[cache_key] = response
self._cache_timestamps[cache_key] = asyncio.get_event_loop().time()
return response
def clear_cache(self, pattern: Optional[str] = None) -> None:
"""Clear cached responses"""
if pattern:
import re
regex = re.compile(pattern)
keys_to_remove = [k for k in self._cache.keys() if regex.match(k)]
for key in keys_to_remove:
del self._cache[key]
if key in self._cache_timestamps:
del self._cache_timestamps[key]
else:
self._cache.clear()
self._cache_timestamps.clear()
class RateLimitedTransport(Transport):
"""Transport mixin for rate limiting"""
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
self._rate_limit = config.get('rate_limit', 60) # requests per minute
self._rate_window = config.get('rate_window', 60) # seconds
self._requests: List[float] = []
self._rate_lock = asyncio.Lock()
async def _check_rate_limit(self) -> None:
"""Check if request is within rate limit"""
async with self._rate_lock:
now = asyncio.get_event_loop().time()
# Remove old requests outside the window
self._requests = [req_time for req_time in self._requests
if now - req_time < self._rate_window]
# Check if we're at the limit
if len(self._requests) >= self._rate_limit:
# Calculate wait time
oldest_request = min(self._requests)
wait_time = self._rate_window - (now - oldest_request)
if wait_time > 0:
logger.warning(f"Rate limit reached, waiting {wait_time:.2f} seconds")
await asyncio.sleep(wait_time)
# Add current request
self._requests.append(now)
async def request(
self,
method: str,
path: str,
data: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[float] = None
) -> Dict[str, Any]:
"""Make request with rate limiting"""
await self._check_rate_limit()
return await super().request(method, path, data, params, headers, timeout)

View File

@@ -0,0 +1,405 @@
"""
HTTP transport implementation for AITBC Python SDK
"""
import asyncio
import json
import logging
from typing import Dict, Any, Optional, AsyncIterator, Union
from datetime import datetime, timedelta
import aiohttp
from aiohttp import ClientTimeout, ClientError, ClientResponseError
from .base import Transport, TransportError, TransportConnectionError, TransportRequestError
logger = logging.getLogger(__name__)
class HTTPTransport(Transport):
"""HTTP transport for REST API calls"""
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
self.base_url = config['base_url'].rstrip('/')
self.session: Optional[aiohttp.ClientSession] = None
self.timeout = ClientTimeout(
total=config.get('timeout', 30),
connect=config.get('connect_timeout', 10),
sock_read=config.get('read_timeout', 30)
)
self.default_headers = config.get('default_headers', {})
self.max_redirects = config.get('max_redirects', 10)
self.verify_ssl = config.get('verify_ssl', True)
self._last_request_time: Optional[float] = None
async def connect(self) -> None:
"""Create HTTP session"""
try:
# Configure SSL context
ssl_context = None
if not self.verify_ssl:
import ssl
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
# Create connector
connector = aiohttp.TCPConnector(
limit=self.config.get('connection_limit', 100),
limit_per_host=self.config.get('connection_limit_per_host', 30),
ttl_dns_cache=self.config.get('dns_cache_ttl', 300),
use_dns_cache=True,
ssl=ssl_context,
enable_cleanup_closed=True
)
# Create session
self.session = aiohttp.ClientSession(
connector=connector,
timeout=self.timeout,
headers=self.default_headers,
max_redirects=self.max_redirects,
raise_for_status=False # We'll handle status codes manually
)
# Test connection with health check
await self.health_check()
self._connected = True
logger.info(f"HTTP transport connected to {self.base_url}")
except Exception as e:
logger.error(f"Failed to connect HTTP transport: {e}")
raise TransportConnectionError(f"Connection failed: {e}")
async def disconnect(self) -> None:
"""Close HTTP session"""
if self.session:
await self.session.close()
self.session = None
self._connected = False
logger.info("HTTP transport disconnected")
async def request(
self,
method: str,
path: str,
data: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[float] = None
) -> Dict[str, Any]:
"""Make HTTP request"""
await self.ensure_connected()
if not self.session:
raise TransportConnectionError("Transport not connected")
# Prepare URL
url = f"{self.base_url}{path}"
# Prepare headers
request_headers = {}
if self.default_headers:
request_headers.update(self.default_headers)
if headers:
request_headers.update(headers)
# Add content-type if data is provided
if data and 'content-type' not in request_headers:
request_headers['content-type'] = 'application/json'
# Prepare request timeout
request_timeout = self.timeout
if timeout:
request_timeout = ClientTimeout(total=timeout)
# Log request
logger.debug(f"HTTP {method} {url}")
try:
# Make request
async with self.session.request(
method=method.upper(),
url=url,
json=data if data and request_headers.get('content-type') == 'application/json' else None,
data=data if data and request_headers.get('content-type') != 'application/json' else None,
params=params,
headers=request_headers,
timeout=request_timeout
) as response:
# Record request time
self._last_request_time = asyncio.get_event_loop().time()
# Handle response
await self._handle_response(response)
# Parse response
if response.content_type == 'application/json':
result = await response.json()
else:
result = {'data': await response.text()}
# Add metadata
result['_metadata'] = {
'status_code': response.status,
'headers': dict(response.headers),
'url': str(response.url)
}
return result
except ClientResponseError as e:
raise TransportRequestError(
f"HTTP {e.status}: {e.message}",
status_code=e.status,
response={'error': e.message}
)
except ClientError as e:
raise TransportError(f"HTTP request failed: {e}")
except asyncio.TimeoutError:
raise TransportError("Request timed out")
except Exception as e:
raise TransportError(f"Unexpected error: {e}")
async def stream(
self,
method: str,
path: str,
data: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None
) -> AsyncIterator[Dict[str, Any]]:
"""Stream responses (not supported for basic HTTP)"""
raise NotImplementedError("HTTP transport does not support streaming")
async def download(
self,
path: str,
file_path: str,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
chunk_size: int = 8192
) -> None:
"""Download file to disk"""
await self.ensure_connected()
if not self.session:
raise TransportConnectionError("Transport not connected")
url = f"{self.base_url}{path}"
try:
async with self.session.get(
url,
params=params,
headers=headers
) as response:
await self._handle_response(response)
# Stream to file
with open(file_path, 'wb') as f:
async for chunk in response.content.iter_chunked(chunk_size):
f.write(chunk)
logger.info(f"Downloaded {url} to {file_path}")
except Exception as e:
raise TransportError(f"Download failed: {e}")
async def upload(
self,
path: str,
file_path: str,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
chunk_size: int = 8192
) -> Dict[str, Any]:
"""Upload file from disk"""
await self.ensure_connected()
if not self.session:
raise TransportConnectionError("Transport not connected")
url = f"{self.base_url}{path}"
try:
# Prepare multipart form data
with open(file_path, 'rb') as f:
data = aiohttp.FormData()
data.add_field(
'file',
f,
filename=file_path.split('/')[-1],
content_type='application/octet-stream'
)
# Add additional fields
if params:
for key, value in params.items():
data.add_field(key, str(value))
async with self.session.post(
url,
data=data,
headers=headers
) as response:
await self._handle_response(response)
if response.content_type == 'application/json':
return await response.json()
else:
return {'status': 'uploaded'}
except Exception as e:
raise TransportError(f"Upload failed: {e}")
async def _handle_response(self, response: aiohttp.ClientResponse) -> None:
"""Handle HTTP response"""
if response.status >= 400:
error_data = {}
try:
if response.content_type == 'application/json':
error_data = await response.json()
else:
error_data = {'error': await response.text()}
except:
error_data = {'error': f'HTTP {response.status}'}
raise TransportRequestError(
error_data.get('error', f'HTTP {response.status}'),
status_code=response.status,
response=error_data
)
def get_stats(self) -> Dict[str, Any]:
"""Get transport statistics"""
stats = {
'connected': self._connected,
'base_url': self.base_url,
'last_request_time': self._last_request_time
}
if self.session:
# Get connector stats
connector = self.session.connector
stats.update({
'total_connections': len(connector._conns),
'available_connections': sum(len(conns) for conns in connector._conns.values())
})
return stats
class AuthenticatedHTTPTransport(HTTPTransport):
"""HTTP transport with authentication"""
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
self.auth_type = config.get('auth_type', 'api_key')
self.auth_config = config.get('auth', {})
async def _add_auth_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
"""Add authentication headers"""
headers = headers.copy()
if self.auth_type == 'api_key':
api_key = self.auth_config.get('api_key')
if api_key:
key_header = self.auth_config.get('key_header', 'X-API-Key')
headers[key_header] = api_key
elif self.auth_type == 'bearer':
token = self.auth_config.get('token')
if token:
headers['Authorization'] = f'Bearer {token}'
elif self.auth_type == 'basic':
username = self.auth_config.get('username')
password = self.auth_config.get('password')
if username and password:
import base64
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
headers['Authorization'] = f'Basic {credentials}'
return headers
async def request(
self,
method: str,
path: str,
data: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[float] = None
) -> Dict[str, Any]:
"""Make authenticated HTTP request"""
# Add auth headers
auth_headers = await self._add_auth_headers(headers or {})
return await super().request(
method, path, data, params, auth_headers, timeout
)
class RetryableHTTPTransport(HTTPTransport):
"""HTTP transport with automatic retry"""
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
self.max_retries = config.get('max_retries', 3)
self.retry_delay = config.get('retry_delay', 1)
self.retry_backoff = config.get('retry_backoff', 2)
self.retry_on = config.get('retry_on', [500, 502, 503, 504])
async def request(
self,
method: str,
path: str,
data: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[float] = None
) -> Dict[str, Any]:
"""Make HTTP request with retry logic"""
last_error = None
for attempt in range(self.max_retries + 1):
try:
return await super().request(
method, path, data, params, headers, timeout
)
except TransportRequestError as e:
last_error = e
# Check if we should retry
if attempt < self.max_retries and e.status_code in self.retry_on:
delay = self.retry_delay * (self.retry_backoff ** attempt)
logger.warning(
f"Request failed (attempt {attempt + 1}/{self.max_retries + 1}), "
f"retrying in {delay}s: {e}"
)
await asyncio.sleep(delay)
continue
# Don't retry on client errors or final attempt
break
except TransportError as e:
last_error = e
# Retry on connection errors
if attempt < self.max_retries:
delay = self.retry_delay * (self.retry_backoff ** attempt)
logger.warning(
f"Request failed (attempt {attempt + 1}/{self.max_retries + 1}), "
f"retrying in {delay}s: {e}"
)
await asyncio.sleep(delay)
continue
break
# All retries failed
raise last_error

View File

@@ -0,0 +1,377 @@
"""
Multi-network support for AITBC Python SDK
"""
import asyncio
import logging
from typing import Dict, Any, Optional, List, Union
from dataclasses import dataclass, field
from datetime import datetime
from .base import Transport, TransportError, TransportConnectionError
from .http import HTTPTransport
from .websocket import WebSocketTransport
logger = logging.getLogger(__name__)
@dataclass
class NetworkConfig:
"""Configuration for a network"""
name: str
chain_id: int
transport: Transport
is_default: bool = False
bridges: List[str] = field(default_factory=list)
explorer_url: Optional[str] = None
rpc_url: Optional[str] = None
native_token: str = "ETH"
gas_token: Optional[str] = None
class MultiNetworkClient:
"""Client supporting multiple networks and cross-chain operations"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
self.networks: Dict[int, NetworkConfig] = {}
self.default_network: Optional[int] = None
self._connected = False
self._connection_lock = asyncio.Lock()
if config:
self._load_config(config)
def _load_config(self, config: Dict[str, Any]) -> None:
"""Load network configurations"""
networks_config = config.get('networks', {})
for name, net_config in networks_config.items():
# Create transport
transport = self._create_transport(net_config)
# Create network config
network = NetworkConfig(
name=name,
chain_id=net_config['chain_id'],
transport=transport,
is_default=net_config.get('default', False),
bridges=net_config.get('bridges', []),
explorer_url=net_config.get('explorer_url'),
rpc_url=net_config.get('rpc_url'),
native_token=net_config.get('native_token', 'ETH'),
gas_token=net_config.get('gas_token')
)
self.add_network(network)
def _create_transport(self, config: Dict[str, Any]) -> Transport:
"""Create transport from config"""
transport_type = config.get('type', 'http')
transport_config = config.copy()
if transport_type == 'http':
return HTTPTransport(transport_config)
elif transport_type == 'websocket':
return WebSocketTransport(transport_config)
else:
raise ValueError(f"Unknown transport type: {transport_type}")
def add_network(self, network: NetworkConfig) -> None:
"""Add a network configuration"""
if network.chain_id in self.networks:
logger.warning(f"Network {network.chain_id} already exists, overwriting")
self.networks[network.chain_id] = network
# Set as default if marked or if no default exists
if network.is_default or self.default_network is None:
self.default_network = network.chain_id
logger.info(f"Added network: {network.name} (chain_id: {network.chain_id})")
def remove_network(self, chain_id: int) -> None:
"""Remove a network configuration"""
if chain_id in self.networks:
network = self.networks[chain_id]
# Disconnect if connected
if network.transport.is_connected:
asyncio.create_task(network.transport.disconnect())
del self.networks[chain_id]
# Update default if necessary
if self.default_network == chain_id:
self.default_network = None
# Set new default if other networks exist
if self.networks:
self.default_network = next(iter(self.networks))
logger.info(f"Removed network: {network.name} (chain_id: {chain_id})")
def get_transport(self, chain_id: Optional[int] = None) -> Transport:
"""Get transport for a network"""
network_id = chain_id or self.default_network
if network_id is None:
raise ValueError("No default network configured")
if network_id not in self.networks:
raise ValueError(f"Network {network_id} not configured")
return self.networks[network_id].transport
def get_network(self, chain_id: int) -> Optional[NetworkConfig]:
"""Get network configuration"""
return self.networks.get(chain_id)
def list_networks(self) -> List[NetworkConfig]:
"""List all configured networks"""
return list(self.networks.values())
def get_default_network(self) -> Optional[NetworkConfig]:
"""Get default network configuration"""
if self.default_network:
return self.networks.get(self.default_network)
return None
def set_default_network(self, chain_id: int) -> None:
"""Set default network"""
if chain_id not in self.networks:
raise ValueError(f"Network {chain_id} not configured")
self.default_network = chain_id
# Update all networks' default flag
for net in self.networks.values():
net.is_default = (net.chain_id == chain_id)
async def connect_all(self) -> None:
"""Connect to all configured networks"""
async with self._connection_lock:
if self._connected:
return
logger.info(f"Connecting to {len(self.networks)} networks...")
# Connect all transports
tasks = []
for chain_id, network in self.networks.items():
task = asyncio.create_task(
self._connect_network(network),
name=f"connect_{network.name}"
)
tasks.append(task)
# Wait for all connections
results = await asyncio.gather(*tasks, return_exceptions=True)
# Check for errors
errors = []
for i, result in enumerate(results):
if isinstance(result, Exception):
network_name = list(self.networks.values())[i].name
errors.append(f"{network_name}: {result}")
logger.error(f"Failed to connect to {network_name}: {result}")
if errors:
raise TransportConnectionError(
f"Failed to connect to some networks: {'; '.join(errors)}"
)
self._connected = True
logger.info("Connected to all networks")
async def disconnect_all(self) -> None:
"""Disconnect from all networks"""
async with self._connection_lock:
if not self._connected:
return
logger.info("Disconnecting from all networks...")
# Disconnect all transports
tasks = []
for network in self.networks.values():
if network.transport.is_connected:
task = asyncio.create_task(
network.transport.disconnect(),
name=f"disconnect_{network.name}"
)
tasks.append(task)
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
self._connected = False
logger.info("Disconnected from all networks")
async def connect_network(self, chain_id: int) -> None:
"""Connect to a specific network"""
network = self.networks.get(chain_id)
if not network:
raise ValueError(f"Network {chain_id} not configured")
await self._connect_network(network)
async def disconnect_network(self, chain_id: int) -> None:
"""Disconnect from a specific network"""
network = self.networks.get(chain_id)
if not network:
raise ValueError(f"Network {chain_id} not configured")
if network.transport.is_connected:
await network.transport.disconnect()
async def _connect_network(self, network: NetworkConfig) -> None:
"""Connect to a specific network"""
try:
if not network.transport.is_connected:
await network.transport.connect()
logger.info(f"Connected to {network.name}")
except Exception as e:
logger.error(f"Failed to connect to {network.name}: {e}")
raise
async def switch_network(self, chain_id: int) -> None:
"""Switch default network"""
if chain_id not in self.networks:
raise ValueError(f"Network {chain_id} not configured")
# Connect if not connected
network = self.networks[chain_id]
if not network.transport.is_connected:
await self._connect_network(network)
# Set as default
self.set_default_network(chain_id)
logger.info(f"Switched to network: {network.name}")
async def health_check_all(self) -> Dict[int, bool]:
"""Check health of all networks"""
results = {}
for chain_id, network in self.networks.items():
try:
results[chain_id] = await network.transport.health_check()
except Exception as e:
logger.warning(f"Health check failed for {network.name}: {e}")
results[chain_id] = False
return results
async def broadcast_request(
self,
method: str,
path: str,
data: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
chain_ids: Optional[List[int]] = None
) -> Dict[int, Dict[str, Any]]:
"""Broadcast request to multiple networks"""
if chain_ids is None:
chain_ids = list(self.networks.keys())
results = {}
# Make requests in parallel
tasks = {}
for chain_id in chain_ids:
if chain_id in self.networks:
transport = self.networks[chain_id].transport
task = asyncio.create_task(
transport.request(method, path, data, params, headers),
name=f"request_{chain_id}"
)
tasks[chain_id] = task
# Wait for all requests
for chain_id, task in tasks.items():
try:
results[chain_id] = await task
except Exception as e:
network_name = self.networks[chain_id].name
logger.error(f"Request failed for {network_name}: {e}")
results[chain_id] = {'error': str(e)}
return results
def get_network_stats(self) -> Dict[int, Dict[str, Any]]:
"""Get statistics for all networks"""
stats = {}
for chain_id, network in self.networks.items():
network_stats = {
'name': network.name,
'chain_id': network.chain_id,
'is_default': network.is_default,
'bridges': network.bridges,
'explorer_url': network.explorer_url,
'rpc_url': network.rpc_url,
'native_token': network.native_token,
'gas_token': network.gas_token
}
# Add transport stats if available
if hasattr(network.transport, 'get_stats'):
network_stats['transport'] = network.transport.get_stats()
stats[chain_id] = network_stats
return stats
def find_network_by_name(self, name: str) -> Optional[NetworkConfig]:
"""Find network by name"""
for network in self.networks.values():
if network.name == name:
return network
return None
def find_networks_by_bridge(self, bridge: str) -> List[NetworkConfig]:
"""Find networks that support a specific bridge"""
return [
network for network in self.networks.values()
if bridge in network.bridges
]
async def __aenter__(self):
"""Async context manager entry"""
await self.connect_all()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit"""
await self.disconnect_all()
class NetworkSwitcher:
"""Utility for switching between networks"""
def __init__(self, client: MultiNetworkClient):
self.client = client
self._original_default: Optional[int] = None
async def __aenter__(self):
"""Store original default network"""
self._original_default = self.client.default_network
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Restore original default network"""
if self._original_default:
await self.client.switch_network(self._original_default)
async def switch_to(self, chain_id: int):
"""Switch to specific network"""
await self.client.switch_network(chain_id)
return self
async def switch_to_name(self, name: str):
"""Switch to network by name"""
network = self.client.find_network_by_name(name)
if not network:
raise ValueError(f"Network {name} not found")
await self.switch_to(network.chain_id)
return self

View File

@@ -0,0 +1,449 @@
"""
WebSocket transport implementation for AITBC Python SDK
"""
import asyncio
import json
import logging
from typing import Dict, Any, Optional, AsyncIterator, Callable
from datetime import datetime
import websockets
from websockets.exceptions import ConnectionClosed, ConnectionClosedError, ConnectionClosedOK
from .base import Transport, TransportError, TransportConnectionError, TransportRequestError
logger = logging.getLogger(__name__)
class WebSocketTransport(Transport):
"""WebSocket transport for real-time updates"""
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
self.ws_url = config['ws_url']
self.websocket: Optional[websockets.WebSocketClientProtocol] = None
self._subscriptions: Dict[str, Dict[str, Any]] = {}
self._message_handlers: Dict[str, Callable] = {}
self._message_queue = asyncio.Queue()
self._consumer_task: Optional[asyncio.Task] = None
self._heartbeat_interval = config.get('heartbeat_interval', 30)
self._heartbeat_task: Optional[asyncio.Task] = None
self._reconnect_enabled = config.get('reconnect', True)
self._max_reconnect_attempts = config.get('max_reconnect_attempts', 5)
self._reconnect_delay = config.get('reconnect_delay', 5)
self._ping_timeout = config.get('ping_timeout', 20)
self._close_code: Optional[int] = None
self._close_reason: Optional[str] = None
async def connect(self) -> None:
"""Connect to WebSocket"""
try:
# Prepare connection parameters
extra_headers = self.config.get('headers', {})
ping_interval = self.config.get('ping_interval', self._heartbeat_interval)
ping_timeout = self._ping_timeout
# Connect to WebSocket
logger.info(f"Connecting to WebSocket: {self.ws_url}")
self.websocket = await websockets.connect(
self.ws_url,
extra_headers=extra_headers,
ping_interval=ping_interval,
ping_timeout=ping_timeout,
close_timeout=self.config.get('close_timeout', 10)
)
# Start consumer task
self._consumer_task = asyncio.create_task(self._consume_messages())
# Start heartbeat task
self._heartbeat_task = asyncio.create_task(self._heartbeat())
self._connected = True
logger.info("WebSocket transport connected")
except Exception as e:
logger.error(f"Failed to connect WebSocket: {e}")
raise TransportConnectionError(f"WebSocket connection failed: {e}")
async def disconnect(self) -> None:
"""Disconnect WebSocket"""
self._connected = False
# Cancel tasks
if self._consumer_task:
self._consumer_task.cancel()
try:
await self._consumer_task
except asyncio.CancelledError:
pass
if self._heartbeat_task:
self._heartbeat_task.cancel()
try:
await self._heartbeat_task
except asyncio.CancelledError:
pass
# Close WebSocket
if self.websocket:
try:
await self.websocket.close()
except Exception as e:
logger.warning(f"Error closing WebSocket: {e}")
finally:
self.websocket = None
logger.info("WebSocket transport disconnected")
async def request(
self,
method: str,
path: str,
data: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[float] = None
) -> Dict[str, Any]:
"""Send request via WebSocket"""
await self.ensure_connected()
if not self.websocket:
raise TransportConnectionError("WebSocket not connected")
# Generate request ID
request_id = self._generate_id()
# Create message
message = {
'id': request_id,
'type': 'request',
'method': method,
'path': path,
'data': data,
'params': params,
'timestamp': datetime.utcnow().isoformat()
}
# Send request
await self._send_message(message)
# Wait for response
timeout = timeout or self.config.get('request_timeout', 30)
try:
response = await asyncio.wait_for(
self._wait_for_response(request_id),
timeout=timeout
)
return response
except asyncio.TimeoutError:
raise TransportError(f"Request timed out after {timeout}s")
async def stream(
self,
method: str,
path: str,
data: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None
) -> AsyncIterator[Dict[str, Any]]:
"""Stream responses from WebSocket"""
await self.ensure_connected()
# Create subscription
subscription_id = self._generate_id()
# Subscribe
message = {
'id': subscription_id,
'type': 'subscribe',
'method': method,
'path': path,
'data': data,
'timestamp': datetime.utcnow().isoformat()
}
await self._send_message(message)
# Store subscription
self._subscriptions[subscription_id] = {
'method': method,
'path': path,
'created_at': datetime.utcnow()
}
try:
# Yield messages as they come
async for message in self._stream_subscription(subscription_id):
yield message
finally:
# Unsubscribe
await self._unsubscribe(subscription_id)
async def subscribe(
self,
event: str,
callback: Callable[[Dict[str, Any]], None],
data: Optional[Dict[str, Any]] = None
) -> str:
"""Subscribe to events"""
await self.ensure_connected()
subscription_id = self._generate_id()
# Store subscription with callback
self._subscriptions[subscription_id] = {
'event': event,
'callback': callback,
'data': data,
'created_at': datetime.utcnow()
}
# Send subscription message
message = {
'id': subscription_id,
'type': 'subscribe',
'event': event,
'data': data,
'timestamp': datetime.utcnow().isoformat()
}
await self._send_message(message)
logger.info(f"Subscribed to event: {event}")
return subscription_id
async def unsubscribe(self, subscription_id: str) -> None:
"""Unsubscribe from events"""
if subscription_id in self._subscriptions:
# Send unsubscribe message
message = {
'id': subscription_id,
'type': 'unsubscribe',
'timestamp': datetime.utcnow().isoformat()
}
await self._send_message(message)
# Remove subscription
del self._subscriptions[subscription_id]
logger.info(f"Unsubscribed: {subscription_id}")
async def emit(self, event: str, data: Optional[Dict[str, Any]] = None) -> None:
"""Emit event to server"""
await self.ensure_connected()
message = {
'type': 'event',
'event': event,
'data': data,
'timestamp': datetime.utcnow().isoformat()
}
await self._send_message(message)
async def _send_message(self, message: Dict[str, Any]) -> None:
"""Send message to WebSocket"""
if not self.websocket:
raise TransportConnectionError("WebSocket not connected")
try:
await self.websocket.send(json.dumps(message))
logger.debug(f"Sent WebSocket message: {message.get('type', 'unknown')}")
except ConnectionClosed:
await self._handle_disconnect()
raise TransportConnectionError("WebSocket connection closed")
except Exception as e:
raise TransportError(f"Failed to send message: {e}")
async def _consume_messages(self) -> None:
"""Consume messages from WebSocket"""
while self._connected:
try:
# Wait for message
message = await asyncio.wait_for(
self.websocket.recv(),
timeout=self._heartbeat_interval * 2
)
# Parse message
try:
data = json.loads(message)
except json.JSONDecodeError:
logger.error(f"Invalid JSON message: {message}")
continue
# Handle message
await self._handle_message(data)
except asyncio.TimeoutError:
# No message received, check connection
continue
except ConnectionClosedOK:
logger.info("WebSocket closed normally")
break
except ConnectionClosedError as e:
logger.warning(f"WebSocket connection closed: {e}")
await self._handle_disconnect()
break
except Exception as e:
logger.error(f"Error consuming message: {e}")
break
async def _handle_message(self, data: Dict[str, Any]) -> None:
"""Handle incoming message"""
message_type = data.get('type')
if message_type == 'response':
# Request response
await self._message_queue.put(data)
elif message_type == 'event':
# Event message
await self._handle_event(data)
elif message_type == 'subscription':
# Subscription update
await self._handle_subscription_update(data)
elif message_type == 'error':
# Error message
logger.error(f"WebSocket error: {data.get('message')}")
else:
logger.warning(f"Unknown message type: {message_type}")
async def _handle_event(self, data: Dict[str, Any]) -> None:
"""Handle event message"""
event = data.get('event')
event_data = data.get('data')
# Find matching subscriptions
for sub_id, sub in self._subscriptions.items():
if sub.get('event') == event:
callback = sub.get('callback')
if callback:
try:
if asyncio.iscoroutinefunction(callback):
await callback(event_data)
else:
callback(event_data)
except Exception as e:
logger.error(f"Error in event callback: {e}")
async def _handle_subscription_update(self, data: Dict[str, Any]) -> None:
"""Handle subscription update"""
subscription_id = data.get('subscription_id')
status = data.get('status')
if subscription_id in self._subscriptions:
sub = self._subscriptions[subscription_id]
sub['status'] = status
if status == 'confirmed':
logger.info(f"Subscription confirmed: {subscription_id}")
elif status == 'error':
logger.error(f"Subscription error: {subscription_id}")
async def _wait_for_response(self, request_id: str) -> Dict[str, Any]:
"""Wait for specific response"""
while True:
message = await self._message_queue.get()
if message.get('id') == request_id:
if message.get('type') == 'error':
raise TransportRequestError(
message.get('message', 'Request failed')
)
return message
async def _stream_subscription(self, subscription_id: str) -> AsyncIterator[Dict[str, Any]]:
"""Stream messages for subscription"""
queue = asyncio.Queue()
# Add queue to subscriptions
if subscription_id in self._subscriptions:
self._subscriptions[subscription_id]['queue'] = queue
try:
while True:
message = await queue.get()
if message.get('type') == 'unsubscribe':
break
yield message
finally:
# Clean up queue
if subscription_id in self._subscriptions:
self._subscriptions[subscription_id].pop('queue', None)
async def _unsubscribe(self, subscription_id: str) -> None:
"""Unsubscribe and clean up"""
await self.unsubscribe(subscription_id)
async def _heartbeat(self) -> None:
"""Send periodic heartbeat"""
while self._connected:
try:
await asyncio.sleep(self._heartbeat_interval)
if self.websocket and self._connected:
# Send ping
await self.websocket.ping()
except Exception as e:
logger.warning(f"Heartbeat failed: {e}")
break
async def _handle_disconnect(self) -> None:
"""Handle unexpected disconnect"""
self._connected = False
if self._reconnect_enabled:
logger.info("Attempting to reconnect...")
await self._reconnect()
async def _reconnect(self) -> None:
"""Attempt to reconnect"""
for attempt in range(self._max_reconnect_attempts):
try:
logger.info(f"Reconnect attempt {attempt + 1}/{self._max_reconnect_attempts}")
# Wait before reconnect
await asyncio.sleep(self._reconnect_delay)
# Reconnect
await self.connect()
# Resubscribe to all subscriptions
for sub_id, sub in list(self._subscriptions.items()):
if sub.get('event'):
await self.subscribe(
sub['event'],
sub['callback'],
sub.get('data')
)
logger.info("Reconnected successfully")
return
except Exception as e:
logger.error(f"Reconnect attempt {attempt + 1} failed: {e}")
logger.error("Failed to reconnect after all attempts")
def _generate_id(self) -> str:
"""Generate unique ID"""
import uuid
return str(uuid.uuid4())
def get_stats(self) -> Dict[str, Any]:
"""Get transport statistics"""
return {
'connected': self._connected,
'ws_url': self.ws_url,
'subscriptions': len(self._subscriptions),
'close_code': self._close_code,
'close_reason': self._close_reason
}