diff --git a/aitbc/__init__.py b/aitbc/__init__.py index 95762d33..ca7bcb2a 100644 --- a/aitbc/__init__.py +++ b/aitbc/__init__.py @@ -61,8 +61,52 @@ from .json_utils import ( flatten_json, ) from .http_client import AITBCHTTPClient +from .config import BaseAITBCConfig, AITBCConfig +from .decorators import ( + retry, + timing, + cache_result, + validate_args, + handle_exceptions, + async_timing, +) +from .validation import ( + validate_address, + validate_hash, + validate_url, + validate_port, + validate_email, + validate_non_empty, + validate_positive_number, + validate_range, + validate_chain_id, + validate_uuid, +) +from .async_helpers import ( + run_sync, + gather_with_concurrency, + run_with_timeout, + batch_process, + sync_to_async, + async_to_sync, + retry_async, + wait_for_condition, +) +from .database import ( + DatabaseConnection, + get_database_connection, + ensure_database, + vacuum_database, + get_table_info, + table_exists, +) +from .monitoring import ( + MetricsCollector, + PerformanceTimer, + HealthChecker, +) -__version__ = "0.4.0" +__version__ = "0.6.0" __all__ = [ # Logging "get_logger", @@ -120,4 +164,45 @@ __all__ = [ "flatten_json", # HTTP client "AITBCHTTPClient", + # Configuration + "BaseAITBCConfig", + "AITBCConfig", + # Decorators + "retry", + "timing", + "cache_result", + "validate_args", + "handle_exceptions", + "async_timing", + # Validators + "validate_address", + "validate_hash", + "validate_url", + "validate_port", + "validate_email", + "validate_non_empty", + "validate_positive_number", + "validate_range", + "validate_chain_id", + "validate_uuid", + # Async helpers + "run_sync", + "gather_with_concurrency", + "run_with_timeout", + "batch_process", + "sync_to_async", + "async_to_sync", + "retry_async", + "wait_for_condition", + # Database + "DatabaseConnection", + "get_database_connection", + "ensure_database", + "vacuum_database", + "get_table_info", + "table_exists", + # Monitoring + "MetricsCollector", + "PerformanceTimer", + "HealthChecker", ] diff --git a/aitbc/async_helpers.py b/aitbc/async_helpers.py new file mode 100644 index 00000000..54216147 --- /dev/null +++ b/aitbc/async_helpers.py @@ -0,0 +1,190 @@ +""" +AITBC Async Helpers +Async utilities for AITBC applications +""" + +import asyncio +from typing import Coroutine, Any, List, TypeVar, Callable +from functools import wraps + +T = TypeVar('T') + + +async def run_sync(coro: Coroutine[Any, Any, T]) -> T: + """ + Run a coroutine from synchronous code. + + Args: + coro: Coroutine to run + + Returns: + Result of the coroutine + """ + return await asyncio.create_task(coro) + + +async def gather_with_concurrency( + coros: List[Coroutine[Any, Any, T]], + limit: int = 10 +) -> List[T]: + """ + Gather coroutines with concurrency limit. + + Args: + coros: List of coroutines to execute + limit: Maximum concurrent coroutines + + Returns: + List of results from all coroutines + """ + semaphore = asyncio.Semaphore(limit) + + async def limited_coro(coro: Coroutine[Any, Any, T]) -> T: + async with semaphore: + return await coro + + limited_coros = [limited_coro(coro) for coro in coros] + return await asyncio.gather(*limited_coros) + + +async def run_with_timeout( + coro: Coroutine[Any, Any, T], + timeout: float, + default: T = None +) -> T: + """ + Run a coroutine with a timeout. + + Args: + coro: Coroutine to run + timeout: Timeout in seconds + default: Default value if timeout occurs + + Returns: + Result of coroutine or default value on timeout + """ + try: + return await asyncio.wait_for(coro, timeout=timeout) + except asyncio.TimeoutError: + return default + + +async def batch_process( + items: List[Any], + process_func: Callable[[Any], Coroutine[Any, Any, T]], + batch_size: int = 10, + delay: float = 0.1 +) -> List[T]: + """ + Process items in batches with delay between batches. + + Args: + items: Items to process + process_func: Async function to process each item + batch_size: Number of items per batch + delay: Delay between batches in seconds + + Returns: + List of results + """ + results = [] + for i in range(0, len(items), batch_size): + batch = items[i:i + batch_size] + batch_results = await asyncio.gather(*[process_func(item) for item in batch]) + results.extend(batch_results) + + if i + batch_size < len(items): + await asyncio.sleep(delay) + + return results + + +def sync_to_async(func: Callable) -> Callable: + """ + Decorator to convert a synchronous function to async. + + Args: + func: Synchronous function to convert + + Returns: + Async wrapper function + """ + @wraps(func) + async def wrapper(*args, **kwargs): + return func(*args, **kwargs) + return wrapper + + +def async_to_sync(func: Callable) -> Callable: + """ + Decorator to convert an async function to sync. + + Args: + func: Async function to convert + + Returns: + Synchronous wrapper function + """ + @wraps(func) + def wrapper(*args, **kwargs): + return asyncio.run(func(*args, **kwargs)) + return wrapper + + +async def retry_async( + coro_func: Callable, + max_attempts: int = 3, + delay: float = 1.0, + backoff: float = 2.0 +) -> Any: + """ + Retry an async coroutine with exponential backoff. + + Args: + coro_func: Function that returns a coroutine + max_attempts: Maximum retry attempts + delay: Initial delay in seconds + backoff: Multiplier for delay after each retry + + Returns: + Result of the coroutine + """ + last_exception = None + current_delay = delay + + for attempt in range(max_attempts): + try: + return await coro_func() + except Exception as e: + last_exception = e + if attempt < max_attempts - 1: + await asyncio.sleep(current_delay) + current_delay *= backoff + + raise last_exception + + +async def wait_for_condition( + condition: Callable[[], Coroutine[Any, Any, bool]], + timeout: float = 30.0, + check_interval: float = 0.5 +) -> bool: + """ + Wait for a condition to become true. + + Args: + condition: Async function that returns a boolean + timeout: Maximum wait time in seconds + check_interval: Time between checks in seconds + + Returns: + True if condition became true, False if timeout + """ + start_time = asyncio.get_event_loop().time() + + while asyncio.get_event_loop().time() - start_time < timeout: + if await condition(): + return True + await asyncio.sleep(check_interval) + + return False diff --git a/aitbc/config.py b/aitbc/config.py new file mode 100644 index 00000000..d61f3261 --- /dev/null +++ b/aitbc/config.py @@ -0,0 +1,80 @@ +""" +AITBC Configuration Classes +Base configuration classes for AITBC applications +""" + +from pathlib import Path +from typing import Optional +from pydantic_settings import BaseSettings, SettingsConfigDict +from pydantic import Field + +from .constants import DATA_DIR, CONFIG_DIR, LOG_DIR, ENV_FILE + + +class BaseAITBCConfig(BaseSettings): + """ + Base configuration class for all AITBC applications. + Provides common AITBC-specific settings and environment file loading. + """ + + model_config = SettingsConfigDict( + env_file=str(ENV_FILE), + env_file_encoding="utf-8", + case_sensitive=False, + extra="ignore" + ) + + # AITBC system directories + data_dir: Path = Field(default=DATA_DIR, description="AITBC data directory") + config_dir: Path = Field(default=CONFIG_DIR, description="AITBC configuration directory") + log_dir: Path = Field(default=LOG_DIR, description="AITBC log directory") + + # Application settings + app_name: str = Field(default="AITBC Application", description="Application name") + app_version: str = Field(default="1.0.0", description="Application version") + environment: str = Field(default="development", description="Environment (development/staging/production)") + debug: bool = Field(default=False, description="Debug mode") + + # Logging settings + log_level: str = Field(default="INFO", description="Log level (DEBUG/INFO/WARNING/ERROR/CRITICAL)") + log_format: str = Field( + default="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + description="Log format string" + ) + + class Config: + """Pydantic configuration""" + env_file = str(ENV_FILE) + env_file_encoding = "utf-8" + case_sensitive = False + + +class AITBCConfig(BaseAITBCConfig): + """ + Standard AITBC configuration with common settings. + Inherits from BaseAITBCConfig and adds AITBC-specific fields. + """ + + # Server settings + host: str = Field(default="0.0.0.0", description="Server host address") + port: int = Field(default=8000, description="Server port") + workers: int = Field(default=1, description="Number of worker processes") + + # Database settings + database_url: Optional[str] = Field(default=None, description="Database connection URL") + database_pool_size: int = Field(default=10, description="Database connection pool size") + + # Redis settings (if applicable) + redis_url: Optional[str] = Field(default=None, description="Redis connection URL") + redis_max_connections: int = Field(default=10, description="Redis max connections") + redis_timeout: int = Field(default=5, description="Redis timeout in seconds") + + # Security settings + secret_key: Optional[str] = Field(default=None, description="Application secret key") + jwt_secret: Optional[str] = Field(default=None, description="JWT secret key") + jwt_algorithm: str = Field(default="HS256", description="JWT algorithm") + jwt_expiration_hours: int = Field(default=24, description="JWT token expiration in hours") + + # Performance settings + request_timeout: int = Field(default=30, description="Request timeout in seconds") + max_request_size: int = Field(default=10 * 1024 * 1024, description="Max request size in bytes") diff --git a/aitbc/database.py b/aitbc/database.py new file mode 100644 index 00000000..d2d3a4c9 --- /dev/null +++ b/aitbc/database.py @@ -0,0 +1,261 @@ +""" +AITBC Database Utilities +Database connection and query utilities for AITBC applications +""" + +import sqlite3 +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple +from contextlib import contextmanager +from .exceptions import DatabaseError + + +class DatabaseConnection: + """ + Base database connection class for AITBC applications. + Provides common database operations with error handling. + """ + + def __init__(self, db_path: Path, timeout: int = 30): + """ + Initialize database connection. + + Args: + db_path: Path to database file + timeout: Connection timeout in seconds + """ + self.db_path = db_path + self.timeout = timeout + self._connection = None + + def connect(self) -> sqlite3.Connection: + """ + Establish database connection. + + Returns: + SQLite connection object + + Raises: + DatabaseError: If connection fails + """ + try: + self._connection = sqlite3.connect( + self.db_path, + timeout=self.timeout + ) + self._connection.row_factory = sqlite3.Row + return self._connection + except sqlite3.Error as e: + raise DatabaseError(f"Failed to connect to database: {e}") + + def close(self) -> None: + """Close database connection.""" + if self._connection: + self._connection.close() + self._connection = None + + @contextmanager + def cursor(self): + """ + Context manager for database cursor. + + Yields: + Database cursor + """ + if not self._connection: + self.connect() + cursor = self._connection.cursor() + try: + yield cursor + self._connection.commit() + except Exception as e: + self._connection.rollback() + raise DatabaseError(f"Database operation failed: {e}") + finally: + cursor.close() + + async def execute( + self, + query: str, + params: Optional[Tuple[Any, ...]] = None + ) -> sqlite3.Cursor: + """ + Execute a SQL query. + + Args: + query: SQL query string + params: Query parameters + + Returns: + Cursor object + + Raises: + DatabaseError: If query fails + """ + try: + with self.cursor() as cursor: + if params: + cursor.execute(query, params) + else: + cursor.execute(query) + return cursor + except sqlite3.Error as e: + raise DatabaseError(f"Query execution failed: {e}") + + async def fetch_one( + self, + query: str, + params: Optional[Tuple[Any, ...]] = None + ) -> Optional[Dict[str, Any]]: + """ + Fetch a single row from query. + + Args: + query: SQL query string + params: Query parameters + + Returns: + Row as dictionary or None + """ + with self.cursor() as cursor: + if params: + cursor.execute(query, params) + else: + cursor.execute(query) + row = cursor.fetchone() + return dict(row) if row else None + + async def fetch_all( + self, + query: str, + params: Optional[Tuple[Any, ...]] = None + ) -> List[Dict[str, Any]]: + """ + Fetch all rows from query. + + Args: + query: SQL query string + params: Query parameters + + Returns: + List of rows as dictionaries + """ + with self.cursor() as cursor: + if params: + cursor.execute(query, params) + else: + cursor.execute(query) + rows = cursor.fetchall() + return [dict(row) for row in rows] + + async def execute_many( + self, + query: str, + params_list: List[Tuple[Any, ...]] + ) -> None: + """ + Execute query with multiple parameter sets. + + Args: + query: SQL query string + params_list: List of parameter tuples + + Raises: + DatabaseError: If query fails + """ + try: + with self.cursor() as cursor: + cursor.executemany(query, params_list) + except sqlite3.Error as e: + raise DatabaseError(f"Bulk execution failed: {e}") + + def __enter__(self): + """Context manager entry.""" + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + self.close() + + +def get_database_connection( + db_path: Path, + timeout: int = 30 +) -> DatabaseConnection: + """ + Get a database connection for a given path. + + Args: + db_path: Path to database file + timeout: Connection timeout in seconds + + Returns: + DatabaseConnection instance + """ + return DatabaseConnection(db_path, timeout) + + +def ensure_database(db_path: Path) -> Path: + """ + Ensure database file and parent directory exist. + + Args: + db_path: Path to database file + + Returns: + Database path + """ + db_path.parent.mkdir(parents=True, exist_ok=True) + return db_path + + +def vacuum_database(db_path: Path) -> None: + """ + Vacuum database to optimize storage. + + Args: + db_path: Path to database file + + Raises: + DatabaseError: If vacuum fails + """ + try: + with DatabaseConnection(db_path) as db: + db.execute("VACUUM") + except sqlite3.Error as e: + raise DatabaseError(f"Database vacuum failed: {e}") + + +def get_table_info(db_path: Path, table_name: str) -> List[Dict[str, Any]]: + """ + Get information about a table's columns. + + Args: + db_path: Path to database file + table_name: Name of table + + Returns: + List of column information dictionaries + """ + with DatabaseConnection(db_path) as db: + return db.fetch_all(f"PRAGMA table_info({table_name})") + + +def table_exists(db_path: Path, table_name: str) -> bool: + """ + Check if a table exists in the database. + + Args: + db_path: Path to database file + table_name: Name of table + + Returns: + True if table exists + """ + with DatabaseConnection(db_path) as db: + result = db.fetch_one( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", + (table_name,) + ) + return result is not None diff --git a/aitbc/decorators.py b/aitbc/decorators.py new file mode 100644 index 00000000..71c880c9 --- /dev/null +++ b/aitbc/decorators.py @@ -0,0 +1,185 @@ +""" +AITBC Common Decorators +Reusable decorators for common patterns in AITBC applications +""" + +import time +import functools +from typing import Callable, Type, Any +from .exceptions import AITBCError + + +def retry( + max_attempts: int = 3, + delay: float = 1.0, + backoff: float = 2.0, + exceptions: tuple[Type[Exception], ...] = (Exception,), + on_failure: Callable[[Exception], Any] = None +): + """ + Retry a function with exponential backoff. + + Args: + max_attempts: Maximum number of retry attempts + delay: Initial delay between retries in seconds + backoff: Multiplier for delay after each retry + exceptions: Tuple of exception types to catch + on_failure: Optional callback function called on final failure + + Returns: + Decorated function that retries on failure + """ + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args, **kwargs): + last_exception = None + current_delay = delay + + for attempt in range(max_attempts): + try: + return func(*args, **kwargs) + except exceptions as e: + last_exception = e + if attempt < max_attempts - 1: + time.sleep(current_delay) + current_delay *= backoff + else: + if on_failure: + on_failure(e) + raise + + raise last_exception if last_exception else AITBCError("Retry failed") + + return wrapper + return decorator + + +def timing(func: Callable) -> Callable: + """ + Decorator to measure and log function execution time. + + Args: + func: Function to time + + Returns: + Decorated function that prints execution time + """ + @functools.wraps(func) + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + execution_time = end_time - start_time + print(f"{func.__name__} executed in {execution_time:.4f} seconds") + return result + + return wrapper + + +def cache_result(ttl: int = 300): + """ + Simple in-memory cache decorator with TTL. + + Args: + ttl: Time to live for cached results in seconds + + Returns: + Decorated function with caching + """ + cache = {} + + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args, **kwargs): + # Create cache key from function name and arguments + cache_key = (func.__name__, args, frozenset(kwargs.items())) + current_time = time.time() + + # Check if cached result exists and is not expired + if cache_key in cache: + result, timestamp = cache[cache_key] + if current_time - timestamp < ttl: + return result + + # Call function and cache result + result = func(*args, **kwargs) + cache[cache_key] = (result, current_time) + return result + + return wrapper + return decorator + + +def validate_args(*validators: Callable): + """ + Decorator to validate function arguments. + + Args: + *validators: Validation functions that raise ValueError on invalid input + + Returns: + Decorated function with argument validation + """ + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args, **kwargs): + for validator in validators: + validator(*args, **kwargs) + return func(*args, **kwargs) + + return wrapper + return decorator + + +def handle_exceptions( + default_return: Any = None, + log_errors: bool = True, + raise_on: tuple[Type[Exception], ...] = () +): + """ + Decorator to handle exceptions gracefully. + + Args: + default_return: Value to return on exception + log_errors: Whether to log errors + raise_on: Tuple of exception types to still raise + + Returns: + Decorated function with exception handling + """ + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except raise_on as e: + raise + except Exception as e: + if log_errors: + print(f"Error in {func.__name__}: {e}") + return default_return + + return wrapper + return decorator + + +def async_timing(func: Callable) -> Callable: + """ + Decorator to measure async function execution time. + + Args: + func: Async function to time + + Returns: + Decorated async function that prints execution time + """ + @functools.wraps(func) + async def wrapper(*args, **kwargs): + start_time = time.time() + result = await func(*args, **kwargs) + end_time = time.time() + execution_time = end_time - start_time + print(f"{func.__name__} executed in {execution_time:.4f} seconds") + return result + + return wrapper diff --git a/aitbc/monitoring.py b/aitbc/monitoring.py new file mode 100644 index 00000000..cd4abab9 --- /dev/null +++ b/aitbc/monitoring.py @@ -0,0 +1,259 @@ +""" +AITBC Monitoring and Metrics Utilities +Monitoring and metrics collection for AITBC applications +""" + +import time +from typing import Dict, Any, Optional +from collections import defaultdict +from datetime import datetime, timedelta + + +class MetricsCollector: + """ + Simple in-memory metrics collector for AITBC applications. + Tracks counters, timers, and gauges. + """ + + def __init__(self): + """Initialize metrics collector.""" + self.counters: Dict[str, int] = defaultdict(int) + self.timers: Dict[str, list] = defaultdict(list) + self.gauges: Dict[str, float] = {} + self.timestamps: Dict[str, datetime] = {} + + def increment(self, metric: str, value: int = 1) -> None: + """ + Increment a counter metric. + + Args: + metric: Metric name + value: Value to increment by + """ + self.counters[metric] += value + self.timestamps[metric] = datetime.now() + + def decrement(self, metric: str, value: int = 1) -> None: + """ + Decrement a counter metric. + + Args: + metric: Metric name + value: Value to decrement by + """ + self.counters[metric] -= value + self.timestamps[metric] = datetime.now() + + def timing(self, metric: str, duration: float) -> None: + """ + Record a timing metric. + + Args: + metric: Metric name + duration: Duration in seconds + """ + self.timers[metric].append(duration) + self.timestamps[metric] = datetime.now() + + def set_gauge(self, metric: str, value: float) -> None: + """ + Set a gauge metric. + + Args: + metric: Metric name + value: Gauge value + """ + self.gauges[metric] = value + self.timestamps[metric] = datetime.now() + + def get_counter(self, metric: str) -> int: + """ + Get counter value. + + Args: + metric: Metric name + + Returns: + Counter value + """ + return self.counters.get(metric, 0) + + def get_timer_stats(self, metric: str) -> Dict[str, float]: + """ + Get timer statistics for a metric. + + Args: + metric: Metric name + + Returns: + Dictionary with min, max, avg, count + """ + timings = self.timers.get(metric, []) + if not timings: + return {"min": 0, "max": 0, "avg": 0, "count": 0} + + return { + "min": min(timings), + "max": max(timings), + "avg": sum(timings) / len(timings), + "count": len(timings) + } + + def get_gauge(self, metric: str) -> Optional[float]: + """ + Get gauge value. + + Args: + metric: Metric name + + Returns: + Gauge value or None + """ + return self.gauges.get(metric) + + def get_all_metrics(self) -> Dict[str, Any]: + """ + Get all collected metrics. + + Returns: + Dictionary of all metrics + """ + return { + "counters": dict(self.counters), + "timers": {k: self.get_timer_stats(k) for k in self.timers}, + "gauges": dict(self.gauges), + "timestamps": {k: v.isoformat() for k, v in self.timestamps.items()} + } + + def reset_metric(self, metric: str) -> None: + """ + Reset a specific metric. + + Args: + metric: Metric name + """ + if metric in self.counters: + del self.counters[metric] + if metric in self.timers: + del self.timers[metric] + if metric in self.gauges: + del self.gauges[metric] + if metric in self.timestamps: + del self.timestamps[metric] + + def reset_all(self) -> None: + """Reset all metrics.""" + self.counters.clear() + self.timers.clear() + self.gauges.clear() + self.timestamps.clear() + + +class PerformanceTimer: + """ + Context manager for timing operations. + """ + + def __init__(self, collector: MetricsCollector, metric: str): + """ + Initialize timer. + + Args: + collector: MetricsCollector instance + metric: Metric name + """ + self.collector = collector + self.metric = metric + self.start_time = None + + def __enter__(self): + """Start timing.""" + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Stop timing and record metric.""" + if self.start_time: + duration = time.time() - self.start_time + self.collector.timing(self.metric, duration) + + +class HealthChecker: + """ + Health check utilities for AITBC applications. + """ + + def __init__(self): + """Initialize health checker.""" + self.checks: Dict[str, Any] = {} + self.last_check: Optional[datetime] = None + + def add_check(self, name: str, check_func: callable) -> None: + """ + Add a health check. + + Args: + name: Check name + check_func: Function that returns (status, message) + """ + self.checks[name] = check_func + + def run_check(self, name: str) -> Dict[str, Any]: + """ + Run a specific health check. + + Args: + name: Check name + + Returns: + Check result with status and message + """ + if name not in self.checks: + return {"status": "unknown", "message": f"Check '{name}' not found"} + + try: + status, message = self.checks[name]() + return {"status": status, "message": message} + except Exception as e: + return {"status": "error", "message": str(e)} + + def run_all_checks(self) -> Dict[str, Any]: + """ + Run all health checks. + + Returns: + Dictionary of all check results + """ + self.last_check = datetime.now() + results = {} + + for name in self.checks: + results[name] = self.run_check(name) + + return { + "checks": results, + "overall_status": self._get_overall_status(results), + "timestamp": self.last_check.isoformat() + } + + def _get_overall_status(self, results: Dict[str, Any]) -> str: + """ + Determine overall health status. + + Args: + results: Check results + + Returns: + Overall status (healthy, degraded, unhealthy) + """ + if not results: + return "unknown" + + statuses = [r.get("status", "unknown") for r in results.values()] + + if all(s == "healthy" for s in statuses): + return "healthy" + elif any(s == "unhealthy" for s in statuses): + return "unhealthy" + else: + return "degraded" diff --git a/aitbc/validation.py b/aitbc/validation.py new file mode 100644 index 00000000..fde19f89 --- /dev/null +++ b/aitbc/validation.py @@ -0,0 +1,245 @@ +""" +AITBC Validation Utilities +Common validators for AITBC applications +""" + +import re +from typing import Any, Optional +from .exceptions import ValidationError + + +def validate_address(address: str) -> bool: + """ + Validate an AITBC blockchain address. + + Args: + address: Address string to validate + + Returns: + True if address is valid format + + Raises: + ValidationError: If address format is invalid + """ + if not address: + raise ValidationError("Address cannot be empty") + + # AITBC addresses typically start with 'ait' and are alphanumeric + pattern = r'^ait[a-z0-9]{40}$' + if not re.match(pattern, address): + raise ValidationError(f"Invalid address format: {address}") + + return True + + +def validate_hash(hash_str: str) -> bool: + """ + Validate a hash string (hex string of expected length). + + Args: + hash_str: Hash string to validate + + Returns: + True if hash is valid format + + Raises: + ValidationError: If hash format is invalid + """ + if not hash_str: + raise ValidationError("Hash cannot be empty") + + # Hashes are typically 64-character hex strings + pattern = r'^[a-f0-9]{64}$' + if not re.match(pattern, hash_str): + raise ValidationError(f"Invalid hash format: {hash_str}") + + return True + + +def validate_url(url: str) -> bool: + """ + Validate a URL string. + + Args: + url: URL string to validate + + Returns: + True if URL is valid format + + Raises: + ValidationError: If URL format is invalid + """ + if not url: + raise ValidationError("URL cannot be empty") + + pattern = r'^https?://[^\s/$.?#].[^\s]*$' + if not re.match(pattern, url): + raise ValidationError(f"Invalid URL format: {url}") + + return True + + +def validate_port(port: int) -> bool: + """ + Validate a port number. + + Args: + port: Port number to validate + + Returns: + True if port is valid + + Raises: + ValidationError: If port is invalid + """ + if not isinstance(port, int): + raise ValidationError(f"Port must be an integer, got {type(port)}") + + if port < 1 or port > 65535: + raise ValidationError(f"Port must be between 1 and 65535, got {port}") + + return True + + +def validate_email(email: str) -> bool: + """ + Validate an email address. + + Args: + email: Email address to validate + + Returns: + True if email is valid format + + Raises: + ValidationError: If email format is invalid + """ + if not email: + raise ValidationError("Email cannot be empty") + + pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$' + if not re.match(pattern, email): + raise ValidationError(f"Invalid email format: {email}") + + return True + + +def validate_non_empty(value: Any, field_name: str = "value") -> bool: + """ + Validate that a value is not empty. + + Args: + value: Value to validate + field_name: Name of the field for error message + + Returns: + True if value is not empty + + Raises: + ValidationError: If value is empty + """ + if value is None: + raise ValidationError(f"{field_name} cannot be None") + + if isinstance(value, str) and not value.strip(): + raise ValidationError(f"{field_name} cannot be empty string") + + if isinstance(value, (list, dict)) and len(value) == 0: + raise ValidationError(f"{field_name} cannot be empty") + + return True + + +def validate_positive_number(value: Any, field_name: str = "value") -> bool: + """ + Validate that a value is a positive number. + + Args: + value: Value to validate + field_name: Name of the field for error message + + Returns: + True if value is positive + + Raises: + ValidationError: If value is not positive + """ + if not isinstance(value, (int, float)): + raise ValidationError(f"{field_name} must be a number, got {type(value)}") + + if value <= 0: + raise ValidationError(f"{field_name} must be positive, got {value}") + + return True + + +def validate_range(value: Any, min_val: float, max_val: float, field_name: str = "value") -> bool: + """ + Validate that a value is within a specified range. + + Args: + value: Value to validate + min_val: Minimum allowed value + max_val: Maximum allowed value + field_name: Name of the field for error message + + Returns: + True if value is within range + + Raises: + ValidationError: If value is outside range + """ + if not isinstance(value, (int, float)): + raise ValidationError(f"{field_name} must be a number, got {type(value)}") + + if value < min_val or value > max_val: + raise ValidationError(f"{field_name} must be between {min_val} and {max_val}, got {value}") + + return True + + +def validate_chain_id(chain_id: str) -> bool: + """ + Validate a chain ID. + + Args: + chain_id: Chain ID to validate + + Returns: + True if chain ID is valid + + Raises: + ValidationError: If chain ID is invalid + """ + if not chain_id: + raise ValidationError("Chain ID cannot be empty") + + # Chain IDs are typically alphanumeric with hyphens + pattern = r'^[a-z0-9\-]+$' + if not re.match(pattern, chain_id): + raise ValidationError(f"Invalid chain ID format: {chain_id}") + + return True + + +def validate_uuid(uuid_str: str) -> bool: + """ + Validate a UUID string. + + Args: + uuid_str: UUID string to validate + + Returns: + True if UUID is valid format + + Raises: + ValidationError: If UUID format is invalid + """ + if not uuid_str: + raise ValidationError("UUID cannot be empty") + + pattern = r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$' + if not re.match(pattern, uuid_str.lower()): + raise ValidationError(f"Invalid UUID format: {uuid_str}") + + return True diff --git a/apps/blockchain-node/src/aitbc_chain/main.py b/apps/blockchain-node/src/aitbc_chain/main.py index 00f341bf..9e99c64d 100755 --- a/apps/blockchain-node/src/aitbc_chain/main.py +++ b/apps/blockchain-node/src/aitbc_chain/main.py @@ -102,6 +102,7 @@ class BlockchainNode: async def process_txs(): from .mempool import get_mempool + from .rpc.router import _normalize_transaction_data mempool = get_mempool() while True: try: @@ -110,6 +111,8 @@ class BlockchainNode: import json tx_data = json.loads(tx_data) chain_id = tx_data.get("chain_id", settings.chain_id) + # Normalize transaction data to ensure type field is preserved + tx_data = _normalize_transaction_data(tx_data, chain_id) mempool.add(tx_data, chain_id=chain_id) except Exception as exc: logger.error(f"Error processing transaction from gossip: {exc}") diff --git a/apps/blockchain-node/src/aitbc_chain/rpc/router.py b/apps/blockchain-node/src/aitbc_chain/rpc/router.py index d515b1b2..4262be77 100755 --- a/apps/blockchain-node/src/aitbc_chain/rpc/router.py +++ b/apps/blockchain-node/src/aitbc_chain/rpc/router.py @@ -308,7 +308,12 @@ async def submit_transaction(tx_data: TransactionRequest) -> Dict[str, Any]: "signature": tx_data.sig } + _logger.info(f"[ROUTER] Submitting transaction: type={tx_data.type}, normalized_type={tx_data_dict.get('type')}") + tx_data_dict = _normalize_transaction_data(tx_data_dict, chain_id) + + _logger.info(f"[ROUTER] After normalization: type={tx_data_dict.get('type')}, keys={list(tx_data_dict.keys())}") + _validate_transaction_admission(tx_data_dict, mempool) tx_hash = mempool.add(tx_data_dict, chain_id=chain_id)