feat: add marketplace metrics, privacy features, and service registry endpoints
- Add Prometheus metrics for marketplace API throughput and error rates with new dashboard panels - Implement confidential transaction models with encryption support and access control - Add key management system with registration, rotation, and audit logging - Create services and registry routers for service discovery and management - Integrate ZK proof generation for privacy-preserving receipts - Add metrics instru
This commit is contained in:
19
python-sdk/aitbc/apis/__init__.py
Normal file
19
python-sdk/aitbc/apis/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
API modules for AITBC Python SDK
|
||||
"""
|
||||
|
||||
from .jobs import JobsAPI, MultiNetworkJobsAPI
|
||||
from .marketplace import MarketplaceAPI
|
||||
from .wallet import WalletAPI
|
||||
from .receipts import ReceiptsAPI
|
||||
from .settlement import SettlementAPI, MultiNetworkSettlementAPI
|
||||
|
||||
__all__ = [
|
||||
"JobsAPI",
|
||||
"MultiNetworkJobsAPI",
|
||||
"MarketplaceAPI",
|
||||
"WalletAPI",
|
||||
"ReceiptsAPI",
|
||||
"SettlementAPI",
|
||||
"MultiNetworkSettlementAPI",
|
||||
]
|
||||
94
python-sdk/aitbc/apis/jobs.py
Normal file
94
python-sdk/aitbc/apis/jobs.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
Jobs API for AITBC Python SDK
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
import logging
|
||||
|
||||
from ..transport import Transport
|
||||
from ..transport.multinetwork import MultiNetworkClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JobsAPI:
|
||||
"""Jobs API client"""
|
||||
|
||||
def __init__(self, transport: Transport):
|
||||
self.transport = transport
|
||||
|
||||
async def create(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Create a new job"""
|
||||
return await self.transport.request('POST', '/v1/jobs', data=data)
|
||||
|
||||
async def get(self, job_id: str) -> Dict[str, Any]:
|
||||
"""Get job details"""
|
||||
return await self.transport.request('GET', f'/v1/jobs/{job_id}')
|
||||
|
||||
async def list(self, **params) -> List[Dict[str, Any]]:
|
||||
"""List jobs"""
|
||||
response = await self.transport.request('GET', '/v1/jobs', params=params)
|
||||
return response.get('jobs', [])
|
||||
|
||||
async def update(self, job_id: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Update job"""
|
||||
return await self.transport.request('PUT', f'/v1/jobs/{job_id}', data=data)
|
||||
|
||||
async def delete(self, job_id: str) -> None:
|
||||
"""Delete job"""
|
||||
await self.transport.request('DELETE', f'/v1/jobs/{job_id}')
|
||||
|
||||
async def wait_for_completion(
|
||||
self,
|
||||
job_id: str,
|
||||
timeout: Optional[int] = None,
|
||||
poll_interval: int = 5
|
||||
) -> Dict[str, Any]:
|
||||
"""Wait for job completion"""
|
||||
# Implementation would poll job status until complete
|
||||
pass
|
||||
|
||||
|
||||
class MultiNetworkJobsAPI(JobsAPI):
|
||||
"""Multi-network Jobs API client"""
|
||||
|
||||
def __init__(self, client: MultiNetworkClient):
|
||||
self.client = client
|
||||
|
||||
async def create(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
chain_id: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a new job on specific network"""
|
||||
transport = self.client.get_transport(chain_id)
|
||||
return await transport.request('POST', '/v1/jobs', data=data)
|
||||
|
||||
async def get(
|
||||
self,
|
||||
job_id: str,
|
||||
chain_id: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get job details from specific network"""
|
||||
transport = self.client.get_transport(chain_id)
|
||||
return await transport.request('GET', f'/v1/jobs/{job_id}')
|
||||
|
||||
async def list(
|
||||
self,
|
||||
chain_id: Optional[int] = None,
|
||||
**params
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List jobs from specific network"""
|
||||
transport = self.client.get_transport(chain_id)
|
||||
response = await transport.request('GET', '/v1/jobs', params=params)
|
||||
return response.get('jobs', [])
|
||||
|
||||
async def broadcast_create(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
chain_ids: Optional[List[int]] = None
|
||||
) -> Dict[int, Dict[str, Any]]:
|
||||
"""Create job on multiple networks"""
|
||||
return await self.client.broadcast_request(
|
||||
'POST', '/v1/jobs', data=data, chain_ids=chain_ids
|
||||
)
|
||||
46
python-sdk/aitbc/apis/marketplace.py
Normal file
46
python-sdk/aitbc/apis/marketplace.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Marketplace API for AITBC Python SDK
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
import logging
|
||||
|
||||
from ..transport import Transport
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarketplaceAPI:
|
||||
"""Marketplace API client"""
|
||||
|
||||
def __init__(self, transport: Transport):
|
||||
self.transport = transport
|
||||
|
||||
async def list_offers(self, **params) -> List[Dict[str, Any]]:
|
||||
"""List marketplace offers"""
|
||||
response = await self.transport.request('GET', '/v1/marketplace/offers', params=params)
|
||||
return response.get('offers', [])
|
||||
|
||||
async def create_offer(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Create a new offer"""
|
||||
return await self.transport.request('POST', '/v1/marketplace/offers', data=data)
|
||||
|
||||
async def get_offer(self, offer_id: str) -> Dict[str, Any]:
|
||||
"""Get offer details"""
|
||||
return await self.transport.request('GET', f'/v1/marketplace/offers/{offer_id}')
|
||||
|
||||
async def update_offer(self, offer_id: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Update offer"""
|
||||
return await self.transport.request('PUT', f'/v1/marketplace/offers/{offer_id}', data=data)
|
||||
|
||||
async def delete_offer(self, offer_id: str) -> None:
|
||||
"""Delete offer"""
|
||||
await self.transport.request('DELETE', f'/v1/marketplace/offers/{offer_id}')
|
||||
|
||||
async def accept_offer(self, offer_id: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Accept an offer"""
|
||||
return await self.transport.request('POST', f'/v1/marketplace/offers/{offer_id}/accept', data=data)
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get marketplace statistics"""
|
||||
return await self.transport.request('GET', '/v1/marketplace/stats')
|
||||
34
python-sdk/aitbc/apis/receipts.py
Normal file
34
python-sdk/aitbc/apis/receipts.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
Receipts API for AITBC Python SDK
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
import logging
|
||||
|
||||
from ..transport import Transport
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReceiptsAPI:
|
||||
"""Receipts API client"""
|
||||
|
||||
def __init__(self, transport: Transport):
|
||||
self.transport = transport
|
||||
|
||||
async def get(self, job_id: str) -> Dict[str, Any]:
|
||||
"""Get job receipt"""
|
||||
return await self.transport.request('GET', f'/v1/receipts/{job_id}')
|
||||
|
||||
async def verify(self, receipt: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Verify receipt"""
|
||||
return await self.transport.request('POST', '/v1/receipts/verify', data=receipt)
|
||||
|
||||
async def list(self, **params) -> List[Dict[str, Any]]:
|
||||
"""List receipts"""
|
||||
response = await self.transport.request('GET', '/v1/receipts', params=params)
|
||||
return response.get('receipts', [])
|
||||
|
||||
async def stream(self, **params):
|
||||
"""Stream new receipts"""
|
||||
return self.transport.stream('GET', '/v1/receipts/stream', params=params)
|
||||
100
python-sdk/aitbc/apis/settlement.py
Normal file
100
python-sdk/aitbc/apis/settlement.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
Settlement API for AITBC Python SDK
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
import logging
|
||||
|
||||
from ..transport import Transport
|
||||
from ..transport.multinetwork import MultiNetworkClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SettlementAPI:
|
||||
"""Settlement API client"""
|
||||
|
||||
def __init__(self, transport: Transport):
|
||||
self.transport = transport
|
||||
|
||||
async def settle_cross_chain(
|
||||
self,
|
||||
job_id: str,
|
||||
target_chain_id: int,
|
||||
bridge_name: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Initiate cross-chain settlement"""
|
||||
data = {
|
||||
'job_id': job_id,
|
||||
'target_chain_id': target_chain_id,
|
||||
'bridge_name': bridge_name
|
||||
}
|
||||
return await self.transport.request('POST', '/v1/settlement/cross-chain', data=data)
|
||||
|
||||
async def get_settlement_status(self, message_id: str) -> Dict[str, Any]:
|
||||
"""Get settlement status"""
|
||||
return await self.transport.request('GET', f'/v1/settlement/{message_id}/status')
|
||||
|
||||
async def estimate_cost(
|
||||
self,
|
||||
job_id: str,
|
||||
target_chain_id: int,
|
||||
bridge_name: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Estimate settlement cost"""
|
||||
data = {
|
||||
'job_id': job_id,
|
||||
'target_chain_id': target_chain_id,
|
||||
'bridge_name': bridge_name
|
||||
}
|
||||
return await self.transport.request('POST', '/v1/settlement/estimate-cost', data=data)
|
||||
|
||||
async def list_bridges(self) -> Dict[str, Any]:
|
||||
"""List supported bridges"""
|
||||
return await self.transport.request('GET', '/v1/settlement/bridges')
|
||||
|
||||
async def list_chains(self) -> Dict[str, Any]:
|
||||
"""List supported chains"""
|
||||
return await self.transport.request('GET', '/v1/settlement/chains')
|
||||
|
||||
async def refund_settlement(self, message_id: str) -> Dict[str, Any]:
|
||||
"""Refund failed settlement"""
|
||||
return await self.transport.request('POST', f'/v1/settlement/{message_id}/refund')
|
||||
|
||||
|
||||
class MultiNetworkSettlementAPI(SettlementAPI):
|
||||
"""Multi-network Settlement API client"""
|
||||
|
||||
def __init__(self, client: MultiNetworkClient):
|
||||
self.client = client
|
||||
|
||||
async def settle_cross_chain(
|
||||
self,
|
||||
job_id: str,
|
||||
target_chain_id: int,
|
||||
source_chain_id: Optional[int] = None,
|
||||
bridge_name: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Initiate cross-chain settlement from specific network"""
|
||||
transport = self.client.get_transport(source_chain_id)
|
||||
data = {
|
||||
'job_id': job_id,
|
||||
'target_chain_id': target_chain_id,
|
||||
'bridge_name': bridge_name
|
||||
}
|
||||
return await transport.request('POST', '/v1/settlement/cross-chain', data=data)
|
||||
|
||||
async def batch_settle(
|
||||
self,
|
||||
job_ids: List[str],
|
||||
target_chain_id: int,
|
||||
bridge_name: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Batch settle multiple jobs"""
|
||||
data = {
|
||||
'job_ids': job_ids,
|
||||
'target_chain_id': target_chain_id,
|
||||
'bridge_name': bridge_name
|
||||
}
|
||||
transport = self.client.get_transport()
|
||||
return await transport.request('POST', '/v1/settlement/batch', data=data)
|
||||
50
python-sdk/aitbc/apis/wallet.py
Normal file
50
python-sdk/aitbc/apis/wallet.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Wallet API for AITBC Python SDK
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
import logging
|
||||
|
||||
from ..transport import Transport
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WalletAPI:
|
||||
"""Wallet API client"""
|
||||
|
||||
def __init__(self, transport: Transport):
|
||||
self.transport = transport
|
||||
|
||||
async def create(self) -> Dict[str, Any]:
|
||||
"""Create a new wallet"""
|
||||
return await self.transport.request('POST', '/v1/wallet')
|
||||
|
||||
async def get_balance(self, token: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Get wallet balance"""
|
||||
params = {}
|
||||
if token:
|
||||
params['token'] = token
|
||||
return await self.transport.request('GET', '/v1/wallet/balance', params=params)
|
||||
|
||||
async def send(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Send tokens"""
|
||||
return await self.transport.request('POST', '/v1/wallet/send', data=data)
|
||||
|
||||
async def get_address(self) -> str:
|
||||
"""Get wallet address"""
|
||||
response = await self.transport.request('GET', '/v1/wallet/address')
|
||||
return response.get('address')
|
||||
|
||||
async def get_transactions(self, **params) -> List[Dict[str, Any]]:
|
||||
"""Get transaction history"""
|
||||
response = await self.transport.request('GET', '/v1/wallet/transactions', params=params)
|
||||
return response.get('transactions', [])
|
||||
|
||||
async def stake(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Stake tokens"""
|
||||
return await self.transport.request('POST', '/v1/wallet/stake', data=data)
|
||||
|
||||
async def unstake(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Unstake tokens"""
|
||||
return await self.transport.request('POST', '/v1/wallet/unstake', data=data)
|
||||
364
python-sdk/aitbc/client.py
Normal file
364
python-sdk/aitbc/client.py
Normal file
@@ -0,0 +1,364 @@
|
||||
"""
|
||||
Main AITBC client with pluggable transport abstraction
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, Union, List
|
||||
from datetime import datetime
|
||||
|
||||
from .transport import (
|
||||
Transport,
|
||||
HTTPTransport,
|
||||
WebSocketTransport,
|
||||
MultiNetworkClient,
|
||||
NetworkConfig,
|
||||
TransportError
|
||||
)
|
||||
from .transport.base import BatchTransport, CachedTransport, RateLimitedTransport
|
||||
from .apis.jobs import JobsAPI, MultiNetworkJobsAPI
|
||||
from .apis.marketplace import MarketplaceAPI
|
||||
from .apis.wallet import WalletAPI
|
||||
from .apis.receipts import ReceiptsAPI
|
||||
from .apis.settlement import SettlementAPI, MultiNetworkSettlementAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AITBCClient:
|
||||
"""AITBC client with pluggable transports and multi-network support"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: Optional[Union[Transport, Dict[str, Any]]] = None,
|
||||
multi_network: bool = False,
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Initialize AITBC client
|
||||
|
||||
Args:
|
||||
transport: Transport instance or configuration
|
||||
multi_network: Enable multi-network mode
|
||||
config: Additional configuration options
|
||||
"""
|
||||
self.config = config or {}
|
||||
self._connected = False
|
||||
self._apis = {}
|
||||
|
||||
# Initialize transport layer
|
||||
if multi_network:
|
||||
self._init_multi_network(transport or {})
|
||||
else:
|
||||
self._init_single_network(transport or self._get_default_config())
|
||||
|
||||
# Initialize API clients
|
||||
self._init_apis()
|
||||
|
||||
def _get_default_config(self) -> Dict[str, Any]:
|
||||
"""Get default configuration for backward compatibility"""
|
||||
return {
|
||||
'type': 'http',
|
||||
'base_url': self.config.get('base_url', 'https://api.aitbc.io'),
|
||||
'timeout': self.config.get('timeout', 30),
|
||||
'api_key': self.config.get('api_key'),
|
||||
'default_headers': {
|
||||
'User-Agent': f'AITBC-Python-SDK/{self._get_version()}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
}
|
||||
|
||||
def _init_single_network(self, transport_config: Union[Transport, Dict[str, Any]]) -> None:
|
||||
"""Initialize single network client"""
|
||||
if isinstance(transport_config, Transport):
|
||||
self.transport = transport_config
|
||||
else:
|
||||
# Create transport from config
|
||||
self.transport = self._create_transport(transport_config)
|
||||
|
||||
self.multi_network = False
|
||||
self.multi_network_client = None
|
||||
|
||||
def _init_multi_network(self, configs: Dict[str, Any]) -> None:
|
||||
"""Initialize multi-network client"""
|
||||
self.multi_network_client = MultiNetworkClient(configs)
|
||||
self.multi_network = True
|
||||
self.transport = None # Use multi_network_client instead
|
||||
|
||||
def _create_transport(self, config: Dict[str, Any]) -> Transport:
|
||||
"""Create transport from configuration"""
|
||||
transport_type = config.get('type', 'http')
|
||||
|
||||
# Add API key to headers if provided
|
||||
if 'api_key' in config and 'default_headers' not in config:
|
||||
config['default_headers'] = {
|
||||
'X-API-Key': config['api_key'],
|
||||
'User-Agent': f'AITBC-Python-SDK/{self._get_version()}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
# Create base transport
|
||||
if transport_type == 'http':
|
||||
transport = HTTPTransport(config)
|
||||
elif transport_type == 'websocket':
|
||||
transport = WebSocketTransport(config)
|
||||
elif transport_type == 'crosschain':
|
||||
# Will be implemented later
|
||||
raise NotImplementedError("CrossChain transport not yet implemented")
|
||||
else:
|
||||
raise ValueError(f"Unknown transport type: {transport_type}")
|
||||
|
||||
# Apply mixins if enabled
|
||||
if config.get('cached', False):
|
||||
transport = CachedTransport(config)
|
||||
|
||||
if config.get('rate_limited', False):
|
||||
transport = RateLimitedTransport(config)
|
||||
|
||||
if config.get('batch', False):
|
||||
transport = BatchTransport(config)
|
||||
|
||||
return transport
|
||||
|
||||
def _init_apis(self) -> None:
|
||||
"""Initialize API clients"""
|
||||
if self.multi_network:
|
||||
# Multi-network APIs
|
||||
self.jobs = MultiNetworkJobsAPI(self.multi_network_client)
|
||||
self.settlement = MultiNetworkSettlementAPI(self.multi_network_client)
|
||||
|
||||
# Single-network APIs (use default network)
|
||||
default_transport = self.multi_network_client.get_transport()
|
||||
self.marketplace = MarketplaceAPI(default_transport)
|
||||
self.wallet = WalletAPI(default_transport)
|
||||
self.receipts = ReceiptsAPI(default_transport)
|
||||
else:
|
||||
# Single-network APIs
|
||||
self.jobs = JobsAPI(self.transport)
|
||||
self.marketplace = MarketplaceAPI(self.transport)
|
||||
self.wallet = WalletAPI(self.transport)
|
||||
self.receipts = ReceiptsAPI(self.transport)
|
||||
self.settlement = SettlementAPI(self.transport)
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to network(s)"""
|
||||
if self.multi_network:
|
||||
await self.multi_network_client.connect_all()
|
||||
else:
|
||||
await self.transport.connect()
|
||||
|
||||
self._connected = True
|
||||
logger.info("AITBC client connected")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from network(s)"""
|
||||
if self.multi_network:
|
||||
await self.multi_network_client.disconnect_all()
|
||||
elif self.transport:
|
||||
await self.transport.disconnect()
|
||||
|
||||
self._connected = False
|
||||
logger.info("AITBC client disconnected")
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if client is connected"""
|
||||
if self.multi_network:
|
||||
return self.multi_network_client._connected
|
||||
elif self.transport:
|
||||
return self.transport.is_connected
|
||||
return False
|
||||
|
||||
# Multi-network methods
|
||||
def add_network(self, network_config: NetworkConfig) -> None:
|
||||
"""Add a network (multi-network mode only)"""
|
||||
if not self.multi_network:
|
||||
raise RuntimeError("Multi-network mode not enabled")
|
||||
|
||||
self.multi_network_client.add_network(network_config)
|
||||
|
||||
def remove_network(self, chain_id: int) -> None:
|
||||
"""Remove a network (multi-network mode only)"""
|
||||
if not self.multi_network:
|
||||
raise RuntimeError("Multi-network mode not enabled")
|
||||
|
||||
self.multi_network_client.remove_network(chain_id)
|
||||
|
||||
def get_networks(self) -> List[NetworkConfig]:
|
||||
"""Get all configured networks"""
|
||||
if not self.multi_network:
|
||||
raise RuntimeError("Multi-network mode not enabled")
|
||||
|
||||
return self.multi_network_client.list_networks()
|
||||
|
||||
def set_default_network(self, chain_id: int) -> None:
|
||||
"""Set default network (multi-network mode only)"""
|
||||
if not self.multi_network:
|
||||
raise RuntimeError("Multi-network mode not enabled")
|
||||
|
||||
self.multi_network_client.set_default_network(chain_id)
|
||||
|
||||
async def switch_network(self, chain_id: int) -> None:
|
||||
"""Switch to a different network (multi-network mode only)"""
|
||||
if not self.multi_network:
|
||||
raise RuntimeError("Multi-network mode not enabled")
|
||||
|
||||
await self.multi_network_client.switch_network(chain_id)
|
||||
|
||||
async def health_check(self) -> Union[bool, Dict[int, bool]]:
|
||||
"""Check health of connection(s)"""
|
||||
if self.multi_network:
|
||||
return await self.multi_network_client.health_check_all()
|
||||
elif self.transport:
|
||||
return await self.transport.health_check()
|
||||
return False
|
||||
|
||||
# Backward compatibility methods
|
||||
def get_api_key(self) -> Optional[str]:
|
||||
"""Get API key (backward compatibility)"""
|
||||
if self.multi_network:
|
||||
# Get from default network
|
||||
default_network = self.multi_network_client.get_default_network()
|
||||
if default_network:
|
||||
return default_network.transport.get_config('api_key')
|
||||
elif self.transport:
|
||||
return self.transport.get_config('api_key')
|
||||
return None
|
||||
|
||||
def set_api_key(self, api_key: str) -> None:
|
||||
"""Set API key (backward compatibility)"""
|
||||
if self.multi_network:
|
||||
# Update all networks
|
||||
for network in self.multi_network_client.networks.values():
|
||||
network.transport.update_config({'api_key': api_key})
|
||||
elif self.transport:
|
||||
self.transport.update_config({'api_key': api_key})
|
||||
|
||||
def get_base_url(self) -> Optional[str]:
|
||||
"""Get base URL (backward compatibility)"""
|
||||
if self.multi_network:
|
||||
default_network = self.multi_network_client.get_default_network()
|
||||
if default_network:
|
||||
return default_network.transport.get_config('base_url')
|
||||
elif self.transport:
|
||||
return self.transport.get_config('base_url')
|
||||
return None
|
||||
|
||||
# Utility methods
|
||||
def _get_version(self) -> str:
|
||||
"""Get SDK version"""
|
||||
try:
|
||||
from . import __version__
|
||||
return __version__
|
||||
except ImportError:
|
||||
return "1.0.0"
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get client statistics"""
|
||||
stats = {
|
||||
'multi_network': self.multi_network,
|
||||
'connected': self._connected,
|
||||
'version': self._get_version()
|
||||
}
|
||||
|
||||
if self.multi_network:
|
||||
stats['networks'] = self.multi_network_client.get_network_stats()
|
||||
elif self.transport:
|
||||
if hasattr(self.transport, 'get_stats'):
|
||||
stats['transport'] = self.transport.get_stats()
|
||||
else:
|
||||
stats['transport'] = {
|
||||
'connected': self.transport.is_connected,
|
||||
'chain_id': self.transport.chain_id
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
# Context managers
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry"""
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit"""
|
||||
await self.disconnect()
|
||||
|
||||
|
||||
# Convenience functions for backward compatibility
|
||||
def create_client(
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
transport: Optional[Union[Transport, str]] = None,
|
||||
**kwargs
|
||||
) -> AITBCClient:
|
||||
"""
|
||||
Create AITBC client with backward-compatible interface
|
||||
|
||||
Args:
|
||||
api_key: API key for authentication
|
||||
base_url: Base URL for the API
|
||||
timeout: Request timeout in seconds
|
||||
transport: Transport type ('http', 'websocket') or Transport instance
|
||||
**kwargs: Additional configuration options
|
||||
|
||||
Returns:
|
||||
AITBCClient instance
|
||||
"""
|
||||
config = {}
|
||||
|
||||
# Build configuration
|
||||
if api_key:
|
||||
config['api_key'] = api_key
|
||||
if base_url:
|
||||
config['base_url'] = base_url
|
||||
if timeout:
|
||||
config['timeout'] = timeout
|
||||
|
||||
# Add other config
|
||||
config.update(kwargs)
|
||||
|
||||
# Handle transport parameter
|
||||
if isinstance(transport, Transport):
|
||||
return AITBCClient(transport=transport, config=config)
|
||||
elif transport:
|
||||
config['type'] = transport
|
||||
|
||||
return AITBCClient(transport=config, config=config)
|
||||
|
||||
|
||||
def create_multi_network_client(
|
||||
networks: Dict[str, Dict[str, Any]],
|
||||
default_network: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> AITBCClient:
|
||||
"""
|
||||
Create multi-network AITBC client
|
||||
|
||||
Args:
|
||||
networks: Dictionary of network configurations
|
||||
default_network: Name of default network
|
||||
**kwargs: Additional configuration options
|
||||
|
||||
Returns:
|
||||
AITBCClient instance with multi-network support
|
||||
"""
|
||||
config = {
|
||||
'networks': networks,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
client = AITBCClient(multi_network=True, config=config)
|
||||
|
||||
# Set default network if specified
|
||||
if default_network:
|
||||
network = client.multi_network_client.find_network_by_name(default_network)
|
||||
if network:
|
||||
client.set_default_network(network.chain_id)
|
||||
|
||||
return client
|
||||
|
||||
|
||||
# Legacy aliases for backward compatibility
|
||||
Client = AITBCClient
|
||||
17
python-sdk/aitbc/transport/__init__.py
Normal file
17
python-sdk/aitbc/transport/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Transport layer for AITBC Python SDK
|
||||
"""
|
||||
|
||||
from .base import Transport, TransportError
|
||||
from .http import HTTPTransport
|
||||
from .websocket import WebSocketTransport
|
||||
from .multinetwork import MultiNetworkClient, NetworkConfig
|
||||
|
||||
__all__ = [
|
||||
"Transport",
|
||||
"TransportError",
|
||||
"HTTPTransport",
|
||||
"WebSocketTransport",
|
||||
"MultiNetworkClient",
|
||||
"NetworkConfig",
|
||||
]
|
||||
264
python-sdk/aitbc/transport/base.py
Normal file
264
python-sdk/aitbc/transport/base.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
Base transport interface for AITBC Python SDK
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional, AsyncIterator, Union, List
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TransportError(Exception):
|
||||
"""Base exception for transport errors"""
|
||||
pass
|
||||
|
||||
|
||||
class TransportConnectionError(TransportError):
|
||||
"""Raised when transport fails to connect"""
|
||||
pass
|
||||
|
||||
|
||||
class TransportRequestError(TransportError):
|
||||
"""Raised when transport request fails"""
|
||||
def __init__(self, message: str, status_code: Optional[int] = None, response: Optional[Dict[str, Any]] = None):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
self.response = response
|
||||
|
||||
|
||||
class Transport(ABC):
|
||||
"""Abstract base class for all transports"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config
|
||||
self._connected = False
|
||||
self._lock = asyncio.Lock()
|
||||
self._connection_attempts = 0
|
||||
self._max_connection_attempts = config.get('max_connection_attempts', 3)
|
||||
self._retry_delay = config.get('retry_delay', 1)
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None:
|
||||
"""Establish connection"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def disconnect(self) -> None:
|
||||
"""Close connection"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[float] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Make a request"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stream(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""Stream responses"""
|
||||
pass
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""Check if transport is healthy"""
|
||||
try:
|
||||
if not self._connected:
|
||||
return False
|
||||
|
||||
# Default health check - make a ping request
|
||||
await self.request('GET', '/health')
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Transport health check failed: {e}")
|
||||
return False
|
||||
|
||||
async def ensure_connected(self) -> None:
|
||||
"""Ensure transport is connected, with retry logic"""
|
||||
async with self._lock:
|
||||
if self._connected:
|
||||
return
|
||||
|
||||
while self._connection_attempts < self._max_connection_attempts:
|
||||
try:
|
||||
await self.connect()
|
||||
self._connection_attempts = 0
|
||||
return
|
||||
except Exception as e:
|
||||
self._connection_attempts += 1
|
||||
logger.warning(f"Connection attempt {self._connection_attempts} failed: {e}")
|
||||
|
||||
if self._connection_attempts < self._max_connection_attempts:
|
||||
await asyncio.sleep(self._retry_delay * self._connection_attempts)
|
||||
else:
|
||||
raise TransportConnectionError(
|
||||
f"Failed to connect after {self._max_connection_attempts} attempts"
|
||||
)
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if transport is connected"""
|
||||
return self._connected
|
||||
|
||||
@property
|
||||
def chain_id(self) -> Optional[int]:
|
||||
"""Get the chain ID this transport is connected to"""
|
||||
return self.config.get('chain_id')
|
||||
|
||||
@property
|
||||
def network_name(self) -> Optional[str]:
|
||||
"""Get the network name"""
|
||||
return self.config.get('network_name')
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get configuration value"""
|
||||
return self.config.get(key, default)
|
||||
|
||||
def update_config(self, updates: Dict[str, Any]) -> None:
|
||||
"""Update configuration"""
|
||||
self.config.update(updates)
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry"""
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit"""
|
||||
await self.disconnect()
|
||||
|
||||
|
||||
class BatchTransport(Transport):
|
||||
"""Transport mixin for batch operations"""
|
||||
|
||||
@abstractmethod
|
||||
async def batch_request(
|
||||
self,
|
||||
requests: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Make multiple requests in batch"""
|
||||
pass
|
||||
|
||||
|
||||
class CachedTransport(Transport):
|
||||
"""Transport mixin for caching responses"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
super().__init__(config)
|
||||
self._cache: Dict[str, Any] = {}
|
||||
self._cache_ttl = config.get('cache_ttl', 300) # 5 minutes
|
||||
self._cache_timestamps: Dict[str, float] = {}
|
||||
|
||||
async def cached_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
cache_key: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Make request with caching"""
|
||||
# Only cache GET requests
|
||||
if method.upper() != 'GET':
|
||||
return await self.request(method, path, data, params, headers)
|
||||
|
||||
# Generate cache key
|
||||
if not cache_key:
|
||||
import hashlib
|
||||
import json
|
||||
cache_data = json.dumps({
|
||||
'method': method,
|
||||
'path': path,
|
||||
'params': params
|
||||
}, sort_keys=True)
|
||||
cache_key = hashlib.md5(cache_data.encode()).hexdigest()
|
||||
|
||||
# Check cache
|
||||
if cache_key in self._cache:
|
||||
timestamp = self._cache_timestamps.get(cache_key, 0)
|
||||
if asyncio.get_event_loop().time() - timestamp < self._cache_ttl:
|
||||
return self._cache[cache_key]
|
||||
|
||||
# Make request
|
||||
response = await self.request(method, path, data, params, headers)
|
||||
|
||||
# Cache response
|
||||
self._cache[cache_key] = response
|
||||
self._cache_timestamps[cache_key] = asyncio.get_event_loop().time()
|
||||
|
||||
return response
|
||||
|
||||
def clear_cache(self, pattern: Optional[str] = None) -> None:
|
||||
"""Clear cached responses"""
|
||||
if pattern:
|
||||
import re
|
||||
regex = re.compile(pattern)
|
||||
keys_to_remove = [k for k in self._cache.keys() if regex.match(k)]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
if key in self._cache_timestamps:
|
||||
del self._cache_timestamps[key]
|
||||
else:
|
||||
self._cache.clear()
|
||||
self._cache_timestamps.clear()
|
||||
|
||||
|
||||
class RateLimitedTransport(Transport):
|
||||
"""Transport mixin for rate limiting"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
super().__init__(config)
|
||||
self._rate_limit = config.get('rate_limit', 60) # requests per minute
|
||||
self._rate_window = config.get('rate_window', 60) # seconds
|
||||
self._requests: List[float] = []
|
||||
self._rate_lock = asyncio.Lock()
|
||||
|
||||
async def _check_rate_limit(self) -> None:
|
||||
"""Check if request is within rate limit"""
|
||||
async with self._rate_lock:
|
||||
now = asyncio.get_event_loop().time()
|
||||
|
||||
# Remove old requests outside the window
|
||||
self._requests = [req_time for req_time in self._requests
|
||||
if now - req_time < self._rate_window]
|
||||
|
||||
# Check if we're at the limit
|
||||
if len(self._requests) >= self._rate_limit:
|
||||
# Calculate wait time
|
||||
oldest_request = min(self._requests)
|
||||
wait_time = self._rate_window - (now - oldest_request)
|
||||
|
||||
if wait_time > 0:
|
||||
logger.warning(f"Rate limit reached, waiting {wait_time:.2f} seconds")
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
# Add current request
|
||||
self._requests.append(now)
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[float] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Make request with rate limiting"""
|
||||
await self._check_rate_limit()
|
||||
return await super().request(method, path, data, params, headers, timeout)
|
||||
405
python-sdk/aitbc/transport/http.py
Normal file
405
python-sdk/aitbc/transport/http.py
Normal file
@@ -0,0 +1,405 @@
|
||||
"""
|
||||
HTTP transport implementation for AITBC Python SDK
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, AsyncIterator, Union
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import ClientTimeout, ClientError, ClientResponseError
|
||||
|
||||
from .base import Transport, TransportError, TransportConnectionError, TransportRequestError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HTTPTransport(Transport):
|
||||
"""HTTP transport for REST API calls"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
super().__init__(config)
|
||||
self.base_url = config['base_url'].rstrip('/')
|
||||
self.session: Optional[aiohttp.ClientSession] = None
|
||||
self.timeout = ClientTimeout(
|
||||
total=config.get('timeout', 30),
|
||||
connect=config.get('connect_timeout', 10),
|
||||
sock_read=config.get('read_timeout', 30)
|
||||
)
|
||||
self.default_headers = config.get('default_headers', {})
|
||||
self.max_redirects = config.get('max_redirects', 10)
|
||||
self.verify_ssl = config.get('verify_ssl', True)
|
||||
self._last_request_time: Optional[float] = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Create HTTP session"""
|
||||
try:
|
||||
# Configure SSL context
|
||||
ssl_context = None
|
||||
if not self.verify_ssl:
|
||||
import ssl
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
# Create connector
|
||||
connector = aiohttp.TCPConnector(
|
||||
limit=self.config.get('connection_limit', 100),
|
||||
limit_per_host=self.config.get('connection_limit_per_host', 30),
|
||||
ttl_dns_cache=self.config.get('dns_cache_ttl', 300),
|
||||
use_dns_cache=True,
|
||||
ssl=ssl_context,
|
||||
enable_cleanup_closed=True
|
||||
)
|
||||
|
||||
# Create session
|
||||
self.session = aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
timeout=self.timeout,
|
||||
headers=self.default_headers,
|
||||
max_redirects=self.max_redirects,
|
||||
raise_for_status=False # We'll handle status codes manually
|
||||
)
|
||||
|
||||
# Test connection with health check
|
||||
await self.health_check()
|
||||
self._connected = True
|
||||
logger.info(f"HTTP transport connected to {self.base_url}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect HTTP transport: {e}")
|
||||
raise TransportConnectionError(f"Connection failed: {e}")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Close HTTP session"""
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
self.session = None
|
||||
self._connected = False
|
||||
logger.info("HTTP transport disconnected")
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[float] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Make HTTP request"""
|
||||
await self.ensure_connected()
|
||||
|
||||
if not self.session:
|
||||
raise TransportConnectionError("Transport not connected")
|
||||
|
||||
# Prepare URL
|
||||
url = f"{self.base_url}{path}"
|
||||
|
||||
# Prepare headers
|
||||
request_headers = {}
|
||||
if self.default_headers:
|
||||
request_headers.update(self.default_headers)
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
|
||||
# Add content-type if data is provided
|
||||
if data and 'content-type' not in request_headers:
|
||||
request_headers['content-type'] = 'application/json'
|
||||
|
||||
# Prepare request timeout
|
||||
request_timeout = self.timeout
|
||||
if timeout:
|
||||
request_timeout = ClientTimeout(total=timeout)
|
||||
|
||||
# Log request
|
||||
logger.debug(f"HTTP {method} {url}")
|
||||
|
||||
try:
|
||||
# Make request
|
||||
async with self.session.request(
|
||||
method=method.upper(),
|
||||
url=url,
|
||||
json=data if data and request_headers.get('content-type') == 'application/json' else None,
|
||||
data=data if data and request_headers.get('content-type') != 'application/json' else None,
|
||||
params=params,
|
||||
headers=request_headers,
|
||||
timeout=request_timeout
|
||||
) as response:
|
||||
# Record request time
|
||||
self._last_request_time = asyncio.get_event_loop().time()
|
||||
|
||||
# Handle response
|
||||
await self._handle_response(response)
|
||||
|
||||
# Parse response
|
||||
if response.content_type == 'application/json':
|
||||
result = await response.json()
|
||||
else:
|
||||
result = {'data': await response.text()}
|
||||
|
||||
# Add metadata
|
||||
result['_metadata'] = {
|
||||
'status_code': response.status,
|
||||
'headers': dict(response.headers),
|
||||
'url': str(response.url)
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except ClientResponseError as e:
|
||||
raise TransportRequestError(
|
||||
f"HTTP {e.status}: {e.message}",
|
||||
status_code=e.status,
|
||||
response={'error': e.message}
|
||||
)
|
||||
except ClientError as e:
|
||||
raise TransportError(f"HTTP request failed: {e}")
|
||||
except asyncio.TimeoutError:
|
||||
raise TransportError("Request timed out")
|
||||
except Exception as e:
|
||||
raise TransportError(f"Unexpected error: {e}")
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""Stream responses (not supported for basic HTTP)"""
|
||||
raise NotImplementedError("HTTP transport does not support streaming")
|
||||
|
||||
async def download(
|
||||
self,
|
||||
path: str,
|
||||
file_path: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
chunk_size: int = 8192
|
||||
) -> None:
|
||||
"""Download file to disk"""
|
||||
await self.ensure_connected()
|
||||
|
||||
if not self.session:
|
||||
raise TransportConnectionError("Transport not connected")
|
||||
|
||||
url = f"{self.base_url}{path}"
|
||||
|
||||
try:
|
||||
async with self.session.get(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers
|
||||
) as response:
|
||||
await self._handle_response(response)
|
||||
|
||||
# Stream to file
|
||||
with open(file_path, 'wb') as f:
|
||||
async for chunk in response.content.iter_chunked(chunk_size):
|
||||
f.write(chunk)
|
||||
|
||||
logger.info(f"Downloaded {url} to {file_path}")
|
||||
|
||||
except Exception as e:
|
||||
raise TransportError(f"Download failed: {e}")
|
||||
|
||||
async def upload(
|
||||
self,
|
||||
path: str,
|
||||
file_path: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
chunk_size: int = 8192
|
||||
) -> Dict[str, Any]:
|
||||
"""Upload file from disk"""
|
||||
await self.ensure_connected()
|
||||
|
||||
if not self.session:
|
||||
raise TransportConnectionError("Transport not connected")
|
||||
|
||||
url = f"{self.base_url}{path}"
|
||||
|
||||
try:
|
||||
# Prepare multipart form data
|
||||
with open(file_path, 'rb') as f:
|
||||
data = aiohttp.FormData()
|
||||
data.add_field(
|
||||
'file',
|
||||
f,
|
||||
filename=file_path.split('/')[-1],
|
||||
content_type='application/octet-stream'
|
||||
)
|
||||
|
||||
# Add additional fields
|
||||
if params:
|
||||
for key, value in params.items():
|
||||
data.add_field(key, str(value))
|
||||
|
||||
async with self.session.post(
|
||||
url,
|
||||
data=data,
|
||||
headers=headers
|
||||
) as response:
|
||||
await self._handle_response(response)
|
||||
|
||||
if response.content_type == 'application/json':
|
||||
return await response.json()
|
||||
else:
|
||||
return {'status': 'uploaded'}
|
||||
|
||||
except Exception as e:
|
||||
raise TransportError(f"Upload failed: {e}")
|
||||
|
||||
async def _handle_response(self, response: aiohttp.ClientResponse) -> None:
|
||||
"""Handle HTTP response"""
|
||||
if response.status >= 400:
|
||||
error_data = {}
|
||||
|
||||
try:
|
||||
if response.content_type == 'application/json':
|
||||
error_data = await response.json()
|
||||
else:
|
||||
error_data = {'error': await response.text()}
|
||||
except:
|
||||
error_data = {'error': f'HTTP {response.status}'}
|
||||
|
||||
raise TransportRequestError(
|
||||
error_data.get('error', f'HTTP {response.status}'),
|
||||
status_code=response.status,
|
||||
response=error_data
|
||||
)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get transport statistics"""
|
||||
stats = {
|
||||
'connected': self._connected,
|
||||
'base_url': self.base_url,
|
||||
'last_request_time': self._last_request_time
|
||||
}
|
||||
|
||||
if self.session:
|
||||
# Get connector stats
|
||||
connector = self.session.connector
|
||||
stats.update({
|
||||
'total_connections': len(connector._conns),
|
||||
'available_connections': sum(len(conns) for conns in connector._conns.values())
|
||||
})
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
class AuthenticatedHTTPTransport(HTTPTransport):
|
||||
"""HTTP transport with authentication"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
super().__init__(config)
|
||||
self.auth_type = config.get('auth_type', 'api_key')
|
||||
self.auth_config = config.get('auth', {})
|
||||
|
||||
async def _add_auth_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Add authentication headers"""
|
||||
headers = headers.copy()
|
||||
|
||||
if self.auth_type == 'api_key':
|
||||
api_key = self.auth_config.get('api_key')
|
||||
if api_key:
|
||||
key_header = self.auth_config.get('key_header', 'X-API-Key')
|
||||
headers[key_header] = api_key
|
||||
|
||||
elif self.auth_type == 'bearer':
|
||||
token = self.auth_config.get('token')
|
||||
if token:
|
||||
headers['Authorization'] = f'Bearer {token}'
|
||||
|
||||
elif self.auth_type == 'basic':
|
||||
username = self.auth_config.get('username')
|
||||
password = self.auth_config.get('password')
|
||||
if username and password:
|
||||
import base64
|
||||
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||
headers['Authorization'] = f'Basic {credentials}'
|
||||
|
||||
return headers
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[float] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Make authenticated HTTP request"""
|
||||
# Add auth headers
|
||||
auth_headers = await self._add_auth_headers(headers or {})
|
||||
|
||||
return await super().request(
|
||||
method, path, data, params, auth_headers, timeout
|
||||
)
|
||||
|
||||
|
||||
class RetryableHTTPTransport(HTTPTransport):
|
||||
"""HTTP transport with automatic retry"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
super().__init__(config)
|
||||
self.max_retries = config.get('max_retries', 3)
|
||||
self.retry_delay = config.get('retry_delay', 1)
|
||||
self.retry_backoff = config.get('retry_backoff', 2)
|
||||
self.retry_on = config.get('retry_on', [500, 502, 503, 504])
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[float] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Make HTTP request with retry logic"""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return await super().request(
|
||||
method, path, data, params, headers, timeout
|
||||
)
|
||||
|
||||
except TransportRequestError as e:
|
||||
last_error = e
|
||||
|
||||
# Check if we should retry
|
||||
if attempt < self.max_retries and e.status_code in self.retry_on:
|
||||
delay = self.retry_delay * (self.retry_backoff ** attempt)
|
||||
logger.warning(
|
||||
f"Request failed (attempt {attempt + 1}/{self.max_retries + 1}), "
|
||||
f"retrying in {delay}s: {e}"
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
|
||||
# Don't retry on client errors or final attempt
|
||||
break
|
||||
|
||||
except TransportError as e:
|
||||
last_error = e
|
||||
|
||||
# Retry on connection errors
|
||||
if attempt < self.max_retries:
|
||||
delay = self.retry_delay * (self.retry_backoff ** attempt)
|
||||
logger.warning(
|
||||
f"Request failed (attempt {attempt + 1}/{self.max_retries + 1}), "
|
||||
f"retrying in {delay}s: {e}"
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
|
||||
break
|
||||
|
||||
# All retries failed
|
||||
raise last_error
|
||||
377
python-sdk/aitbc/transport/multinetwork.py
Normal file
377
python-sdk/aitbc/transport/multinetwork.py
Normal file
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
Multi-network support for AITBC Python SDK
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List, Union
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
from .base import Transport, TransportError, TransportConnectionError
|
||||
from .http import HTTPTransport
|
||||
from .websocket import WebSocketTransport
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NetworkConfig:
|
||||
"""Configuration for a network"""
|
||||
name: str
|
||||
chain_id: int
|
||||
transport: Transport
|
||||
is_default: bool = False
|
||||
bridges: List[str] = field(default_factory=list)
|
||||
explorer_url: Optional[str] = None
|
||||
rpc_url: Optional[str] = None
|
||||
native_token: str = "ETH"
|
||||
gas_token: Optional[str] = None
|
||||
|
||||
|
||||
class MultiNetworkClient:
|
||||
"""Client supporting multiple networks and cross-chain operations"""
|
||||
|
||||
def __init__(self, config: Optional[Dict[str, Any]] = None):
|
||||
self.networks: Dict[int, NetworkConfig] = {}
|
||||
self.default_network: Optional[int] = None
|
||||
self._connected = False
|
||||
self._connection_lock = asyncio.Lock()
|
||||
|
||||
if config:
|
||||
self._load_config(config)
|
||||
|
||||
def _load_config(self, config: Dict[str, Any]) -> None:
|
||||
"""Load network configurations"""
|
||||
networks_config = config.get('networks', {})
|
||||
|
||||
for name, net_config in networks_config.items():
|
||||
# Create transport
|
||||
transport = self._create_transport(net_config)
|
||||
|
||||
# Create network config
|
||||
network = NetworkConfig(
|
||||
name=name,
|
||||
chain_id=net_config['chain_id'],
|
||||
transport=transport,
|
||||
is_default=net_config.get('default', False),
|
||||
bridges=net_config.get('bridges', []),
|
||||
explorer_url=net_config.get('explorer_url'),
|
||||
rpc_url=net_config.get('rpc_url'),
|
||||
native_token=net_config.get('native_token', 'ETH'),
|
||||
gas_token=net_config.get('gas_token')
|
||||
)
|
||||
|
||||
self.add_network(network)
|
||||
|
||||
def _create_transport(self, config: Dict[str, Any]) -> Transport:
|
||||
"""Create transport from config"""
|
||||
transport_type = config.get('type', 'http')
|
||||
transport_config = config.copy()
|
||||
|
||||
if transport_type == 'http':
|
||||
return HTTPTransport(transport_config)
|
||||
elif transport_type == 'websocket':
|
||||
return WebSocketTransport(transport_config)
|
||||
else:
|
||||
raise ValueError(f"Unknown transport type: {transport_type}")
|
||||
|
||||
def add_network(self, network: NetworkConfig) -> None:
|
||||
"""Add a network configuration"""
|
||||
if network.chain_id in self.networks:
|
||||
logger.warning(f"Network {network.chain_id} already exists, overwriting")
|
||||
|
||||
self.networks[network.chain_id] = network
|
||||
|
||||
# Set as default if marked or if no default exists
|
||||
if network.is_default or self.default_network is None:
|
||||
self.default_network = network.chain_id
|
||||
|
||||
logger.info(f"Added network: {network.name} (chain_id: {network.chain_id})")
|
||||
|
||||
def remove_network(self, chain_id: int) -> None:
|
||||
"""Remove a network configuration"""
|
||||
if chain_id in self.networks:
|
||||
network = self.networks[chain_id]
|
||||
|
||||
# Disconnect if connected
|
||||
if network.transport.is_connected:
|
||||
asyncio.create_task(network.transport.disconnect())
|
||||
|
||||
del self.networks[chain_id]
|
||||
|
||||
# Update default if necessary
|
||||
if self.default_network == chain_id:
|
||||
self.default_network = None
|
||||
# Set new default if other networks exist
|
||||
if self.networks:
|
||||
self.default_network = next(iter(self.networks))
|
||||
|
||||
logger.info(f"Removed network: {network.name} (chain_id: {chain_id})")
|
||||
|
||||
def get_transport(self, chain_id: Optional[int] = None) -> Transport:
|
||||
"""Get transport for a network"""
|
||||
network_id = chain_id or self.default_network
|
||||
|
||||
if network_id is None:
|
||||
raise ValueError("No default network configured")
|
||||
|
||||
if network_id not in self.networks:
|
||||
raise ValueError(f"Network {network_id} not configured")
|
||||
|
||||
return self.networks[network_id].transport
|
||||
|
||||
def get_network(self, chain_id: int) -> Optional[NetworkConfig]:
|
||||
"""Get network configuration"""
|
||||
return self.networks.get(chain_id)
|
||||
|
||||
def list_networks(self) -> List[NetworkConfig]:
|
||||
"""List all configured networks"""
|
||||
return list(self.networks.values())
|
||||
|
||||
def get_default_network(self) -> Optional[NetworkConfig]:
|
||||
"""Get default network configuration"""
|
||||
if self.default_network:
|
||||
return self.networks.get(self.default_network)
|
||||
return None
|
||||
|
||||
def set_default_network(self, chain_id: int) -> None:
|
||||
"""Set default network"""
|
||||
if chain_id not in self.networks:
|
||||
raise ValueError(f"Network {chain_id} not configured")
|
||||
|
||||
self.default_network = chain_id
|
||||
|
||||
# Update all networks' default flag
|
||||
for net in self.networks.values():
|
||||
net.is_default = (net.chain_id == chain_id)
|
||||
|
||||
async def connect_all(self) -> None:
|
||||
"""Connect to all configured networks"""
|
||||
async with self._connection_lock:
|
||||
if self._connected:
|
||||
return
|
||||
|
||||
logger.info(f"Connecting to {len(self.networks)} networks...")
|
||||
|
||||
# Connect all transports
|
||||
tasks = []
|
||||
for chain_id, network in self.networks.items():
|
||||
task = asyncio.create_task(
|
||||
self._connect_network(network),
|
||||
name=f"connect_{network.name}"
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
# Wait for all connections
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Check for errors
|
||||
errors = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
network_name = list(self.networks.values())[i].name
|
||||
errors.append(f"{network_name}: {result}")
|
||||
logger.error(f"Failed to connect to {network_name}: {result}")
|
||||
|
||||
if errors:
|
||||
raise TransportConnectionError(
|
||||
f"Failed to connect to some networks: {'; '.join(errors)}"
|
||||
)
|
||||
|
||||
self._connected = True
|
||||
logger.info("Connected to all networks")
|
||||
|
||||
async def disconnect_all(self) -> None:
|
||||
"""Disconnect from all networks"""
|
||||
async with self._connection_lock:
|
||||
if not self._connected:
|
||||
return
|
||||
|
||||
logger.info("Disconnecting from all networks...")
|
||||
|
||||
# Disconnect all transports
|
||||
tasks = []
|
||||
for network in self.networks.values():
|
||||
if network.transport.is_connected:
|
||||
task = asyncio.create_task(
|
||||
network.transport.disconnect(),
|
||||
name=f"disconnect_{network.name}"
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
self._connected = False
|
||||
logger.info("Disconnected from all networks")
|
||||
|
||||
async def connect_network(self, chain_id: int) -> None:
|
||||
"""Connect to a specific network"""
|
||||
network = self.networks.get(chain_id)
|
||||
if not network:
|
||||
raise ValueError(f"Network {chain_id} not configured")
|
||||
|
||||
await self._connect_network(network)
|
||||
|
||||
async def disconnect_network(self, chain_id: int) -> None:
|
||||
"""Disconnect from a specific network"""
|
||||
network = self.networks.get(chain_id)
|
||||
if not network:
|
||||
raise ValueError(f"Network {chain_id} not configured")
|
||||
|
||||
if network.transport.is_connected:
|
||||
await network.transport.disconnect()
|
||||
|
||||
async def _connect_network(self, network: NetworkConfig) -> None:
|
||||
"""Connect to a specific network"""
|
||||
try:
|
||||
if not network.transport.is_connected:
|
||||
await network.transport.connect()
|
||||
logger.info(f"Connected to {network.name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to {network.name}: {e}")
|
||||
raise
|
||||
|
||||
async def switch_network(self, chain_id: int) -> None:
|
||||
"""Switch default network"""
|
||||
if chain_id not in self.networks:
|
||||
raise ValueError(f"Network {chain_id} not configured")
|
||||
|
||||
# Connect if not connected
|
||||
network = self.networks[chain_id]
|
||||
if not network.transport.is_connected:
|
||||
await self._connect_network(network)
|
||||
|
||||
# Set as default
|
||||
self.set_default_network(chain_id)
|
||||
logger.info(f"Switched to network: {network.name}")
|
||||
|
||||
async def health_check_all(self) -> Dict[int, bool]:
|
||||
"""Check health of all networks"""
|
||||
results = {}
|
||||
|
||||
for chain_id, network in self.networks.items():
|
||||
try:
|
||||
results[chain_id] = await network.transport.health_check()
|
||||
except Exception as e:
|
||||
logger.warning(f"Health check failed for {network.name}: {e}")
|
||||
results[chain_id] = False
|
||||
|
||||
return results
|
||||
|
||||
async def broadcast_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
chain_ids: Optional[List[int]] = None
|
||||
) -> Dict[int, Dict[str, Any]]:
|
||||
"""Broadcast request to multiple networks"""
|
||||
if chain_ids is None:
|
||||
chain_ids = list(self.networks.keys())
|
||||
|
||||
results = {}
|
||||
|
||||
# Make requests in parallel
|
||||
tasks = {}
|
||||
for chain_id in chain_ids:
|
||||
if chain_id in self.networks:
|
||||
transport = self.networks[chain_id].transport
|
||||
task = asyncio.create_task(
|
||||
transport.request(method, path, data, params, headers),
|
||||
name=f"request_{chain_id}"
|
||||
)
|
||||
tasks[chain_id] = task
|
||||
|
||||
# Wait for all requests
|
||||
for chain_id, task in tasks.items():
|
||||
try:
|
||||
results[chain_id] = await task
|
||||
except Exception as e:
|
||||
network_name = self.networks[chain_id].name
|
||||
logger.error(f"Request failed for {network_name}: {e}")
|
||||
results[chain_id] = {'error': str(e)}
|
||||
|
||||
return results
|
||||
|
||||
def get_network_stats(self) -> Dict[int, Dict[str, Any]]:
|
||||
"""Get statistics for all networks"""
|
||||
stats = {}
|
||||
|
||||
for chain_id, network in self.networks.items():
|
||||
network_stats = {
|
||||
'name': network.name,
|
||||
'chain_id': network.chain_id,
|
||||
'is_default': network.is_default,
|
||||
'bridges': network.bridges,
|
||||
'explorer_url': network.explorer_url,
|
||||
'rpc_url': network.rpc_url,
|
||||
'native_token': network.native_token,
|
||||
'gas_token': network.gas_token
|
||||
}
|
||||
|
||||
# Add transport stats if available
|
||||
if hasattr(network.transport, 'get_stats'):
|
||||
network_stats['transport'] = network.transport.get_stats()
|
||||
|
||||
stats[chain_id] = network_stats
|
||||
|
||||
return stats
|
||||
|
||||
def find_network_by_name(self, name: str) -> Optional[NetworkConfig]:
|
||||
"""Find network by name"""
|
||||
for network in self.networks.values():
|
||||
if network.name == name:
|
||||
return network
|
||||
return None
|
||||
|
||||
def find_networks_by_bridge(self, bridge: str) -> List[NetworkConfig]:
|
||||
"""Find networks that support a specific bridge"""
|
||||
return [
|
||||
network for network in self.networks.values()
|
||||
if bridge in network.bridges
|
||||
]
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry"""
|
||||
await self.connect_all()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit"""
|
||||
await self.disconnect_all()
|
||||
|
||||
|
||||
class NetworkSwitcher:
|
||||
"""Utility for switching between networks"""
|
||||
|
||||
def __init__(self, client: MultiNetworkClient):
|
||||
self.client = client
|
||||
self._original_default: Optional[int] = None
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Store original default network"""
|
||||
self._original_default = self.client.default_network
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Restore original default network"""
|
||||
if self._original_default:
|
||||
await self.client.switch_network(self._original_default)
|
||||
|
||||
async def switch_to(self, chain_id: int):
|
||||
"""Switch to specific network"""
|
||||
await self.client.switch_network(chain_id)
|
||||
return self
|
||||
|
||||
async def switch_to_name(self, name: str):
|
||||
"""Switch to network by name"""
|
||||
network = self.client.find_network_by_name(name)
|
||||
if not network:
|
||||
raise ValueError(f"Network {name} not found")
|
||||
|
||||
await self.switch_to(network.chain_id)
|
||||
return self
|
||||
449
python-sdk/aitbc/transport/websocket.py
Normal file
449
python-sdk/aitbc/transport/websocket.py
Normal file
@@ -0,0 +1,449 @@
|
||||
"""
|
||||
WebSocket transport implementation for AITBC Python SDK
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, AsyncIterator, Callable
|
||||
from datetime import datetime
|
||||
|
||||
import websockets
|
||||
from websockets.exceptions import ConnectionClosed, ConnectionClosedError, ConnectionClosedOK
|
||||
|
||||
from .base import Transport, TransportError, TransportConnectionError, TransportRequestError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebSocketTransport(Transport):
|
||||
"""WebSocket transport for real-time updates"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
super().__init__(config)
|
||||
self.ws_url = config['ws_url']
|
||||
self.websocket: Optional[websockets.WebSocketClientProtocol] = None
|
||||
self._subscriptions: Dict[str, Dict[str, Any]] = {}
|
||||
self._message_handlers: Dict[str, Callable] = {}
|
||||
self._message_queue = asyncio.Queue()
|
||||
self._consumer_task: Optional[asyncio.Task] = None
|
||||
self._heartbeat_interval = config.get('heartbeat_interval', 30)
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
self._reconnect_enabled = config.get('reconnect', True)
|
||||
self._max_reconnect_attempts = config.get('max_reconnect_attempts', 5)
|
||||
self._reconnect_delay = config.get('reconnect_delay', 5)
|
||||
self._ping_timeout = config.get('ping_timeout', 20)
|
||||
self._close_code: Optional[int] = None
|
||||
self._close_reason: Optional[str] = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to WebSocket"""
|
||||
try:
|
||||
# Prepare connection parameters
|
||||
extra_headers = self.config.get('headers', {})
|
||||
ping_interval = self.config.get('ping_interval', self._heartbeat_interval)
|
||||
ping_timeout = self._ping_timeout
|
||||
|
||||
# Connect to WebSocket
|
||||
logger.info(f"Connecting to WebSocket: {self.ws_url}")
|
||||
self.websocket = await websockets.connect(
|
||||
self.ws_url,
|
||||
extra_headers=extra_headers,
|
||||
ping_interval=ping_interval,
|
||||
ping_timeout=ping_timeout,
|
||||
close_timeout=self.config.get('close_timeout', 10)
|
||||
)
|
||||
|
||||
# Start consumer task
|
||||
self._consumer_task = asyncio.create_task(self._consume_messages())
|
||||
|
||||
# Start heartbeat task
|
||||
self._heartbeat_task = asyncio.create_task(self._heartbeat())
|
||||
|
||||
self._connected = True
|
||||
logger.info("WebSocket transport connected")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect WebSocket: {e}")
|
||||
raise TransportConnectionError(f"WebSocket connection failed: {e}")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect WebSocket"""
|
||||
self._connected = False
|
||||
|
||||
# Cancel tasks
|
||||
if self._consumer_task:
|
||||
self._consumer_task.cancel()
|
||||
try:
|
||||
await self._consumer_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
try:
|
||||
await self._heartbeat_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Close WebSocket
|
||||
if self.websocket:
|
||||
try:
|
||||
await self.websocket.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing WebSocket: {e}")
|
||||
finally:
|
||||
self.websocket = None
|
||||
|
||||
logger.info("WebSocket transport disconnected")
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[float] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Send request via WebSocket"""
|
||||
await self.ensure_connected()
|
||||
|
||||
if not self.websocket:
|
||||
raise TransportConnectionError("WebSocket not connected")
|
||||
|
||||
# Generate request ID
|
||||
request_id = self._generate_id()
|
||||
|
||||
# Create message
|
||||
message = {
|
||||
'id': request_id,
|
||||
'type': 'request',
|
||||
'method': method,
|
||||
'path': path,
|
||||
'data': data,
|
||||
'params': params,
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# Send request
|
||||
await self._send_message(message)
|
||||
|
||||
# Wait for response
|
||||
timeout = timeout or self.config.get('request_timeout', 30)
|
||||
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
self._wait_for_response(request_id),
|
||||
timeout=timeout
|
||||
)
|
||||
return response
|
||||
except asyncio.TimeoutError:
|
||||
raise TransportError(f"Request timed out after {timeout}s")
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""Stream responses from WebSocket"""
|
||||
await self.ensure_connected()
|
||||
|
||||
# Create subscription
|
||||
subscription_id = self._generate_id()
|
||||
|
||||
# Subscribe
|
||||
message = {
|
||||
'id': subscription_id,
|
||||
'type': 'subscribe',
|
||||
'method': method,
|
||||
'path': path,
|
||||
'data': data,
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await self._send_message(message)
|
||||
|
||||
# Store subscription
|
||||
self._subscriptions[subscription_id] = {
|
||||
'method': method,
|
||||
'path': path,
|
||||
'created_at': datetime.utcnow()
|
||||
}
|
||||
|
||||
try:
|
||||
# Yield messages as they come
|
||||
async for message in self._stream_subscription(subscription_id):
|
||||
yield message
|
||||
finally:
|
||||
# Unsubscribe
|
||||
await self._unsubscribe(subscription_id)
|
||||
|
||||
async def subscribe(
|
||||
self,
|
||||
event: str,
|
||||
callback: Callable[[Dict[str, Any]], None],
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""Subscribe to events"""
|
||||
await self.ensure_connected()
|
||||
|
||||
subscription_id = self._generate_id()
|
||||
|
||||
# Store subscription with callback
|
||||
self._subscriptions[subscription_id] = {
|
||||
'event': event,
|
||||
'callback': callback,
|
||||
'data': data,
|
||||
'created_at': datetime.utcnow()
|
||||
}
|
||||
|
||||
# Send subscription message
|
||||
message = {
|
||||
'id': subscription_id,
|
||||
'type': 'subscribe',
|
||||
'event': event,
|
||||
'data': data,
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await self._send_message(message)
|
||||
|
||||
logger.info(f"Subscribed to event: {event}")
|
||||
return subscription_id
|
||||
|
||||
async def unsubscribe(self, subscription_id: str) -> None:
|
||||
"""Unsubscribe from events"""
|
||||
if subscription_id in self._subscriptions:
|
||||
# Send unsubscribe message
|
||||
message = {
|
||||
'id': subscription_id,
|
||||
'type': 'unsubscribe',
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await self._send_message(message)
|
||||
|
||||
# Remove subscription
|
||||
del self._subscriptions[subscription_id]
|
||||
|
||||
logger.info(f"Unsubscribed: {subscription_id}")
|
||||
|
||||
async def emit(self, event: str, data: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""Emit event to server"""
|
||||
await self.ensure_connected()
|
||||
|
||||
message = {
|
||||
'type': 'event',
|
||||
'event': event,
|
||||
'data': data,
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
await self._send_message(message)
|
||||
|
||||
async def _send_message(self, message: Dict[str, Any]) -> None:
|
||||
"""Send message to WebSocket"""
|
||||
if not self.websocket:
|
||||
raise TransportConnectionError("WebSocket not connected")
|
||||
|
||||
try:
|
||||
await self.websocket.send(json.dumps(message))
|
||||
logger.debug(f"Sent WebSocket message: {message.get('type', 'unknown')}")
|
||||
except ConnectionClosed:
|
||||
await self._handle_disconnect()
|
||||
raise TransportConnectionError("WebSocket connection closed")
|
||||
except Exception as e:
|
||||
raise TransportError(f"Failed to send message: {e}")
|
||||
|
||||
async def _consume_messages(self) -> None:
|
||||
"""Consume messages from WebSocket"""
|
||||
while self._connected:
|
||||
try:
|
||||
# Wait for message
|
||||
message = await asyncio.wait_for(
|
||||
self.websocket.recv(),
|
||||
timeout=self._heartbeat_interval * 2
|
||||
)
|
||||
|
||||
# Parse message
|
||||
try:
|
||||
data = json.loads(message)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Invalid JSON message: {message}")
|
||||
continue
|
||||
|
||||
# Handle message
|
||||
await self._handle_message(data)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# No message received, check connection
|
||||
continue
|
||||
except ConnectionClosedOK:
|
||||
logger.info("WebSocket closed normally")
|
||||
break
|
||||
except ConnectionClosedError as e:
|
||||
logger.warning(f"WebSocket connection closed: {e}")
|
||||
await self._handle_disconnect()
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error consuming message: {e}")
|
||||
break
|
||||
|
||||
async def _handle_message(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle incoming message"""
|
||||
message_type = data.get('type')
|
||||
|
||||
if message_type == 'response':
|
||||
# Request response
|
||||
await self._message_queue.put(data)
|
||||
|
||||
elif message_type == 'event':
|
||||
# Event message
|
||||
await self._handle_event(data)
|
||||
|
||||
elif message_type == 'subscription':
|
||||
# Subscription update
|
||||
await self._handle_subscription_update(data)
|
||||
|
||||
elif message_type == 'error':
|
||||
# Error message
|
||||
logger.error(f"WebSocket error: {data.get('message')}")
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown message type: {message_type}")
|
||||
|
||||
async def _handle_event(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle event message"""
|
||||
event = data.get('event')
|
||||
event_data = data.get('data')
|
||||
|
||||
# Find matching subscriptions
|
||||
for sub_id, sub in self._subscriptions.items():
|
||||
if sub.get('event') == event:
|
||||
callback = sub.get('callback')
|
||||
if callback:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(event_data)
|
||||
else:
|
||||
callback(event_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in event callback: {e}")
|
||||
|
||||
async def _handle_subscription_update(self, data: Dict[str, Any]) -> None:
|
||||
"""Handle subscription update"""
|
||||
subscription_id = data.get('subscription_id')
|
||||
status = data.get('status')
|
||||
|
||||
if subscription_id in self._subscriptions:
|
||||
sub = self._subscriptions[subscription_id]
|
||||
sub['status'] = status
|
||||
|
||||
if status == 'confirmed':
|
||||
logger.info(f"Subscription confirmed: {subscription_id}")
|
||||
elif status == 'error':
|
||||
logger.error(f"Subscription error: {subscription_id}")
|
||||
|
||||
async def _wait_for_response(self, request_id: str) -> Dict[str, Any]:
|
||||
"""Wait for specific response"""
|
||||
while True:
|
||||
message = await self._message_queue.get()
|
||||
|
||||
if message.get('id') == request_id:
|
||||
if message.get('type') == 'error':
|
||||
raise TransportRequestError(
|
||||
message.get('message', 'Request failed')
|
||||
)
|
||||
return message
|
||||
|
||||
async def _stream_subscription(self, subscription_id: str) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""Stream messages for subscription"""
|
||||
queue = asyncio.Queue()
|
||||
|
||||
# Add queue to subscriptions
|
||||
if subscription_id in self._subscriptions:
|
||||
self._subscriptions[subscription_id]['queue'] = queue
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await queue.get()
|
||||
if message.get('type') == 'unsubscribe':
|
||||
break
|
||||
yield message
|
||||
finally:
|
||||
# Clean up queue
|
||||
if subscription_id in self._subscriptions:
|
||||
self._subscriptions[subscription_id].pop('queue', None)
|
||||
|
||||
async def _unsubscribe(self, subscription_id: str) -> None:
|
||||
"""Unsubscribe and clean up"""
|
||||
await self.unsubscribe(subscription_id)
|
||||
|
||||
async def _heartbeat(self) -> None:
|
||||
"""Send periodic heartbeat"""
|
||||
while self._connected:
|
||||
try:
|
||||
await asyncio.sleep(self._heartbeat_interval)
|
||||
|
||||
if self.websocket and self._connected:
|
||||
# Send ping
|
||||
await self.websocket.ping()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Heartbeat failed: {e}")
|
||||
break
|
||||
|
||||
async def _handle_disconnect(self) -> None:
|
||||
"""Handle unexpected disconnect"""
|
||||
self._connected = False
|
||||
|
||||
if self._reconnect_enabled:
|
||||
logger.info("Attempting to reconnect...")
|
||||
await self._reconnect()
|
||||
|
||||
async def _reconnect(self) -> None:
|
||||
"""Attempt to reconnect"""
|
||||
for attempt in range(self._max_reconnect_attempts):
|
||||
try:
|
||||
logger.info(f"Reconnect attempt {attempt + 1}/{self._max_reconnect_attempts}")
|
||||
|
||||
# Wait before reconnect
|
||||
await asyncio.sleep(self._reconnect_delay)
|
||||
|
||||
# Reconnect
|
||||
await self.connect()
|
||||
|
||||
# Resubscribe to all subscriptions
|
||||
for sub_id, sub in list(self._subscriptions.items()):
|
||||
if sub.get('event'):
|
||||
await self.subscribe(
|
||||
sub['event'],
|
||||
sub['callback'],
|
||||
sub.get('data')
|
||||
)
|
||||
|
||||
logger.info("Reconnected successfully")
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Reconnect attempt {attempt + 1} failed: {e}")
|
||||
|
||||
logger.error("Failed to reconnect after all attempts")
|
||||
|
||||
def _generate_id(self) -> str:
|
||||
"""Generate unique ID"""
|
||||
import uuid
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get transport statistics"""
|
||||
return {
|
||||
'connected': self._connected,
|
||||
'ws_url': self.ws_url,
|
||||
'subscriptions': len(self._subscriptions),
|
||||
'close_code': self._close_code,
|
||||
'close_reason': self._close_reason
|
||||
}
|
||||
Reference in New Issue
Block a user