feat: add marketplace metrics, privacy features, and service registry endpoints
- Add Prometheus metrics for marketplace API throughput and error rates with new dashboard panels - Implement confidential transaction models with encryption support and access control - Add key management system with registration, rotation, and audit logging - Create services and registry routers for service discovery and management - Integrate ZK proof generation for privacy-preserving receipts - Add metrics instru
This commit is contained in:
@ -0,0 +1,30 @@
|
||||
"""
|
||||
AITBC Enterprise Connectors SDK
|
||||
|
||||
Python SDK for integrating AITBC with enterprise systems including
|
||||
payment processors, ERP systems, and other business applications.
|
||||
"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "AITBC Team"
|
||||
|
||||
from .core import AITBCClient, ConnectorConfig
|
||||
from .base import BaseConnector
|
||||
from .exceptions import (
|
||||
AITBCError,
|
||||
AuthenticationError,
|
||||
RateLimitError,
|
||||
APIError,
|
||||
ConfigurationError
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AITBCClient",
|
||||
"ConnectorConfig",
|
||||
"BaseConnector",
|
||||
"AITBCError",
|
||||
"AuthenticationError",
|
||||
"RateLimitError",
|
||||
"APIError",
|
||||
"ConfigurationError",
|
||||
]
|
||||
207
enterprise-connectors/python-sdk/aitbc_enterprise/auth.py
Normal file
207
enterprise-connectors/python-sdk/aitbc_enterprise/auth.py
Normal file
@ -0,0 +1,207 @@
|
||||
"""
|
||||
Authentication handlers for AITBC Enterprise Connectors
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from .core import ConnectorConfig
|
||||
from .exceptions import AuthenticationError
|
||||
|
||||
|
||||
class AuthHandler(ABC):
|
||||
"""Abstract base class for authentication handlers"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_headers(self) -> Dict[str, str]:
|
||||
"""Get authentication headers"""
|
||||
pass
|
||||
|
||||
|
||||
class BearerAuthHandler(AuthHandler):
|
||||
"""Bearer token authentication"""
|
||||
|
||||
def __init__(self, config: ConnectorConfig):
|
||||
self.api_key = config.api_key
|
||||
|
||||
async def get_headers(self) -> Dict[str, str]:
|
||||
"""Get Bearer token headers"""
|
||||
return {
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
|
||||
|
||||
class BasicAuthHandler(AuthHandler):
|
||||
"""Basic authentication"""
|
||||
|
||||
def __init__(self, config: ConnectorConfig):
|
||||
self.username = config.auth_config.get("username")
|
||||
self.password = config.auth_config.get("password")
|
||||
|
||||
async def get_headers(self) -> Dict[str, str]:
|
||||
"""Get Basic auth headers"""
|
||||
if not self.username or not self.password:
|
||||
raise AuthenticationError("Username and password required for Basic auth")
|
||||
|
||||
credentials = f"{self.username}:{self.password}"
|
||||
encoded = base64.b64encode(credentials.encode()).decode()
|
||||
|
||||
return {
|
||||
"Authorization": f"Basic {encoded}"
|
||||
}
|
||||
|
||||
|
||||
class APIKeyAuthHandler(AuthHandler):
|
||||
"""API key authentication (custom header)"""
|
||||
|
||||
def __init__(self, config: ConnectorConfig):
|
||||
self.api_key = config.api_key
|
||||
self.header_name = config.auth_config.get("header_name", "X-API-Key")
|
||||
|
||||
async def get_headers(self) -> Dict[str, str]:
|
||||
"""Get API key headers"""
|
||||
return {
|
||||
self.header_name: self.api_key
|
||||
}
|
||||
|
||||
|
||||
class HMACAuthHandler(AuthHandler):
|
||||
"""HMAC signature authentication"""
|
||||
|
||||
def __init__(self, config: ConnectorConfig):
|
||||
self.api_key = config.api_key
|
||||
self.secret = config.auth_config.get("secret")
|
||||
self.algorithm = config.auth_config.get("algorithm", "sha256")
|
||||
|
||||
async def get_headers(self) -> Dict[str, str]:
|
||||
"""Get HMAC signature headers"""
|
||||
if not self.secret:
|
||||
raise AuthenticationError("Secret required for HMAC auth")
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
message = f"{timestamp}:{self.api_key}"
|
||||
|
||||
signature = hmac.new(
|
||||
self.secret.encode(),
|
||||
message.encode(),
|
||||
getattr(hashlib, self.algorithm)
|
||||
).hexdigest()
|
||||
|
||||
return {
|
||||
"X-API-Key": self.api_key,
|
||||
"X-Timestamp": timestamp,
|
||||
"X-Signature": signature
|
||||
}
|
||||
|
||||
|
||||
class OAuth2Handler(AuthHandler):
|
||||
"""OAuth 2.0 authentication"""
|
||||
|
||||
def __init__(self, config: ConnectorConfig):
|
||||
self.client_id = config.auth_config.get("client_id")
|
||||
self.client_secret = config.auth_config.get("client_secret")
|
||||
self.token_url = config.auth_config.get("token_url")
|
||||
self.scope = config.auth_config.get("scope", "")
|
||||
|
||||
self._access_token = None
|
||||
self._refresh_token = None
|
||||
self._expires_at = None
|
||||
|
||||
async def get_headers(self) -> Dict[str, str]:
|
||||
"""Get OAuth 2.0 headers"""
|
||||
if not self._is_token_valid():
|
||||
await self._refresh_access_token()
|
||||
|
||||
return {
|
||||
"Authorization": f"Bearer {self._access_token}"
|
||||
}
|
||||
|
||||
def _is_token_valid(self) -> bool:
|
||||
"""Check if access token is valid"""
|
||||
if not self._access_token or not self._expires_at:
|
||||
return False
|
||||
|
||||
# Refresh 5 minutes before expiry
|
||||
return datetime.utcnow() < (self._expires_at - timedelta(minutes=5))
|
||||
|
||||
async def _refresh_access_token(self):
|
||||
"""Refresh OAuth 2.0 access token"""
|
||||
import aiohttp
|
||||
|
||||
data = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"scope": self.scope
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(self.token_url, data=data) as response:
|
||||
if response.status != 200:
|
||||
raise AuthenticationError(f"OAuth token request failed: {response.status}")
|
||||
|
||||
token_data = await response.json()
|
||||
|
||||
self._access_token = token_data["access_token"]
|
||||
self._refresh_token = token_data.get("refresh_token")
|
||||
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
self._expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
||||
|
||||
|
||||
class CertificateAuthHandler(AuthHandler):
|
||||
"""Certificate-based authentication"""
|
||||
|
||||
def __init__(self, config: ConnectorConfig):
|
||||
self.cert_path = config.auth_config.get("cert_path")
|
||||
self.key_path = config.auth_config.get("key_path")
|
||||
self.passphrase = config.auth_config.get("passphrase")
|
||||
|
||||
async def get_headers(self) -> Dict[str, str]:
|
||||
"""Certificate auth uses client cert, not headers"""
|
||||
return {}
|
||||
|
||||
def get_ssl_context(self):
|
||||
"""Get SSL context for certificate authentication"""
|
||||
import ssl
|
||||
|
||||
context = ssl.create_default_context()
|
||||
|
||||
if self.cert_path and self.key_path:
|
||||
context.load_cert_chain(
|
||||
self.cert_path,
|
||||
self.key_path,
|
||||
password=self.passphrase
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
|
||||
class AuthHandlerFactory:
|
||||
"""Factory for creating authentication handlers"""
|
||||
|
||||
@staticmethod
|
||||
def create(config: ConnectorConfig) -> AuthHandler:
|
||||
"""Create appropriate auth handler based on config"""
|
||||
auth_type = config.auth_type.lower()
|
||||
|
||||
if auth_type == "bearer":
|
||||
return BearerAuthHandler(config)
|
||||
elif auth_type == "basic":
|
||||
return BasicAuthHandler(config)
|
||||
elif auth_type == "api_key":
|
||||
return APIKeyAuthHandler(config)
|
||||
elif auth_type == "hmac":
|
||||
return HMACAuthHandler(config)
|
||||
elif auth_type == "oauth2":
|
||||
return OAuth2Handler(config)
|
||||
elif auth_type == "certificate":
|
||||
return CertificateAuthHandler(config)
|
||||
else:
|
||||
raise AuthenticationError(f"Unsupported auth type: {auth_type}")
|
||||
369
enterprise-connectors/python-sdk/aitbc_enterprise/base.py
Normal file
369
enterprise-connectors/python-sdk/aitbc_enterprise/base.py
Normal file
@ -0,0 +1,369 @@
|
||||
"""
|
||||
Base connector class for AITBC Enterprise Connectors
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional, List, Union, Callable, Awaitable
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
|
||||
from .core import AITBCClient, ConnectorConfig
|
||||
from .exceptions import AITBCError, ConnectorError, ValidationError
|
||||
from .webhooks import WebhookHandler
|
||||
from .validators import BaseValidator
|
||||
|
||||
|
||||
@dataclass
|
||||
class OperationResult:
|
||||
"""Result of a connector operation"""
|
||||
success: bool
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
timestamp: datetime = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.timestamp is None:
|
||||
self.timestamp = datetime.utcnow()
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transaction:
|
||||
"""Standard transaction representation"""
|
||||
id: str
|
||||
amount: float
|
||||
currency: str
|
||||
status: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
metadata: Dict[str, Any] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"amount": self.amount,
|
||||
"currency": self.currency,
|
||||
"status": self.status,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
"metadata": self.metadata or {}
|
||||
}
|
||||
|
||||
|
||||
class BaseConnector(ABC):
|
||||
"""Base class for all enterprise connectors"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: AITBCClient,
|
||||
config: ConnectorConfig,
|
||||
validator: Optional[BaseValidator] = None,
|
||||
webhook_handler: Optional[WebhookHandler] = None
|
||||
):
|
||||
self.client = client
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(f"aitbc.{self.__class__.__name__}")
|
||||
|
||||
# Injected dependencies
|
||||
self.validator = validator
|
||||
self.webhook_handler = webhook_handler
|
||||
|
||||
# Connector state
|
||||
self._initialized = False
|
||||
self._last_sync = None
|
||||
|
||||
# Event handlers
|
||||
self._operation_handlers: Dict[str, List[Callable]] = {}
|
||||
|
||||
# Metrics
|
||||
self._operation_count = 0
|
||||
self._error_count = 0
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the connector"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
# Perform connector-specific initialization
|
||||
await self._initialize()
|
||||
|
||||
# Set up webhooks if configured
|
||||
if self.config.webhook_endpoint and self.webhook_handler:
|
||||
await self._setup_webhooks()
|
||||
|
||||
# Register event handlers
|
||||
self._register_handlers()
|
||||
|
||||
self._initialized = True
|
||||
self.logger.info(f"{self.__class__.__name__} initialized")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize {self.__class__.__name__}: {e}")
|
||||
raise ConnectorError(f"Initialization failed: {e}")
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""Cleanup connector resources"""
|
||||
try:
|
||||
# Perform connector-specific cleanup
|
||||
await self._cleanup()
|
||||
|
||||
# Cleanup webhooks
|
||||
if self.webhook_handler:
|
||||
await self.webhook_handler.cleanup()
|
||||
|
||||
self._initialized = False
|
||||
self.logger.info(f"{self.__class__.__name__} cleaned up")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during cleanup: {e}")
|
||||
|
||||
async def execute_operation(
|
||||
self,
|
||||
operation: str,
|
||||
data: Dict[str, Any],
|
||||
**kwargs
|
||||
) -> OperationResult:
|
||||
"""Execute an operation with validation and error handling"""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
# Validate input if validator is configured
|
||||
if self.validator:
|
||||
await self.validator.validate(operation, data)
|
||||
|
||||
# Pre-operation hook
|
||||
await self._before_operation(operation, data)
|
||||
|
||||
# Execute the operation
|
||||
result = await self._execute_operation(operation, data, **kwargs)
|
||||
|
||||
# Post-operation hook
|
||||
await self._after_operation(operation, data, result)
|
||||
|
||||
# Update metrics
|
||||
self._operation_count += 1
|
||||
|
||||
# Emit operation event
|
||||
await self._emit_operation_event(operation, result)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self._error_count += 1
|
||||
self.logger.error(f"Operation {operation} failed: {e}")
|
||||
|
||||
error_result = OperationResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
timestamp=datetime.utcnow()
|
||||
)
|
||||
|
||||
# Emit error event
|
||||
await self._emit_operation_event(f"{operation}.error", error_result)
|
||||
|
||||
return error_result
|
||||
|
||||
finally:
|
||||
# Log operation duration
|
||||
duration = (datetime.utcnow() - start_time).total_seconds()
|
||||
self.logger.debug(f"Operation {operation} completed in {duration:.3f}s")
|
||||
|
||||
async def batch_execute(
|
||||
self,
|
||||
operations: List[Dict[str, Any]],
|
||||
max_concurrent: int = 10
|
||||
) -> List[OperationResult]:
|
||||
"""Execute multiple operations concurrently"""
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
|
||||
async def _execute_with_semaphore(op_data):
|
||||
async with semaphore:
|
||||
return await self.execute_operation(**op_data)
|
||||
|
||||
tasks = [_execute_with_semaphore(op) for op in operations]
|
||||
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def sync(
|
||||
self,
|
||||
since: Optional[datetime] = None,
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Synchronize data with external system"""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
try:
|
||||
# Perform sync
|
||||
result = await self._sync(since, filters)
|
||||
|
||||
# Update last sync timestamp
|
||||
self._last_sync = datetime.utcnow()
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Sync failed: {e}")
|
||||
raise ConnectorError(f"Sync failed: {e}")
|
||||
|
||||
async def validate_webhook(self, payload: Dict[str, Any], signature: str) -> bool:
|
||||
"""Validate incoming webhook payload"""
|
||||
if not self.webhook_handler:
|
||||
return False
|
||||
|
||||
return await self.webhook_handler.validate(payload, signature)
|
||||
|
||||
async def handle_webhook(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Handle incoming webhook"""
|
||||
if not self.webhook_handler:
|
||||
raise ConnectorError("Webhook handler not configured")
|
||||
|
||||
return await self.webhook_handler.handle(payload)
|
||||
|
||||
def add_operation_handler(
|
||||
self,
|
||||
operation: str,
|
||||
handler: Callable[[Dict[str, Any]], Awaitable[None]]
|
||||
):
|
||||
"""Add handler for specific operation"""
|
||||
if operation not in self._operation_handlers:
|
||||
self._operation_handlers[operation] = []
|
||||
self._operation_handlers[operation].append(handler)
|
||||
|
||||
def remove_operation_handler(
|
||||
self,
|
||||
operation: str,
|
||||
handler: Callable
|
||||
):
|
||||
"""Remove handler for specific operation"""
|
||||
if operation in self._operation_handlers:
|
||||
try:
|
||||
self._operation_handlers[operation].remove(handler)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Abstract methods to be implemented by subclasses
|
||||
|
||||
@abstractmethod
|
||||
async def _initialize(self) -> None:
|
||||
"""Connector-specific initialization"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def _cleanup(self) -> None:
|
||||
"""Connector-specific cleanup"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def _execute_operation(
|
||||
self,
|
||||
operation: str,
|
||||
data: Dict[str, Any],
|
||||
**kwargs
|
||||
) -> OperationResult:
|
||||
"""Execute connector-specific operation"""
|
||||
pass
|
||||
|
||||
async def _sync(
|
||||
self,
|
||||
since: Optional[datetime],
|
||||
filters: Optional[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Default sync implementation"""
|
||||
return {
|
||||
"synced_at": datetime.utcnow().isoformat(),
|
||||
"records": 0,
|
||||
"message": "Sync not implemented"
|
||||
}
|
||||
|
||||
# Hook methods
|
||||
|
||||
async def _before_operation(
|
||||
self,
|
||||
operation: str,
|
||||
data: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Called before operation execution"""
|
||||
pass
|
||||
|
||||
async def _after_operation(
|
||||
self,
|
||||
operation: str,
|
||||
data: Dict[str, Any],
|
||||
result: OperationResult
|
||||
) -> None:
|
||||
"""Called after operation execution"""
|
||||
pass
|
||||
|
||||
# Private methods
|
||||
|
||||
async def _setup_webhooks(self) -> None:
|
||||
"""Setup webhook endpoints"""
|
||||
if not self.webhook_handler:
|
||||
return
|
||||
|
||||
await self.webhook_handler.setup(
|
||||
endpoint=self.config.webhook_endpoint,
|
||||
secret=self.config.webhook_secret
|
||||
)
|
||||
|
||||
def _register_handlers(self) -> None:
|
||||
"""Register default event handlers"""
|
||||
# Register with client if needed
|
||||
pass
|
||||
|
||||
async def _emit_operation_event(
|
||||
self,
|
||||
event: str,
|
||||
result: OperationResult
|
||||
) -> None:
|
||||
"""Emit operation event to handlers"""
|
||||
if event in self._operation_handlers:
|
||||
tasks = []
|
||||
for handler in self._operation_handlers[event]:
|
||||
try:
|
||||
tasks.append(handler(result.to_dict() if result.data else {}))
|
||||
except Exception as e:
|
||||
self.logger.error(f"Handler error: {e}")
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Properties
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if connector is initialized"""
|
||||
return self._initialized
|
||||
|
||||
@property
|
||||
def last_sync(self) -> Optional[datetime]:
|
||||
"""Get last sync timestamp"""
|
||||
return self._last_sync
|
||||
|
||||
@property
|
||||
def metrics(self) -> Dict[str, Any]:
|
||||
"""Get connector metrics"""
|
||||
return {
|
||||
"operation_count": self._operation_count,
|
||||
"error_count": self._error_count,
|
||||
"error_rate": self._error_count / max(self._operation_count, 1),
|
||||
"last_sync": self._last_sync.isoformat() if self._last_sync else None
|
||||
}
|
||||
|
||||
# Context manager
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry"""
|
||||
await self.initialize()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit"""
|
||||
await self.cleanup()
|
||||
296
enterprise-connectors/python-sdk/aitbc_enterprise/core.py
Normal file
296
enterprise-connectors/python-sdk/aitbc_enterprise/core.py
Normal file
@ -0,0 +1,296 @@
|
||||
"""
|
||||
Core components for AITBC Enterprise Connectors SDK
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, Callable, Awaitable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
import aiohttp
|
||||
from aiohttp import ClientTimeout, ClientSession
|
||||
|
||||
from .auth import AuthHandler
|
||||
from .rate_limiter import RateLimiter
|
||||
from .metrics import MetricsCollector
|
||||
from .exceptions import ConfigurationError
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectorConfig:
|
||||
"""Configuration for AITBC connectors"""
|
||||
|
||||
# API Configuration
|
||||
base_url: str
|
||||
api_key: str
|
||||
api_version: str = "v1"
|
||||
|
||||
# Connection Settings
|
||||
timeout: float = 30.0
|
||||
max_connections: int = 100
|
||||
max_retries: int = 3
|
||||
retry_backoff: float = 1.0
|
||||
|
||||
# Rate Limiting
|
||||
rate_limit: Optional[int] = None # Requests per second
|
||||
burst_limit: Optional[int] = None
|
||||
|
||||
# Authentication
|
||||
auth_type: str = "bearer" # bearer, basic, custom
|
||||
auth_config: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Webhooks
|
||||
webhook_secret: Optional[str] = None
|
||||
webhook_endpoint: Optional[str] = None
|
||||
|
||||
# Monitoring
|
||||
enable_metrics: bool = True
|
||||
metrics_endpoint: Optional[str] = None
|
||||
|
||||
# Logging
|
||||
log_level: str = "INFO"
|
||||
log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
|
||||
# Enterprise Features
|
||||
enterprise_id: Optional[str] = None
|
||||
tenant_id: Optional[str] = None
|
||||
region: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration"""
|
||||
if not self.base_url:
|
||||
raise ConfigurationError("base_url is required")
|
||||
if not self.api_key:
|
||||
raise ConfigurationError("api_key is required")
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, self.log_level.upper()),
|
||||
format=self.log_format
|
||||
)
|
||||
|
||||
|
||||
class AITBCClient:
|
||||
"""Main client for AITBC Enterprise Connectors"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ConnectorConfig,
|
||||
session: Optional[ClientSession] = None,
|
||||
auth_handler: Optional[AuthHandler] = None,
|
||||
rate_limiter: Optional[RateLimiter] = None,
|
||||
metrics: Optional[MetricsCollector] = None
|
||||
):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(f"aitbc.{self.__class__.__name__}")
|
||||
|
||||
# Initialize components with dependency injection
|
||||
self._session = session or self._create_session()
|
||||
self._auth = auth_handler or AuthHandler(config)
|
||||
self._rate_limiter = rate_limiter or RateLimiter(config)
|
||||
self._metrics = metrics or MetricsCollector(config) if config.enable_metrics else None
|
||||
|
||||
# Event handlers
|
||||
self._event_handlers: Dict[str, list] = {}
|
||||
|
||||
# Connection state
|
||||
self._connected = False
|
||||
self._last_activity = None
|
||||
|
||||
def _create_session(self) -> ClientSession:
|
||||
"""Create HTTP session with configuration"""
|
||||
timeout = ClientTimeout(total=self.config.timeout)
|
||||
|
||||
# Set up headers
|
||||
headers = {
|
||||
"User-Agent": f"AITBC-SDK/{__version__}",
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
return ClientSession(
|
||||
timeout=timeout,
|
||||
headers=headers,
|
||||
connector=aiohttp.TCPConnector(
|
||||
limit=self.config.max_connections,
|
||||
limit_per_host=self.config.max_connections // 4
|
||||
)
|
||||
)
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish connection to AITBC"""
|
||||
if self._connected:
|
||||
return
|
||||
|
||||
try:
|
||||
# Test connection
|
||||
await self._test_connection()
|
||||
|
||||
# Start metrics collection
|
||||
if self._metrics:
|
||||
await self._metrics.start()
|
||||
|
||||
self._connected = True
|
||||
self._last_activity = datetime.utcnow()
|
||||
|
||||
self.logger.info("Connected to AITBC")
|
||||
await self._emit_event("connected", {"timestamp": self._last_activity})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to connect: {e}")
|
||||
raise
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Close connection to AITBC"""
|
||||
if not self._connected:
|
||||
return
|
||||
|
||||
try:
|
||||
# Stop metrics collection
|
||||
if self._metrics:
|
||||
await self._metrics.stop()
|
||||
|
||||
# Close session
|
||||
await self._session.close()
|
||||
|
||||
self._connected = False
|
||||
self.logger.info("Disconnected from AITBC")
|
||||
await self._emit_event("disconnected", {"timestamp": datetime.utcnow()})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during disconnect: {e}")
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""Make authenticated request to AITBC API"""
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
# Apply rate limiting
|
||||
if self.config.rate_limit:
|
||||
await self._rate_limiter.acquire()
|
||||
|
||||
# Prepare request
|
||||
url = f"{self.config.base_url}/{self.config.api_version}/{path.lstrip('/')}"
|
||||
|
||||
# Add authentication
|
||||
headers = kwargs.pop("headers", {})
|
||||
auth_headers = await self._auth.get_headers()
|
||||
headers.update(auth_headers)
|
||||
|
||||
# Retry logic
|
||||
last_exception = None
|
||||
for attempt in range(self.config.max_retries + 1):
|
||||
try:
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
async with self._session.request(
|
||||
method,
|
||||
url,
|
||||
headers=headers,
|
||||
**kwargs
|
||||
) as response:
|
||||
# Record metrics
|
||||
if self._metrics:
|
||||
duration = (datetime.utcnow() - start_time).total_seconds()
|
||||
await self._metrics.record_request(
|
||||
method=method,
|
||||
path=path,
|
||||
status=response.status,
|
||||
duration=duration
|
||||
)
|
||||
|
||||
# Handle response
|
||||
if response.status == 429:
|
||||
retry_after = int(response.headers.get("Retry-After", self.config.retry_backoff))
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
data = await response.json()
|
||||
self._last_activity = datetime.utcnow()
|
||||
|
||||
return data
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
last_exception = e
|
||||
if attempt < self.config.max_retries:
|
||||
backoff = self.config.retry_backoff * (2 ** attempt)
|
||||
self.logger.warning(f"Request failed, retrying in {backoff}s: {e}")
|
||||
await asyncio.sleep(backoff)
|
||||
else:
|
||||
self.logger.error(f"Request failed after {self.config.max_retries} retries: {e}")
|
||||
raise
|
||||
|
||||
raise last_exception
|
||||
|
||||
async def get(self, path: str, **kwargs) -> Dict[str, Any]:
|
||||
"""Make GET request"""
|
||||
return await self.request("GET", path, **kwargs)
|
||||
|
||||
async def post(self, path: str, **kwargs) -> Dict[str, Any]:
|
||||
"""Make POST request"""
|
||||
return await self.request("POST", path, **kwargs)
|
||||
|
||||
async def put(self, path: str, **kwargs) -> Dict[str, Any]:
|
||||
"""Make PUT request"""
|
||||
return await self.request("PUT", path, **kwargs)
|
||||
|
||||
async def delete(self, path: str, **kwargs) -> Dict[str, Any]:
|
||||
"""Make DELETE request"""
|
||||
return await self.request("DELETE", path, **kwargs)
|
||||
|
||||
def on(self, event: str, handler: Callable[[Dict[str, Any]], Awaitable[None]]):
|
||||
"""Register event handler"""
|
||||
if event not in self._event_handlers:
|
||||
self._event_handlers[event] = []
|
||||
self._event_handlers[event].append(handler)
|
||||
|
||||
def off(self, event: str, handler: Callable):
|
||||
"""Unregister event handler"""
|
||||
if event in self._event_handlers:
|
||||
try:
|
||||
self._event_handlers[event].remove(handler)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
async def _emit_event(self, event: str, data: Dict[str, Any]):
|
||||
"""Emit event to registered handlers"""
|
||||
if event in self._event_handlers:
|
||||
tasks = []
|
||||
for handler in self._event_handlers[event]:
|
||||
tasks.append(handler(data))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def _test_connection(self):
|
||||
"""Test connection to AITBC"""
|
||||
try:
|
||||
await self.get("/health")
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"Failed to connect to AITBC: {e}")
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if client is connected"""
|
||||
return self._connected
|
||||
|
||||
@property
|
||||
def last_activity(self) -> Optional[datetime]:
|
||||
"""Get last activity timestamp"""
|
||||
return self._last_activity
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry"""
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit"""
|
||||
await self.disconnect()
|
||||
@ -0,0 +1,18 @@
|
||||
"""
|
||||
ERP system connectors for AITBC Enterprise
|
||||
"""
|
||||
|
||||
from .base import ERPConnector, ERPDataModel, ProtocolHandler, DataMapper
|
||||
from .sap import SAPConnector
|
||||
from .oracle import OracleConnector
|
||||
from .netsuite import NetSuiteConnector
|
||||
|
||||
__all__ = [
|
||||
"ERPConnector",
|
||||
"ERPDataModel",
|
||||
"ProtocolHandler",
|
||||
"DataMapper",
|
||||
"SAPConnector",
|
||||
"OracleConnector",
|
||||
"NetSuiteConnector",
|
||||
]
|
||||
501
enterprise-connectors/python-sdk/aitbc_enterprise/erp/base.py
Normal file
501
enterprise-connectors/python-sdk/aitbc_enterprise/erp/base.py
Normal file
@ -0,0 +1,501 @@
|
||||
"""
|
||||
Base classes for ERP connectors with plugin architecture
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List, Optional, Type, Union, Callable
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import importlib
|
||||
|
||||
from ..base import BaseConnector, OperationResult
|
||||
from ..core import ConnectorConfig
|
||||
from ..exceptions import ERPError, ValidationError
|
||||
|
||||
|
||||
class ERPSystem(Enum):
|
||||
"""Supported ERP systems"""
|
||||
SAP = "sap"
|
||||
ORACLE = "oracle"
|
||||
NETSUITE = "netsuite"
|
||||
MICROSOFT_DYNAMICS = "dynamics"
|
||||
SALESFORCE = "salesforce"
|
||||
|
||||
|
||||
class Protocol(Enum):
|
||||
"""Supported protocols"""
|
||||
REST = "rest"
|
||||
SOAP = "soap"
|
||||
ODATA = "odata"
|
||||
IDOC = "idoc"
|
||||
BAPI = "bapi"
|
||||
SUITE_TALK = "suite_talk"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ERPDataModel:
|
||||
"""ERP data model definition"""
|
||||
entity_type: str
|
||||
fields: Dict[str, Any]
|
||||
relationships: Dict[str, str] = field(default_factory=dict)
|
||||
validations: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"entity_type": self.entity_type,
|
||||
"fields": self.fields,
|
||||
"relationships": self.relationships,
|
||||
"validations": self.validations
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SyncResult:
|
||||
"""Synchronization result"""
|
||||
entity_type: str
|
||||
synced_count: int
|
||||
failed_count: int
|
||||
errors: List[str] = field(default_factory=list)
|
||||
last_sync: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"entity_type": self.entity_type,
|
||||
"synced_count": self.synced_count,
|
||||
"failed_count": self.failed_count,
|
||||
"errors": self.errors,
|
||||
"last_sync": self.last_sync.isoformat()
|
||||
}
|
||||
|
||||
|
||||
class ProtocolHandler(ABC):
|
||||
"""Abstract base class for protocol handlers"""
|
||||
|
||||
def __init__(self, config: ConnectorConfig):
|
||||
self.config = config
|
||||
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> bool:
|
||||
"""Establish protocol connection"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def disconnect(self):
|
||||
"""Close protocol connection"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send_request(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Send request via protocol"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def batch_request(self, requests: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Send batch requests"""
|
||||
pass
|
||||
|
||||
|
||||
class DataMapper:
|
||||
"""Maps data between AITBC and ERP formats"""
|
||||
|
||||
def __init__(self, mappings: Dict[str, Dict[str, str]]):
|
||||
self.mappings = mappings
|
||||
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
|
||||
|
||||
def to_erp(self, entity_type: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Map AITBC format to ERP format"""
|
||||
if entity_type not in self.mappings:
|
||||
raise ValidationError(f"No mapping for entity type: {entity_type}")
|
||||
|
||||
mapping = self.mappings[entity_type]
|
||||
erp_data = {}
|
||||
|
||||
for aitbc_field, erp_field in mapping.items():
|
||||
if aitbc_field in data:
|
||||
erp_data[erp_field] = data[aitbc_field]
|
||||
|
||||
return erp_data
|
||||
|
||||
def from_erp(self, entity_type: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Map ERP format to AITBC format"""
|
||||
if entity_type not in self.mappings:
|
||||
raise ValidationError(f"No mapping for entity type: {entity_type}")
|
||||
|
||||
mapping = self.mappings[entity_type]
|
||||
aitbc_data = {}
|
||||
|
||||
# Reverse mapping
|
||||
reverse_mapping = {v: k for k, v in mapping.items()}
|
||||
|
||||
for erp_field, value in data.items():
|
||||
if erp_field in reverse_mapping:
|
||||
aitbc_data[reverse_mapping[erp_field]] = value
|
||||
|
||||
return aitbc_data
|
||||
|
||||
|
||||
class BatchProcessor:
|
||||
"""Handles batch operations for ERP connectors"""
|
||||
|
||||
def __init__(self, batch_size: int = 100, max_concurrent: int = 5):
|
||||
self.batch_size = batch_size
|
||||
self.max_concurrent = max_concurrent
|
||||
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
|
||||
|
||||
async def process_batches(
|
||||
self,
|
||||
items: List[Dict[str, Any]],
|
||||
processor: Callable[[List[Dict[str, Any]]], List[Dict[str, Any]]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Process items in batches"""
|
||||
results = []
|
||||
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||
|
||||
async def process_batch(batch):
|
||||
async with semaphore:
|
||||
try:
|
||||
return await processor(batch)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Batch processing failed: {e}")
|
||||
return [{"error": str(e)} for _ in batch]
|
||||
|
||||
# Create batches
|
||||
batches = [
|
||||
items[i:i + self.batch_size]
|
||||
for i in range(0, len(items), self.batch_size)
|
||||
]
|
||||
|
||||
# Process batches concurrently
|
||||
tasks = [process_batch(batch) for batch in batches]
|
||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Flatten results
|
||||
for result in batch_results:
|
||||
if isinstance(result, list):
|
||||
results.extend(result)
|
||||
else:
|
||||
results.append({"error": str(result)})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class ChangeTracker:
|
||||
"""Tracks changes for delta synchronization"""
|
||||
|
||||
def __init__(self):
|
||||
self.last_syncs: Dict[str, datetime] = {}
|
||||
self.change_logs: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
def update_last_sync(self, entity_type: str, timestamp: datetime):
|
||||
"""Update last sync timestamp"""
|
||||
self.last_syncs[entity_type] = timestamp
|
||||
|
||||
def get_last_sync(self, entity_type: str) -> Optional[datetime]:
|
||||
"""Get last sync timestamp"""
|
||||
return self.last_syncs.get(entity_type)
|
||||
|
||||
def log_change(self, entity_type: str, change: Dict[str, Any]):
|
||||
"""Log a change"""
|
||||
if entity_type not in self.change_logs:
|
||||
self.change_logs[entity_type] = []
|
||||
|
||||
self.change_logs[entity_type].append({
|
||||
**change,
|
||||
"timestamp": datetime.utcnow()
|
||||
})
|
||||
|
||||
def get_changes_since(
|
||||
self,
|
||||
entity_type: str,
|
||||
since: datetime
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get changes since timestamp"""
|
||||
changes = self.change_logs.get(entity_type, [])
|
||||
return [
|
||||
c for c in changes
|
||||
if c["timestamp"] > since
|
||||
]
|
||||
|
||||
|
||||
class ERPConnector(BaseConnector):
|
||||
"""Base class for ERP connectors with plugin architecture"""
|
||||
|
||||
# Registry for protocol handlers
|
||||
_protocol_registry: Dict[Protocol, Type[ProtocolHandler]] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: 'AITBCClient',
|
||||
config: ConnectorConfig,
|
||||
erp_system: ERPSystem,
|
||||
protocol: Protocol,
|
||||
data_mapper: Optional[DataMapper] = None
|
||||
):
|
||||
super().__init__(client, config)
|
||||
|
||||
self.erp_system = erp_system
|
||||
self.protocol = protocol
|
||||
|
||||
# Initialize components
|
||||
self.protocol_handler = self._create_protocol_handler()
|
||||
self.data_mapper = data_mapper or DataMapper({})
|
||||
self.batch_processor = BatchProcessor()
|
||||
self.change_tracker = ChangeTracker()
|
||||
|
||||
# ERP-specific configuration
|
||||
self.erp_config = config.auth_config.get("erp", {})
|
||||
|
||||
# Data models
|
||||
self.data_models: Dict[str, ERPDataModel] = {}
|
||||
|
||||
@classmethod
|
||||
def register_protocol(
|
||||
cls,
|
||||
protocol: Protocol,
|
||||
handler_class: Type[ProtocolHandler]
|
||||
):
|
||||
"""Register a protocol handler"""
|
||||
cls._protocol_registry[protocol] = handler_class
|
||||
|
||||
def _create_protocol_handler(self) -> ProtocolHandler:
|
||||
"""Create protocol handler from registry"""
|
||||
if self.protocol not in self._protocol_registry:
|
||||
raise ERPError(f"No handler registered for protocol: {self.protocol}")
|
||||
|
||||
handler_class = self._protocol_registry[self.protocol]
|
||||
return handler_class(self.config)
|
||||
|
||||
async def _initialize(self) -> None:
|
||||
"""Initialize ERP connector"""
|
||||
# Connect via protocol
|
||||
if not await self.protocol_handler.connect():
|
||||
raise ERPError(f"Failed to connect via {self.protocol}")
|
||||
|
||||
# Load data models
|
||||
await self._load_data_models()
|
||||
|
||||
self.logger.info(f"{self.erp_system.value} connector initialized")
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
"""Cleanup ERP connector"""
|
||||
await self.protocol_handler.disconnect()
|
||||
|
||||
async def _execute_operation(
|
||||
self,
|
||||
operation: str,
|
||||
data: Dict[str, Any],
|
||||
**kwargs
|
||||
) -> OperationResult:
|
||||
"""Execute ERP-specific operations"""
|
||||
try:
|
||||
if operation.startswith("create_"):
|
||||
entity_type = operation[7:] # Remove "create_" prefix
|
||||
return await self._create_entity(entity_type, data)
|
||||
elif operation.startswith("update_"):
|
||||
entity_type = operation[7:] # Remove "update_" prefix
|
||||
return await self._update_entity(entity_type, data)
|
||||
elif operation.startswith("delete_"):
|
||||
entity_type = operation[7:] # Remove "delete_" prefix
|
||||
return await self._delete_entity(entity_type, data)
|
||||
elif operation == "sync":
|
||||
return await self._sync_data(data)
|
||||
elif operation == "batch_sync":
|
||||
return await self._batch_sync(data)
|
||||
else:
|
||||
raise ValidationError(f"Unknown operation: {operation}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"ERP operation failed: {e}")
|
||||
raise ERPError(f"Operation failed: {e}")
|
||||
|
||||
async def _create_entity(self, entity_type: str, data: Dict[str, Any]) -> OperationResult:
|
||||
"""Create entity in ERP"""
|
||||
# Map data to ERP format
|
||||
erp_data = self.data_mapper.to_erp(entity_type, data)
|
||||
|
||||
# Send to ERP
|
||||
endpoint = f"/{entity_type}"
|
||||
result = await self.protocol_handler.send_request(endpoint, erp_data)
|
||||
|
||||
# Track change
|
||||
self.change_tracker.log_change(entity_type, {
|
||||
"action": "create",
|
||||
"data": result
|
||||
})
|
||||
|
||||
return OperationResult(
|
||||
success=True,
|
||||
data=result,
|
||||
metadata={"entity_type": entity_type, "action": "create"}
|
||||
)
|
||||
|
||||
async def _update_entity(self, entity_type: str, data: Dict[str, Any]) -> OperationResult:
|
||||
"""Update entity in ERP"""
|
||||
entity_id = data.get("id")
|
||||
if not entity_id:
|
||||
raise ValidationError("Entity ID required for update")
|
||||
|
||||
# Map data to ERP format
|
||||
erp_data = self.data_mapper.to_erp(entity_type, data)
|
||||
|
||||
# Send to ERP
|
||||
endpoint = f"/{entity_type}/{entity_id}"
|
||||
result = await self.protocol_handler.send_request(endpoint, erp_data, method="PUT")
|
||||
|
||||
# Track change
|
||||
self.change_tracker.log_change(entity_type, {
|
||||
"action": "update",
|
||||
"entity_id": entity_id,
|
||||
"data": result
|
||||
})
|
||||
|
||||
return OperationResult(
|
||||
success=True,
|
||||
data=result,
|
||||
metadata={"entity_type": entity_type, "action": "update"}
|
||||
)
|
||||
|
||||
async def _delete_entity(self, entity_type: str, data: Dict[str, Any]) -> OperationResult:
|
||||
"""Delete entity from ERP"""
|
||||
entity_id = data.get("id")
|
||||
if not entity_id:
|
||||
raise ValidationError("Entity ID required for delete")
|
||||
|
||||
# Send to ERP
|
||||
endpoint = f"/{entity_type}/{entity_id}"
|
||||
await self.protocol_handler.send_request(endpoint, {}, method="DELETE")
|
||||
|
||||
# Track change
|
||||
self.change_tracker.log_change(entity_type, {
|
||||
"action": "delete",
|
||||
"entity_id": entity_id
|
||||
})
|
||||
|
||||
return OperationResult(
|
||||
success=True,
|
||||
metadata={"entity_type": entity_type, "action": "delete"}
|
||||
)
|
||||
|
||||
async def _sync_data(self, data: Dict[str, Any]) -> OperationResult:
|
||||
"""Synchronize data from ERP"""
|
||||
entity_type = data.get("entity_type")
|
||||
since = data.get("since")
|
||||
|
||||
if not entity_type:
|
||||
raise ValidationError("entity_type required")
|
||||
|
||||
# Get last sync if not provided
|
||||
if not since:
|
||||
since = self.change_tracker.get_last_sync(entity_type)
|
||||
|
||||
# Query ERP for changes
|
||||
endpoint = f"/{entity_type}"
|
||||
params = {"since": since.isoformat()} if since else {}
|
||||
|
||||
result = await self.protocol_handler.send_request(endpoint, params)
|
||||
|
||||
# Map data to AITBC format
|
||||
items = result.get("items", [])
|
||||
mapped_items = [
|
||||
self.data_mapper.from_erp(entity_type, item)
|
||||
for item in items
|
||||
]
|
||||
|
||||
# Update last sync
|
||||
self.change_tracker.update_last_sync(entity_type, datetime.utcnow())
|
||||
|
||||
return OperationResult(
|
||||
success=True,
|
||||
data={"items": mapped_items, "count": len(mapped_items)},
|
||||
metadata={"entity_type": entity_type, "since": since}
|
||||
)
|
||||
|
||||
async def _batch_sync(self, data: Dict[str, Any]) -> OperationResult:
|
||||
"""Batch synchronize data"""
|
||||
entity_type = data.get("entity_type")
|
||||
items = data.get("items", [])
|
||||
|
||||
if not entity_type or not items:
|
||||
raise ValidationError("entity_type and items required")
|
||||
|
||||
# Process in batches
|
||||
batch_data = [{"entity_type": entity_type, "item": item} for item in items]
|
||||
|
||||
results = await self.batch_processor.process_batches(
|
||||
batch_data,
|
||||
self._process_sync_batch
|
||||
)
|
||||
|
||||
# Count successes and failures
|
||||
successful = sum(1 for r in results if "error" not in r)
|
||||
failed = len(results) - successful
|
||||
|
||||
return OperationResult(
|
||||
success=failed == 0,
|
||||
data={"results": results},
|
||||
metadata={
|
||||
"entity_type": entity_type,
|
||||
"total": len(items),
|
||||
"successful": successful,
|
||||
"failed": failed
|
||||
}
|
||||
)
|
||||
|
||||
async def _process_sync_batch(self, batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Process a sync batch"""
|
||||
entity_type = batch[0]["entity_type"]
|
||||
items = [b["item"] for b in batch]
|
||||
|
||||
# Map to ERP format
|
||||
erp_items = [
|
||||
self.data_mapper.to_erp(entity_type, item)
|
||||
for item in items
|
||||
]
|
||||
|
||||
# Send batch request
|
||||
endpoint = f"/{entity_type}/batch"
|
||||
results = await self.protocol_handler.batch_request([
|
||||
{"method": "POST", "endpoint": endpoint, "data": item}
|
||||
for item in erp_items
|
||||
])
|
||||
|
||||
return results
|
||||
|
||||
async def _load_data_models(self):
|
||||
"""Load ERP data models"""
|
||||
# Default models - override in subclasses
|
||||
self.data_models = {
|
||||
"customer": ERPDataModel(
|
||||
entity_type="customer",
|
||||
fields={"id": str, "name": str, "email": str, "phone": str}
|
||||
),
|
||||
"order": ERPDataModel(
|
||||
entity_type="order",
|
||||
fields={"id": str, "customer_id": str, "items": list, "total": float}
|
||||
),
|
||||
"invoice": ERPDataModel(
|
||||
entity_type="invoice",
|
||||
fields={"id": str, "order_id": str, "amount": float, "status": str}
|
||||
)
|
||||
}
|
||||
|
||||
def register_data_model(self, model: ERPDataModel):
|
||||
"""Register a data model"""
|
||||
self.data_models[model.entity_type] = model
|
||||
|
||||
def get_data_model(self, entity_type: str) -> Optional[ERPDataModel]:
|
||||
"""Get data model by type"""
|
||||
return self.data_models.get(entity_type)
|
||||
|
||||
|
||||
# Protocol handler registry decorator
|
||||
def register_protocol(protocol: Protocol):
|
||||
"""Decorator to register protocol handlers"""
|
||||
def decorator(handler_class: Type[ProtocolHandler]):
|
||||
ERPConnector.register_protocol(protocol, handler_class)
|
||||
return handler_class
|
||||
return decorator
|
||||
@ -0,0 +1,19 @@
|
||||
"""
|
||||
NetSuite ERP connector for AITBC Enterprise (Placeholder)
|
||||
"""
|
||||
|
||||
from .base import ERPConnector, ERPSystem, Protocol
|
||||
|
||||
|
||||
class NetSuiteConnector(ERPConnector):
|
||||
"""NetSuite ERP connector with SuiteTalk support"""
|
||||
|
||||
def __init__(self, client, config, netsuite_account, netsuite_consumer_key, netsuite_consumer_secret):
|
||||
# TODO: Implement NetSuite connector
|
||||
raise NotImplementedError("NetSuite connector not yet implemented")
|
||||
|
||||
# TODO: Implement NetSuite-specific methods
|
||||
# - SuiteTalk REST API
|
||||
# - SuiteTalk SOAP web services
|
||||
# - OAuth authentication
|
||||
# - Data mapping for NetSuite records
|
||||
@ -0,0 +1,19 @@
|
||||
"""
|
||||
Oracle ERP connector for AITBC Enterprise (Placeholder)
|
||||
"""
|
||||
|
||||
from .base import ERPConnector, ERPSystem, Protocol
|
||||
|
||||
|
||||
class OracleConnector(ERPConnector):
|
||||
"""Oracle ERP connector with REST and SOAP support"""
|
||||
|
||||
def __init__(self, client, config, oracle_client_id, oracle_secret):
|
||||
# TODO: Implement Oracle connector
|
||||
raise NotImplementedError("Oracle connector not yet implemented")
|
||||
|
||||
# TODO: Implement Oracle-specific methods
|
||||
# - REST API calls
|
||||
# - SOAP web services
|
||||
# - Oracle authentication
|
||||
# - Data mapping for Oracle modules
|
||||
19
enterprise-connectors/python-sdk/aitbc_enterprise/erp/sap.py
Normal file
19
enterprise-connectors/python-sdk/aitbc_enterprise/erp/sap.py
Normal file
@ -0,0 +1,19 @@
|
||||
"""
|
||||
SAP ERP connector for AITBC Enterprise (Placeholder)
|
||||
"""
|
||||
|
||||
from .base import ERPConnector, ERPSystem, Protocol
|
||||
|
||||
|
||||
class SAPConnector(ERPConnector):
|
||||
"""SAP ERP connector with IDOC and BAPI support"""
|
||||
|
||||
def __init__(self, client, config, sap_client):
|
||||
# TODO: Implement SAP connector
|
||||
raise NotImplementedError("SAP connector not yet implemented")
|
||||
|
||||
# TODO: Implement SAP-specific methods
|
||||
# - IDOC processing
|
||||
# - BAPI calls
|
||||
# - SAP authentication
|
||||
# - Data mapping for SAP structures
|
||||
@ -0,0 +1,68 @@
|
||||
"""
|
||||
Exception classes for AITBC Enterprise Connectors
|
||||
"""
|
||||
|
||||
|
||||
class AITBCError(Exception):
|
||||
"""Base exception for all AITBC errors"""
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationError(AITBCError):
|
||||
"""Raised when authentication fails"""
|
||||
pass
|
||||
|
||||
|
||||
class RateLimitError(AITBCError):
|
||||
"""Raised when rate limit is exceeded"""
|
||||
def __init__(self, message: str, retry_after: int = None):
|
||||
super().__init__(message)
|
||||
self.retry_after = retry_after
|
||||
|
||||
|
||||
class APIError(AITBCError):
|
||||
"""Raised when API request fails"""
|
||||
def __init__(self, message: str, status_code: int = None, response: dict = None):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
self.response = response
|
||||
|
||||
|
||||
class ConfigurationError(AITBCError):
|
||||
"""Raised when configuration is invalid"""
|
||||
pass
|
||||
|
||||
|
||||
class ConnectorError(AITBCError):
|
||||
"""Raised when connector operation fails"""
|
||||
pass
|
||||
|
||||
|
||||
class PaymentError(ConnectorError):
|
||||
"""Raised when payment operation fails"""
|
||||
pass
|
||||
|
||||
|
||||
class ValidationError(AITBCError):
|
||||
"""Raised when data validation fails"""
|
||||
pass
|
||||
|
||||
|
||||
class WebhookError(AITBCError):
|
||||
"""Raised when webhook processing fails"""
|
||||
pass
|
||||
|
||||
|
||||
class ERPError(ConnectorError):
|
||||
"""Raised when ERP operation fails"""
|
||||
pass
|
||||
|
||||
|
||||
class SyncError(ConnectorError):
|
||||
"""Raised when synchronization fails"""
|
||||
pass
|
||||
|
||||
|
||||
class TimeoutError(AITBCError):
|
||||
"""Raised when operation times out"""
|
||||
pass
|
||||
293
enterprise-connectors/python-sdk/aitbc_enterprise/metrics.py
Normal file
293
enterprise-connectors/python-sdk/aitbc_enterprise/metrics.py
Normal file
@ -0,0 +1,293 @@
|
||||
"""
|
||||
Metrics collection for AITBC Enterprise Connectors
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Any, Optional, List
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
|
||||
from .core import ConnectorConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetricPoint:
|
||||
"""Single metric data point"""
|
||||
name: str
|
||||
value: float
|
||||
timestamp: datetime
|
||||
tags: Dict[str, str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"value": self.value,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
"tags": self.tags or {}
|
||||
}
|
||||
|
||||
|
||||
class MetricsCollector:
|
||||
"""Collects and manages metrics for connectors"""
|
||||
|
||||
def __init__(self, config: ConnectorConfig):
|
||||
self.config = config
|
||||
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
|
||||
|
||||
# Metric storage
|
||||
self._counters: Dict[str, float] = defaultdict(float)
|
||||
self._gauges: Dict[str, float] = {}
|
||||
self._histograms: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000))
|
||||
self._timers: Dict[str, List[float]] = defaultdict(list)
|
||||
|
||||
# Runtime state
|
||||
self._running = False
|
||||
self._flush_task = None
|
||||
self._buffer: List[MetricPoint] = []
|
||||
self._buffer_size = 1000
|
||||
|
||||
# Aggregated metrics
|
||||
self._request_count = 0
|
||||
self._error_count = 0
|
||||
self._total_duration = 0.0
|
||||
self._last_flush = None
|
||||
|
||||
async def start(self):
|
||||
"""Start metrics collection"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._last_flush = datetime.utcnow()
|
||||
|
||||
# Start periodic flush task
|
||||
if self.config.metrics_endpoint:
|
||||
self._flush_task = asyncio.create_task(self._flush_loop())
|
||||
|
||||
self.logger.info("Metrics collection started")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop metrics collection"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
|
||||
# Cancel flush task
|
||||
if self._flush_task:
|
||||
self._flush_task.cancel()
|
||||
try:
|
||||
await self._flush_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Final flush
|
||||
await self._flush_metrics()
|
||||
|
||||
self.logger.info("Metrics collection stopped")
|
||||
|
||||
def increment(self, name: str, value: float = 1.0, tags: Dict[str, str] = None):
|
||||
"""Increment counter metric"""
|
||||
key = self._make_key(name, tags)
|
||||
self._counters[key] += value
|
||||
|
||||
# Add to buffer
|
||||
self._add_to_buffer(name, value, tags)
|
||||
|
||||
def gauge(self, name: str, value: float, tags: Dict[str, str] = None):
|
||||
"""Set gauge metric"""
|
||||
key = self._make_key(name, tags)
|
||||
self._gauges[key] = value
|
||||
|
||||
# Add to buffer
|
||||
self._add_to_buffer(name, value, tags)
|
||||
|
||||
def histogram(self, name: str, value: float, tags: Dict[str, str] = None):
|
||||
"""Add value to histogram"""
|
||||
key = self._make_key(name, tags)
|
||||
self._histograms[key].append(value)
|
||||
|
||||
# Add to buffer
|
||||
self._add_to_buffer(name, value, tags)
|
||||
|
||||
def timer(self, name: str, duration: float, tags: Dict[str, str] = None):
|
||||
"""Record timing metric"""
|
||||
key = self._make_key(name, tags)
|
||||
self._timers[key].append(duration)
|
||||
|
||||
# Keep only last 1000 timings
|
||||
if len(self._timers[key]) > 1000:
|
||||
self._timers[key] = self._timers[key][-1000:]
|
||||
|
||||
# Add to buffer
|
||||
self._add_to_buffer(f"{name}_duration", duration, tags)
|
||||
|
||||
async def record_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
status: int,
|
||||
duration: float
|
||||
):
|
||||
"""Record request metrics"""
|
||||
# Update aggregated metrics
|
||||
self._request_count += 1
|
||||
self._total_duration += duration
|
||||
|
||||
if status >= 400:
|
||||
self._error_count += 1
|
||||
|
||||
# Record detailed metrics
|
||||
tags = {
|
||||
"method": method,
|
||||
"path": path,
|
||||
"status": str(status)
|
||||
}
|
||||
|
||||
self.increment("requests_total", 1.0, tags)
|
||||
self.timer("request_duration", duration, tags)
|
||||
|
||||
if status >= 400:
|
||||
self.increment("errors_total", 1.0, tags)
|
||||
|
||||
def get_metric(self, name: str, tags: Dict[str, str] = None) -> Optional[float]:
|
||||
"""Get current metric value"""
|
||||
key = self._make_key(name, tags)
|
||||
|
||||
if key in self._counters:
|
||||
return self._counters[key]
|
||||
elif key in self._gauges:
|
||||
return self._gauges[key]
|
||||
elif key in self._histograms:
|
||||
values = list(self._histograms[key])
|
||||
return sum(values) / len(values) if values else 0
|
||||
elif key in self._timers:
|
||||
values = self._timers[key]
|
||||
return sum(values) / len(values) if values else 0
|
||||
|
||||
return None
|
||||
|
||||
def get_summary(self) -> Dict[str, Any]:
|
||||
"""Get metrics summary"""
|
||||
return {
|
||||
"requests_total": self._request_count,
|
||||
"errors_total": self._error_count,
|
||||
"error_rate": self._error_count / max(self._request_count, 1),
|
||||
"avg_duration": self._total_duration / max(self._request_count, 1),
|
||||
"last_flush": self._last_flush.isoformat() if self._last_flush else None,
|
||||
"metrics_count": len(self._counters) + len(self._gauges) + len(self._histograms) + len(self._timers)
|
||||
}
|
||||
|
||||
def _make_key(self, name: str, tags: Dict[str, str] = None) -> str:
|
||||
"""Create metric key with tags"""
|
||||
if not tags:
|
||||
return name
|
||||
|
||||
tag_str = ",".join(f"{k}={v}" for k, v in sorted(tags.items()))
|
||||
return f"{name}[{tag_str}]"
|
||||
|
||||
def _add_to_buffer(self, name: str, value: float, tags: Dict[str, str] = None):
|
||||
"""Add metric point to buffer"""
|
||||
point = MetricPoint(
|
||||
name=name,
|
||||
value=value,
|
||||
timestamp=datetime.utcnow(),
|
||||
tags=tags
|
||||
)
|
||||
|
||||
self._buffer.append(point)
|
||||
|
||||
# Flush if buffer is full
|
||||
if len(self._buffer) >= self._buffer_size:
|
||||
asyncio.create_task(self._flush_metrics())
|
||||
|
||||
async def _flush_loop(self):
|
||||
"""Periodic flush loop"""
|
||||
while self._running:
|
||||
try:
|
||||
await asyncio.sleep(60) # Flush every minute
|
||||
await self._flush_metrics()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"Flush loop error: {e}")
|
||||
|
||||
async def _flush_metrics(self):
|
||||
"""Flush metrics to endpoint"""
|
||||
if not self.config.metrics_endpoint or not self._buffer:
|
||||
return
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
# Prepare metrics payload
|
||||
payload = {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"source": "aitbc-enterprise-sdk",
|
||||
"metrics": [asdict(point) for point in self._buffer]
|
||||
}
|
||||
|
||||
# Send to endpoint
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
self.config.metrics_endpoint,
|
||||
json=payload,
|
||||
timeout=10
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
self._buffer.clear()
|
||||
self._last_flush = datetime.utcnow()
|
||||
self.logger.debug(f"Flushed {len(payload['metrics'])} metrics")
|
||||
else:
|
||||
self.logger.error(f"Failed to flush metrics: {response.status}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error flushing metrics: {e}")
|
||||
|
||||
|
||||
class PerformanceTracker:
|
||||
"""Track performance metrics for operations"""
|
||||
|
||||
def __init__(self, metrics: MetricsCollector):
|
||||
self.metrics = metrics
|
||||
self._operations: Dict[str, float] = {}
|
||||
|
||||
def start_operation(self, operation: str):
|
||||
"""Start timing an operation"""
|
||||
self._operations[operation] = time.time()
|
||||
|
||||
def end_operation(self, operation: str, tags: Dict[str, str] = None):
|
||||
"""End timing an operation"""
|
||||
if operation in self._operations:
|
||||
duration = time.time() - self._operations[operation]
|
||||
del self._operations[operation]
|
||||
|
||||
self.metrics.timer(f"operation_{operation}", duration, tags)
|
||||
|
||||
return duration
|
||||
return None
|
||||
|
||||
async def track_operation(self, operation: str, coro, tags: Dict[str, str] = None):
|
||||
"""Context manager for tracking operations"""
|
||||
start = time.time()
|
||||
try:
|
||||
result = await coro
|
||||
success = True
|
||||
return result
|
||||
except Exception as e:
|
||||
success = False
|
||||
raise
|
||||
finally:
|
||||
duration = time.time() - start
|
||||
|
||||
metric_tags = {
|
||||
"operation": operation,
|
||||
"success": str(success),
|
||||
**(tags or {})
|
||||
}
|
||||
|
||||
self.metrics.timer(f"operation_{operation}", duration, metric_tags)
|
||||
self.metrics.increment(f"operations_total", 1.0, metric_tags)
|
||||
@ -0,0 +1,19 @@
|
||||
"""
|
||||
Payment processor connectors for AITBC Enterprise
|
||||
"""
|
||||
|
||||
from .base import PaymentConnector, PaymentMethod, Charge, Refund, Subscription
|
||||
from .stripe import StripeConnector
|
||||
from .paypal import PayPalConnector
|
||||
from .square import SquareConnector
|
||||
|
||||
__all__ = [
|
||||
"PaymentConnector",
|
||||
"PaymentMethod",
|
||||
"Charge",
|
||||
"Refund",
|
||||
"Subscription",
|
||||
"StripeConnector",
|
||||
"PayPalConnector",
|
||||
"SquareConnector",
|
||||
]
|
||||
@ -0,0 +1,256 @@
|
||||
"""
|
||||
Base classes for payment processor connectors
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class PaymentStatus(Enum):
|
||||
"""Payment status enumeration"""
|
||||
PENDING = "pending"
|
||||
SUCCEEDED = "succeeded"
|
||||
FAILED = "failed"
|
||||
REFUNDED = "refunded"
|
||||
PARTIALLY_REFUNDED = "partially_refunded"
|
||||
CANCELED = "canceled"
|
||||
|
||||
|
||||
class RefundStatus(Enum):
|
||||
"""Refund status enumeration"""
|
||||
PENDING = "pending"
|
||||
SUCCEEDED = "succeeded"
|
||||
FAILED = "failed"
|
||||
CANCELED = "canceled"
|
||||
|
||||
|
||||
class SubscriptionStatus(Enum):
|
||||
"""Subscription status enumeration"""
|
||||
TRIALING = "trialing"
|
||||
ACTIVE = "active"
|
||||
PAST_DUE = "past_due"
|
||||
CANCELED = "canceled"
|
||||
UNPAID = "unpaid"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PaymentMethod:
|
||||
"""Payment method representation"""
|
||||
id: str
|
||||
type: str
|
||||
created_at: datetime
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
# Card-specific fields
|
||||
brand: Optional[str] = None
|
||||
last4: Optional[str] = None
|
||||
exp_month: Optional[int] = None
|
||||
exp_year: Optional[int] = None
|
||||
|
||||
# Bank account fields
|
||||
bank_name: Optional[str] = None
|
||||
last4_ach: Optional[str] = None
|
||||
routing_number: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_stripe_payment_method(cls, pm_data: Dict[str, Any]) -> 'PaymentMethod':
|
||||
"""Create from Stripe payment method data"""
|
||||
card = pm_data.get("card", {})
|
||||
|
||||
return cls(
|
||||
id=pm_data["id"],
|
||||
type=pm_data["type"],
|
||||
created_at=datetime.fromtimestamp(pm_data["created"]),
|
||||
metadata=pm_data.get("metadata", {}),
|
||||
brand=card.get("brand"),
|
||||
last4=card.get("last4"),
|
||||
exp_month=card.get("exp_month"),
|
||||
exp_year=card.get("exp_year")
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Charge:
|
||||
"""Charge representation"""
|
||||
id: str
|
||||
amount: int
|
||||
currency: str
|
||||
status: PaymentStatus
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
description: Optional[str]
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
# Refund information
|
||||
amount_refunded: int = 0
|
||||
refunds: List[Dict[str, Any]] = None
|
||||
|
||||
# Payment method
|
||||
payment_method_id: Optional[str] = None
|
||||
payment_method_details: Optional[Dict[str, Any]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.refunds is None:
|
||||
self.refunds = []
|
||||
|
||||
@classmethod
|
||||
def from_stripe_charge(cls, charge_data: Dict[str, Any]) -> 'Charge':
|
||||
"""Create from Stripe charge data"""
|
||||
return cls(
|
||||
id=charge_data["id"],
|
||||
amount=charge_data["amount"],
|
||||
currency=charge_data["currency"],
|
||||
status=PaymentStatus(charge_data["status"]),
|
||||
created_at=datetime.fromtimestamp(charge_data["created"]),
|
||||
updated_at=datetime.fromtimestamp(charge_data.get("updated", charge_data["created"])),
|
||||
description=charge_data.get("description"),
|
||||
metadata=charge_data.get("metadata", {}),
|
||||
amount_refunded=charge_data.get("amount_refunded", 0),
|
||||
refunds=[r.to_dict() for r in charge_data.get("refunds", {}).get("data", [])],
|
||||
payment_method_id=charge_data.get("payment_method"),
|
||||
payment_method_details=charge_data.get("payment_method_details")
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Refund:
|
||||
"""Refund representation"""
|
||||
id: str
|
||||
amount: int
|
||||
currency: str
|
||||
status: RefundStatus
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
charge_id: str
|
||||
reason: Optional[str]
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
@classmethod
|
||||
def from_stripe_refund(cls, refund_data: Dict[str, Any]) -> 'Refund':
|
||||
"""Create from Stripe refund data"""
|
||||
return cls(
|
||||
id=refund_data["id"],
|
||||
amount=refund_data["amount"],
|
||||
currency=refund_data["currency"],
|
||||
status=RefundStatus(refund_data["status"]),
|
||||
created_at=datetime.fromtimestamp(refund_data["created"]),
|
||||
updated_at=datetime.fromtimestamp(refund_data.get("updated", refund_data["created"])),
|
||||
charge_id=refund_data["charge"],
|
||||
reason=refund_data.get("reason"),
|
||||
metadata=refund_data.get("metadata", {})
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Subscription:
|
||||
"""Subscription representation"""
|
||||
id: str
|
||||
status: SubscriptionStatus
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
current_period_start: datetime
|
||||
current_period_end: datetime
|
||||
customer_id: str
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
# Pricing
|
||||
amount: Optional[int] = None
|
||||
currency: Optional[str] = None
|
||||
interval: Optional[str] = None
|
||||
interval_count: Optional[int] = None
|
||||
|
||||
# Trial
|
||||
trial_start: Optional[datetime] = None
|
||||
trial_end: Optional[datetime] = None
|
||||
|
||||
# Cancellation
|
||||
canceled_at: Optional[datetime] = None
|
||||
ended_at: Optional[datetime] = None
|
||||
|
||||
@classmethod
|
||||
def from_stripe_subscription(cls, sub_data: Dict[str, Any]) -> 'Subscription':
|
||||
"""Create from Stripe subscription data"""
|
||||
items = sub_data.get("items", {}).get("data", [])
|
||||
first_item = items[0] if items else {}
|
||||
price = first_item.get("price", {})
|
||||
|
||||
return cls(
|
||||
id=sub_data["id"],
|
||||
status=SubscriptionStatus(sub_data["status"]),
|
||||
created_at=datetime.fromtimestamp(sub_data["created"]),
|
||||
updated_at=datetime.fromtimestamp(sub_data.get("updated", sub_data["created"])),
|
||||
current_period_start=datetime.fromtimestamp(sub_data["current_period_start"]),
|
||||
current_period_end=datetime.fromtimestamp(sub_data["current_period_end"]),
|
||||
customer_id=sub_data["customer"],
|
||||
metadata=sub_data.get("metadata", {}),
|
||||
amount=price.get("unit_amount"),
|
||||
currency=price.get("currency"),
|
||||
interval=price.get("recurring", {}).get("interval"),
|
||||
interval_count=price.get("recurring", {}).get("interval_count"),
|
||||
trial_start=datetime.fromtimestamp(sub_data["trial_start"]) if sub_data.get("trial_start") else None,
|
||||
trial_end=datetime.fromtimestamp(sub_data["trial_end"]) if sub_data.get("trial_end") else None,
|
||||
canceled_at=datetime.fromtimestamp(sub_data["canceled_at"]) if sub_data.get("canceled_at") else None,
|
||||
ended_at=datetime.fromtimestamp(sub_data["ended_at"]) if sub_data.get("ended_at") else None
|
||||
)
|
||||
|
||||
|
||||
class PaymentConnector(ABC):
|
||||
"""Abstract base class for payment connectors"""
|
||||
|
||||
def __init__(self, client, config):
|
||||
self.client = client
|
||||
self.config = config
|
||||
|
||||
@abstractmethod
|
||||
async def create_charge(
|
||||
self,
|
||||
amount: int,
|
||||
currency: str,
|
||||
source: str,
|
||||
description: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Charge:
|
||||
"""Create a charge"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def create_refund(
|
||||
self,
|
||||
charge_id: str,
|
||||
amount: Optional[int] = None,
|
||||
reason: Optional[str] = None
|
||||
) -> Refund:
|
||||
"""Create a refund"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def create_payment_method(
|
||||
self,
|
||||
type: str,
|
||||
card: Dict[str, Any],
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> PaymentMethod:
|
||||
"""Create a payment method"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def create_subscription(
|
||||
self,
|
||||
customer: str,
|
||||
items: List[Dict[str, Any]],
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Subscription:
|
||||
"""Create a subscription"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def cancel_subscription(
|
||||
self,
|
||||
subscription_id: str,
|
||||
at_period_end: bool = True
|
||||
) -> Subscription:
|
||||
"""Cancel a subscription"""
|
||||
pass
|
||||
@ -0,0 +1,33 @@
|
||||
"""
|
||||
PayPal payment connector for AITBC Enterprise (Placeholder)
|
||||
"""
|
||||
|
||||
from .base import PaymentConnector, PaymentMethod, Charge, Refund, Subscription
|
||||
|
||||
|
||||
class PayPalConnector(PaymentConnector):
|
||||
"""PayPal payment processor connector"""
|
||||
|
||||
def __init__(self, client, config, paypal_client_id, paypal_secret):
|
||||
# TODO: Implement PayPal connector
|
||||
raise NotImplementedError("PayPal connector not yet implemented")
|
||||
|
||||
async def create_charge(self, amount, currency, source, description=None, metadata=None):
|
||||
# TODO: Implement PayPal charge creation
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_refund(self, charge_id, amount=None, reason=None):
|
||||
# TODO: Implement PayPal refund
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_payment_method(self, type, card, metadata=None):
|
||||
# TODO: Implement PayPal payment method
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_subscription(self, customer, items, metadata=None):
|
||||
# TODO: Implement PayPal subscription
|
||||
raise NotImplementedError
|
||||
|
||||
async def cancel_subscription(self, subscription_id, at_period_end=True):
|
||||
# TODO: Implement PayPal subscription cancellation
|
||||
raise NotImplementedError
|
||||
@ -0,0 +1,33 @@
|
||||
"""
|
||||
Square payment connector for AITBC Enterprise (Placeholder)
|
||||
"""
|
||||
|
||||
from .base import PaymentConnector, PaymentMethod, Charge, Refund, Subscription
|
||||
|
||||
|
||||
class SquareConnector(PaymentConnector):
|
||||
"""Square payment processor connector"""
|
||||
|
||||
def __init__(self, client, config, square_access_token):
|
||||
# TODO: Implement Square connector
|
||||
raise NotImplementedError("Square connector not yet implemented")
|
||||
|
||||
async def create_charge(self, amount, currency, source, description=None, metadata=None):
|
||||
# TODO: Implement Square charge creation
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_refund(self, charge_id, amount=None, reason=None):
|
||||
# TODO: Implement Square refund
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_payment_method(self, type, card, metadata=None):
|
||||
# TODO: Implement Square payment method
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_subscription(self, customer, items, metadata=None):
|
||||
# TODO: Implement Square subscription
|
||||
raise NotImplementedError
|
||||
|
||||
async def cancel_subscription(self, subscription_id, at_period_end=True):
|
||||
# TODO: Implement Square subscription cancellation
|
||||
raise NotImplementedError
|
||||
@ -0,0 +1,489 @@
|
||||
"""
|
||||
Stripe payment connector for AITBC Enterprise
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
import stripe
|
||||
|
||||
from ..base import BaseConnector, OperationResult, Transaction
|
||||
from ..core import ConnectorConfig
|
||||
from .base import PaymentConnector, PaymentMethod, Charge, Refund, Subscription
|
||||
from ..exceptions import PaymentError, ValidationError
|
||||
|
||||
|
||||
class StripeConnector(PaymentConnector):
|
||||
"""Stripe payment processor connector"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: 'AITBCClient',
|
||||
config: ConnectorConfig,
|
||||
stripe_api_key: str,
|
||||
webhook_secret: Optional[str] = None
|
||||
):
|
||||
super().__init__(client, config)
|
||||
|
||||
# Stripe configuration
|
||||
self.stripe_api_key = stripe_api_key
|
||||
self.webhook_secret = webhook_secret
|
||||
|
||||
# Initialize Stripe client
|
||||
stripe.api_key = stripe_api_key
|
||||
stripe.api_version = "2023-10-16"
|
||||
|
||||
# Stripe-specific configuration
|
||||
self._stripe_config = {
|
||||
"api_key": stripe_api_key,
|
||||
"api_version": stripe.api_version,
|
||||
"connect_timeout": config.timeout,
|
||||
"read_timeout": config.timeout
|
||||
}
|
||||
|
||||
async def _initialize(self) -> None:
|
||||
"""Initialize Stripe connector"""
|
||||
try:
|
||||
# Test Stripe connection
|
||||
await self._test_stripe_connection()
|
||||
|
||||
# Set up webhook handler
|
||||
if self.webhook_secret:
|
||||
await self._setup_webhook_handler()
|
||||
|
||||
self.logger.info("Stripe connector initialized")
|
||||
|
||||
except Exception as e:
|
||||
raise PaymentError(f"Failed to initialize Stripe: {e}")
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
"""Cleanup Stripe connector"""
|
||||
# No specific cleanup needed for Stripe
|
||||
pass
|
||||
|
||||
async def _execute_operation(
|
||||
self,
|
||||
operation: str,
|
||||
data: Dict[str, Any],
|
||||
**kwargs
|
||||
) -> OperationResult:
|
||||
"""Execute Stripe-specific operations"""
|
||||
try:
|
||||
if operation == "create_charge":
|
||||
return await self._create_charge(data)
|
||||
elif operation == "create_refund":
|
||||
return await self._create_refund(data)
|
||||
elif operation == "create_payment_method":
|
||||
return await self._create_payment_method(data)
|
||||
elif operation == "create_customer":
|
||||
return await self._create_customer(data)
|
||||
elif operation == "create_subscription":
|
||||
return await self._create_subscription(data)
|
||||
elif operation == "cancel_subscription":
|
||||
return await self._cancel_subscription(data)
|
||||
elif operation == "retrieve_balance":
|
||||
return await self._retrieve_balance()
|
||||
else:
|
||||
raise ValidationError(f"Unknown operation: {operation}")
|
||||
|
||||
except stripe.error.StripeError as e:
|
||||
self.logger.error(f"Stripe error: {e}")
|
||||
return OperationResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
metadata={"stripe_error_code": getattr(e, 'code', None)}
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Operation failed: {e}")
|
||||
return OperationResult(
|
||||
success=False,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
async def create_charge(
|
||||
self,
|
||||
amount: int,
|
||||
currency: str,
|
||||
source: str,
|
||||
description: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Charge:
|
||||
"""Create a charge"""
|
||||
result = await self.execute_operation(
|
||||
"create_charge",
|
||||
{
|
||||
"amount": amount,
|
||||
"currency": currency,
|
||||
"source": source,
|
||||
"description": description,
|
||||
"metadata": metadata or {}
|
||||
}
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
raise PaymentError(result.error)
|
||||
|
||||
return Charge.from_stripe_charge(result.data)
|
||||
|
||||
async def create_refund(
|
||||
self,
|
||||
charge_id: str,
|
||||
amount: Optional[int] = None,
|
||||
reason: Optional[str] = None
|
||||
) -> Refund:
|
||||
"""Create a refund"""
|
||||
result = await self.execute_operation(
|
||||
"create_refund",
|
||||
{
|
||||
"charge": charge_id,
|
||||
"amount": amount,
|
||||
"reason": reason
|
||||
}
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
raise PaymentError(result.error)
|
||||
|
||||
return Refund.from_stripe_refund(result.data)
|
||||
|
||||
async def create_payment_method(
|
||||
self,
|
||||
type: str,
|
||||
card: Dict[str, Any],
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> PaymentMethod:
|
||||
"""Create a payment method"""
|
||||
result = await self.execute_operation(
|
||||
"create_payment_method",
|
||||
{
|
||||
"type": type,
|
||||
"card": card,
|
||||
"metadata": metadata or {}
|
||||
}
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
raise PaymentError(result.error)
|
||||
|
||||
return PaymentMethod.from_stripe_payment_method(result.data)
|
||||
|
||||
async def create_subscription(
|
||||
self,
|
||||
customer: str,
|
||||
items: List[Dict[str, Any]],
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Subscription:
|
||||
"""Create a subscription"""
|
||||
result = await self.execute_operation(
|
||||
"create_subscription",
|
||||
{
|
||||
"customer": customer,
|
||||
"items": items,
|
||||
"metadata": metadata or {}
|
||||
}
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
raise PaymentError(result.error)
|
||||
|
||||
return Subscription.from_stripe_subscription(result.data)
|
||||
|
||||
async def cancel_subscription(
|
||||
self,
|
||||
subscription_id: str,
|
||||
at_period_end: bool = True
|
||||
) -> Subscription:
|
||||
"""Cancel a subscription"""
|
||||
result = await self.execute_operation(
|
||||
"cancel_subscription",
|
||||
{
|
||||
"subscription": subscription_id,
|
||||
"at_period_end": at_period_end
|
||||
}
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
raise PaymentError(result.error)
|
||||
|
||||
return Subscription.from_stripe_subscription(result.data)
|
||||
|
||||
async def retrieve_balance(self) -> Dict[str, Any]:
|
||||
"""Retrieve account balance"""
|
||||
result = await self.execute_operation("retrieve_balance", {})
|
||||
|
||||
if not result.success:
|
||||
raise PaymentError(result.error)
|
||||
|
||||
return result.data
|
||||
|
||||
async def verify_webhook(self, payload: bytes, signature: str) -> bool:
|
||||
"""Verify Stripe webhook signature"""
|
||||
try:
|
||||
stripe.WebhookSignature.verify_header(
|
||||
payload,
|
||||
signature,
|
||||
self.webhook_secret,
|
||||
300
|
||||
)
|
||||
return True
|
||||
except stripe.error.SignatureVerificationError:
|
||||
return False
|
||||
|
||||
async def handle_webhook(self, payload: bytes) -> Dict[str, Any]:
|
||||
"""Handle Stripe webhook"""
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload,
|
||||
None, # Already verified
|
||||
self.webhook_secret,
|
||||
300
|
||||
)
|
||||
|
||||
# Process event based on type
|
||||
result = await self._process_webhook_event(event)
|
||||
|
||||
return {
|
||||
"processed": True,
|
||||
"event_type": event.type,
|
||||
"event_id": event.id,
|
||||
"result": result
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Webhook processing failed: {e}")
|
||||
return {
|
||||
"processed": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
# Private methods
|
||||
|
||||
async def _test_stripe_connection(self):
|
||||
"""Test Stripe API connection"""
|
||||
try:
|
||||
# Use asyncio to run in thread
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, stripe.Balance.retrieve)
|
||||
except Exception as e:
|
||||
raise PaymentError(f"Stripe connection test failed: {e}")
|
||||
|
||||
async def _setup_webhook_handler(self):
|
||||
"""Setup webhook handler"""
|
||||
# Register webhook verification with base connector
|
||||
self.add_operation_handler("webhook.verified", self._handle_verified_webhook)
|
||||
|
||||
async def _create_charge(self, data: Dict[str, Any]) -> OperationResult:
|
||||
"""Create Stripe charge"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
try:
|
||||
charge = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: stripe.Charge.create(**data)
|
||||
)
|
||||
|
||||
return OperationResult(
|
||||
success=True,
|
||||
data=charge.to_dict(),
|
||||
metadata={"charge_id": charge.id}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise PaymentError(f"Failed to create charge: {e}")
|
||||
|
||||
async def _create_refund(self, data: Dict[str, Any]) -> OperationResult:
|
||||
"""Create Stripe refund"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
try:
|
||||
refund = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: stripe.Refund.create(**data)
|
||||
)
|
||||
|
||||
return OperationResult(
|
||||
success=True,
|
||||
data=refund.to_dict(),
|
||||
metadata={"refund_id": refund.id}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise PaymentError(f"Failed to create refund: {e}")
|
||||
|
||||
async def _create_payment_method(self, data: Dict[str, Any]) -> OperationResult:
|
||||
"""Create Stripe payment method"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
try:
|
||||
pm = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: stripe.PaymentMethod.create(**data)
|
||||
)
|
||||
|
||||
return OperationResult(
|
||||
success=True,
|
||||
data=pm.to_dict(),
|
||||
metadata={"payment_method_id": pm.id}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise PaymentError(f"Failed to create payment method: {e}")
|
||||
|
||||
async def _create_customer(self, data: Dict[str, Any]) -> OperationResult:
|
||||
"""Create Stripe customer"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
try:
|
||||
customer = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: stripe.Customer.create(**data)
|
||||
)
|
||||
|
||||
return OperationResult(
|
||||
success=True,
|
||||
data=customer.to_dict(),
|
||||
metadata={"customer_id": customer.id}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise PaymentError(f"Failed to create customer: {e}")
|
||||
|
||||
async def _create_subscription(self, data: Dict[str, Any]) -> OperationResult:
|
||||
"""Create Stripe subscription"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
try:
|
||||
subscription = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: stripe.Subscription.create(**data)
|
||||
)
|
||||
|
||||
return OperationResult(
|
||||
success=True,
|
||||
data=subscription.to_dict(),
|
||||
metadata={"subscription_id": subscription.id}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise PaymentError(f"Failed to create subscription: {e}")
|
||||
|
||||
async def _cancel_subscription(self, data: Dict[str, Any]) -> OperationResult:
|
||||
"""Cancel Stripe subscription"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
try:
|
||||
subscription = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: stripe.Subscription.retrieve(data["subscription"])
|
||||
)
|
||||
|
||||
subscription = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: subscription.cancel(at_period_end=data.get("at_period_end", True))
|
||||
)
|
||||
|
||||
return OperationResult(
|
||||
success=True,
|
||||
data=subscription.to_dict(),
|
||||
metadata={"subscription_id": subscription.id}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise PaymentError(f"Failed to cancel subscription: {e}")
|
||||
|
||||
async def _retrieve_balance(self) -> OperationResult:
|
||||
"""Retrieve Stripe balance"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
try:
|
||||
balance = await loop.run_in_executor(None, stripe.Balance.retrieve)
|
||||
|
||||
return OperationResult(
|
||||
success=True,
|
||||
data=balance.to_dict()
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise PaymentError(f"Failed to retrieve balance: {e}")
|
||||
|
||||
async def _process_webhook_event(self, event) -> Dict[str, Any]:
|
||||
"""Process webhook event"""
|
||||
event_type = event.type
|
||||
|
||||
if event_type.startswith("charge."):
|
||||
return await self._handle_charge_event(event)
|
||||
elif event_type.startswith("payment_method."):
|
||||
return await self._handle_payment_method_event(event)
|
||||
elif event_type.startswith("customer."):
|
||||
return await self._handle_customer_event(event)
|
||||
elif event_type.startswith("invoice."):
|
||||
return await self._handle_invoice_event(event)
|
||||
else:
|
||||
self.logger.info(f"Unhandled webhook event type: {event_type}")
|
||||
return {"status": "ignored"}
|
||||
|
||||
async def _handle_charge_event(self, event) -> Dict[str, Any]:
|
||||
"""Handle charge-related webhook events"""
|
||||
charge = event.data.object
|
||||
|
||||
# Emit to AITBC
|
||||
await self.client.post(
|
||||
"/webhooks/stripe/charge",
|
||||
json={
|
||||
"event_id": event.id,
|
||||
"event_type": event.type,
|
||||
"charge": charge.to_dict()
|
||||
}
|
||||
)
|
||||
|
||||
return {"status": "processed", "charge_id": charge.id}
|
||||
|
||||
async def _handle_payment_method_event(self, event) -> Dict[str, Any]:
|
||||
"""Handle payment method webhook events"""
|
||||
pm = event.data.object
|
||||
|
||||
await self.client.post(
|
||||
"/webhooks/stripe/payment_method",
|
||||
json={
|
||||
"event_id": event.id,
|
||||
"event_type": event.type,
|
||||
"payment_method": pm.to_dict()
|
||||
}
|
||||
)
|
||||
|
||||
return {"status": "processed", "payment_method_id": pm.id}
|
||||
|
||||
async def _handle_customer_event(self, event) -> Dict[str, Any]:
|
||||
"""Handle customer webhook events"""
|
||||
customer = event.data.object
|
||||
|
||||
await self.client.post(
|
||||
"/webhooks/stripe/customer",
|
||||
json={
|
||||
"event_id": event.id,
|
||||
"event_type": event.type,
|
||||
"customer": customer.to_dict()
|
||||
}
|
||||
)
|
||||
|
||||
return {"status": "processed", "customer_id": customer.id}
|
||||
|
||||
async def _handle_invoice_event(self, event) -> Dict[str, Any]:
|
||||
"""Handle invoice webhook events"""
|
||||
invoice = event.data.object
|
||||
|
||||
await self.client.post(
|
||||
"/webhooks/stripe/invoice",
|
||||
json={
|
||||
"event_id": event.id,
|
||||
"event_type": event.type,
|
||||
"invoice": invoice.to_dict()
|
||||
}
|
||||
)
|
||||
|
||||
return {"status": "processed", "invoice_id": invoice.id}
|
||||
|
||||
async def _handle_verified_webhook(self, data: Dict[str, Any]):
|
||||
"""Handle verified webhook"""
|
||||
self.logger.info(f"Webhook verified: {data}")
|
||||
@ -0,0 +1,189 @@
|
||||
"""
|
||||
Rate limiting for AITBC Enterprise Connectors
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Optional, Dict, Any
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .core import ConnectorConfig
|
||||
from .exceptions import RateLimitError
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitInfo:
|
||||
"""Rate limit information"""
|
||||
limit: int
|
||||
remaining: int
|
||||
reset_time: float
|
||||
retry_after: Optional[int] = None
|
||||
|
||||
|
||||
class TokenBucket:
|
||||
"""Token bucket rate limiter"""
|
||||
|
||||
def __init__(self, rate: float, capacity: int):
|
||||
self.rate = rate # Tokens per second
|
||||
self.capacity = capacity
|
||||
self.tokens = capacity
|
||||
self.last_refill = time.time()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def acquire(self, tokens: int = 1) -> bool:
|
||||
"""Acquire tokens from bucket"""
|
||||
async with self._lock:
|
||||
now = time.time()
|
||||
|
||||
# Refill tokens
|
||||
elapsed = now - self.last_refill
|
||||
self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
|
||||
self.last_refill = now
|
||||
|
||||
# Check if enough tokens
|
||||
if self.tokens >= tokens:
|
||||
self.tokens -= tokens
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def wait_for_token(self, tokens: int = 1):
|
||||
"""Wait until token is available"""
|
||||
while not await self.acquire(tokens):
|
||||
# Calculate wait time
|
||||
wait_time = (tokens - self.tokens) / self.rate
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
|
||||
class SlidingWindowCounter:
|
||||
"""Sliding window rate limiter"""
|
||||
|
||||
def __init__(self, limit: int, window: int):
|
||||
self.limit = limit
|
||||
self.window = window # Window size in seconds
|
||||
self.requests = deque()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def is_allowed(self) -> bool:
|
||||
"""Check if request is allowed"""
|
||||
async with self._lock:
|
||||
now = time.time()
|
||||
|
||||
# Remove old requests
|
||||
while self.requests and self.requests[0] <= now - self.window:
|
||||
self.requests.popleft()
|
||||
|
||||
# Check if under limit
|
||||
if len(self.requests) < self.limit:
|
||||
self.requests.append(now)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def wait_for_slot(self):
|
||||
"""Wait until request slot is available"""
|
||||
while not await self.is_allowed():
|
||||
# Calculate wait time until oldest request expires
|
||||
if self.requests:
|
||||
wait_time = self.requests[0] + self.window - time.time()
|
||||
if wait_time > 0:
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Rate limiter with multiple strategies"""
|
||||
|
||||
def __init__(self, config: ConnectorConfig):
|
||||
self.config = config
|
||||
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
|
||||
|
||||
# Initialize rate limiters
|
||||
self._token_bucket = None
|
||||
self._sliding_window = None
|
||||
self._strategy = "token_bucket"
|
||||
|
||||
if config.rate_limit:
|
||||
# Default to token bucket with burst capacity
|
||||
burst = config.burst_limit or config.rate_limit * 2
|
||||
self._token_bucket = TokenBucket(
|
||||
rate=config.rate_limit,
|
||||
capacity=burst
|
||||
)
|
||||
|
||||
# Track rate limit info from server
|
||||
self._server_limits: Dict[str, RateLimitInfo] = {}
|
||||
|
||||
async def acquire(self, endpoint: str = None) -> None:
|
||||
"""Acquire rate limit permit"""
|
||||
if self._strategy == "token_bucket" and self._token_bucket:
|
||||
await self._token_bucket.wait_for_token()
|
||||
elif self._strategy == "sliding_window" and self._sliding_window:
|
||||
await self._sliding_window.wait_for_slot()
|
||||
|
||||
# Check server-side limits
|
||||
if endpoint and endpoint in self._server_limits:
|
||||
limit_info = self._server_limits[endpoint]
|
||||
|
||||
if limit_info.remaining <= 0:
|
||||
wait_time = limit_info.reset_time - time.time()
|
||||
if wait_time > 0:
|
||||
raise RateLimitError(
|
||||
f"Rate limit exceeded for {endpoint}",
|
||||
retry_after=int(wait_time) + 1
|
||||
)
|
||||
|
||||
def update_server_limit(self, endpoint: str, headers: Dict[str, str]):
|
||||
"""Update rate limit info from server response"""
|
||||
# Parse common rate limit headers
|
||||
limit = headers.get("X-RateLimit-Limit")
|
||||
remaining = headers.get("X-RateLimit-Remaining")
|
||||
reset = headers.get("X-RateLimit-Reset")
|
||||
retry_after = headers.get("Retry-After")
|
||||
|
||||
if limit or remaining or reset:
|
||||
self._server_limits[endpoint] = RateLimitInfo(
|
||||
limit=int(limit) if limit else 0,
|
||||
remaining=int(remaining) if remaining else 0,
|
||||
reset_time=float(reset) if reset else time.time() + 3600,
|
||||
retry_after=int(retry_after) if retry_after else None
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Updated rate limit for {endpoint}: "
|
||||
f"{remaining}/{limit} remaining"
|
||||
)
|
||||
|
||||
def get_limit_info(self, endpoint: str = None) -> Optional[RateLimitInfo]:
|
||||
"""Get current rate limit info"""
|
||||
if endpoint and endpoint in self._server_limits:
|
||||
return self._server_limits[endpoint]
|
||||
|
||||
# Return configured limit if no server limit
|
||||
if self.config.rate_limit:
|
||||
return RateLimitInfo(
|
||||
limit=self.config.rate_limit,
|
||||
remaining=self.config.rate_limit, # Approximate
|
||||
reset_time=time.time() + 3600
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def set_strategy(self, strategy: str):
|
||||
"""Set rate limiting strategy"""
|
||||
if strategy not in ["token_bucket", "sliding_window", "none"]:
|
||||
raise ValueError(f"Unknown strategy: {strategy}")
|
||||
|
||||
self._strategy = strategy
|
||||
|
||||
def reset(self):
|
||||
"""Reset rate limiter state"""
|
||||
if self._token_bucket:
|
||||
self._token_bucket.tokens = self._token_bucket.capacity
|
||||
self._token_bucket.last_refill = time.time()
|
||||
|
||||
if self._sliding_window:
|
||||
self._sliding_window.requests.clear()
|
||||
|
||||
self._server_limits.clear()
|
||||
self.logger.info("Rate limiter reset")
|
||||
318
enterprise-connectors/python-sdk/aitbc_enterprise/validators.py
Normal file
318
enterprise-connectors/python-sdk/aitbc_enterprise/validators.py
Normal file
@ -0,0 +1,318 @@
|
||||
"""
|
||||
Validation utilities for AITBC Enterprise Connectors
|
||||
"""
|
||||
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
from .exceptions import ValidationError
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationRule:
|
||||
"""Validation rule definition"""
|
||||
name: str
|
||||
required: bool = True
|
||||
type: type = str
|
||||
min_length: Optional[int] = None
|
||||
max_length: Optional[int] = None
|
||||
pattern: Optional[str] = None
|
||||
min_value: Optional[Union[int, float]] = None
|
||||
max_value: Optional[Union[int, float]] = None
|
||||
allowed_values: Optional[List[Any]] = None
|
||||
custom_validator: Optional[callable] = None
|
||||
|
||||
|
||||
class BaseValidator(ABC):
|
||||
"""Abstract base class for validators"""
|
||||
|
||||
@abstractmethod
|
||||
async def validate(self, operation: str, data: Dict[str, Any]) -> bool:
|
||||
"""Validate operation data"""
|
||||
pass
|
||||
|
||||
|
||||
class SchemaValidator(BaseValidator):
|
||||
"""Schema-based validator"""
|
||||
|
||||
def __init__(self, schemas: Dict[str, Dict[str, ValidationRule]]):
|
||||
self.schemas = schemas
|
||||
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
|
||||
|
||||
async def validate(self, operation: str, data: Dict[str, Any]) -> bool:
|
||||
"""Validate data against schema"""
|
||||
if operation not in self.schemas:
|
||||
self.logger.warning(f"No schema for operation: {operation}")
|
||||
return True
|
||||
|
||||
schema = self.schemas[operation]
|
||||
errors = []
|
||||
|
||||
# Validate each field
|
||||
for field_name, rule in schema.items():
|
||||
try:
|
||||
self._validate_field(field_name, data.get(field_name), rule)
|
||||
except ValidationError as e:
|
||||
errors.append(f"{field_name}: {str(e)}")
|
||||
|
||||
# Check for unexpected fields
|
||||
allowed_fields = set(schema.keys())
|
||||
provided_fields = set(data.keys())
|
||||
unexpected = provided_fields - allowed_fields
|
||||
|
||||
if unexpected:
|
||||
self.logger.warning(f"Unexpected fields: {unexpected}")
|
||||
|
||||
if errors:
|
||||
raise ValidationError(f"Validation failed: {'; '.join(errors)}")
|
||||
|
||||
return True
|
||||
|
||||
def _validate_field(self, name: str, value: Any, rule: ValidationRule):
|
||||
"""Validate a single field"""
|
||||
# Check required
|
||||
if rule.required and value is None:
|
||||
raise ValidationError(f"{name} is required")
|
||||
|
||||
# Skip validation if not required and value is None
|
||||
if not rule.required and value is None:
|
||||
return
|
||||
|
||||
# Type validation
|
||||
if not isinstance(value, rule.type):
|
||||
try:
|
||||
value = rule.type(value)
|
||||
except (ValueError, TypeError):
|
||||
raise ValidationError(f"{name} must be of type {rule.type.__name__}")
|
||||
|
||||
# String validations
|
||||
if isinstance(value, str):
|
||||
if rule.min_length and len(value) < rule.min_length:
|
||||
raise ValidationError(f"{name} must be at least {rule.min_length} characters")
|
||||
|
||||
if rule.max_length and len(value) > rule.max_length:
|
||||
raise ValidationError(f"{name} must be at most {rule.max_length} characters")
|
||||
|
||||
if rule.pattern and not re.match(rule.pattern, value):
|
||||
raise ValidationError(f"{name} does not match required pattern")
|
||||
|
||||
# Numeric validations
|
||||
if isinstance(value, (int, float)):
|
||||
if rule.min_value is not None and value < rule.min_value:
|
||||
raise ValidationError(f"{name} must be at least {rule.min_value}")
|
||||
|
||||
if rule.max_value is not None and value > rule.max_value:
|
||||
raise ValidationError(f"{name} must be at most {rule.max_value}")
|
||||
|
||||
# Allowed values
|
||||
if rule.allowed_values and value not in rule.allowed_values:
|
||||
raise ValidationError(f"{name} must be one of: {rule.allowed_values}")
|
||||
|
||||
# Custom validator
|
||||
if rule.custom_validator:
|
||||
try:
|
||||
if not rule.custom_validator(value):
|
||||
raise ValidationError(f"{name} failed custom validation")
|
||||
except Exception as e:
|
||||
raise ValidationError(f"{name} validation error: {str(e)}")
|
||||
|
||||
|
||||
class PaymentValidator(SchemaValidator):
|
||||
"""Validator for payment operations"""
|
||||
|
||||
def __init__(self):
|
||||
schemas = {
|
||||
"create_charge": {
|
||||
"amount": ValidationRule(
|
||||
name="amount",
|
||||
type=int,
|
||||
min_value=50, # Minimum $0.50
|
||||
max_value=99999999, # Maximum $999,999.99
|
||||
custom_validator=lambda x: x % 1 == 0 # Must be whole cents
|
||||
),
|
||||
"currency": ValidationRule(
|
||||
name="currency",
|
||||
type=str,
|
||||
min_length=3,
|
||||
max_length=3,
|
||||
pattern=r"^[A-Z]{3}$",
|
||||
allowed_values=["USD", "EUR", "GBP", "JPY", "CAD", "AUD"]
|
||||
),
|
||||
"source": ValidationRule(
|
||||
name="source",
|
||||
type=str,
|
||||
min_length=1,
|
||||
max_length=255
|
||||
),
|
||||
"description": ValidationRule(
|
||||
name="description",
|
||||
type=str,
|
||||
required=False,
|
||||
max_length=1000
|
||||
)
|
||||
},
|
||||
"create_refund": {
|
||||
"charge": ValidationRule(
|
||||
name="charge",
|
||||
type=str,
|
||||
min_length=1,
|
||||
pattern=r"^ch_[a-zA-Z0-9]+$"
|
||||
),
|
||||
"amount": ValidationRule(
|
||||
name="amount",
|
||||
type=int,
|
||||
required=False,
|
||||
min_value=50,
|
||||
custom_validator=lambda x: x % 1 == 0
|
||||
),
|
||||
"reason": ValidationRule(
|
||||
name="reason",
|
||||
type=str,
|
||||
required=False,
|
||||
allowed_values=["duplicate", "fraudulent", "requested_by_customer"]
|
||||
)
|
||||
},
|
||||
"create_payment_method": {
|
||||
"type": ValidationRule(
|
||||
name="type",
|
||||
type=str,
|
||||
allowed_values=["card", "bank_account"]
|
||||
),
|
||||
"card": ValidationRule(
|
||||
name="card",
|
||||
type=dict,
|
||||
custom_validator=lambda x: all(k in x for k in ["number", "exp_month", "exp_year"])
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
super().__init__(schemas)
|
||||
|
||||
|
||||
class ERPValidator(SchemaValidator):
|
||||
"""Validator for ERP operations"""
|
||||
|
||||
def __init__(self):
|
||||
schemas = {
|
||||
"create_customer": {
|
||||
"name": ValidationRule(
|
||||
name="name",
|
||||
type=str,
|
||||
min_length=1,
|
||||
max_length=100
|
||||
),
|
||||
"email": ValidationRule(
|
||||
name="email",
|
||||
type=str,
|
||||
pattern=r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
),
|
||||
"phone": ValidationRule(
|
||||
name="phone",
|
||||
type=str,
|
||||
required=False,
|
||||
pattern=r"^\+?[1-9]\d{1,14}$"
|
||||
),
|
||||
"address": ValidationRule(
|
||||
name="address",
|
||||
type=dict,
|
||||
required=False
|
||||
)
|
||||
},
|
||||
"create_order": {
|
||||
"customer_id": ValidationRule(
|
||||
name="customer_id",
|
||||
type=str,
|
||||
min_length=1
|
||||
),
|
||||
"items": ValidationRule(
|
||||
name="items",
|
||||
type=list,
|
||||
min_length=1,
|
||||
custom_validator=lambda x: all(isinstance(i, dict) and "product_id" in i and "quantity" in i for i in x)
|
||||
),
|
||||
"currency": ValidationRule(
|
||||
name="currency",
|
||||
type=str,
|
||||
pattern=r"^[A-Z]{3}$"
|
||||
)
|
||||
},
|
||||
"sync_data": {
|
||||
"entity_type": ValidationRule(
|
||||
name="entity_type",
|
||||
type=str,
|
||||
allowed_values=["customers", "orders", "products", "invoices"]
|
||||
),
|
||||
"since": ValidationRule(
|
||||
name="since",
|
||||
type=str,
|
||||
required=False,
|
||||
pattern=r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z$"
|
||||
),
|
||||
"limit": ValidationRule(
|
||||
name="limit",
|
||||
type=int,
|
||||
required=False,
|
||||
min_value=1,
|
||||
max_value=1000
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
super().__init__(schemas)
|
||||
|
||||
|
||||
class CompositeValidator(BaseValidator):
|
||||
"""Combines multiple validators"""
|
||||
|
||||
def __init__(self, validators: List[BaseValidator]):
|
||||
self.validators = validators
|
||||
|
||||
async def validate(self, operation: str, data: Dict[str, Any]) -> bool:
|
||||
"""Run all validators"""
|
||||
errors = []
|
||||
|
||||
for validator in self.validators:
|
||||
try:
|
||||
await validator.validate(operation, data)
|
||||
except ValidationError as e:
|
||||
errors.append(str(e))
|
||||
|
||||
if errors:
|
||||
raise ValidationError(f"Validation failed: {'; '.join(errors)}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# Common validation functions
|
||||
def validate_email(email: str) -> bool:
|
||||
"""Validate email address"""
|
||||
pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
return re.match(pattern, email) is not None
|
||||
|
||||
|
||||
def validate_phone(phone: str) -> bool:
|
||||
"""Validate phone number (E.164 format)"""
|
||||
pattern = r"^\+?[1-9]\d{1,14}$"
|
||||
return re.match(pattern, phone) is not None
|
||||
|
||||
|
||||
def validate_amount(amount: int) -> bool:
|
||||
"""Validate amount in cents"""
|
||||
return amount > 0 and amount % 1 == 0
|
||||
|
||||
|
||||
def validate_currency(currency: str) -> bool:
|
||||
"""Validate currency code"""
|
||||
return len(currency) == 3 and currency.isupper()
|
||||
|
||||
|
||||
def validate_timestamp(timestamp: str) -> bool:
|
||||
"""Validate ISO 8601 timestamp"""
|
||||
try:
|
||||
datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
309
enterprise-connectors/python-sdk/aitbc_enterprise/webhooks.py
Normal file
309
enterprise-connectors/python-sdk/aitbc_enterprise/webhooks.py
Normal file
@ -0,0 +1,309 @@
|
||||
"""
|
||||
Webhook handling for AITBC Enterprise Connectors
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional, Callable, List, Awaitable
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .exceptions import WebhookError
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebhookEvent:
|
||||
"""Webhook event representation"""
|
||||
id: str
|
||||
type: str
|
||||
source: str
|
||||
timestamp: datetime
|
||||
data: Dict[str, Any]
|
||||
signature: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"type": self.type,
|
||||
"source": self.source,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
"data": self.data,
|
||||
"signature": self.signature
|
||||
}
|
||||
|
||||
|
||||
class WebhookHandler:
|
||||
"""Handles webhook processing and verification"""
|
||||
|
||||
def __init__(self, secret: str = None):
|
||||
self.secret = secret
|
||||
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
|
||||
|
||||
# Event handlers
|
||||
self._handlers: Dict[str, List[Callable]] = {}
|
||||
|
||||
# Processing state
|
||||
self._processing = False
|
||||
self._queue: asyncio.Queue = None
|
||||
self._worker_task = None
|
||||
|
||||
async def setup(self, endpoint: str, secret: str = None):
|
||||
"""Setup webhook handler"""
|
||||
if secret:
|
||||
self.secret = secret
|
||||
|
||||
# Initialize queue and worker
|
||||
self._queue = asyncio.Queue(maxsize=1000)
|
||||
self._worker_task = asyncio.create_task(self._process_queue())
|
||||
|
||||
self.logger.info(f"Webhook handler setup for endpoint: {endpoint}")
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup webhook handler"""
|
||||
if self._worker_task:
|
||||
self._worker_task.cancel()
|
||||
try:
|
||||
await self._worker_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self.logger.info("Webhook handler cleaned up")
|
||||
|
||||
def add_handler(self, event_type: str, handler: Callable[[WebhookEvent], Awaitable[None]]):
|
||||
"""Add handler for specific event type"""
|
||||
if event_type not in self._handlers:
|
||||
self._handlers[event_type] = []
|
||||
self._handlers[event_type].append(handler)
|
||||
|
||||
def remove_handler(self, event_type: str, handler: Callable):
|
||||
"""Remove handler for specific event type"""
|
||||
if event_type in self._handlers:
|
||||
try:
|
||||
self._handlers[event_type].remove(handler)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
async def verify(self, payload: bytes, signature: str, algorithm: str = "sha256") -> bool:
|
||||
"""Verify webhook signature"""
|
||||
if not self.secret:
|
||||
self.logger.warning("No webhook secret configured, skipping verification")
|
||||
return True
|
||||
|
||||
try:
|
||||
expected_signature = hmac.new(
|
||||
self.secret.encode(),
|
||||
payload,
|
||||
getattr(hashlib, algorithm)
|
||||
).hexdigest()
|
||||
|
||||
# Compare signatures securely
|
||||
return hmac.compare_digest(expected_signature, signature)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Webhook verification failed: {e}")
|
||||
return False
|
||||
|
||||
async def handle(self, payload: bytes, signature: str = None) -> Dict[str, Any]:
|
||||
"""Handle incoming webhook"""
|
||||
try:
|
||||
# Parse payload
|
||||
data = json.loads(payload.decode())
|
||||
|
||||
# Create event
|
||||
event = WebhookEvent(
|
||||
id=data.get("id", f"evt_{int(datetime.utcnow().timestamp())}"),
|
||||
type=data.get("type", "unknown"),
|
||||
source=data.get("source", "unknown"),
|
||||
timestamp=datetime.fromisoformat(data.get("timestamp", datetime.utcnow().isoformat())),
|
||||
data=data.get("data", {}),
|
||||
signature=signature
|
||||
)
|
||||
|
||||
# Verify signature if provided
|
||||
if signature and not await self.verify(payload, signature):
|
||||
raise WebhookError("Invalid webhook signature")
|
||||
|
||||
# Queue for processing
|
||||
if self._queue:
|
||||
await self._queue.put(event)
|
||||
return {
|
||||
"status": "queued",
|
||||
"event_id": event.id
|
||||
}
|
||||
else:
|
||||
# Process immediately
|
||||
result = await self._process_event(event)
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise WebhookError(f"Invalid JSON payload: {e}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Webhook handling failed: {e}")
|
||||
raise WebhookError(f"Processing failed: {e}")
|
||||
|
||||
async def _process_queue(self):
|
||||
"""Process webhook events from queue"""
|
||||
while True:
|
||||
try:
|
||||
event = await self._queue.get()
|
||||
await self._process_event(event)
|
||||
self._queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing webhook event: {e}")
|
||||
|
||||
async def _process_event(self, event: WebhookEvent) -> Dict[str, Any]:
|
||||
"""Process a single webhook event"""
|
||||
try:
|
||||
self.logger.debug(f"Processing webhook event: {event.type}")
|
||||
|
||||
# Get handlers for event type
|
||||
handlers = self._handlers.get(event.type, [])
|
||||
|
||||
# Also check for wildcard handlers
|
||||
wildcard_handlers = self._handlers.get("*", [])
|
||||
handlers.extend(wildcard_handlers)
|
||||
|
||||
if not handlers:
|
||||
self.logger.warning(f"No handlers for event type: {event.type}")
|
||||
return {
|
||||
"status": "ignored",
|
||||
"event_id": event.id,
|
||||
"message": "No handlers registered"
|
||||
}
|
||||
|
||||
# Execute handlers
|
||||
tasks = []
|
||||
for handler in handlers:
|
||||
tasks.append(handler(event))
|
||||
|
||||
# Wait for all handlers to complete
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Check for errors
|
||||
errors = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
errors.append(str(result))
|
||||
self.logger.error(f"Handler {i} failed: {result}")
|
||||
|
||||
return {
|
||||
"status": "processed" if not errors else "partial",
|
||||
"event_id": event.id,
|
||||
"handlers_count": len(handlers),
|
||||
"errors_count": len(errors),
|
||||
"errors": errors if errors else None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to process webhook event: {e}")
|
||||
return {
|
||||
"status": "failed",
|
||||
"event_id": event.id,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
class StripeWebhookHandler(WebhookHandler):
|
||||
"""Stripe-specific webhook handler"""
|
||||
|
||||
def __init__(self, secret: str):
|
||||
super().__init__(secret)
|
||||
self._setup_default_handlers()
|
||||
|
||||
def _setup_default_handlers(self):
|
||||
"""Setup default Stripe event handlers"""
|
||||
self.add_handler("charge.succeeded", self._handle_charge_succeeded)
|
||||
self.add_handler("charge.failed", self._handle_charge_failed)
|
||||
self.add_handler("payment_method.attached", self._handle_payment_method_attached)
|
||||
self.add_handler("invoice.payment_succeeded", self._handle_invoice_succeeded)
|
||||
|
||||
async def verify(self, payload: bytes, signature: str) -> bool:
|
||||
"""Verify Stripe webhook signature"""
|
||||
try:
|
||||
import stripe
|
||||
|
||||
stripe.WebhookSignature.verify_header(
|
||||
payload,
|
||||
signature,
|
||||
self.secret,
|
||||
300 # 5 minutes tolerance
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Stripe webhook verification failed: {e}")
|
||||
return False
|
||||
|
||||
async def _handle_charge_succeeded(self, event: WebhookEvent):
|
||||
"""Handle successful charge"""
|
||||
charge = event.data.get("object", {})
|
||||
self.logger.info(f"Charge succeeded: {charge.get('id')} - ${charge.get('amount', 0) / 100:.2f}")
|
||||
|
||||
async def _handle_charge_failed(self, event: WebhookEvent):
|
||||
"""Handle failed charge"""
|
||||
charge = event.data.get("object", {})
|
||||
self.logger.warning(f"Charge failed: {charge.get('id')} - {charge.get('failure_message')}")
|
||||
|
||||
async def _handle_payment_method_attached(self, event: WebhookEvent):
|
||||
"""Handle payment method attachment"""
|
||||
pm = event.data.get("object", {})
|
||||
self.logger.info(f"Payment method attached: {pm.get('id')} - {pm.get('type')}")
|
||||
|
||||
async def _handle_invoice_succeeded(self, event: WebhookEvent):
|
||||
"""Handle successful invoice payment"""
|
||||
invoice = event.data.get("object", {})
|
||||
self.logger.info(f"Invoice paid: {invoice.get('id')} - ${invoice.get('amount_paid', 0) / 100:.2f}")
|
||||
|
||||
|
||||
class WebhookServer:
|
||||
"""Simple webhook server for testing"""
|
||||
|
||||
def __init__(self, handler: WebhookHandler, port: int = 8080):
|
||||
self.handler = handler
|
||||
self.port = port
|
||||
self.server = None
|
||||
self.logger = __import__('logging').getLogger(f"aitbc.{self.__class__.__name__}")
|
||||
|
||||
async def start(self):
|
||||
"""Start webhook server"""
|
||||
from aiohttp import web
|
||||
|
||||
async def handle_webhook(request):
|
||||
# Get signature from header
|
||||
signature = request.headers.get("Stripe-Signature") or request.headers.get("X-Signature")
|
||||
|
||||
# Read payload
|
||||
payload = await request.read()
|
||||
|
||||
try:
|
||||
# Handle webhook
|
||||
result = await self.handler.handle(payload, signature)
|
||||
return web.json_response(result)
|
||||
except WebhookError as e:
|
||||
return web.json_response(
|
||||
{"error": str(e)},
|
||||
status=400
|
||||
)
|
||||
|
||||
# Create app
|
||||
app = web.Application()
|
||||
app.router.add_post("/webhook", handle_webhook)
|
||||
|
||||
# Start server
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "localhost", self.port)
|
||||
await site.start()
|
||||
|
||||
self.server = runner
|
||||
self.logger.info(f"Webhook server started on port {self.port}")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop webhook server"""
|
||||
if self.server:
|
||||
await self.server.cleanup()
|
||||
self.logger.info("Webhook server stopped")
|
||||
Reference in New Issue
Block a user