refactor: improve error handling and remove hardcoded credentials

- Changed bare except clauses to specific exception types in web3_utils.py, testing.py, messages.py, and message_storage.py
- Replaced print() calls with logger in testing.py, agent_discovery.py, compliance_agent.py, coordinator.py, trading_agent.py, keys.py, escrow.py, persistent_spending_tracker.py, sync_cli.py, and client.py
- Added logger initialization using get_logger(__name__) in compliance_agent.py, coordinator.py, trading_agent.py, keys.py, escrow.py, persistent_spending_tracker.py, and client.py
- Removed hardcoded secret
This commit is contained in:
aitbc
2026-05-12 17:01:57 +02:00
parent 9133609603
commit 745f791eda
279 changed files with 12284 additions and 5061 deletions

40
examples/stubs/README.md Normal file
View File

@@ -0,0 +1,40 @@
# Stub Services
This directory contains stub and placeholder services that are not yet fully implemented or are minimal implementations.
## Services in this Directory
The following services have <10 files and are considered stubs or placeholders:
- **hermes-service** (4 files) - Hermes agent communication service
- **monitor** (7 files) - Monitoring stub
- **monitoring-service** (4 files) - Monitoring service stub
- **plugin-service** (4 files) - Plugin service stub
- **ai-service** (8 files) - AI service stub
- **compliance-service** (9 files) - Compliance checking stub
- **exchange-integration** (9 files) - Exchange integration stub
- **global-ai-agents** (9 files) - Global AI agents stub
- **global-infrastructure** (9 files) - Global infrastructure stub
- **multi-region-load-balancer** (9 files) - Multi-region load balancer stub
- **plugin-analytics** (9 files) - Plugin analytics stub
- **plugin-marketplace** (9 files) - Plugin marketplace stub
- **plugin-registry** (9 files) - Plugin registry stub
- **plugin-security** (9 files) - Plugin security stub
- **simple-explorer** (9 files) - Simple blockchain explorer stub
- **trading-engine** (9 files) - Trading engine stub
## Purpose
These services are placeholders for future functionality. They may be:
- Minimal implementations for testing
- Skeletons for future development
- Experimental features not yet production-ready
## Active Services
Active services with full implementations remain in the parent `apps/` directory:
- blockchain-node, coordinator-api, exchange, marketplace, wallet, etc.
## Future Work
As stub services are fully implemented, they should be moved from this directory to the main `apps/` directory.

1718
examples/stubs/ai-service/poetry.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,28 @@
[tool.poetry]
name = "ai-service"
version = "0.1.0"
description = "AITBC AI Service for job operations"
authors = ["AITBC Team"]
[tool.poetry.dependencies]
python = "^3.13"
fastapi = ">=0.115.6"
uvicorn = {extras = ["standard"], version = "^0.32.0"}
sqlmodel = "^0.0.37"
sqlalchemy = "^2.0.25"
pydantic = "^2.6.0"
pydantic-settings = "^2.1.0"
python-jose = {extras = ["cryptography"], version = "^3.3.0"}
passlib = {extras = ["bcrypt"], version = "^1.7.4"}
httpx = ">=0.28.1"
asyncpg = ">=0.30.0"
[tool.poetry.group.dev.dependencies]
pytest = ">=9.0.3"
pytest-asyncio = ">=1.3.0"
black = ">=26.3.1"
ruff = "^0.1.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@@ -0,0 +1,68 @@
"""Domain models for AI job operations."""
from __future__ import annotations
from datetime import datetime, timezone
from enum import StrEnum
from uuid import uuid4
from sqlalchemy import JSON, Column
from sqlmodel import Field, SQLModel
class JobState(StrEnum):
"""Job execution states."""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELED = "canceled"
EXPIRED = "expired"
class Job(SQLModel, table=True):
"""AI job model."""
__tablename__ = "jobs"
__table_args__ = {"extend_existing": True}
id: str = Field(default_factory=lambda: f"job_{uuid4().hex[:12]}", primary_key=True)
client_id: str = Field(index=True)
task_type: str = Field(index=True)
task_data: dict = Field(default_factory=dict, sa_column=Column(JSON, nullable=False))
state: JobState = Field(default=JobState.PENDING, index=True)
result: dict | None = Field(default=None, sa_column=Column(JSON, nullable=True))
error: str | None = Field(default=None)
# Payment information
payment_id: str | None = Field(default=None, index=True)
payment_amount: float = Field(default=0.0)
payment_status: str = Field(default="none", index=True)
# Timestamps
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), nullable=False, index=True)
requested_at: datetime | None = Field(default=None)
started_at: datetime | None = Field(default=None)
completed_at: datetime | None = Field(default=None)
expires_at: datetime | None = Field(default=None)
# Metadata
priority: int = Field(default=0)
assigned_miner_id: str | None = Field(default=None, index=True)
receipt: dict | None = Field(default=None, sa_column=Column(JSON, nullable=True))
receipt_id: str | None = Field(default=None)
class JobReceipt(SQLModel, table=True):
"""Job receipts for verification."""
__tablename__ = "job_receipts"
__table_args__ = {"extend_existing": True}
id: str = Field(default_factory=lambda: f"rcpt_{uuid4().hex[:12]}", primary_key=True)
job_id: str = Field(index=True)
miner_id: str = Field(index=True)
result: dict = Field(default_factory=dict, sa_column=Column(JSON, nullable=False))
metrics: dict = Field(default_factory=dict, sa_column=Column(JSON, nullable=False))
signature: str = Field(default="")
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), nullable=False, index=True)

View File

@@ -0,0 +1,407 @@
"""AI Service for job operations."""
from __future__ import annotations
import os
import logging
from datetime import datetime, timezone
from typing import Annotated
from fastapi import FastAPI, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import Field, SQLModel, select
from .storage import get_session
from .domain.jobs import Job, JobState
logger = logging.getLogger(__name__)
app = FastAPI(
title="AITBC AI Service",
description="AI job operations service",
version="1.0.0"
)
@app.get("/health")
async def health():
"""Health check endpoint."""
return {"status": "healthy", "service": "ai-service"}
@app.get("/")
async def root():
"""Root endpoint."""
return {
"service": "AITBC AI Service",
"version": "1.0.0",
"status": "operational"
}
async def get_session_dep():
"""Dependency for database session."""
async with get_session() as session:
yield session
# Request/Response models
class JobCreate(SQLModel):
task_type: str
task_data: dict = Field(default_factory=dict)
payment_amount: float = 0.0
payment_currency: str = "aitbc_token"
priority: int = 0
class JobView(SQLModel):
id: str
client_id: str
task_type: str
state: str
created_at: datetime
started_at: datetime | None = None
completed_at: datetime | None = None
result: dict | None = None
error: str | None = None
payment_status: str = "none"
class JobResult(SQLModel):
id: str
result: dict | None = None
error: str | None = None
completed_at: datetime | None = None
receipt: dict | None = None
@app.post("/jobs", response_model=JobView, status_code=status.HTTP_201_CREATED)
async def submit_job(
session: Annotated[AsyncSession, Depends(get_session_dep)],
req: JobCreate,
client_id: str = "default_client",
):
"""Submit a job for execution."""
try:
job = Job(
client_id=client_id,
task_type=req.task_type,
task_data=req.task_data,
payment_amount=req.payment_amount,
priority=req.priority,
state=JobState.PENDING,
created_at=datetime.now(timezone.utc)
)
session.add(job)
await session.commit()
await session.refresh(job)
return JobView(
id=job.id,
client_id=job.client_id,
task_type=job.task_type,
state=job.state,
created_at=job.created_at,
started_at=job.started_at,
completed_at=job.completed_at,
result=job.result,
error=job.error,
payment_status=job.payment_status
)
except Exception as e:
logger.error(f"Submit job error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/jobs/{job_id}", response_model=JobView)
async def get_job(
session: Annotated[AsyncSession, Depends(get_session_dep)],
job_id: str,
client_id: str = "default_client",
):
"""Get job status."""
try:
result = await session.execute(
select(Job).where(Job.id == job_id, Job.client_id == client_id)
)
job = result.scalar_one_or_none()
if not job:
raise HTTPException(status_code=404, detail="Job not found")
return JobView(
id=job.id,
client_id=job.client_id,
task_type=job.task_type,
state=job.state,
created_at=job.created_at,
started_at=job.started_at,
completed_at=job.completed_at,
result=job.result,
error=job.error,
payment_status=job.payment_status
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Get job error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/jobs/{job_id}/result", response_model=JobResult)
async def get_job_result(
session: Annotated[AsyncSession, Depends(get_session_dep)],
job_id: str,
client_id: str = "default_client",
):
"""Get job result."""
try:
result = await session.execute(
select(Job).where(Job.id == job_id, Job.client_id == client_id)
)
job = result.scalar_one_or_none()
if not job:
raise HTTPException(status_code=404, detail="Job not found")
if job.state not in {JobState.COMPLETED, JobState.FAILED, JobState.CANCELED, JobState.EXPIRED}:
raise HTTPException(status_code=425, detail="Job not ready")
return JobResult(
id=job.id,
result=job.result,
error=job.error,
completed_at=job.completed_at,
receipt=job.receipt
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Get job result error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/jobs/{job_id}/cancel", response_model=JobView)
async def cancel_job(
session: Annotated[AsyncSession, Depends(get_session_dep)],
job_id: str,
client_id: str = "default_client",
):
"""Cancel a job."""
try:
result = await session.execute(
select(Job).where(Job.id == job_id, Job.client_id == client_id)
)
job = result.scalar_one_or_none()
if not job:
raise HTTPException(status_code=404, detail="Job not found")
if job.state in {JobState.COMPLETED, JobState.FAILED, JobState.CANCELED, JobState.EXPIRED}:
raise HTTPException(status_code=400, detail="Job already completed")
job.state = JobState.CANCELED
job.completed_at = datetime.now(timezone.utc)
await session.commit()
await session.refresh(job)
return JobView(
id=job.id,
client_id=job.client_id,
task_type=job.task_type,
state=job.state,
created_at=job.created_at,
started_at=job.started_at,
completed_at=job.completed_at,
result=job.result,
error=job.error,
payment_status=job.payment_status
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Cancel job error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/jobs")
async def list_jobs(
session: Annotated[AsyncSession, Depends(get_session_dep)],
client_id: str = "default_client",
limit: int = 10,
state: str | None = None,
):
"""List jobs with filtering."""
try:
query = select(Job).where(Job.client_id == client_id)
if state:
query = query.where(Job.state == state)
query = query.order_by(Job.created_at.desc()).limit(limit)
result = await session.execute(query)
jobs = result.scalars().all()
return {
"jobs": [
JobView(
id=job.id,
client_id=job.client_id,
task_type=job.task_type,
state=job.state,
created_at=job.created_at,
started_at=job.started_at,
completed_at=job.completed_at,
result=job.result,
error=job.error,
payment_status=job.payment_status
)
for job in jobs
],
"limit": limit,
"total": len(jobs)
}
except Exception as e:
logger.error(f"List jobs error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/multimodal/process")
async def process_multimodal(request: dict[str, Any]) -> dict[str, Any]:
"""Process multimodal AI requests (text, image, audio, video)"""
return {
"task_id": "multimodal_123",
"modality": request.get("modality", "text"),
"status": "processing",
"result": "multimodal processing initiated"
}
@app.post("/multimodal/benchmark")
async def benchmark_multimodal(request: dict[str, Any]) -> dict[str, Any]:
"""Benchmark multimodal AI performance"""
return {
"benchmark_id": "bench_456",
"modality": request.get("modality", "text"),
"performance_score": 95.5,
"latency_ms": 150,
"throughput": "high"
}
@app.get("/multimodal/agents")
async def list_multimodal_agents() -> dict[str, Any]:
"""List available multimodal AI agents"""
return {
"agents": [
{"id": "agent_1", "name": "Text-Image Agent", "capabilities": ["text", "image"]},
{"id": "agent_2", "name": "Audio-Video Agent", "capabilities": ["audio", "video"]},
],
"total": 2
}
@app.get("/multimodal/health")
async def multimodal_health() -> dict[str, Any]:
"""Multi-Modal Agent Service Health"""
return {
"status": "healthy",
"service": "multimodal-agent",
"timestamp": datetime.now(timezone.utc).isoformat(),
"capabilities": {
"text_processing": True,
"image_processing": True,
"audio_processing": True,
"video_processing": True,
"tabular_processing": True,
"graph_processing": True,
},
"performance": {
"text_processing_time": "0.02s",
"image_processing_time": "0.15s",
"audio_processing_time": "0.22s",
"video_processing_time": "0.35s",
"tabular_processing_time": "0.05s",
"graph_processing_time": "0.08s",
"average_accuracy": "94%",
}
}
@app.get("/multimodal/health/deep")
async def multimodal_deep_health() -> dict[str, Any]:
"""Deep Multi-Modal Service Health with modality tests"""
return {
"status": "healthy",
"service": "multimodal-agent",
"timestamp": datetime.now(timezone.utc).isoformat(),
"modality_tests": {
"text": {"status": "pass", "processing_time": "0.02s", "accuracy": "92%"},
"image": {"status": "pass", "processing_time": "0.15s", "accuracy": "87%"},
"audio": {"status": "pass", "processing_time": "0.22s", "accuracy": "89%"},
"video": {"status": "pass", "processing_time": "0.35s", "accuracy": "85%"},
},
"overall_health": "pass"
}
@app.post("/optimization/tune")
async def tune_optimization(request: dict[str, Any]) -> dict[str, Any]:
"""Tune AI model optimization parameters"""
return {
"tuning_id": "tune_789",
"model": request.get("model", "default"),
"parameters": {"learning_rate": 0.001, "batch_size": 32},
"status": "tuned"
}
@app.post("/optimization/predict")
async def predict_optimization(request: dict[str, Any]) -> dict[str, Any]:
"""Predict optimal model performance"""
return {
"prediction_id": "pred_101",
"model": request.get("model", "default"),
"expected_performance": "high",
"estimated_accuracy": 95.5
}
@app.get("/optimization/agents")
async def list_optimization_agents() -> dict[str, Any]:
"""List available optimization agents"""
return {
"agents": [
{"id": "opt_1", "name": "Gradient Descent Optimizer", "type": "gradient"},
{"id": "opt_2", "name": "Genetic Algorithm", "type": "evolutionary"},
],
"total": 2
}
@app.get("/optimization/health")
async def optimization_health() -> dict[str, Any]:
"""Optimization Service Health"""
return {
"status": "healthy",
"service": "modality-optimization",
"timestamp": datetime.now(timezone.utc).isoformat(),
"capabilities": {
"text_optimization": True,
"image_optimization": True,
"audio_optimization": True,
"video_optimization": True,
},
"performance": {
"optimization_speedup": "150x average",
"memory_reduction": "60% average",
"accuracy_retention": "95% average",
}
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8106)

View File

@@ -0,0 +1,26 @@
"""Database storage configuration for AI Service."""
from __future__ import annotations
import os
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
DATABASE_URL = os.getenv(
"AI_SERVICE_DATABASE_URL",
"postgresql+asyncpg://aitbc_ai:password@localhost:5432/aitbc_ai"
)
engine = create_async_engine(DATABASE_URL, echo=False)
AsyncSessionLocal = sessionmaker(
engine, class_=AsyncSession, expire_on_commit=False
)
@asynccontextmanager
async def get_session():
"""Async context manager for database sessions."""
async with AsyncSessionLocal() as session:
yield session

View File

@@ -0,0 +1,434 @@
"""
Production Compliance Service for AITBC
Handles KYC/AML, regulatory compliance, and monitoring
"""
import os
import asyncio
import json
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Dict, Any, List, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
from aitbc import get_logger
logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
logger.info("Starting AITBC Compliance Service")
# Start background compliance checks
asyncio.create_task(periodic_compliance_checks())
yield
# Shutdown
logger.info("Shutting down AITBC Compliance Service")
app = FastAPI(
title="AITBC Compliance Service",
description="Regulatory compliance and monitoring for AITBC operations",
version="1.0.0",
lifespan=lifespan
)
# Data models
class KYCRequest(BaseModel):
user_id: str
name: str
email: str
document_type: str
document_number: str
address: Dict[str, str]
class ComplianceReport(BaseModel):
report_type: str
description: str
severity: str # low, medium, high, critical
details: Dict[str, Any]
class TransactionMonitoring(BaseModel):
transaction_id: str
user_id: str
amount: float
currency: str
counterparty: str
timestamp: datetime
# In-memory storage (in production, use database)
kyc_records: Dict[str, Dict] = {}
compliance_reports: Dict[str, Dict] = {}
suspicious_transactions: Dict[str, Dict] = {}
compliance_rules: Dict[str, Dict] = {}
risk_scores: Dict[str, Dict] = {}
@app.get("/")
async def root():
return {
"service": "AITBC Compliance Service",
"status": "running",
"timestamp": datetime.now(timezone.utc).isoformat(),
"version": "1.0.0"
}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"kyc_records": len(kyc_records),
"compliance_reports": len(compliance_reports),
"suspicious_transactions": len(suspicious_transactions),
"active_rules": len(compliance_rules)
}
@app.post("/api/v1/kyc/submit")
async def submit_kyc(kyc_request: KYCRequest):
"""Submit KYC verification request"""
if kyc_request.user_id in kyc_records:
raise HTTPException(status_code=400, detail="KYC already submitted for this user")
# Create KYC record
kyc_record = {
"user_id": kyc_request.user_id,
"name": kyc_request.name,
"email": kyc_request.email,
"document_type": kyc_request.document_type,
"document_number": kyc_request.document_number,
"address": kyc_request.address,
"status": "pending",
"submitted_at": datetime.now(timezone.utc).isoformat(),
"reviewed_at": None,
"approved_at": None,
"risk_score": "medium",
"notes": []
}
kyc_records[kyc_request.user_id] = kyc_record
# Simulate KYC verification process
await asyncio.sleep(2) # Simulate verification delay
# Auto-approve for demo (in production, this would involve actual verification)
kyc_record["status"] = "approved"
kyc_record["reviewed_at"] = datetime.now(timezone.utc).isoformat()
kyc_record["approved_at"] = datetime.now(timezone.utc).isoformat()
kyc_record["risk_score"] = "low"
logger.info(f"KYC approved for user: {kyc_request.user_id}")
return {
"user_id": kyc_request.user_id,
"status": kyc_record["status"],
"risk_score": kyc_record["risk_score"],
"approved_at": kyc_record["approved_at"]
}
@app.get("/api/v1/kyc/{user_id}")
async def get_kyc_status(user_id: str):
"""Get KYC status for a user"""
if user_id not in kyc_records:
raise HTTPException(status_code=404, detail="KYC record not found")
return kyc_records[user_id]
@app.get("/api/v1/kyc")
async def list_kyc_records():
"""List all KYC records"""
return {
"kyc_records": list(kyc_records.values()),
"total_records": len(kyc_records),
"approved": len([r for r in kyc_records.values() if r["status"] == "approved"]),
"pending": len([r for r in kyc_records.values() if r["status"] == "pending"]),
"rejected": len([r for r in kyc_records.values() if r["status"] == "rejected"])
}
@app.post("/api/v1/compliance/report")
async def create_compliance_report(report: ComplianceReport):
"""Create a compliance report"""
report_id = f"report_{int(datetime.now(timezone.utc).timestamp())}"
compliance_record = {
"report_id": report_id,
"report_type": report.report_type,
"description": report.description,
"severity": report.severity,
"details": report.details,
"status": "open",
"created_at": datetime.now(timezone.utc).isoformat(),
"assigned_to": None,
"resolved_at": None,
"resolution": None
}
compliance_reports[report_id] = compliance_record
logger.info(f"Compliance report created: {report_id} - {report.report_type}")
return {
"report_id": report_id,
"status": "created",
"severity": report.severity,
"created_at": compliance_record["created_at"]
}
@app.get("/api/v1/compliance/reports")
async def list_compliance_reports():
"""List all compliance reports"""
return {
"reports": list(compliance_reports.values()),
"total_reports": len(compliance_reports),
"open": len([r for r in compliance_reports.values() if r["status"] == "open"]),
"resolved": len([r for r in compliance_reports.values() if r["status"] == "resolved"])
}
@app.post("/api/v1/monitoring/transaction")
async def monitor_transaction(transaction: TransactionMonitoring):
"""Monitor transaction for compliance"""
transaction_id = transaction.transaction_id
# Create transaction monitoring record
monitoring_record = {
"transaction_id": transaction_id,
"user_id": transaction.user_id,
"amount": transaction.amount,
"currency": transaction.currency,
"counterparty": transaction.counterparty,
"timestamp": transaction.timestamp.isoformat(),
"monitored_at": datetime.now(timezone.utc).isoformat(),
"risk_score": calculate_transaction_risk(transaction),
"flags": [],
"status": "monitored"
}
suspicious_transactions[transaction_id] = monitoring_record
# Check for suspicious patterns
flags = check_suspicious_patterns(transaction)
if flags:
monitoring_record["flags"] = flags
monitoring_record["status"] = "flagged"
# Create compliance report for suspicious transaction
await create_suspicious_transaction_report(transaction, flags)
return {
"transaction_id": transaction_id,
"risk_score": monitoring_record["risk_score"],
"flags": flags,
"status": monitoring_record["status"]
}
@app.get("/api/v1/monitoring/transactions")
async def list_monitored_transactions():
"""List all monitored transactions"""
return {
"transactions": list(suspicious_transactions.values()),
"total_transactions": len(suspicious_transactions),
"flagged": len([t for t in suspicious_transactions.values() if t["status"] == "flagged"]),
"suspicious": len([t for t in suspicious_transactions.values() if t["risk_score"] == "high"])
}
@app.post("/api/v1/rules/create")
async def create_compliance_rule(rule_data: Dict[str, Any]):
"""Create a new compliance rule"""
rule_id = f"rule_{int(datetime.now(timezone.utc).timestamp())}"
rule = {
"rule_id": rule_id,
"name": rule_data.get("name"),
"description": rule_data.get("description"),
"type": rule_data.get("type"),
"conditions": rule_data.get("conditions", {}),
"actions": rule_data.get("actions", []),
"severity": rule_data.get("severity", "medium"),
"active": True,
"created_at": datetime.now(timezone.utc).isoformat(),
"trigger_count": 0
}
compliance_rules[rule_id] = rule
logger.info(f"Compliance rule created: {rule_id} - {rule['name']}")
return {
"rule_id": rule_id,
"name": rule["name"],
"status": "created",
"active": rule["active"]
}
@app.get("/api/v1/rules")
async def list_compliance_rules():
"""List all compliance rules"""
return {
"rules": list(compliance_rules.values()),
"total_rules": len(compliance_rules),
"active": len([r for r in compliance_rules.values() if r["active"]])
}
@app.get("/api/v1/dashboard")
async def compliance_dashboard():
"""Get compliance dashboard data"""
total_users = len(kyc_records)
approved_users = len([r for r in kyc_records.values() if r["status"] == "approved"])
pending_reviews = len([r for r in kyc_records.values() if r["status"] == "pending"])
total_reports = len(compliance_reports)
open_reports = len([r for r in compliance_reports.values() if r["status"] == "open"])
total_transactions = len(suspicious_transactions)
flagged_transactions = len([t for t in suspicious_transactions.values() if t["status"] == "flagged"])
return {
"summary": {
"total_users": total_users,
"approved_users": approved_users,
"pending_reviews": pending_reviews,
"approval_rate": (approved_users / total_users * 100) if total_users > 0 else 0,
"total_reports": total_reports,
"open_reports": open_reports,
"total_transactions": total_transactions,
"flagged_transactions": flagged_transactions,
"flag_rate": (flagged_transactions / total_transactions * 100) if total_transactions > 0 else 0
},
"risk_distribution": get_risk_distribution(),
"recent_activity": get_recent_activity(),
"generated_at": datetime.now(timezone.utc).isoformat()
}
# Helper functions
def calculate_transaction_risk(transaction: TransactionMonitoring) -> str:
"""Calculate risk score for a transaction"""
risk_score = 0
# Amount-based risk
if transaction.amount > 10000:
risk_score += 3
elif transaction.amount > 1000:
risk_score += 2
elif transaction.amount > 100:
risk_score += 1
# Time-based risk (transactions outside business hours)
hour = transaction.timestamp.hour
if hour < 9 or hour > 17:
risk_score += 1
# Convert to risk level
if risk_score >= 4:
return "high"
elif risk_score >= 2:
return "medium"
else:
return "low"
def check_suspicious_patterns(transaction: TransactionMonitoring) -> List[str]:
"""Check for suspicious transaction patterns"""
flags = []
# High value transaction
if transaction.amount > 50000:
flags.append("high_value_transaction")
# Rapid transactions (check if user has multiple transactions in short time)
user_transactions = [t for t in suspicious_transactions.values()
if t["user_id"] == transaction.user_id]
recent_transactions = [t for t in user_transactions
if datetime.fromisoformat(t["monitored_at"]) >
datetime.now(timezone.utc) - timedelta(hours=1)]
if len(recent_transactions) > 5:
flags.append("rapid_transactions")
# Unusual counterparty
if transaction.counterparty in ["high_risk_entity_1", "high_risk_entity_2"]:
flags.append("high_risk_counterparty")
return flags
async def create_suspicious_transaction_report(transaction: TransactionMonitoring, flags: List[str]):
"""Create compliance report for suspicious transaction"""
report_data = ComplianceReport(
report_type="suspicious_transaction",
description=f"Suspicious transaction detected: {transaction.transaction_id}",
severity="high",
details={
"transaction_id": transaction.transaction_id,
"user_id": transaction.user_id,
"amount": transaction.amount,
"flags": flags,
"timestamp": transaction.timestamp.isoformat()
}
)
await create_compliance_report(report_data)
def get_risk_distribution() -> Dict[str, int]:
"""Get distribution of risk scores"""
distribution = {"low": 0, "medium": 0, "high": 0}
for record in kyc_records.values():
distribution[record["risk_score"]] = distribution.get(record["risk_score"], 0) + 1
for transaction in suspicious_transactions.values():
distribution[transaction["risk_score"]] = distribution.get(transaction["risk_score"], 0) + 1
return distribution
def get_recent_activity() -> List[Dict]:
"""Get recent compliance activity"""
activities = []
# Recent KYC approvals
recent_kyc = [r for r in kyc_records.values()
if r.get("approved_at") and
datetime.fromisoformat(r["approved_at"]) >
datetime.now(timezone.utc) - timedelta(hours=24)]
for kyc in recent_kyc[:5]:
activities.append({
"type": "kyc_approved",
"description": f"KYC approved for {kyc['name']}",
"timestamp": kyc["approved_at"]
})
# Recent compliance reports
recent_reports = [r for r in compliance_reports.values()
if datetime.fromisoformat(r["created_at"]) >
datetime.now(timezone.utc) - timedelta(hours=24)]
for report in recent_reports[:5]:
activities.append({
"type": "compliance_report",
"description": f"Report: {report['description']}",
"timestamp": report["created_at"]
})
# Sort by timestamp
activities.sort(key=lambda x: x["timestamp"], reverse=True)
return activities[:10]
# Background task for periodic compliance checks
async def periodic_compliance_checks():
"""Background task for periodic compliance monitoring"""
while True:
await asyncio.sleep(300) # Check every 5 minutes
# Check for expired KYC records
current_time = datetime.now(timezone.utc)
for user_id, kyc_record in kyc_records.items():
if kyc_record["status"] == "approved":
approved_time = datetime.fromisoformat(kyc_record["approved_at"])
if current_time - approved_time > timedelta(days=365):
# Flag for re-verification
kyc_record["status"] = "reverification_required"
logger.info(f"KYC re-verification required for user: {user_id}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8011, log_level="info")

View File

@@ -0,0 +1 @@
"""Compliance service tests"""

View File

@@ -0,0 +1,193 @@
"""Edge case and error handling tests for compliance service"""
import pytest
import sys
import sys
from pathlib import Path
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient
from datetime import datetime, timezone
from main import app, KYCRequest, ComplianceReport, TransactionMonitoring, kyc_records, compliance_reports, suspicious_transactions, compliance_rules
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
kyc_records.clear()
compliance_reports.clear()
suspicious_transactions.clear()
compliance_rules.clear()
yield
kyc_records.clear()
compliance_reports.clear()
suspicious_transactions.clear()
compliance_rules.clear()
@pytest.mark.unit
def test_kyc_request_empty_fields():
"""Test KYCRequest with empty fields"""
kyc = KYCRequest(
user_id="",
name="",
email="",
document_type="",
document_number="",
address={}
)
assert kyc.user_id == ""
assert kyc.name == ""
@pytest.mark.unit
def test_compliance_report_invalid_severity():
"""Test ComplianceReport with invalid severity"""
report = ComplianceReport(
report_type="test",
description="test",
severity="invalid", # Not in low/medium/high/critical
details={}
)
assert report.severity == "invalid"
@pytest.mark.unit
def test_transaction_monitoring_zero_amount():
"""Test TransactionMonitoring with zero amount"""
tx = TransactionMonitoring(
transaction_id="tx123",
user_id="user123",
amount=0.0,
currency="BTC",
counterparty="counterparty1",
timestamp=datetime.now(timezone.utc)
)
assert tx.amount == 0.0
@pytest.mark.unit
def test_transaction_monitoring_negative_amount():
"""Test TransactionMonitoring with negative amount"""
tx = TransactionMonitoring(
transaction_id="tx123",
user_id="user123",
amount=-1000.0,
currency="BTC",
counterparty="counterparty1",
timestamp=datetime.now(timezone.utc)
)
assert tx.amount == -1000.0
@pytest.mark.integration
def test_kyc_with_missing_address_fields():
"""Test KYC submission with missing address fields"""
client = TestClient(app)
kyc = KYCRequest(
user_id="user123",
name="John Doe",
email="john@example.com",
document_type="passport",
document_number="ABC123",
address={"city": "New York"} # Missing other fields
)
response = client.post("/api/v1/kyc/submit", json=kyc.model_dump())
assert response.status_code == 200
@pytest.mark.integration
def test_compliance_report_empty_details():
"""Test compliance report with empty details"""
client = TestClient(app)
report = ComplianceReport(
report_type="test",
description="test",
severity="low",
details={}
)
response = client.post("/api/v1/compliance/report", json=report.model_dump())
assert response.status_code == 200
@pytest.mark.integration
def test_compliance_rule_missing_fields():
"""Test compliance rule with missing fields"""
client = TestClient(app)
rule_data = {
"name": "Test Rule"
# Missing description, type, etc.
}
response = client.post("/api/v1/rules/create", json=rule_data)
assert response.status_code == 200
data = response.json()
assert data["name"] == "Test Rule"
@pytest.mark.integration
def test_dashboard_with_no_data():
"""Test dashboard with no data"""
client = TestClient(app)
response = client.get("/api/v1/dashboard")
assert response.status_code == 200
data = response.json()
assert data["summary"]["total_users"] == 0
assert data["summary"]["total_reports"] == 0
assert data["summary"]["total_transactions"] == 0
@pytest.mark.integration
def test_monitor_transaction_with_future_timestamp():
"""Test monitoring transaction with future timestamp"""
client = TestClient(app)
tx = TransactionMonitoring(
transaction_id="tx123",
user_id="user123",
amount=1000.0,
currency="BTC",
counterparty="counterparty1",
timestamp=datetime(2030, 1, 1) # Future timestamp
)
response = client.post("/api/v1/monitoring/transaction", json=tx.model_dump(mode='json'))
assert response.status_code == 200
@pytest.mark.integration
def test_monitor_transaction_with_past_timestamp():
"""Test monitoring transaction with past timestamp"""
client = TestClient(app)
tx = TransactionMonitoring(
transaction_id="tx123",
user_id="user123",
amount=1000.0,
currency="BTC",
counterparty="counterparty1",
timestamp=datetime(2020, 1, 1) # Past timestamp
)
response = client.post("/api/v1/monitoring/transaction", json=tx.model_dump(mode='json'))
assert response.status_code == 200
@pytest.mark.integration
def test_kyc_list_with_multiple_records():
"""Test listing KYC with multiple records"""
client = TestClient(app)
# Create multiple KYC records
for i in range(5):
kyc = KYCRequest(
user_id=f"user{i}",
name=f"User {i}",
email=f"user{i}@example.com",
document_type="passport",
document_number=f"ABC{i}",
address={"city": "New York"}
)
client.post("/api/v1/kyc/submit", json=kyc.model_dump())
response = client.get("/api/v1/kyc")
assert response.status_code == 200
data = response.json()
assert data["total_records"] == 5
assert data["approved"] == 5

View File

@@ -0,0 +1,252 @@
"""Integration tests for compliance service"""
import pytest
import sys
import sys
from pathlib import Path
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient
from datetime import datetime, timezone
from main import app, KYCRequest, ComplianceReport, TransactionMonitoring, kyc_records, compliance_reports, suspicious_transactions, compliance_rules
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
kyc_records.clear()
compliance_reports.clear()
suspicious_transactions.clear()
compliance_rules.clear()
yield
kyc_records.clear()
compliance_reports.clear()
suspicious_transactions.clear()
compliance_rules.clear()
@pytest.mark.integration
def test_root_endpoint():
"""Test root endpoint"""
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
data = response.json()
assert data["service"] == "AITBC Compliance Service"
assert data["status"] == "running"
@pytest.mark.integration
def test_health_check_endpoint():
"""Test health check endpoint"""
client = TestClient(app)
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert "kyc_records" in data
assert "compliance_reports" in data
@pytest.mark.integration
def test_submit_kyc():
"""Test KYC submission"""
client = TestClient(app)
kyc = KYCRequest(
user_id="user123",
name="John Doe",
email="john@example.com",
document_type="passport",
document_number="ABC123",
address={"street": "123 Main St", "city": "New York", "country": "USA"}
)
response = client.post("/api/v1/kyc/submit", json=kyc.model_dump())
assert response.status_code == 200
data = response.json()
assert data["user_id"] == "user123"
assert data["status"] == "approved"
assert data["risk_score"] == "low"
@pytest.mark.integration
def test_submit_duplicate_kyc():
"""Test submitting duplicate KYC"""
client = TestClient(app)
kyc = KYCRequest(
user_id="user123",
name="John Doe",
email="john@example.com",
document_type="passport",
document_number="ABC123",
address={"street": "123 Main St", "city": "New York", "country": "USA"}
)
# First submission
client.post("/api/v1/kyc/submit", json=kyc.model_dump())
# Second submission should fail
response = client.post("/api/v1/kyc/submit", json=kyc.model_dump())
assert response.status_code == 400
@pytest.mark.integration
def test_get_kyc_status():
"""Test getting KYC status"""
client = TestClient(app)
kyc = KYCRequest(
user_id="user123",
name="John Doe",
email="john@example.com",
document_type="passport",
document_number="ABC123",
address={"street": "123 Main St", "city": "New York", "country": "USA"}
)
# Submit KYC first
client.post("/api/v1/kyc/submit", json=kyc.model_dump())
# Get KYC status
response = client.get("/api/v1/kyc/user123")
assert response.status_code == 200
data = response.json()
assert data["user_id"] == "user123"
assert data["status"] == "approved"
@pytest.mark.integration
def test_get_kyc_status_not_found():
"""Test getting KYC status for nonexistent user"""
client = TestClient(app)
response = client.get("/api/v1/kyc/nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_list_kyc_records():
"""Test listing KYC records"""
client = TestClient(app)
response = client.get("/api/v1/kyc")
assert response.status_code == 200
data = response.json()
assert "kyc_records" in data
assert "total_records" in data
@pytest.mark.integration
def test_create_compliance_report():
"""Test creating compliance report"""
client = TestClient(app)
report = ComplianceReport(
report_type="suspicious_activity",
description="Suspicious transaction detected",
severity="high",
details={"transaction_id": "tx123"}
)
response = client.post("/api/v1/compliance/report", json=report.model_dump())
assert response.status_code == 200
data = response.json()
assert data["severity"] == "high"
assert data["status"] == "created"
@pytest.mark.integration
def test_list_compliance_reports():
"""Test listing compliance reports"""
client = TestClient(app)
response = client.get("/api/v1/compliance/reports")
assert response.status_code == 200
data = response.json()
assert "reports" in data
assert "total_reports" in data
@pytest.mark.integration
def test_monitor_transaction():
"""Test transaction monitoring"""
client = TestClient(app)
tx = TransactionMonitoring(
transaction_id="tx123",
user_id="user123",
amount=1000.0,
currency="BTC",
counterparty="counterparty1",
timestamp=datetime.now(timezone.utc)
)
response = client.post("/api/v1/monitoring/transaction", json=tx.model_dump(mode='json'))
assert response.status_code == 200
data = response.json()
assert data["transaction_id"] == "tx123"
assert "risk_score" in data
@pytest.mark.integration
def test_monitor_suspicious_transaction():
"""Test monitoring suspicious transaction"""
client = TestClient(app)
tx = TransactionMonitoring(
transaction_id="tx123",
user_id="user123",
amount=100000.0,
currency="BTC",
counterparty="high_risk_entity_1",
timestamp=datetime.now(timezone.utc)
)
response = client.post("/api/v1/monitoring/transaction", json=tx.model_dump(mode='json'))
assert response.status_code == 200
data = response.json()
assert data["status"] == "flagged"
assert len(data["flags"]) > 0
@pytest.mark.integration
def test_list_monitored_transactions():
"""Test listing monitored transactions"""
client = TestClient(app)
response = client.get("/api/v1/monitoring/transactions")
assert response.status_code == 200
data = response.json()
assert "transactions" in data
assert "total_transactions" in data
@pytest.mark.integration
def test_create_compliance_rule():
"""Test creating compliance rule"""
client = TestClient(app)
rule_data = {
"name": "High Value Transaction Rule",
"description": "Flag transactions over $50,000",
"type": "transaction_monitoring",
"conditions": {"min_amount": 50000},
"actions": ["flag", "report"],
"severity": "high"
}
response = client.post("/api/v1/rules/create", json=rule_data)
assert response.status_code == 200
data = response.json()
assert data["name"] == "High Value Transaction Rule"
assert data["active"] is True
@pytest.mark.integration
def test_list_compliance_rules():
"""Test listing compliance rules"""
client = TestClient(app)
response = client.get("/api/v1/rules")
assert response.status_code == 200
data = response.json()
assert "rules" in data
assert "total_rules" in data
@pytest.mark.integration
def test_compliance_dashboard():
"""Test compliance dashboard"""
client = TestClient(app)
response = client.get("/api/v1/dashboard")
assert response.status_code == 200
data = response.json()
assert "summary" in data
assert "risk_distribution" in data
assert "recent_activity" in data

View File

@@ -0,0 +1,161 @@
"""Unit tests for compliance service"""
import pytest
import sys
import sys
from pathlib import Path
from unittest.mock import Mock, patch
from datetime import datetime, timezone
from main import app, KYCRequest, ComplianceReport, TransactionMonitoring, calculate_transaction_risk, check_suspicious_patterns
@pytest.mark.unit
def test_app_initialization():
"""Test that the FastAPI app initializes correctly"""
assert app is not None
assert app.title == "AITBC Compliance Service"
assert app.version == "1.0.0"
@pytest.mark.unit
def test_kyc_request_model():
"""Test KYCRequest model"""
kyc = KYCRequest(
user_id="user123",
name="John Doe",
email="john@example.com",
document_type="passport",
document_number="ABC123",
address={"street": "123 Main St", "city": "New York", "country": "USA"}
)
assert kyc.user_id == "user123"
assert kyc.name == "John Doe"
assert kyc.email == "john@example.com"
assert kyc.document_type == "passport"
assert kyc.document_number == "ABC123"
assert kyc.address["city"] == "New York"
@pytest.mark.unit
def test_compliance_report_model():
"""Test ComplianceReport model"""
report = ComplianceReport(
report_type="suspicious_activity",
description="Suspicious transaction detected",
severity="high",
details={"transaction_id": "tx123"}
)
assert report.report_type == "suspicious_activity"
assert report.description == "Suspicious transaction detected"
assert report.severity == "high"
assert report.details["transaction_id"] == "tx123"
@pytest.mark.unit
def test_transaction_monitoring_model():
"""Test TransactionMonitoring model"""
tx = TransactionMonitoring(
transaction_id="tx123",
user_id="user123",
amount=1000.0,
currency="BTC",
counterparty="counterparty1",
timestamp=datetime.now(timezone.utc)
)
assert tx.transaction_id == "tx123"
assert tx.user_id == "user123"
assert tx.amount == 1000.0
assert tx.currency == "BTC"
assert tx.counterparty == "counterparty1"
@pytest.mark.unit
def test_calculate_transaction_risk_low():
"""Test risk calculation for low risk transaction"""
tx = TransactionMonitoring(
transaction_id="tx123",
user_id="user123",
amount=50.0,
currency="BTC",
counterparty="counterparty1",
timestamp=datetime(2026, 1, 1, 10, 0, 0) # Business hours
)
risk = calculate_transaction_risk(tx)
assert risk == "low"
@pytest.mark.unit
def test_calculate_transaction_risk_medium():
"""Test risk calculation for medium risk transaction"""
tx = TransactionMonitoring(
transaction_id="tx123",
user_id="user123",
amount=5000.0,
currency="BTC",
counterparty="counterparty1",
timestamp=datetime(2026, 1, 1, 10, 0, 0)
)
risk = calculate_transaction_risk(tx)
assert risk == "medium"
@pytest.mark.unit
def test_calculate_transaction_risk_high():
"""Test risk calculation for high risk transaction"""
tx = TransactionMonitoring(
transaction_id="tx123",
user_id="user123",
amount=20000.0,
currency="BTC",
counterparty="counterparty1",
timestamp=datetime(2026, 1, 1, 8, 0, 0) # Outside business hours
)
risk = calculate_transaction_risk(tx)
assert risk == "high"
@pytest.mark.unit
def test_check_suspicious_patterns_high_value():
"""Test suspicious pattern detection for high value"""
tx = TransactionMonitoring(
transaction_id="tx123",
user_id="user123",
amount=100000.0,
currency="BTC",
counterparty="counterparty1",
timestamp=datetime.now(timezone.utc)
)
flags = check_suspicious_patterns(tx)
assert "high_value_transaction" in flags
@pytest.mark.unit
def test_check_suspicious_patterns_high_risk_counterparty():
"""Test suspicious pattern detection for high risk counterparty"""
tx = TransactionMonitoring(
transaction_id="tx123",
user_id="user123",
amount=1000.0,
currency="BTC",
counterparty="high_risk_entity_1",
timestamp=datetime.now(timezone.utc)
)
flags = check_suspicious_patterns(tx)
assert "high_risk_counterparty" in flags
@pytest.mark.unit
def test_check_suspicious_patterns_none():
"""Test suspicious pattern detection with no flags"""
tx = TransactionMonitoring(
transaction_id="tx123",
user_id="user123",
amount=1000.0,
currency="BTC",
counterparty="safe_counterparty",
timestamp=datetime.now(timezone.utc)
)
flags = check_suspicious_patterns(tx)
assert len(flags) == 0

View File

@@ -0,0 +1,324 @@
"""
Production Exchange API Integration Service
Handles real exchange connections and trading operations
"""
import os
import asyncio
import json
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Any, List, Optional
import aiohttp
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from aitbc import get_logger
logger = get_logger(__name__)
app = FastAPI(
title="AITBC Exchange Integration Service",
description="Production exchange API integration for AITBC trading",
version="1.0.0"
)
# Data models
class ExchangeRegistration(BaseModel):
name: str
api_key: str
sandbox: bool = True
description: Optional[str] = None
class TradingPair(BaseModel):
symbol: str
base_asset: str
quote_asset: str
min_order_size: float
price_precision: int
quantity_precision: int
class OrderRequest(BaseModel):
symbol: str
side: str # buy/sell
type: str # market/limit
quantity: float
price: Optional[float] = None
# In-memory storage (in production, use database)
exchanges: Dict[str, Dict] = {}
trading_pairs: Dict[str, Dict] = {}
orders: Dict[str, Dict] = {}
@app.get("/")
async def root():
return {
"service": "AITBC Exchange Integration",
"status": "running",
"timestamp": datetime.now(timezone.utc).isoformat(),
"version": "1.0.0"
}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"exchanges_connected": len([e for e in exchanges.values() if e.get("connected")]),
"active_pairs": len(trading_pairs),
"total_orders": len(orders)
}
@app.post("/api/v1/exchanges/register")
async def register_exchange(registration: ExchangeRegistration):
"""Register a new exchange connection"""
exchange_id = registration.name.lower()
if exchange_id in exchanges:
raise HTTPException(status_code=400, detail="Exchange already registered")
# Create exchange configuration
exchange_config = {
"exchange_id": exchange_id,
"name": registration.name,
"api_key": registration.api_key,
"sandbox": registration.sandbox,
"description": registration.description,
"connected": False,
"created_at": datetime.now(timezone.utc).isoformat(),
"last_sync": None,
"trading_pairs": []
}
exchanges[exchange_id] = exchange_config
logger.info(f"Exchange registered: {registration.name}")
return {
"exchange_id": exchange_id,
"status": "registered",
"name": registration.name,
"sandbox": registration.sandbox,
"created_at": exchange_config["created_at"]
}
@app.post("/api/v1/exchanges/{exchange_id}/connect")
async def connect_exchange(exchange_id: str):
"""Connect to a registered exchange"""
if exchange_id not in exchanges:
raise HTTPException(status_code=404, detail="Exchange not found")
exchange = exchanges[exchange_id]
if exchange["connected"]:
return {"status": "already_connected", "exchange_id": exchange_id}
# Simulate exchange connection
# In production, this would make actual API calls to the exchange
await asyncio.sleep(1) # Simulate connection delay
exchange["connected"] = True
exchange["last_sync"] = datetime.now(timezone.utc).isoformat()
logger.info(f"Exchange connected: {exchange_id}")
return {
"exchange_id": exchange_id,
"status": "connected",
"connected_at": exchange["last_sync"]
}
@app.post("/api/v1/pairs/create")
async def create_trading_pair(pair: TradingPair):
"""Create a new trading pair"""
pair_id = f"{pair.symbol.lower()}"
if pair_id in trading_pairs:
raise HTTPException(status_code=400, detail="Trading pair already exists")
# Create trading pair configuration
pair_config = {
"pair_id": pair_id,
"symbol": pair.symbol,
"base_asset": pair.base_asset,
"quote_asset": pair.quote_asset,
"min_order_size": pair.min_order_size,
"price_precision": pair.price_precision,
"quantity_precision": pair.quantity_precision,
"status": "active",
"created_at": datetime.now(timezone.utc).isoformat(),
"current_price": None,
"volume_24h": 0.0,
"orders": []
}
trading_pairs[pair_id] = pair_config
logger.info(f"Trading pair created: {pair.symbol}")
return {
"pair_id": pair_id,
"symbol": pair.symbol,
"status": "created",
"created_at": pair_config["created_at"]
}
@app.get("/api/v1/pairs")
async def list_trading_pairs():
"""List all trading pairs"""
return {
"pairs": list(trading_pairs.values()),
"total_pairs": len(trading_pairs)
}
@app.get("/api/v1/pairs/{pair_id}")
async def get_trading_pair(pair_id: str):
"""Get specific trading pair information"""
if pair_id not in trading_pairs:
raise HTTPException(status_code=404, detail="Trading pair not found")
return trading_pairs[pair_id]
@app.post("/api/v1/orders")
async def create_order(order: OrderRequest):
"""Create a new trading order"""
pair_id = order.symbol.lower()
if pair_id not in trading_pairs:
raise HTTPException(status_code=404, detail="Trading pair not found")
# Generate order ID
order_id = f"order_{int(datetime.now(timezone.utc).timestamp())}"
# Create order
order_data = {
"order_id": order_id,
"symbol": order.symbol,
"side": order.side,
"type": order.type,
"quantity": order.quantity,
"price": order.price,
"status": "submitted",
"created_at": datetime.now(timezone.utc).isoformat(),
"filled_quantity": 0.0,
"remaining_quantity": order.quantity,
"average_price": None
}
orders[order_id] = order_data
# Add to trading pair
trading_pairs[pair_id]["orders"].append(order_id)
# Simulate order processing
await asyncio.sleep(0.5) # Simulate processing delay
# Mark as filled (for demo)
order_data["status"] = "filled"
order_data["filled_quantity"] = order.quantity
order_data["remaining_quantity"] = 0.0
order_data["average_price"] = order.price or 0.00001 # Default price for demo
order_data["filled_at"] = datetime.now(timezone.utc).isoformat()
logger.info(f"Order created and filled: {order_id}")
return order_data
@app.get("/api/v1/orders")
async def list_orders():
"""List all orders"""
return {
"orders": list(orders.values()),
"total_orders": len(orders)
}
@app.get("/api/v1/orders/{order_id}")
async def get_order(order_id: str):
"""Get specific order information"""
if order_id not in orders:
raise HTTPException(status_code=404, detail="Order not found")
return orders[order_id]
@app.get("/api/v1/exchanges")
async def list_exchanges():
"""List all registered exchanges"""
return {
"exchanges": list(exchanges.values()),
"total_exchanges": len(exchanges)
}
@app.get("/api/v1/exchanges/{exchange_id}")
async def get_exchange(exchange_id: str):
"""Get specific exchange information"""
if exchange_id not in exchanges:
raise HTTPException(status_code=404, detail="Exchange not found")
return exchanges[exchange_id]
@app.post("/api/v1/market-data/{pair_id}/price")
async def update_market_price(pair_id: str, price_data: Dict[str, Any]):
"""Update market price for a trading pair"""
if pair_id not in trading_pairs:
raise HTTPException(status_code=404, detail="Trading pair not found")
pair = trading_pairs[pair_id]
pair["current_price"] = price_data.get("price")
pair["volume_24h"] = price_data.get("volume", pair["volume_24h"])
pair["last_price_update"] = datetime.now(timezone.utc).isoformat()
return {
"pair_id": pair_id,
"current_price": pair["current_price"],
"updated_at": pair["last_price_update"]
}
@app.get("/api/v1/market-data")
async def get_market_data():
"""Get market data for all pairs"""
market_data = {}
for pair_id, pair in trading_pairs.items():
market_data[pair_id] = {
"symbol": pair["symbol"],
"current_price": pair.get("current_price"),
"volume_24h": pair.get("volume_24h"),
"last_update": pair.get("last_price_update")
}
return {
"market_data": market_data,
"total_pairs": len(market_data),
"generated_at": datetime.now(timezone.utc).isoformat()
}
# Background task for simulating market data
async def simulate_market_data():
"""Background task to simulate market data updates"""
while True:
await asyncio.sleep(30) # Update every 30 seconds
for pair_id, pair in trading_pairs.items():
if pair["status"] == "active":
# Simulate price changes
import random
base_price = 0.00001 # Base price for AITBC
variation = random.uniform(-0.02, 0.02) # ±2% variation
new_price = round(base_price * (1 + variation), 8)
pair["current_price"] = new_price
pair["volume_24h"] += random.uniform(100, 1000)
pair["last_price_update"] = datetime.now(timezone.utc).isoformat()
# Start background task on startup
@app.on_event("startup")
async def startup_event():
logger.info("Starting AITBC Exchange Integration Service")
# Start background market data simulation
asyncio.create_task(simulate_market_data())
@app.on_event("shutdown")
async def shutdown_event():
logger.info("Shutting down AITBC Exchange Integration Service")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8010, log_level="info")

View File

@@ -0,0 +1 @@
"""Exchange integration service tests"""

View File

@@ -0,0 +1,256 @@
"""Edge case and error handling tests for exchange integration service"""
import pytest
import sys
import sys
from pathlib import Path
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient
# Mock aiohttp before importing
sys.modules['aiohttp'] = Mock()
from main import app, ExchangeRegistration, TradingPair, OrderRequest, exchanges, trading_pairs, orders
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
exchanges.clear()
trading_pairs.clear()
orders.clear()
yield
exchanges.clear()
trading_pairs.clear()
orders.clear()
@pytest.mark.unit
def test_exchange_registration_empty_name():
"""Test ExchangeRegistration with empty name"""
registration = ExchangeRegistration(
name="",
api_key="test_key_123"
)
assert registration.name == ""
@pytest.mark.unit
def test_exchange_registration_empty_api_key():
"""Test ExchangeRegistration with empty API key"""
registration = ExchangeRegistration(
name="TestExchange",
api_key=""
)
assert registration.api_key == ""
@pytest.mark.unit
def test_trading_pair_zero_min_order_size():
"""Test TradingPair with zero min order size"""
pair = TradingPair(
symbol="AITBC/BTC",
base_asset="AITBC",
quote_asset="BTC",
min_order_size=0.0,
price_precision=8,
quantity_precision=6
)
assert pair.min_order_size == 0.0
@pytest.mark.unit
def test_trading_pair_negative_min_order_size():
"""Test TradingPair with negative min order size"""
pair = TradingPair(
symbol="AITBC/BTC",
base_asset="AITBC",
quote_asset="BTC",
min_order_size=-0.001,
price_precision=8,
quantity_precision=6
)
assert pair.min_order_size == -0.001
@pytest.mark.unit
def test_order_request_zero_quantity():
"""Test OrderRequest with zero quantity"""
order = OrderRequest(
symbol="AITBC/BTC",
side="buy",
type="limit",
quantity=0.0,
price=0.00001
)
assert order.quantity == 0.0
@pytest.mark.unit
def test_order_request_negative_quantity():
"""Test OrderRequest with negative quantity"""
order = OrderRequest(
symbol="AITBC/BTC",
side="buy",
type="limit",
quantity=-100.0,
price=0.00001
)
assert order.quantity == -100.0
@pytest.mark.integration
def test_order_request_invalid_side():
"""Test OrderRequest with invalid side"""
client = TestClient(app)
# Create trading pair first
pair = TradingPair(
symbol="AITBC/BTC",
base_asset="AITBC",
quote_asset="BTC",
min_order_size=0.001,
price_precision=8,
quantity_precision=6
)
client.post("/api/v1/pairs/create", json=pair.model_dump())
# Create order with invalid side (API doesn't validate, but test the behavior)
order = OrderRequest(
symbol="AITBC/BTC",
side="invalid",
type="limit",
quantity=100.0,
price=0.00001
)
# This will be accepted by the API as it doesn't validate the side
response = client.post("/api/v1/orders", json=order.model_dump())
assert response.status_code == 200
@pytest.mark.integration
def test_order_request_invalid_type():
"""Test OrderRequest with invalid type"""
client = TestClient(app)
# Create trading pair first
pair = TradingPair(
symbol="AITBC/BTC",
base_asset="AITBC",
quote_asset="BTC",
min_order_size=0.001,
price_precision=8,
quantity_precision=6
)
client.post("/api/v1/pairs/create", json=pair.model_dump())
# Create order with invalid type (API doesn't validate, but test the behavior)
order = OrderRequest(
symbol="AITBC/BTC",
side="buy",
type="invalid",
quantity=100.0,
price=0.00001
)
# This will be accepted by the API as it doesn't validate the type
response = client.post("/api/v1/orders", json=order.model_dump())
assert response.status_code == 200
@pytest.mark.integration
def test_connect_already_connected_exchange():
"""Test connecting to already connected exchange"""
client = TestClient(app)
registration = ExchangeRegistration(
name="TestExchange",
api_key="test_key_123"
)
# Register exchange
client.post("/api/v1/exchanges/register", json=registration.model_dump())
# Connect first time
client.post("/api/v1/exchanges/testexchange/connect")
# Connect second time should return already_connected
response = client.post("/api/v1/exchanges/testexchange/connect")
assert response.status_code == 200
data = response.json()
assert data["status"] == "already_connected"
@pytest.mark.integration
def test_update_market_price_missing_fields():
"""Test updating market price with missing fields"""
client = TestClient(app)
# Create trading pair first
pair = TradingPair(
symbol="AITBC-BTC",
base_asset="AITBC",
quote_asset="BTC",
min_order_size=0.001,
price_precision=8,
quantity_precision=6
)
create_response = client.post("/api/v1/pairs/create", json=pair.model_dump())
assert create_response.status_code == 200
# Update with missing price
price_data = {"volume": 50000.0}
response = client.post("/api/v1/market-data/aitbc-btc/price", json=price_data)
assert response.status_code == 200
data = response.json()
# Should use None for missing price
assert data["current_price"] is None
@pytest.mark.integration
def test_update_market_price_zero_price():
"""Test updating market price with zero price"""
client = TestClient(app)
# Create trading pair first
pair = TradingPair(
symbol="AITBC-BTC",
base_asset="AITBC",
quote_asset="BTC",
min_order_size=0.001,
price_precision=8,
quantity_precision=6
)
create_response = client.post("/api/v1/pairs/create", json=pair.model_dump())
assert create_response.status_code == 200
# Update with zero price
price_data = {"price": 0.0}
response = client.post("/api/v1/market-data/aitbc-btc/price", json=price_data)
assert response.status_code == 200
data = response.json()
assert data["current_price"] == 0.0
@pytest.mark.integration
def test_update_market_price_negative_price():
"""Test updating market price with negative price"""
client = TestClient(app)
# Create trading pair first
pair = TradingPair(
symbol="AITBC-BTC",
base_asset="AITBC",
quote_asset="BTC",
min_order_size=0.001,
price_precision=8,
quantity_precision=6
)
create_response = client.post("/api/v1/pairs/create", json=pair.model_dump())
assert create_response.status_code == 200
# Update with negative price
price_data = {"price": -0.00001}
response = client.post("/api/v1/market-data/aitbc-btc/price", json=price_data)
assert response.status_code == 200
data = response.json()
assert data["current_price"] == -0.00001

View File

@@ -0,0 +1,378 @@
"""Integration tests for exchange integration service"""
import pytest
import sys
import sys
from pathlib import Path
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient
# Mock aiohttp before importing
sys.modules['aiohttp'] = Mock()
from main import app, ExchangeRegistration, TradingPair, OrderRequest, exchanges, trading_pairs, orders
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
exchanges.clear()
trading_pairs.clear()
orders.clear()
yield
exchanges.clear()
trading_pairs.clear()
orders.clear()
@pytest.mark.integration
def test_root_endpoint():
"""Test root endpoint"""
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
data = response.json()
assert data["service"] == "AITBC Exchange Integration"
assert data["status"] == "running"
@pytest.mark.integration
def test_health_check_endpoint():
"""Test health check endpoint"""
client = TestClient(app)
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert "exchanges_connected" in data
assert "active_pairs" in data
assert "total_orders" in data
@pytest.mark.integration
def test_register_exchange():
"""Test exchange registration"""
client = TestClient(app)
registration = ExchangeRegistration(
name="TestExchange",
api_key="test_key_123",
sandbox=True
)
response = client.post("/api/v1/exchanges/register", json=registration.model_dump())
assert response.status_code == 200
data = response.json()
assert data["exchange_id"] == "testexchange"
assert data["status"] == "registered"
assert data["name"] == "TestExchange"
@pytest.mark.integration
def test_register_duplicate_exchange():
"""Test registering duplicate exchange"""
client = TestClient(app)
registration = ExchangeRegistration(
name="TestExchange",
api_key="test_key_123"
)
# First registration
client.post("/api/v1/exchanges/register", json=registration.model_dump())
# Second registration should fail
response = client.post("/api/v1/exchanges/register", json=registration.model_dump())
assert response.status_code == 400
@pytest.mark.integration
def test_connect_exchange():
"""Test connecting to exchange"""
client = TestClient(app)
registration = ExchangeRegistration(
name="TestExchange",
api_key="test_key_123"
)
# Register exchange first
client.post("/api/v1/exchanges/register", json=registration.model_dump())
# Connect to exchange
response = client.post("/api/v1/exchanges/testexchange/connect")
assert response.status_code == 200
data = response.json()
assert data["exchange_id"] == "testexchange"
assert data["status"] == "connected"
@pytest.mark.integration
def test_connect_nonexistent_exchange():
"""Test connecting to nonexistent exchange"""
client = TestClient(app)
response = client.post("/api/v1/exchanges/nonexistent/connect")
assert response.status_code == 404
@pytest.mark.integration
def test_create_trading_pair():
"""Test creating trading pair"""
client = TestClient(app)
pair = TradingPair(
symbol="AITBC/BTC",
base_asset="AITBC",
quote_asset="BTC",
min_order_size=0.001,
price_precision=8,
quantity_precision=6
)
response = client.post("/api/v1/pairs/create", json=pair.model_dump())
assert response.status_code == 200
data = response.json()
assert data["pair_id"] == "aitbc/btc"
assert data["symbol"] == "AITBC/BTC"
assert data["status"] == "created"
@pytest.mark.integration
def test_create_duplicate_trading_pair():
"""Test creating duplicate trading pair"""
client = TestClient(app)
pair = TradingPair(
symbol="AITBC/BTC",
base_asset="AITBC",
quote_asset="BTC",
min_order_size=0.001,
price_precision=8,
quantity_precision=6
)
# First creation
client.post("/api/v1/pairs/create", json=pair.model_dump())
# Second creation should fail
response = client.post("/api/v1/pairs/create", json=pair.model_dump())
assert response.status_code == 400
@pytest.mark.integration
def test_list_trading_pairs():
"""Test listing trading pairs"""
client = TestClient(app)
response = client.get("/api/v1/pairs")
assert response.status_code == 200
data = response.json()
assert "pairs" in data
assert "total_pairs" in data
@pytest.mark.integration
def test_get_trading_pair():
"""Test getting specific trading pair"""
client = TestClient(app)
pair = TradingPair(
symbol="AITBC-BTC",
base_asset="AITBC",
quote_asset="BTC",
min_order_size=0.001,
price_precision=8,
quantity_precision=6
)
# Create pair first
client.post("/api/v1/pairs/create", json=pair.model_dump())
# Get pair with lowercase symbol as pair_id
response = client.get("/api/v1/pairs/aitbc-btc")
assert response.status_code == 200
data = response.json()
assert data["symbol"] == "AITBC-BTC"
@pytest.mark.integration
def test_get_nonexistent_trading_pair():
"""Test getting nonexistent trading pair"""
client = TestClient(app)
response = client.get("/api/v1/pairs/nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_create_order():
"""Test creating order"""
client = TestClient(app)
# Create trading pair first
pair = TradingPair(
symbol="AITBC/BTC",
base_asset="AITBC",
quote_asset="BTC",
min_order_size=0.001,
price_precision=8,
quantity_precision=6
)
client.post("/api/v1/pairs/create", json=pair.model_dump())
# Create order
order = OrderRequest(
symbol="AITBC/BTC",
side="buy",
type="limit",
quantity=100.0,
price=0.00001
)
response = client.post("/api/v1/orders", json=order.model_dump())
assert response.status_code == 200
data = response.json()
assert data["symbol"] == "AITBC/BTC"
assert data["side"] == "buy"
assert data["status"] == "filled"
assert data["filled_quantity"] == 100.0
@pytest.mark.integration
def test_create_order_nonexistent_pair():
"""Test creating order for nonexistent pair"""
client = TestClient(app)
order = OrderRequest(
symbol="NONEXISTENT/BTC",
side="buy",
type="limit",
quantity=100.0,
price=0.00001
)
response = client.post("/api/v1/orders", json=order.model_dump())
assert response.status_code == 404
@pytest.mark.integration
def test_list_orders():
"""Test listing orders"""
client = TestClient(app)
response = client.get("/api/v1/orders")
assert response.status_code == 200
data = response.json()
assert "orders" in data
assert "total_orders" in data
@pytest.mark.integration
def test_get_order():
"""Test getting specific order"""
client = TestClient(app)
# Create trading pair first
pair = TradingPair(
symbol="AITBC/BTC",
base_asset="AITBC",
quote_asset="BTC",
min_order_size=0.001,
price_precision=8,
quantity_precision=6
)
client.post("/api/v1/pairs/create", json=pair.model_dump())
# Create order
order = OrderRequest(
symbol="AITBC/BTC",
side="buy",
type="limit",
quantity=100.0,
price=0.00001
)
create_response = client.post("/api/v1/orders", json=order.model_dump())
order_id = create_response.json()["order_id"]
# Get order
response = client.get(f"/api/v1/orders/{order_id}")
assert response.status_code == 200
data = response.json()
assert data["order_id"] == order_id
@pytest.mark.integration
def test_get_nonexistent_order():
"""Test getting nonexistent order"""
client = TestClient(app)
response = client.get("/api/v1/orders/nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_list_exchanges():
"""Test listing exchanges"""
client = TestClient(app)
response = client.get("/api/v1/exchanges")
assert response.status_code == 200
data = response.json()
assert "exchanges" in data
assert "total_exchanges" in data
@pytest.mark.integration
def test_get_exchange():
"""Test getting specific exchange"""
client = TestClient(app)
registration = ExchangeRegistration(
name="TestExchange",
api_key="test_key_123"
)
# Register exchange first
client.post("/api/v1/exchanges/register", json=registration.model_dump())
# Get exchange
response = client.get("/api/v1/exchanges/testexchange")
assert response.status_code == 200
data = response.json()
assert data["exchange_id"] == "testexchange"
@pytest.mark.integration
def test_get_nonexistent_exchange():
"""Test getting nonexistent exchange"""
client = TestClient(app)
response = client.get("/api/v1/exchanges/nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_update_market_price():
"""Test updating market price"""
client = TestClient(app)
# Create trading pair first
pair = TradingPair(
symbol="AITBC-BTC",
base_asset="AITBC",
quote_asset="BTC",
min_order_size=0.001,
price_precision=8,
quantity_precision=6
)
client.post("/api/v1/pairs/create", json=pair.model_dump())
# Update price
price_data = {"price": 0.000015, "volume": 50000.0}
response = client.post("/api/v1/market-data/aitbc-btc/price", json=price_data)
assert response.status_code == 200
data = response.json()
assert data["current_price"] == 0.000015
@pytest.mark.integration
def test_update_price_nonexistent_pair():
"""Test updating price for nonexistent pair"""
client = TestClient(app)
price_data = {"price": 0.000015}
response = client.post("/api/v1/market-data/nonexistent/price", json=price_data)
assert response.status_code == 404
@pytest.mark.integration
def test_get_market_data():
"""Test getting market data"""
client = TestClient(app)
response = client.get("/api/v1/market-data")
assert response.status_code == 200
data = response.json()
assert "market_data" in data
assert "total_pairs" in data

View File

@@ -0,0 +1,101 @@
"""Unit tests for exchange integration service"""
import pytest
import sys
import sys
from pathlib import Path
from unittest.mock import Mock, patch
# Mock aiohttp before importing
sys.modules['aiohttp'] = Mock()
from main import app, ExchangeRegistration, TradingPair, OrderRequest
@pytest.mark.unit
def test_app_initialization():
"""Test that the FastAPI app initializes correctly"""
assert app is not None
assert app.title == "AITBC Exchange Integration Service"
assert app.version == "1.0.0"
@pytest.mark.unit
def test_exchange_registration_model():
"""Test ExchangeRegistration model"""
registration = ExchangeRegistration(
name="TestExchange",
api_key="test_key_123",
sandbox=True,
description="Test exchange"
)
assert registration.name == "TestExchange"
assert registration.api_key == "test_key_123"
assert registration.sandbox is True
assert registration.description == "Test exchange"
@pytest.mark.unit
def test_exchange_registration_defaults():
"""Test ExchangeRegistration default values"""
registration = ExchangeRegistration(
name="TestExchange",
api_key="test_key_123"
)
assert registration.name == "TestExchange"
assert registration.api_key == "test_key_123"
assert registration.sandbox is True
assert registration.description is None
@pytest.mark.unit
def test_trading_pair_model():
"""Test TradingPair model"""
pair = TradingPair(
symbol="AITBC/BTC",
base_asset="AITBC",
quote_asset="BTC",
min_order_size=0.001,
price_precision=8,
quantity_precision=6
)
assert pair.symbol == "AITBC/BTC"
assert pair.base_asset == "AITBC"
assert pair.quote_asset == "BTC"
assert pair.min_order_size == 0.001
assert pair.price_precision == 8
assert pair.quantity_precision == 6
@pytest.mark.unit
def test_order_request_model():
"""Test OrderRequest model"""
order = OrderRequest(
symbol="AITBC/BTC",
side="buy",
type="limit",
quantity=100.0,
price=0.00001
)
assert order.symbol == "AITBC/BTC"
assert order.side == "buy"
assert order.type == "limit"
assert order.quantity == 100.0
assert order.price == 0.00001
@pytest.mark.unit
def test_order_request_market_order():
"""Test OrderRequest for market order"""
order = OrderRequest(
symbol="AITBC/BTC",
side="sell",
type="market",
quantity=50.0
)
assert order.symbol == "AITBC/BTC"
assert order.side == "sell"
assert order.type == "market"
assert order.quantity == 50.0
assert order.price is None

View File

@@ -0,0 +1,662 @@
"""
Global AI Agent Communication Service for AITBC
Handles cross-chain and cross-region AI agent communication with global optimization
"""
import os
import asyncio
import json
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Dict, Any, List, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from aitbc import get_logger
logger = get_logger(__name__)
app = FastAPI(
title="AITBC Global AI Agent Communication Service",
description="Global AI agent communication and collaboration platform",
version="1.0.0"
)
# Data models
class Agent(BaseModel):
agent_id: str
name: str
type: str # ai, blockchain, oracle, market_maker, etc.
region: str
capabilities: List[str]
status: str # active, inactive, busy
languages: List[str] # Languages the agent can communicate in
specialization: str
performance_score: float
class AgentMessage(BaseModel):
message_id: str
sender_id: str
recipient_id: Optional[str] # None for broadcast
message_type: str # request, response, collaboration, data_share
content: Dict[str, Any]
priority: str # low, medium, high, critical
language: str
timestamp: datetime
encryption_key: Optional[str] = None
class CollaborationSession(BaseModel):
session_id: str
participants: List[str]
session_type: str # task_force, research, trading, governance
objective: str
created_at: datetime
expires_at: datetime
status: str # active, completed, expired
class AgentPerformance(BaseModel):
agent_id: str
timestamp: datetime
tasks_completed: int
response_time_ms: float
accuracy_score: float
collaboration_score: float
resource_usage: Dict[str, float]
# In-memory storage (in production, use database)
global_agents: Dict[str, Dict] = {}
agent_messages: Dict[str, List[Dict]] = {}
collaboration_sessions: Dict[str, Dict] = {}
agent_performance: Dict[str, List[Dict]] = {}
global_network_stats: Dict[str, Any] = {}
@app.get("/")
async def root():
return {
"service": "AITBC Global AI Agent Communication Service",
"status": "running",
"timestamp": datetime.now(timezone.utc).isoformat(),
"version": "1.0.0"
}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"total_agents": len(global_agents),
"active_agents": len([a for a in global_agents.values() if a["status"] == "active"]),
"active_sessions": len([s for s in collaboration_sessions.values() if s["status"] == "active"]),
"total_messages": sum(len(messages) for messages in agent_messages.values())
}
@app.post("/api/v1/agents/register")
async def register_agent(agent: Agent):
"""Register a new AI agent in the global network"""
if agent.agent_id in global_agents:
raise HTTPException(status_code=400, detail="Agent already registered")
# Create agent record
agent_record = {
"agent_id": agent.agent_id,
"name": agent.name,
"type": agent.type,
"region": agent.region,
"capabilities": agent.capabilities,
"status": agent.status,
"languages": agent.languages,
"specialization": agent.specialization,
"performance_score": agent.performance_score,
"created_at": datetime.now(timezone.utc).isoformat(),
"last_active": datetime.now(timezone.utc).isoformat(),
"total_messages_sent": 0,
"total_messages_received": 0,
"collaborations_participated": 0,
"tasks_completed": 0,
"reputation_score": 5.0,
"network_connections": []
}
global_agents[agent.agent_id] = agent_record
agent_messages[agent.agent_id] = []
logger.info(f"Agent registered: {agent.name} ({agent.agent_id}) in {agent.region}")
return {
"agent_id": agent.agent_id,
"status": "registered",
"name": agent.name,
"region": agent.region,
"created_at": agent_record["created_at"]
}
@app.get("/api/v1/agents")
async def list_agents(region: Optional[str] = None,
agent_type: Optional[str] = None,
status: Optional[str] = None):
"""List all agents with filtering"""
agents = list(global_agents.values())
# Apply filters
if region:
agents = [a for a in agents if a["region"] == region]
if agent_type:
agents = [a for a in agents if a["type"] == agent_type]
if status:
agents = [a for a in agents if a["status"] == status]
return {
"agents": agents,
"total_agents": len(agents),
"filters": {
"region": region,
"agent_type": agent_type,
"status": status
}
}
@app.get("/api/v1/agents/{agent_id}")
async def get_agent(agent_id: str):
"""Get detailed agent information"""
if agent_id not in global_agents:
raise HTTPException(status_code=404, detail="Agent not found")
agent = global_agents[agent_id].copy()
# Add recent messages
agent["recent_messages"] = agent_messages.get(agent_id, [])[-10:]
# Add performance metrics
agent["performance_metrics"] = agent_performance.get(agent_id, [])
return agent
@app.post("/api/v1/messages/send")
async def send_message(message: AgentMessage):
"""Send a message from one agent to another or broadcast"""
# Validate sender
if message.sender_id not in global_agents:
raise HTTPException(status_code=400, detail="Sender agent not found")
# Create message record
message_record = {
"message_id": message.message_id,
"sender_id": message.sender_id,
"recipient_id": message.recipient_id,
"message_type": message.message_type,
"content": message.content,
"priority": message.priority,
"language": message.language,
"timestamp": message.timestamp.isoformat(),
"encryption_key": message.encryption_key,
"status": "delivered",
"delivered_at": datetime.now(timezone.utc).isoformat(),
"read_at": None
}
# Handle broadcast
if message.recipient_id is None:
# Broadcast to all active agents
for agent_id in global_agents:
if agent_id != message.sender_id and global_agents[agent_id]["status"] == "active":
if agent_id not in agent_messages:
agent_messages[agent_id] = []
agent_messages[agent_id].append(message_record.copy())
# Update sender stats
global_agents[message.sender_id]["total_messages_sent"] += len(global_agents) - 1
logger.info(f"Broadcast message sent from {message.sender_id} to all agents")
else:
# Direct message
if message.recipient_id not in global_agents:
raise HTTPException(status_code=400, detail="Recipient agent not found")
if message.recipient_id not in agent_messages:
agent_messages[message.recipient_id] = []
agent_messages[message.recipient_id].append(message_record)
# Update stats
global_agents[message.sender_id]["total_messages_sent"] += 1
global_agents[message.recipient_id]["total_messages_received"] += 1
logger.info(f"Message sent from {message.sender_id} to {message.recipient_id}")
return {
"message_id": message.message_id,
"status": "delivered",
"delivered_at": message_record["delivered_at"]
}
@app.get("/api/v1/messages/{agent_id}")
async def get_agent_messages(agent_id: str, limit: int = 50):
"""Get messages for an agent"""
if agent_id not in global_agents:
raise HTTPException(status_code=404, detail="Agent not found")
messages = agent_messages.get(agent_id, [])
# Sort by timestamp (most recent first)
messages.sort(key=lambda x: x["timestamp"], reverse=True)
return {
"agent_id": agent_id,
"messages": messages[:limit],
"total_messages": len(messages),
"unread_count": len([m for m in messages if m.get("read_at") is None])
}
@app.post("/api/v1/collaborations/create")
async def create_collaboration(session: CollaborationSession):
"""Create a new collaboration session"""
# Validate participants
for participant_id in session.participants:
if participant_id not in global_agents:
raise HTTPException(status_code=400, detail=f"Participant {participant_id} not found")
# Create collaboration session
session_record = {
"session_id": session.session_id,
"participants": session.participants,
"session_type": session.session_type,
"objective": session.objective,
"created_at": session.created_at.isoformat(),
"expires_at": session.expires_at.isoformat(),
"status": session.status,
"messages": [],
"shared_resources": {},
"task_progress": {},
"outcome": None
}
collaboration_sessions[session.session_id] = session_record
# Update participant stats
for participant_id in session.participants:
global_agents[participant_id]["collaborations_participated"] += 1
# Notify participants
notification = {
"type": "collaboration_invite",
"session_id": session.session_id,
"objective": session.objective,
"participants": session.participants
}
for participant_id in session.participants:
message_record = {
"message_id": f"collab_{int(datetime.now(timezone.utc).timestamp())}",
"sender_id": "system",
"recipient_id": participant_id,
"message_type": "notification",
"content": notification,
"priority": "medium",
"language": "english",
"timestamp": datetime.now(timezone.utc).isoformat(),
"status": "delivered",
"delivered_at": datetime.now(timezone.utc).isoformat()
}
if participant_id not in agent_messages:
agent_messages[participant_id] = []
agent_messages[participant_id].append(message_record)
logger.info(f"Collaboration session created: {session.session_id} with {len(session.participants)} participants")
return {
"session_id": session.session_id,
"status": "created",
"participants": session.participants,
"objective": session.objective,
"created_at": session_record["created_at"]
}
@app.get("/api/v1/collaborations/{session_id}")
async def get_collaboration(session_id: str):
"""Get collaboration session details"""
if session_id not in collaboration_sessions:
raise HTTPException(status_code=404, detail="Collaboration session not found")
return collaboration_sessions[session_id]
@app.post("/api/v1/collaborations/{session_id}/message")
async def send_collaboration_message(session_id: str, sender_id: str, content: Dict[str, Any]):
"""Send a message within a collaboration session"""
if session_id not in collaboration_sessions:
raise HTTPException(status_code=404, detail="Collaboration session not found")
if sender_id not in collaboration_sessions[session_id]["participants"]:
raise HTTPException(status_code=400, detail="Sender not a participant in this session")
# Create collaboration message
message_record = {
"message_id": f"collab_msg_{int(datetime.now(timezone.utc).timestamp())}",
"sender_id": sender_id,
"session_id": session_id,
"content": content,
"timestamp": datetime.now(timezone.utc).isoformat(),
"type": "collaboration_message"
}
collaboration_sessions[session_id]["messages"].append(message_record)
# Notify all participants
for participant_id in collaboration_sessions[session_id]["participants"]:
if participant_id != sender_id:
notification = {
"type": "collaboration_message",
"session_id": session_id,
"sender_id": sender_id,
"content": content
}
msg_record = {
"message_id": f"notif_{int(datetime.now(timezone.utc).timestamp())}",
"sender_id": "system",
"recipient_id": participant_id,
"message_type": "notification",
"content": notification,
"priority": "medium",
"language": "english",
"timestamp": datetime.now(timezone.utc).isoformat(),
"status": "delivered",
"delivered_at": datetime.now(timezone.utc).isoformat()
}
if participant_id not in agent_messages:
agent_messages[participant_id] = []
agent_messages[participant_id].append(msg_record)
return {
"message_id": message_record["message_id"],
"status": "delivered",
"timestamp": message_record["timestamp"]
}
@app.post("/api/v1/performance/record")
async def record_agent_performance(performance: AgentPerformance):
"""Record performance metrics for an agent"""
if performance.agent_id not in global_agents:
raise HTTPException(status_code=404, detail="Agent not found")
# Create performance record
performance_record = {
"performance_id": f"perf_{int(datetime.now(timezone.utc).timestamp())}",
"agent_id": performance.agent_id,
"timestamp": performance.timestamp.isoformat(),
"tasks_completed": performance.tasks_completed,
"response_time_ms": performance.response_time_ms,
"accuracy_score": performance.accuracy_score,
"collaboration_score": performance.collaboration_score,
"resource_usage": performance.resource_usage
}
if performance.agent_id not in agent_performance:
agent_performance[performance.agent_id] = []
agent_performance[performance.agent_id].append(performance_record)
# Update agent's performance score
recent_performances = agent_performance[performance.agent_id][-10:] # Last 10 records
if recent_performances:
avg_accuracy = sum(p["accuracy_score"] for p in recent_performances) / len(recent_performances)
avg_collaboration = sum(p["collaboration_score"] for p in recent_performances) / len(recent_performances)
# Update overall performance score
new_score = (avg_accuracy * 0.6 + avg_collaboration * 0.4)
global_agents[performance.agent_id]["performance_score"] = round(new_score, 2)
# Update tasks completed
global_agents[performance.agent_id]["tasks_completed"] += performance.tasks_completed
return {
"performance_id": performance_record["performance_id"],
"status": "recorded",
"updated_performance_score": global_agents[performance.agent_id]["performance_score"]
}
@app.get("/api/v1/performance/{agent_id}")
async def get_agent_performance(agent_id: str, hours: int = 24):
"""Get performance metrics for an agent"""
if agent_id not in global_agents:
raise HTTPException(status_code=404, detail="Agent not found")
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours)
performance_records = agent_performance.get(agent_id, [])
recent_performance = [
p for p in performance_records
if datetime.fromisoformat(p["timestamp"]) > cutoff_time
]
# Calculate statistics
if recent_performance:
avg_response_time = sum(p["response_time_ms"] for p in recent_performance) / len(recent_performance)
avg_accuracy = sum(p["accuracy_score"] for p in recent_performance) / len(recent_performance)
avg_collaboration = sum(p["collaboration_score"] for p in recent_performance) / len(recent_performance)
total_tasks = sum(p["tasks_completed"] for p in recent_performance)
else:
avg_response_time = avg_accuracy = avg_collaboration = total_tasks = 0.0
return {
"agent_id": agent_id,
"period_hours": hours,
"performance_records": recent_performance,
"statistics": {
"average_response_time_ms": round(avg_response_time, 2),
"average_accuracy_score": round(avg_accuracy, 3),
"average_collaboration_score": round(avg_collaboration, 3),
"total_tasks_completed": int(total_tasks),
"total_records": len(recent_performance)
},
"generated_at": datetime.now(timezone.utc).isoformat()
}
@app.get("/api/v1/network/dashboard")
async def get_network_dashboard():
"""Get global AI agent network dashboard"""
# Calculate network statistics
total_agents = len(global_agents)
active_agents = len([a for a in global_agents.values() if a["status"] == "active"])
# Agent type distribution
type_distribution = {}
for agent in global_agents.values():
agent_type = agent["type"]
type_distribution[agent_type] = type_distribution.get(agent_type, 0) + 1
# Regional distribution
region_distribution = {}
for agent in global_agents.values():
region = agent["region"]
region_distribution[region] = region_distribution.get(region, 0) + 1
# Performance summary
performance_scores = [a["performance_score"] for a in global_agents.values()]
avg_performance = sum(performance_scores) / len(performance_scores) if performance_scores else 0.0
# Recent activity
recent_messages = 0
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=1)
for messages in agent_messages.values():
recent_messages += len([m for m in messages if datetime.fromisoformat(m["timestamp"]) > cutoff_time])
return {
"dashboard": {
"network_overview": {
"total_agents": total_agents,
"active_agents": active_agents,
"agent_utilization": round((active_agents / total_agents * 100) if total_agents > 0 else 0, 2),
"average_performance_score": round(avg_performance, 3)
},
"agent_distribution": {
"by_type": type_distribution,
"by_region": region_distribution
},
"collaborations": {
"total_sessions": len(collaboration_sessions),
"active_sessions": len([s for s in collaboration_sessions.values() if s["status"] == "active"]),
"total_participants": sum(len(s["participants"]) for s in collaboration_sessions.values())
},
"activity": {
"recent_messages_hour": recent_messages,
"total_messages_sent": sum(a["total_messages_sent"] for a in global_agents.values()),
"total_tasks_completed": sum(a["tasks_completed"] for a in global_agents.values())
}
},
"generated_at": datetime.now(timezone.utc).isoformat()
}
@app.get("/api/v1/network/optimize")
async def optimize_network():
"""Optimize global agent network performance"""
optimization_results = {
"recommendations": [],
"actions_taken": [],
"performance_improvements": {}
}
# Find underperforming agents
for agent_id, agent in global_agents.items():
if agent["performance_score"] < 3.0 and agent["status"] == "active":
optimization_results["recommendations"].append({
"type": "agent_performance",
"agent_id": agent_id,
"issue": "Low performance score",
"recommendation": "Consider agent retraining or resource allocation"
})
# Find overloaded regions
region_load = {}
for agent in global_agents.values():
if agent["status"] == "active":
region = agent["region"]
region_load[region] = region_load.get(region, 0) + 1
total_capacity = len(global_agents)
for region, load in region_load.items():
if load > total_capacity * 0.4: # More than 40% of agents in one region
optimization_results["recommendations"].append({
"type": "regional_balance",
"region": region,
"issue": "Agent concentration imbalance",
"recommendation": "Redistribute agents to other regions"
})
# Find inactive agents with good performance
for agent_id, agent in global_agents.items():
if agent["status"] == "inactive" and agent["performance_score"] > 4.0:
optimization_results["actions_taken"].append({
"type": "agent_activation",
"agent_id": agent_id,
"action": "Activated high-performing inactive agent"
})
agent["status"] = "active"
return {
"optimization_results": optimization_results,
"generated_at": datetime.now(timezone.utc).isoformat()
}
# Background task for network monitoring
async def network_monitoring_task():
"""Background task for global network monitoring"""
while True:
await asyncio.sleep(300) # Monitor every 5 minutes
# Update network statistics
global_network_stats["last_update"] = datetime.now(timezone.utc).isoformat()
global_network_stats["total_agents"] = len(global_agents)
global_network_stats["active_agents"] = len([a for a in global_agents.values() if a["status"] == "active"])
# Check for expired collaboration sessions
current_time = datetime.now(timezone.utc)
for session_id, session in collaboration_sessions.items():
if datetime.fromisoformat(session["expires_at"]) < current_time and session["status"] == "active":
session["status"] = "expired"
logger.info(f"Collaboration session expired: {session_id}")
# Clean up old messages (older than 7 days)
cutoff_time = current_time - timedelta(days=7)
for agent_id in agent_messages:
agent_messages[agent_id] = [
m for m in agent_messages[agent_id]
if datetime.fromisoformat(m["timestamp"]) > cutoff_time
]
# Initialize with some default AI agents
@app.on_event("startup")
async def startup_event():
logger.info("Starting AITBC Global AI Agent Communication Service")
# Initialize default AI agents
default_agents = [
{
"agent_id": "ai-trader-001",
"name": "AlphaTrader",
"type": "trading",
"region": "us-east-1",
"capabilities": ["market_analysis", "trading", "risk_management"],
"status": "active",
"languages": ["english", "chinese", "japanese", "spanish"],
"specialization": "cryptocurrency_trading",
"performance_score": 4.7
},
{
"agent_id": "ai-oracle-001",
"name": "OraclePro",
"type": "oracle",
"region": "eu-west-1",
"capabilities": ["price_feeds", "data_analysis", "prediction"],
"status": "active",
"languages": ["english", "german", "french"],
"specialization": "price_discovery",
"performance_score": 4.9
},
{
"agent_id": "ai-research-001",
"name": "ResearchNova",
"type": "research",
"region": "ap-southeast-1",
"capabilities": ["data_analysis", "pattern_recognition", "reporting"],
"status": "active",
"languages": ["english", "chinese", "korean"],
"specialization": "blockchain_research",
"performance_score": 4.5
}
]
for agent_data in default_agents:
agent = Agent(**agent_data)
agent_record = {
"agent_id": agent.agent_id,
"name": agent.name,
"type": agent.type,
"region": agent.region,
"capabilities": agent.capabilities,
"status": agent.status,
"languages": agent.languages,
"specialization": agent.specialization,
"performance_score": agent.performance_score,
"created_at": datetime.now(timezone.utc).isoformat(),
"last_active": datetime.now(timezone.utc).isoformat(),
"total_messages_sent": 0,
"total_messages_received": 0,
"collaborations_participated": 0,
"tasks_completed": 0,
"reputation_score": 5.0,
"network_connections": []
}
global_agents[agent.agent_id] = agent_record
agent_messages[agent.agent_id] = []
# Start network monitoring
asyncio.create_task(network_monitoring_task())
@app.on_event("shutdown")
async def shutdown_event():
logger.info("Shutting down AITBC Global AI Agent Communication Service")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8018, log_level="info")

View File

@@ -0,0 +1 @@
"""Global AI agents service tests"""

View File

@@ -0,0 +1,186 @@
"""Edge case and error handling tests for global AI agents service"""
import pytest
import sys
import sys
from pathlib import Path
from fastapi.testclient import TestClient
from datetime import datetime, timezone, timedelta
from main import app, Agent, AgentMessage, CollaborationSession, AgentPerformance, global_agents, agent_messages, collaboration_sessions, agent_performance
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
global_agents.clear()
agent_messages.clear()
collaboration_sessions.clear()
agent_performance.clear()
yield
global_agents.clear()
agent_messages.clear()
collaboration_sessions.clear()
agent_performance.clear()
@pytest.mark.unit
def test_agent_empty_name():
"""Test Agent with empty name"""
agent = Agent(
agent_id="agent_123",
name="",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
assert agent.name == ""
@pytest.mark.unit
def test_agent_negative_performance_score():
"""Test Agent with negative performance score"""
agent = Agent(
agent_id="agent_123",
name="Test Agent",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=-4.5
)
assert agent.performance_score == -4.5
@pytest.mark.unit
def test_agent_performance_out_of_range_score():
"""Test AgentPerformance with out of range scores"""
performance = AgentPerformance(
agent_id="agent_123",
timestamp=datetime.now(timezone.utc),
tasks_completed=10,
response_time_ms=50.5,
accuracy_score=2.0,
collaboration_score=2.0,
resource_usage={}
)
assert performance.accuracy_score == 2.0
assert performance.collaboration_score == 2.0
@pytest.mark.unit
def test_agent_message_empty_content():
"""Test AgentMessage with empty content"""
message = AgentMessage(
message_id="msg_123",
sender_id="agent_123",
recipient_id="agent_456",
message_type="request",
content={},
priority="high",
language="english",
timestamp=datetime.now(timezone.utc)
)
assert message.content == {}
@pytest.mark.integration
def test_list_agents_with_no_agents():
"""Test listing agents when no agents exist"""
client = TestClient(app)
response = client.get("/api/v1/agents")
assert response.status_code == 200
data = response.json()
assert data["total_agents"] == 0
@pytest.mark.integration
def test_get_agent_messages_agent_not_found():
"""Test getting messages for nonexistent agent"""
client = TestClient(app)
response = client.get("/api/v1/messages/nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_get_collaboration_not_found():
"""Test getting nonexistent collaboration session"""
client = TestClient(app)
response = client.get("/api/v1/collaborations/nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_send_collaboration_message_session_not_found():
"""Test sending message to nonexistent collaboration session"""
client = TestClient(app)
response = client.post("/api/v1/collaborations/nonexistent/message", params={"sender_id": "agent_123"}, json={"content": "test"})
assert response.status_code == 404
@pytest.mark.integration
def test_send_collaboration_message_sender_not_participant():
"""Test sending message from non-participant"""
client = TestClient(app)
# Register agent and create collaboration
agent = Agent(
agent_id="agent_123",
name="Agent 1",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
client.post("/api/v1/agents/register", json=agent.model_dump())
session = CollaborationSession(
session_id="session_123",
participants=["agent_123"],
session_type="research",
objective="Research task",
created_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
status="active"
)
client.post("/api/v1/collaborations/create", json=session.model_dump(mode='json'))
response = client.post("/api/v1/collaborations/session_123/message", params={"sender_id": "nonexistent"}, json={"content": "test"})
assert response.status_code == 400
@pytest.mark.integration
def test_get_agent_performance_agent_not_found():
"""Test getting performance for nonexistent agent"""
client = TestClient(app)
response = client.get("/api/v1/performance/nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_dashboard_with_no_data():
"""Test dashboard with no data"""
client = TestClient(app)
response = client.get("/api/v1/network/dashboard")
assert response.status_code == 200
data = response.json()
assert data["dashboard"]["network_overview"]["total_agents"] == 0
@pytest.mark.integration
def test_optimize_network_with_no_agents():
"""Test network optimization with no agents"""
client = TestClient(app)
response = client.get("/api/v1/network/optimize")
assert response.status_code == 200
data = response.json()
assert "optimization_results" in data

View File

@@ -0,0 +1,590 @@
"""Integration tests for global AI agents service"""
import pytest
import sys
import sys
from pathlib import Path
from fastapi.testclient import TestClient
from datetime import datetime, timezone, timedelta
from main import app, Agent, AgentMessage, CollaborationSession, AgentPerformance, global_agents, agent_messages, collaboration_sessions, agent_performance
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
global_agents.clear()
agent_messages.clear()
collaboration_sessions.clear()
agent_performance.clear()
yield
global_agents.clear()
agent_messages.clear()
collaboration_sessions.clear()
agent_performance.clear()
@pytest.mark.integration
def test_root_endpoint():
"""Test root endpoint"""
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
data = response.json()
assert data["service"] == "AITBC Global AI Agent Communication Service"
assert data["status"] == "running"
@pytest.mark.integration
def test_health_check_endpoint():
"""Test health check endpoint"""
client = TestClient(app)
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert "total_agents" in data
@pytest.mark.integration
def test_register_agent():
"""Test registering a new agent"""
client = TestClient(app)
agent = Agent(
agent_id="agent_123",
name="Test Agent",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
response = client.post("/api/v1/agents/register", json=agent.model_dump())
assert response.status_code == 200
data = response.json()
assert data["agent_id"] == "agent_123"
assert data["status"] == "registered"
@pytest.mark.integration
def test_register_duplicate_agent():
"""Test registering duplicate agent"""
client = TestClient(app)
agent = Agent(
agent_id="agent_123",
name="Test Agent",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
client.post("/api/v1/agents/register", json=agent.model_dump())
response = client.post("/api/v1/agents/register", json=agent.model_dump())
assert response.status_code == 400
@pytest.mark.integration
def test_list_agents():
"""Test listing all agents"""
client = TestClient(app)
response = client.get("/api/v1/agents")
assert response.status_code == 200
data = response.json()
assert "agents" in data
assert "total_agents" in data
@pytest.mark.integration
def test_list_agents_with_filters():
"""Test listing agents with filters"""
client = TestClient(app)
# Register an agent first
agent = Agent(
agent_id="agent_123",
name="Test Agent",
type="trading",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
client.post("/api/v1/agents/register", json=agent.model_dump())
response = client.get("/api/v1/agents?region=us-east-1&type=trading&status=active")
assert response.status_code == 200
data = response.json()
assert "filters" in data
@pytest.mark.integration
def test_get_agent():
"""Test getting specific agent"""
client = TestClient(app)
agent = Agent(
agent_id="agent_123",
name="Test Agent",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
client.post("/api/v1/agents/register", json=agent.model_dump())
response = client.get("/api/v1/agents/agent_123")
assert response.status_code == 200
data = response.json()
assert data["agent_id"] == "agent_123"
@pytest.mark.integration
def test_get_agent_not_found():
"""Test getting nonexistent agent"""
client = TestClient(app)
response = client.get("/api/v1/agents/nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_send_direct_message():
"""Test sending direct message"""
client = TestClient(app)
# Register two agents
agent1 = Agent(
agent_id="agent_123",
name="Agent 1",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
agent2 = Agent(
agent_id="agent_456",
name="Agent 2",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
client.post("/api/v1/agents/register", json=agent1.model_dump())
client.post("/api/v1/agents/register", json=agent2.model_dump())
message = AgentMessage(
message_id="msg_123",
sender_id="agent_123",
recipient_id="agent_456",
message_type="request",
content={"data": "test"},
priority="high",
language="english",
timestamp=datetime.now(timezone.utc)
)
response = client.post("/api/v1/messages/send", json=message.model_dump(mode='json'))
assert response.status_code == 200
data = response.json()
assert data["message_id"] == "msg_123"
assert data["status"] == "delivered"
@pytest.mark.integration
def test_send_broadcast_message():
"""Test sending broadcast message"""
client = TestClient(app)
# Register two agents
agent1 = Agent(
agent_id="agent_123",
name="Agent 1",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
agent2 = Agent(
agent_id="agent_456",
name="Agent 2",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
client.post("/api/v1/agents/register", json=agent1.model_dump())
client.post("/api/v1/agents/register", json=agent2.model_dump())
message = AgentMessage(
message_id="msg_123",
sender_id="agent_123",
recipient_id=None,
message_type="broadcast",
content={"data": "test"},
priority="medium",
language="english",
timestamp=datetime.now(timezone.utc)
)
response = client.post("/api/v1/messages/send", json=message.model_dump(mode='json'))
assert response.status_code == 200
data = response.json()
assert data["message_id"] == "msg_123"
@pytest.mark.integration
def test_send_message_sender_not_found():
"""Test sending message with nonexistent sender"""
client = TestClient(app)
message = AgentMessage(
message_id="msg_123",
sender_id="nonexistent",
recipient_id="agent_456",
message_type="request",
content={"data": "test"},
priority="high",
language="english",
timestamp=datetime.now(timezone.utc)
)
response = client.post("/api/v1/messages/send", json=message.model_dump(mode='json'))
assert response.status_code == 400
@pytest.mark.integration
def test_send_message_recipient_not_found():
"""Test sending message with nonexistent recipient"""
client = TestClient(app)
agent = Agent(
agent_id="agent_123",
name="Agent 1",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
client.post("/api/v1/agents/register", json=agent.model_dump())
message = AgentMessage(
message_id="msg_123",
sender_id="agent_123",
recipient_id="nonexistent",
message_type="request",
content={"data": "test"},
priority="high",
language="english",
timestamp=datetime.now(timezone.utc)
)
response = client.post("/api/v1/messages/send", json=message.model_dump(mode='json'))
assert response.status_code == 400
@pytest.mark.integration
def test_get_agent_messages():
"""Test getting agent messages"""
client = TestClient(app)
agent = Agent(
agent_id="agent_123",
name="Agent 1",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
client.post("/api/v1/agents/register", json=agent.model_dump())
response = client.get("/api/v1/messages/agent_123")
assert response.status_code == 200
data = response.json()
assert data["agent_id"] == "agent_123"
@pytest.mark.integration
def test_get_agent_messages_with_limit():
"""Test getting agent messages with limit parameter"""
client = TestClient(app)
agent = Agent(
agent_id="agent_123",
name="Agent 1",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
client.post("/api/v1/agents/register", json=agent.model_dump())
response = client.get("/api/v1/messages/agent_123?limit=10")
assert response.status_code == 200
@pytest.mark.integration
def test_create_collaboration():
"""Test creating collaboration session"""
client = TestClient(app)
# Register two agents
agent1 = Agent(
agent_id="agent_123",
name="Agent 1",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
agent2 = Agent(
agent_id="agent_456",
name="Agent 2",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
client.post("/api/v1/agents/register", json=agent1.model_dump())
client.post("/api/v1/agents/register", json=agent2.model_dump())
session = CollaborationSession(
session_id="session_123",
participants=["agent_123", "agent_456"],
session_type="task_force",
objective="Complete task",
created_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
status="active"
)
response = client.post("/api/v1/collaborations/create", json=session.model_dump(mode='json'))
assert response.status_code == 200
data = response.json()
assert data["session_id"] == "session_123"
@pytest.mark.integration
def test_create_collaboration_participant_not_found():
"""Test creating collaboration with nonexistent participant"""
client = TestClient(app)
session = CollaborationSession(
session_id="session_123",
participants=["nonexistent"],
session_type="task_force",
objective="Complete task",
created_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
status="active"
)
response = client.post("/api/v1/collaborations/create", json=session.model_dump(mode='json'))
assert response.status_code == 400
@pytest.mark.integration
def test_get_collaboration():
"""Test getting collaboration session"""
client = TestClient(app)
# Register agents and create collaboration
agent = Agent(
agent_id="agent_123",
name="Agent 1",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
client.post("/api/v1/agents/register", json=agent.model_dump())
session = CollaborationSession(
session_id="session_123",
participants=["agent_123"],
session_type="research",
objective="Research task",
created_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
status="active"
)
client.post("/api/v1/collaborations/create", json=session.model_dump(mode='json'))
response = client.get("/api/v1/collaborations/session_123")
assert response.status_code == 200
data = response.json()
assert data["session_id"] == "session_123"
@pytest.mark.integration
def test_send_collaboration_message():
"""Test sending message within collaboration session"""
client = TestClient(app)
# Register agent and create collaboration
agent = Agent(
agent_id="agent_123",
name="Agent 1",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
client.post("/api/v1/agents/register", json=agent.model_dump())
session = CollaborationSession(
session_id="session_123",
participants=["agent_123"],
session_type="research",
objective="Research task",
created_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
status="active"
)
client.post("/api/v1/collaborations/create", json=session.model_dump(mode='json'))
response = client.post("/api/v1/collaborations/session_123/message", params={"sender_id": "agent_123"}, json={"content": "test message"})
assert response.status_code == 200
data = response.json()
assert data["status"] == "delivered"
@pytest.mark.integration
def test_record_agent_performance():
"""Test recording agent performance"""
client = TestClient(app)
agent = Agent(
agent_id="agent_123",
name="Agent 1",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
client.post("/api/v1/agents/register", json=agent.model_dump())
performance = AgentPerformance(
agent_id="agent_123",
timestamp=datetime.now(timezone.utc),
tasks_completed=10,
response_time_ms=50.5,
accuracy_score=0.95,
collaboration_score=0.9,
resource_usage={"cpu": 50.0}
)
response = client.post("/api/v1/performance/record", json=performance.model_dump(mode='json'))
assert response.status_code == 200
data = response.json()
assert data["performance_id"]
assert data["status"] == "recorded"
@pytest.mark.integration
def test_record_performance_agent_not_found():
"""Test recording performance for nonexistent agent"""
client = TestClient(app)
performance = AgentPerformance(
agent_id="nonexistent",
timestamp=datetime.now(timezone.utc),
tasks_completed=10,
response_time_ms=50.5,
accuracy_score=0.95,
collaboration_score=0.9,
resource_usage={}
)
response = client.post("/api/v1/performance/record", json=performance.model_dump(mode='json'))
assert response.status_code == 404
@pytest.mark.integration
def test_get_agent_performance():
"""Test getting agent performance"""
client = TestClient(app)
agent = Agent(
agent_id="agent_123",
name="Agent 1",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
client.post("/api/v1/agents/register", json=agent.model_dump())
response = client.get("/api/v1/performance/agent_123")
assert response.status_code == 200
data = response.json()
assert data["agent_id"] == "agent_123"
@pytest.mark.integration
def test_get_agent_performance_hours_parameter():
"""Test getting agent performance with custom hours parameter"""
client = TestClient(app)
agent = Agent(
agent_id="agent_123",
name="Agent 1",
type="ai",
region="us-east-1",
capabilities=["trading"],
status="active",
languages=["english"],
specialization="trading",
performance_score=4.5
)
client.post("/api/v1/agents/register", json=agent.model_dump())
response = client.get("/api/v1/performance/agent_123?hours=12")
assert response.status_code == 200
data = response.json()
assert data["period_hours"] == 12
@pytest.mark.integration
def test_get_network_dashboard():
"""Test getting network dashboard"""
client = TestClient(app)
response = client.get("/api/v1/network/dashboard")
assert response.status_code == 200
data = response.json()
assert "dashboard" in data
@pytest.mark.integration
def test_optimize_network():
"""Test network optimization"""
client = TestClient(app)
response = client.get("/api/v1/network/optimize")
assert response.status_code == 200
data = response.json()
assert "optimization_results" in data

View File

@@ -0,0 +1,158 @@
"""Unit tests for global AI agents service"""
import pytest
import sys
import sys
from pathlib import Path
from datetime import datetime, timezone
from main import app, Agent, AgentMessage, CollaborationSession, AgentPerformance
@pytest.mark.unit
def test_app_initialization():
"""Test that the FastAPI app initializes correctly"""
assert app is not None
assert app.title == "AITBC Global AI Agent Communication Service"
assert app.version == "1.0.0"
@pytest.mark.unit
def test_agent_model():
"""Test Agent model"""
agent = Agent(
agent_id="agent_123",
name="Test Agent",
type="ai",
region="us-east-1",
capabilities=["trading", "analysis"],
status="active",
languages=["english", "chinese"],
specialization="trading",
performance_score=4.5
)
assert agent.agent_id == "agent_123"
assert agent.name == "Test Agent"
assert agent.type == "ai"
assert agent.status == "active"
assert agent.performance_score == 4.5
@pytest.mark.unit
def test_agent_empty_capabilities():
"""Test Agent with empty capabilities"""
agent = Agent(
agent_id="agent_123",
name="Test Agent",
type="ai",
region="us-east-1",
capabilities=[],
status="active",
languages=["english"],
specialization="general",
performance_score=4.5
)
assert agent.capabilities == []
@pytest.mark.unit
def test_agent_message_model():
"""Test AgentMessage model"""
message = AgentMessage(
message_id="msg_123",
sender_id="agent_123",
recipient_id="agent_456",
message_type="request",
content={"data": "test"},
priority="high",
language="english",
timestamp=datetime.now(timezone.utc)
)
assert message.message_id == "msg_123"
assert message.sender_id == "agent_123"
assert message.recipient_id == "agent_456"
assert message.message_type == "request"
assert message.priority == "high"
@pytest.mark.unit
def test_agent_message_broadcast():
"""Test AgentMessage with None recipient (broadcast)"""
message = AgentMessage(
message_id="msg_123",
sender_id="agent_123",
recipient_id=None,
message_type="broadcast",
content={"data": "test"},
priority="medium",
language="english",
timestamp=datetime.now(timezone.utc)
)
assert message.recipient_id is None
@pytest.mark.unit
def test_collaboration_session_model():
"""Test CollaborationSession model"""
session = CollaborationSession(
session_id="session_123",
participants=["agent_123", "agent_456"],
session_type="task_force",
objective="Complete trading task",
created_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc),
status="active"
)
assert session.session_id == "session_123"
assert session.participants == ["agent_123", "agent_456"]
assert session.session_type == "task_force"
@pytest.mark.unit
def test_collaboration_session_empty_participants():
"""Test CollaborationSession with empty participants"""
session = CollaborationSession(
session_id="session_123",
participants=[],
session_type="research",
objective="Research task",
created_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc),
status="active"
)
assert session.participants == []
@pytest.mark.unit
def test_agent_performance_model():
"""Test AgentPerformance model"""
performance = AgentPerformance(
agent_id="agent_123",
timestamp=datetime.now(timezone.utc),
tasks_completed=10,
response_time_ms=50.5,
accuracy_score=0.95,
collaboration_score=0.9,
resource_usage={"cpu": 50.0, "memory": 60.0}
)
assert performance.agent_id == "agent_123"
assert performance.tasks_completed == 10
assert performance.response_time_ms == 50.5
assert performance.accuracy_score == 0.95
@pytest.mark.unit
def test_agent_performance_negative_values():
"""Test AgentPerformance with negative values"""
performance = AgentPerformance(
agent_id="agent_123",
timestamp=datetime.now(timezone.utc),
tasks_completed=-10,
response_time_ms=-50.5,
accuracy_score=-0.95,
collaboration_score=-0.9,
resource_usage={}
)
assert performance.tasks_completed == -10
assert performance.response_time_ms == -50.5

View File

@@ -0,0 +1,602 @@
"""
Global Infrastructure Deployment Service for AITBC
Handles multi-region deployment, load balancing, and global optimization
"""
import os
import asyncio
import json
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Dict, Any, List, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from aitbc import get_logger
logger = get_logger(__name__)
app = FastAPI(
title="AITBC Global Infrastructure Service",
description="Global infrastructure deployment and multi-region optimization",
version="1.0.0"
)
# Data models
class Region(BaseModel):
region_id: str
name: str
location: str
endpoint: str
status: str # active, inactive, maintenance
capacity: int
current_load: float
latency_ms: float
compliance_level: str
class GlobalDeployment(BaseModel):
deployment_id: str
service_name: str
target_regions: List[str]
configuration: Dict[str, Any]
deployment_strategy: str # blue_green, canary, rolling
health_checks: List[str]
class LoadBalancer(BaseModel):
balancer_id: str
name: str
algorithm: str # round_robin, weighted, least_connections
target_regions: List[str]
health_check_interval: int
failover_threshold: int
class PerformanceMetrics(BaseModel):
region_id: str
timestamp: datetime
cpu_usage: float
memory_usage: float
network_io: float
disk_io: float
active_connections: int
response_time_ms: float
# In-memory storage (in production, use database)
global_regions: Dict[str, Dict] = {}
deployments: Dict[str, Dict] = {}
load_balancers: Dict[str, Dict] = {}
performance_metrics: Dict[str, List[Dict]] = {}
compliance_data: Dict[str, Dict] = {}
global_monitoring: Dict[str, Dict] = {}
@app.get("/")
async def root():
return {
"service": "AITBC Global Infrastructure Service",
"status": "running",
"timestamp": datetime.now(timezone.utc).isoformat(),
"version": "1.0.0"
}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"total_regions": len(global_regions),
"active_regions": len([r for r in global_regions.values() if r["status"] == "active"]),
"total_deployments": len(deployments),
"active_load_balancers": len([lb for lb in load_balancers.values() if lb["status"] == "active"])
}
@app.post("/api/v1/regions/register")
async def register_region(region: Region):
"""Register a new global region"""
if region.region_id in global_regions:
raise HTTPException(status_code=400, detail="Region already registered")
# Create region record
region_record = {
"region_id": region.region_id,
"name": region.name,
"location": region.location,
"endpoint": region.endpoint,
"status": region.status,
"capacity": region.capacity,
"current_load": region.current_load,
"latency_ms": region.latency_ms,
"compliance_level": region.compliance_level,
"created_at": datetime.now(timezone.utc).isoformat(),
"last_health_check": None,
"services_deployed": [],
"performance_history": []
}
global_regions[region.region_id] = region_record
logger.info(f"Region registered: {region.name} ({region.region_id})")
return {
"region_id": region.region_id,
"status": "registered",
"name": region.name,
"created_at": region_record["created_at"]
}
@app.get("/api/v1/regions")
async def list_regions():
"""List all registered regions"""
return {
"regions": list(global_regions.values()),
"total_regions": len(global_regions),
"active_regions": len([r for r in global_regions.values() if r["status"] == "active"])
}
@app.get("/api/v1/regions/{region_id}")
async def get_region(region_id: str):
"""Get detailed region information"""
if region_id not in global_regions:
raise HTTPException(status_code=404, detail="Region not found")
region = global_regions[region_id].copy()
# Add performance metrics
region["performance_metrics"] = performance_metrics.get(region_id, [])
# Add compliance data
region["compliance_data"] = compliance_data.get(region_id, {})
return region
@app.post("/api/v1/deployments/create")
async def create_deployment(deployment: GlobalDeployment):
"""Create a new global deployment"""
deployment_id = f"deploy_{int(datetime.now(timezone.utc).timestamp())}"
# Validate target regions
for region_id in deployment.target_regions:
if region_id not in global_regions:
raise HTTPException(status_code=400, detail=f"Region {region_id} not found")
# Create deployment record
deployment_record = {
"deployment_id": deployment_id,
"service_name": deployment.service_name,
"target_regions": deployment.target_regions,
"configuration": deployment.configuration,
"deployment_strategy": deployment.deployment_strategy,
"health_checks": deployment.health_checks,
"status": "pending",
"created_at": datetime.now(timezone.utc).isoformat(),
"started_at": None,
"completed_at": None,
"deployment_progress": {},
"rollback_available": False
}
deployments[deployment_id] = deployment_record
# Start async deployment
asyncio.create_task(execute_deployment(deployment_id))
logger.info(f"Deployment created: {deployment_id} for {deployment.service_name}")
return {
"deployment_id": deployment_id,
"status": "pending",
"service_name": deployment.service_name,
"target_regions": deployment.target_regions,
"created_at": deployment_record["created_at"]
}
@app.get("/api/v1/deployments/{deployment_id}")
async def get_deployment(deployment_id: str):
"""Get deployment status and details"""
if deployment_id not in deployments:
raise HTTPException(status_code=404, detail="Deployment not found")
return deployments[deployment_id]
@app.get("/api/v1/deployments")
async def list_deployments(status: Optional[str] = None):
"""List all deployments"""
deployment_list = list(deployments.values())
if status:
deployment_list = [d for d in deployment_list if d["status"] == status]
# Sort by creation date (most recent first)
deployment_list.sort(key=lambda x: x["created_at"], reverse=True)
return {
"deployments": deployment_list,
"total_deployments": len(deployment_list),
"status_filter": status
}
@app.post("/api/v1/load-balancers/create")
async def create_load_balancer(balancer: LoadBalancer):
"""Create a new load balancer"""
balancer_id = f"lb_{int(datetime.now(timezone.utc).timestamp())}"
# Validate target regions
for region_id in balancer.target_regions:
if region_id not in global_regions:
raise HTTPException(status_code=400, detail=f"Region {region_id} not found")
# Create load balancer record
balancer_record = {
"balancer_id": balancer_id,
"name": balancer.name,
"algorithm": balancer.algorithm,
"target_regions": balancer.target_regions,
"health_check_interval": balancer.health_check_interval,
"failover_threshold": balancer.failover_threshold,
"status": "active",
"created_at": datetime.now(timezone.utc).isoformat(),
"current_weights": {region_id: 1.0 for region_id in balancer.target_regions},
"health_status": {region_id: "healthy" for region_id in balancer.target_regions},
"total_requests": 0,
"failed_requests": 0
}
load_balancers[balancer_id] = balancer_record
# Start health checking
asyncio.create_task(start_health_monitoring(balancer_id))
logger.info(f"Load balancer created: {balancer_id} - {balancer.name}")
return {
"balancer_id": balancer_id,
"status": "active",
"name": balancer.name,
"algorithm": balancer.algorithm,
"created_at": balancer_record["created_at"]
}
@app.get("/api/v1/load-balancers")
async def list_load_balancers():
"""List all load balancers"""
return {
"load_balancers": list(load_balancers.values()),
"total_balancers": len(load_balancers),
"active_balancers": len([lb for lb in load_balancers.values() if lb["status"] == "active"])
}
@app.post("/api/v1/performance/metrics")
async def record_performance_metrics(metrics: PerformanceMetrics):
"""Record performance metrics for a region"""
metrics_record = {
"metrics_id": f"metrics_{int(datetime.now(timezone.utc).timestamp())}",
"region_id": metrics.region_id,
"timestamp": metrics.timestamp.isoformat(),
"cpu_usage": metrics.cpu_usage,
"memory_usage": metrics.memory_usage,
"network_io": metrics.network_io,
"disk_io": metrics.disk_io,
"active_connections": metrics.active_connections,
"response_time_ms": metrics.response_time_ms
}
if metrics.region_id not in performance_metrics:
performance_metrics[metrics.region_id] = []
performance_metrics[metrics.region_id].append(metrics_record)
# Keep only last 1000 records per region
if len(performance_metrics[metrics.region_id]) > 1000:
performance_metrics[metrics.region_id] = performance_metrics[metrics.region_id][-1000:]
# Update region performance history
if metrics.region_id in global_regions:
global_regions[metrics.region_id]["performance_history"].append({
"timestamp": metrics.timestamp.isoformat(),
"cpu_usage": metrics.cpu_usage,
"memory_usage": metrics.memory_usage,
"response_time_ms": metrics.response_time_ms
})
# Keep only last 100 records
if len(global_regions[metrics.region_id]["performance_history"]) > 100:
global_regions[metrics.region_id]["performance_history"] = global_regions[metrics.region_id]["performance_history"][-100:]
return {
"metrics_id": metrics_record["metrics_id"],
"status": "recorded",
"timestamp": metrics_record["timestamp"]
}
@app.get("/api/v1/performance/{region_id}")
async def get_region_performance(region_id: str, hours: int = 24):
"""Get performance metrics for a region"""
if region_id not in performance_metrics:
raise HTTPException(status_code=404, detail="No performance data for region")
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours)
recent_metrics = [
m for m in performance_metrics[region_id]
if datetime.fromisoformat(m["timestamp"]) > cutoff_time
]
# Calculate statistics
if recent_metrics:
avg_cpu = sum(m["cpu_usage"] for m in recent_metrics) / len(recent_metrics)
avg_memory = sum(m["memory_usage"] for m in recent_metrics) / len(recent_metrics)
avg_response_time = sum(m["response_time_ms"] for m in recent_metrics) / len(recent_metrics)
else:
avg_cpu = avg_memory = avg_response_time = 0.0
return {
"region_id": region_id,
"period_hours": hours,
"metrics": recent_metrics,
"statistics": {
"average_cpu_usage": round(avg_cpu, 2),
"average_memory_usage": round(avg_memory, 2),
"average_response_time_ms": round(avg_response_time, 2),
"total_samples": len(recent_metrics)
},
"generated_at": datetime.now(timezone.utc).isoformat()
}
@app.get("/api/v1/compliance/{region_id}")
async def get_region_compliance(region_id: str):
"""Get compliance information for a region"""
if region_id not in global_regions:
raise HTTPException(status_code=404, detail="Region not found")
# Mock compliance data (in production, this would be from actual compliance systems)
compliance_info = {
"region_id": region_id,
"region_name": global_regions[region_id]["name"],
"compliance_level": global_regions[region_id]["compliance_level"],
"certifications": ["SOC2", "ISO27001", "GDPR"],
"data_residency": "compliant",
"last_audit": (datetime.now(timezone.utc) - timedelta(days=90)).isoformat(),
"next_audit": (datetime.now(timezone.utc) + timedelta(days=275)).isoformat(),
"regulations": ["GDPR", "CCPA", "PDPA"],
"data_protection": "end-to-end-encryption",
"access_controls": "role-based-access",
"audit_logging": "enabled"
}
return compliance_info
@app.get("/api/v1/global/dashboard")
async def get_global_dashboard():
"""Get global infrastructure dashboard"""
# Calculate global statistics
total_capacity = sum(r["capacity"] for r in global_regions.values())
total_load = sum(r["current_load"] for r in global_regions.values())
avg_latency = sum(r["latency_ms"] for r in global_regions.values()) / len(global_regions) if global_regions else 0
# Deployment statistics
deployment_stats = {
"total": len(deployments),
"pending": len([d for d in deployments.values() if d["status"] == "pending"]),
"in_progress": len([d for d in deployments.values() if d["status"] == "in_progress"]),
"completed": len([d for d in deployments.values() if d["status"] == "completed"]),
"failed": len([d for d in deployments.values() if d["status"] == "failed"])
}
# Performance summary
performance_summary = {}
for region_id, metrics_list in performance_metrics.items():
if metrics_list:
latest_metrics = metrics_list[-1]
performance_summary[region_id] = {
"cpu_usage": latest_metrics["cpu_usage"],
"memory_usage": latest_metrics["memory_usage"],
"response_time_ms": latest_metrics["response_time_ms"],
"active_connections": latest_metrics["active_connections"]
}
return {
"dashboard": {
"infrastructure": {
"total_regions": len(global_regions),
"active_regions": len([r for r in global_regions.values() if r["status"] == "active"]),
"total_capacity": total_capacity,
"current_load": total_load,
"utilization_percentage": round((total_load / total_capacity * 100) if total_capacity > 0 else 0, 2),
"average_latency_ms": round(avg_latency, 2)
},
"deployments": deployment_stats,
"load_balancers": {
"total": len(load_balancers),
"active": len([lb for lb in load_balancers.values() if lb["status"] == "active"])
},
"performance": performance_summary,
"compliance": {
"compliant_regions": len([r for r in global_regions.values() if r["compliance_level"] == "full"]),
"partial_compliance": len([r for r in global_regions.values() if r["compliance_level"] == "partial"])
}
},
"generated_at": datetime.now(timezone.utc).isoformat()
}
# Core deployment and load balancing functions
async def execute_deployment(deployment_id: str):
"""Execute a global deployment"""
deployment = deployments[deployment_id]
deployment["status"] = "in_progress"
deployment["started_at"] = datetime.now(timezone.utc).isoformat()
try:
for region_id in deployment["target_regions"]:
deployment["deployment_progress"][region_id] = {
"status": "deploying",
"started_at": datetime.now(timezone.utc).isoformat(),
"progress": 0
}
# Simulate deployment process
await simulate_deployment_step(region_id, deployment_id)
deployment["deployment_progress"][region_id].update({
"status": "completed",
"completed_at": datetime.now(timezone.utc).isoformat(),
"progress": 100
})
# Update region services
if region_id in global_regions:
if deployment["service_name"] not in global_regions[region_id]["services_deployed"]:
global_regions[region_id]["services_deployed"].append(deployment["service_name"])
deployment["status"] = "completed"
deployment["completed_at"] = datetime.now(timezone.utc).isoformat()
logger.info(f"Deployment completed: {deployment_id}")
except Exception as e:
deployment["status"] = "failed"
deployment["error"] = str(e)
logger.error(f"Deployment failed: {deployment_id} - {str(e)}")
async def simulate_deployment_step(region_id: str, deployment_id: str):
"""Simulate deployment step for demo"""
deployment = deployments[deployment_id]
# Simulate deployment progress
for progress in range(0, 101, 10):
if region_id in deployment["deployment_progress"]:
deployment["deployment_progress"][region_id]["progress"] = progress
await asyncio.sleep(0.1) # Simulate work
async def start_health_monitoring(balancer_id: str):
"""Start health monitoring for a load balancer"""
balancer = load_balancers[balancer_id]
while balancer["status"] == "active":
try:
# Check health of target regions
for region_id in balancer["target_regions"]:
if region_id in global_regions:
region = global_regions[region_id]
# Simulate health check (in production, this would be actual health checks)
is_healthy = region["status"] == "active" and region["current_load"] < region["capacity"] * 0.9
balancer["health_status"][region_id] = "healthy" if is_healthy else "unhealthy"
# Update load balancer weights based on performance
update_load_balancer_weights(balancer_id)
await asyncio.sleep(balancer["health_check_interval"])
except Exception as e:
logger.error(f"Health monitoring error for {balancer_id}: {str(e)}")
await asyncio.sleep(10)
def update_load_balancer_weights(balancer_id: str):
"""Update load balancer weights based on region performance"""
balancer = load_balancers[balancer_id]
if balancer["algorithm"] == "weighted":
# Calculate weights based on capacity and current load
for region_id in balancer["target_regions"]:
if region_id in global_regions:
region = global_regions[region_id]
# Weight based on available capacity
available_capacity = region["capacity"] - region["current_load"]
total_available = sum(
global_regions[r]["capacity"] - global_regions[r]["current_load"]
for r in balancer["target_regions"]
if r in global_regions
)
if total_available > 0:
weight = available_capacity / total_available
balancer["current_weights"][region_id] = round(weight, 3)
# Background task for global monitoring
async def global_monitoring_task():
"""Background task for global infrastructure monitoring"""
while True:
await asyncio.sleep(60) # Monitor every minute
# Update global monitoring data
global_monitoring["last_update"] = datetime.now(timezone.utc).isoformat()
global_monitoring["total_requests"] = sum(lb.get("total_requests", 0) for lb in load_balancers.values())
global_monitoring["failed_requests"] = sum(lb.get("failed_requests", 0) for lb in load_balancers.values())
# Check for regions that need attention
for region_id, region in global_regions.items():
if region["current_load"] > region["capacity"] * 0.8:
logger.warning(f"High load detected in region {region_id}: {region['current_load']}/{region['capacity']}")
if region["latency_ms"] > 500:
logger.warning(f"High latency detected in region {region_id}: {region['latency_ms']}ms")
# Initialize with some default regions
@app.on_event("startup")
async def startup_event():
logger.info("Starting AITBC Global Infrastructure Service")
# Initialize default regions
default_regions = [
{
"region_id": "us-east-1",
"name": "US East (N. Virginia)",
"location": "North America",
"endpoint": "https://us-east-1.api.aitbc.dev",
"status": "active",
"capacity": 10000,
"current_load": 3500,
"latency_ms": 45,
"compliance_level": "full"
},
{
"region_id": "eu-west-1",
"name": "EU West (Ireland)",
"location": "Europe",
"endpoint": "https://eu-west-1.api.aitbc.dev",
"status": "active",
"capacity": 8000,
"current_load": 2800,
"latency_ms": 38,
"compliance_level": "full"
},
{
"region_id": "ap-southeast-1",
"name": "AP Southeast (Singapore)",
"location": "Asia Pacific",
"endpoint": "https://ap-southeast-1.api.aitbc.dev",
"status": "active",
"capacity": 6000,
"current_load": 2200,
"latency_ms": 62,
"compliance_level": "partial"
}
]
for region_data in default_regions:
region = Region(**region_data)
region_record = {
"region_id": region.region_id,
"name": region.name,
"location": region.location,
"endpoint": region.endpoint,
"status": region.status,
"capacity": region.capacity,
"current_load": region.current_load,
"latency_ms": region.latency_ms,
"compliance_level": region.compliance_level,
"created_at": datetime.now(timezone.utc).isoformat(),
"last_health_check": None,
"services_deployed": [],
"performance_history": []
}
global_regions[region.region_id] = region_record
# Start global monitoring
asyncio.create_task(global_monitoring_task())
@app.on_event("shutdown")
async def shutdown_event():
logger.info("Shutting down AITBC Global Infrastructure Service")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8017, log_level="info")

View File

@@ -0,0 +1 @@
"""Global infrastructure service tests"""

View File

@@ -0,0 +1,195 @@
"""Edge case and error handling tests for global infrastructure service"""
import pytest
import sys
import sys
from pathlib import Path
from fastapi.testclient import TestClient
from datetime import datetime, timezone
from main import app, Region, GlobalDeployment, LoadBalancer, PerformanceMetrics, global_regions, deployments, load_balancers, performance_metrics
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
global_regions.clear()
deployments.clear()
load_balancers.clear()
performance_metrics.clear()
yield
global_regions.clear()
deployments.clear()
load_balancers.clear()
performance_metrics.clear()
@pytest.mark.unit
def test_region_negative_capacity():
"""Test Region with negative capacity"""
region = Region(
region_id="us-west-1",
name="US West",
location="North America",
endpoint="https://us-west-1.api.aitbc.dev",
status="active",
capacity=-1000,
current_load=-500,
latency_ms=-50,
compliance_level="full"
)
assert region.capacity == -1000
assert region.current_load == -500
@pytest.mark.unit
def test_region_empty_name():
"""Test Region with empty name"""
region = Region(
region_id="us-west-1",
name="",
location="North America",
endpoint="https://us-west-1.api.aitbc.dev",
status="active",
capacity=8000,
current_load=2000,
latency_ms=50,
compliance_level="full"
)
assert region.name == ""
@pytest.mark.unit
def test_deployment_empty_target_regions():
"""Test GlobalDeployment with empty target regions"""
deployment = GlobalDeployment(
deployment_id="deploy_123",
service_name="test-service",
target_regions=[],
configuration={},
deployment_strategy="blue_green",
health_checks=[]
)
assert deployment.target_regions == []
@pytest.mark.unit
def test_load_balancer_negative_health_check_interval():
"""Test LoadBalancer with negative health check interval"""
balancer = LoadBalancer(
balancer_id="lb_123",
name="Main LB",
algorithm="round_robin",
target_regions=["us-east-1"],
health_check_interval=-30,
failover_threshold=3
)
assert balancer.health_check_interval == -30
@pytest.mark.unit
def test_performance_metrics_negative_values():
"""Test PerformanceMetrics with negative values"""
metrics = PerformanceMetrics(
region_id="us-east-1",
timestamp=datetime.now(timezone.utc),
cpu_usage=-50.5,
memory_usage=-60.2,
network_io=-1000.5,
disk_io=-500.3,
active_connections=-100,
response_time_ms=-45.2
)
assert metrics.cpu_usage == -50.5
assert metrics.active_connections == -100
@pytest.mark.integration
def test_list_regions_with_no_regions():
"""Test listing regions when no regions exist"""
client = TestClient(app)
response = client.get("/api/v1/regions")
assert response.status_code == 200
data = response.json()
assert data["total_regions"] == 0
@pytest.mark.integration
def test_list_deployments_with_no_deployments():
"""Test listing deployments when no deployments exist"""
client = TestClient(app)
response = client.get("/api/v1/deployments")
assert response.status_code == 200
data = response.json()
assert data["total_deployments"] == 0
@pytest.mark.integration
def test_list_load_balancers_with_no_balancers():
"""Test listing load balancers when no balancers exist"""
client = TestClient(app)
response = client.get("/api/v1/load-balancers")
assert response.status_code == 200
data = response.json()
assert data["total_balancers"] == 0
@pytest.mark.integration
def test_get_deployment_not_found():
"""Test getting nonexistent deployment"""
client = TestClient(app)
response = client.get("/api/v1/deployments/nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_get_region_performance_no_data():
"""Test getting region performance when no data exists"""
client = TestClient(app)
response = client.get("/api/v1/performance/nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_get_region_compliance_nonexistent():
"""Test getting compliance for nonexistent region"""
client = TestClient(app)
response = client.get("/api/v1/compliance/nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_create_load_balancer_nonexistent_region():
"""Test creating load balancer with nonexistent region"""
client = TestClient(app)
balancer = LoadBalancer(
balancer_id="lb_123",
name="Main LB",
algorithm="round_robin",
target_regions=["nonexistent"],
health_check_interval=30,
failover_threshold=3
)
response = client.post("/api/v1/load-balancers/create", json=balancer.model_dump())
assert response.status_code == 400
@pytest.mark.integration
def test_list_deployments_with_status_filter():
"""Test listing deployments with status filter"""
client = TestClient(app)
response = client.get("/api/v1/deployments?status=pending")
assert response.status_code == 200
data = response.json()
assert "status_filter" in data
@pytest.mark.integration
def test_global_dashboard_with_no_data():
"""Test global dashboard with no data"""
client = TestClient(app)
response = client.get("/api/v1/global/dashboard")
assert response.status_code == 200
data = response.json()
assert data["dashboard"]["infrastructure"]["total_regions"] == 0

View File

@@ -0,0 +1,353 @@
"""Integration tests for global infrastructure service"""
import pytest
import sys
import sys
from pathlib import Path
from fastapi.testclient import TestClient
from datetime import datetime, timezone
from main import app, Region, GlobalDeployment, LoadBalancer, PerformanceMetrics, global_regions, deployments, load_balancers, performance_metrics
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
global_regions.clear()
deployments.clear()
load_balancers.clear()
performance_metrics.clear()
yield
global_regions.clear()
deployments.clear()
load_balancers.clear()
performance_metrics.clear()
@pytest.mark.integration
def test_root_endpoint():
"""Test root endpoint"""
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
data = response.json()
assert data["service"] == "AITBC Global Infrastructure Service"
assert data["status"] == "running"
@pytest.mark.integration
def test_health_check_endpoint():
"""Test health check endpoint"""
client = TestClient(app)
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert "total_regions" in data
assert "active_regions" in data
@pytest.mark.integration
def test_register_region():
"""Test registering a new region"""
client = TestClient(app)
region = Region(
region_id="us-west-1",
name="US West",
location="North America",
endpoint="https://us-west-1.api.aitbc.dev",
status="active",
capacity=8000,
current_load=2000,
latency_ms=50,
compliance_level="full"
)
response = client.post("/api/v1/regions/register", json=region.model_dump())
assert response.status_code == 200
data = response.json()
assert data["region_id"] == "us-west-1"
assert data["status"] == "registered"
@pytest.mark.integration
def test_register_duplicate_region():
"""Test registering duplicate region"""
client = TestClient(app)
region = Region(
region_id="us-west-1",
name="US West",
location="North America",
endpoint="https://us-west-1.api.aitbc.dev",
status="active",
capacity=8000,
current_load=2000,
latency_ms=50,
compliance_level="full"
)
client.post("/api/v1/regions/register", json=region.model_dump())
response = client.post("/api/v1/regions/register", json=region.model_dump())
assert response.status_code == 400
@pytest.mark.integration
def test_list_regions():
"""Test listing all regions"""
client = TestClient(app)
response = client.get("/api/v1/regions")
assert response.status_code == 200
data = response.json()
assert "regions" in data
assert "total_regions" in data
@pytest.mark.integration
def test_get_region():
"""Test getting specific region"""
client = TestClient(app)
region = Region(
region_id="us-west-1",
name="US West",
location="North America",
endpoint="https://us-west-1.api.aitbc.dev",
status="active",
capacity=8000,
current_load=2000,
latency_ms=50,
compliance_level="full"
)
client.post("/api/v1/regions/register", json=region.model_dump())
response = client.get("/api/v1/regions/us-west-1")
assert response.status_code == 200
data = response.json()
assert data["region_id"] == "us-west-1"
@pytest.mark.integration
def test_get_region_not_found():
"""Test getting nonexistent region"""
client = TestClient(app)
response = client.get("/api/v1/regions/nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_create_deployment():
"""Test creating a deployment"""
client = TestClient(app)
# Register region first
region = Region(
region_id="us-west-1",
name="US West",
location="North America",
endpoint="https://us-west-1.api.aitbc.dev",
status="active",
capacity=8000,
current_load=2000,
latency_ms=50,
compliance_level="full"
)
client.post("/api/v1/regions/register", json=region.model_dump())
deployment = GlobalDeployment(
deployment_id="deploy_123",
service_name="test-service",
target_regions=["us-west-1"],
configuration={"replicas": 3},
deployment_strategy="blue_green",
health_checks=["/health"]
)
response = client.post("/api/v1/deployments/create", json=deployment.model_dump())
assert response.status_code == 200
data = response.json()
assert data["deployment_id"]
assert data["status"] == "pending"
@pytest.mark.integration
def test_create_deployment_nonexistent_region():
"""Test creating deployment with nonexistent region"""
client = TestClient(app)
deployment = GlobalDeployment(
deployment_id="deploy_123",
service_name="test-service",
target_regions=["nonexistent"],
configuration={"replicas": 3},
deployment_strategy="blue_green",
health_checks=["/health"]
)
response = client.post("/api/v1/deployments/create", json=deployment.model_dump())
assert response.status_code == 400
@pytest.mark.integration
def test_get_deployment():
"""Test getting deployment details"""
client = TestClient(app)
# Register region first
region = Region(
region_id="us-west-1",
name="US West",
location="North America",
endpoint="https://us-west-1.api.aitbc.dev",
status="active",
capacity=8000,
current_load=2000,
latency_ms=50,
compliance_level="full"
)
client.post("/api/v1/regions/register", json=region.model_dump())
deployment = GlobalDeployment(
deployment_id="deploy_123",
service_name="test-service",
target_regions=["us-west-1"],
configuration={"replicas": 3},
deployment_strategy="blue_green",
health_checks=["/health"]
)
create_response = client.post("/api/v1/deployments/create", json=deployment.model_dump())
deployment_id = create_response.json()["deployment_id"]
response = client.get(f"/api/v1/deployments/{deployment_id}")
assert response.status_code == 200
data = response.json()
assert data["deployment_id"] == deployment_id
@pytest.mark.integration
def test_list_deployments():
"""Test listing all deployments"""
client = TestClient(app)
response = client.get("/api/v1/deployments")
assert response.status_code == 200
data = response.json()
assert "deployments" in data
assert "total_deployments" in data
@pytest.mark.integration
def test_create_load_balancer():
"""Test creating a load balancer"""
client = TestClient(app)
# Register region first
region = Region(
region_id="us-west-1",
name="US West",
location="North America",
endpoint="https://us-west-1.api.aitbc.dev",
status="active",
capacity=8000,
current_load=2000,
latency_ms=50,
compliance_level="full"
)
client.post("/api/v1/regions/register", json=region.model_dump())
balancer = LoadBalancer(
balancer_id="lb_123",
name="Main LB",
algorithm="round_robin",
target_regions=["us-west-1"],
health_check_interval=30,
failover_threshold=3
)
response = client.post("/api/v1/load-balancers/create", json=balancer.model_dump())
assert response.status_code == 200
data = response.json()
assert data["balancer_id"]
assert data["status"] == "active"
@pytest.mark.integration
def test_list_load_balancers():
"""Test listing all load balancers"""
client = TestClient(app)
response = client.get("/api/v1/load-balancers")
assert response.status_code == 200
data = response.json()
assert "load_balancers" in data
assert "total_balancers" in data
@pytest.mark.integration
def test_record_performance_metrics():
"""Test recording performance metrics"""
client = TestClient(app)
metrics = PerformanceMetrics(
region_id="us-west-1",
timestamp=datetime.now(timezone.utc),
cpu_usage=50.5,
memory_usage=60.2,
network_io=1000.5,
disk_io=500.3,
active_connections=100,
response_time_ms=45.2
)
response = client.post("/api/v1/performance/metrics", json=metrics.model_dump(mode='json'))
assert response.status_code == 200
data = response.json()
assert data["metrics_id"]
assert data["status"] == "recorded"
@pytest.mark.integration
def test_get_region_performance():
"""Test getting region performance metrics"""
client = TestClient(app)
# Record metrics first
metrics = PerformanceMetrics(
region_id="us-west-1",
timestamp=datetime.now(timezone.utc),
cpu_usage=50.5,
memory_usage=60.2,
network_io=1000.5,
disk_io=500.3,
active_connections=100,
response_time_ms=45.2
)
client.post("/api/v1/performance/metrics", json=metrics.model_dump(mode='json'))
response = client.get("/api/v1/performance/us-west-1")
assert response.status_code == 200
data = response.json()
assert data["region_id"] == "us-west-1"
assert "statistics" in data
@pytest.mark.integration
def test_get_region_compliance():
"""Test getting region compliance information"""
client = TestClient(app)
# Register region first
region = Region(
region_id="us-west-1",
name="US West",
location="North America",
endpoint="https://us-west-1.api.aitbc.dev",
status="active",
capacity=8000,
current_load=2000,
latency_ms=50,
compliance_level="full"
)
client.post("/api/v1/regions/register", json=region.model_dump())
response = client.get("/api/v1/compliance/us-west-1")
assert response.status_code == 200
data = response.json()
assert data["region_id"] == "us-west-1"
assert "compliance_level" in data
@pytest.mark.integration
def test_get_global_dashboard():
"""Test getting global dashboard"""
client = TestClient(app)
response = client.get("/api/v1/global/dashboard")
assert response.status_code == 200
data = response.json()
assert "dashboard" in data
assert "infrastructure" in data["dashboard"]

View File

@@ -0,0 +1,93 @@
"""Unit tests for global infrastructure service"""
import pytest
import sys
import sys
from pathlib import Path
from datetime import datetime, timezone
from main import app, Region, GlobalDeployment, LoadBalancer, PerformanceMetrics
@pytest.mark.unit
def test_app_initialization():
"""Test that the FastAPI app initializes correctly"""
assert app is not None
assert app.title == "AITBC Global Infrastructure Service"
assert app.version == "1.0.0"
@pytest.mark.unit
def test_region_model():
"""Test Region model"""
region = Region(
region_id="us-east-1",
name="US East",
location="North America",
endpoint="https://us-east-1.api.aitbc.dev",
status="active",
capacity=10000,
current_load=3500,
latency_ms=45,
compliance_level="full"
)
assert region.region_id == "us-east-1"
assert region.name == "US East"
assert region.status == "active"
assert region.capacity == 10000
assert region.compliance_level == "full"
@pytest.mark.unit
def test_global_deployment_model():
"""Test GlobalDeployment model"""
deployment = GlobalDeployment(
deployment_id="deploy_123",
service_name="test-service",
target_regions=["us-east-1", "eu-west-1"],
configuration={"replicas": 3},
deployment_strategy="blue_green",
health_checks=["/health", "/ready"]
)
assert deployment.deployment_id == "deploy_123"
assert deployment.service_name == "test-service"
assert deployment.target_regions == ["us-east-1", "eu-west-1"]
assert deployment.deployment_strategy == "blue_green"
@pytest.mark.unit
def test_load_balancer_model():
"""Test LoadBalancer model"""
balancer = LoadBalancer(
balancer_id="lb_123",
name="Main LB",
algorithm="round_robin",
target_regions=["us-east-1", "eu-west-1"],
health_check_interval=30,
failover_threshold=3
)
assert balancer.balancer_id == "lb_123"
assert balancer.name == "Main LB"
assert balancer.algorithm == "round_robin"
assert balancer.health_check_interval == 30
@pytest.mark.unit
def test_performance_metrics_model():
"""Test PerformanceMetrics model"""
metrics = PerformanceMetrics(
region_id="us-east-1",
timestamp=datetime.now(timezone.utc),
cpu_usage=50.5,
memory_usage=60.2,
network_io=1000.5,
disk_io=500.3,
active_connections=100,
response_time_ms=45.2
)
assert metrics.region_id == "us-east-1"
assert metrics.cpu_usage == 50.5
assert metrics.memory_usage == 60.2
assert metrics.active_connections == 100
assert metrics.response_time_ms == 45.2

1292
examples/stubs/hermes-service/poetry.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,25 @@
[tool.poetry]
name = "hermes-service"
version = "0.1.0"
description = "AITBC hermes Service for agent orchestration and edge computing"
authors = ["AITBC Team"]
[tool.poetry.dependencies]
python = "^3.13"
fastapi = ">=0.115.6"
uvicorn = {extras = ["standard"], version = "^0.32.0"}
sqlmodel = "^0.0.37"
sqlalchemy = "^2.0.25"
pydantic = "^2.6.0"
pydantic-settings = "^2.1.0"
httpx = ">=0.28.1"
[tool.poetry.group.dev.dependencies]
pytest = ">=9.0.3"
pytest-asyncio = ">=1.3.0"
black = ">=26.3.1"
ruff = "^0.1.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@@ -0,0 +1,122 @@
"""Hermes Service for agent orchestration and edge computing."""
from __future__ import annotations
import os
import logging
from typing import Any
from fastapi import FastAPI
logger = logging.getLogger(__name__)
app = FastAPI(
title="AITBC Hermes Service",
description="Agent orchestration and edge computing service",
version="1.0.0"
)
@app.get("/health")
async def health():
"""Health check endpoint."""
return {"status": "healthy", "service": "hermes-service"}
@app.get("/")
async def root():
"""Root endpoint."""
return {
"service": "AITBC Hermes Service",
"version": "1.0.0",
"status": "operational"
}
@app.post("/routing/skill")
async def route_agent_skill(request: dict[str, Any]) -> dict[str, Any]:
"""Sophisticated agent skill routing"""
return {
"selected_agent": "agent_default",
"routing_strategy": "performance_based",
"expected_performance": "high",
"estimated_cost": 0.5
}
@app.post("/offloading/intelligent")
async def intelligent_job_offloading(request: dict[str, Any]) -> dict[str, Any]:
"""Intelligent job offloading strategies"""
return {
"should_offload": True,
"job_size": "medium",
"cost_analysis": {"estimated_cost": 10.0, "savings": 5.0},
"performance_prediction": "improved",
"fallback_mechanism": "local_execution"
}
@app.post("/collaboration/coordinate")
async def coordinate_agent_collaboration(request: dict[str, Any]) -> dict[str, Any]:
"""Agent collaboration and coordination"""
return {
"coordination_method": "consensus",
"selected_coordinator": "agent_coordinator_1",
"consensus_reached": True,
"task_distribution": {"agent_1": 0.5, "agent_2": 0.5},
"estimated_completion_time": 300
}
@app.post("/execution/hybrid-optimize")
async def optimize_hybrid_execution(request: dict[str, Any]) -> dict[str, Any]:
"""Hybrid execution optimization"""
return {
"execution_mode": "hybrid",
"strategy": "cost_performance_balance",
"resource_allocation": {"cpu": 50, "memory": 40, "gpu": 10},
"performance_tuning": "optimized",
"expected_improvement": "30%"
}
@app.post("/edge/deploy")
async def deploy_to_edge(request: dict[str, Any]) -> dict[str, Any]:
"""Deploy agent to edge computing infrastructure"""
return {
"deployment_id": "deployment_123",
"agent_id": request.get("agent_id", "unknown"),
"edge_locations": request.get("edge_locations", []),
"deployment_results": {"success": True, "deployed_count": 3},
"status": "completed"
}
@app.post("/edge/coordinate")
async def coordinate_edge_to_cloud(request: dict[str, Any]) -> dict[str, Any]:
"""Coordinate edge-to-cloud agent operations"""
return {
"coordination_id": "coord_456",
"edge_deployment_id": request.get("edge_deployment_id", "unknown"),
"synchronization": "enabled",
"load_balancing": "round_robin",
"failover": "active",
"status": "active"
}
@app.post("/ecosystem/develop")
async def develop_hermes_ecosystem(request: dict[str, Any]) -> dict[str, Any]:
"""Build comprehensive Hermes ecosystem"""
return {
"ecosystem_id": "eco_789",
"developer_tools": ["cli", "sdk", "dashboard"],
"marketplace": "enabled",
"community": "growing",
"partnerships": ["partner_1", "partner_2"],
"status": "developing"
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8108)

View File

@@ -0,0 +1,48 @@
#!/usr/bin/env python3
"""
AITBC Monitor Service
"""
import time
import json
from pathlib import Path
import psutil
from aitbc import get_logger, DATA_DIR
def main():
logger = get_logger('aitbc-monitor')
while True:
try:
# System stats
cpu_percent = psutil.cpu_percent()
memory_percent = psutil.virtual_memory().percent
logger.info(f'System: CPU {cpu_percent}%, Memory {memory_percent}%')
# Blockchain stats
blockchain_file = DATA_DIR / 'data/blockchain/aitbc/blockchain.json'
if blockchain_file.exists():
with open(blockchain_file, 'r') as f:
data = json.load(f)
logger.info(f'Blockchain: {len(data.get("blocks", []))} blocks')
# Marketplace stats
marketplace_dir = DATA_DIR / 'data/marketplace'
if marketplace_dir.exists():
listings_file = marketplace_dir / 'gpu_listings.json'
if listings_file.exists():
with open(listings_file, 'r') as f:
listings = json.load(f)
logger.info(f'Marketplace: {len(listings)} GPU listings')
time.sleep(30)
except (json.JSONDecodeError, FileNotFoundError, PermissionError, IOError) as e:
logger.error(f'Monitoring error: {type(e).__name__}: {e}')
time.sleep(60)
except psutil.Error as e:
logger.error(f'System monitoring error: {type(e).__name__}: {e}')
time.sleep(60)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1 @@
"""Monitor service tests"""

View File

@@ -0,0 +1,216 @@
"""Edge case and error handling tests for monitor service"""
import sys
import pytest
import sys
from unittest.mock import Mock, patch, MagicMock, mock_open
from pathlib import Path
import json
# Create a proper psutil mock with Error exception class
class PsutilError(Exception):
pass
mock_psutil = MagicMock()
mock_psutil.cpu_percent = Mock(return_value=45.5)
mock_psutil.virtual_memory = Mock(return_value=MagicMock(percent=60.2))
mock_psutil.Error = PsutilError
sys.modules['psutil'] = mock_psutil
import monitor
@pytest.mark.unit
def test_json_decode_error_handling():
"""Test JSON decode error is handled correctly"""
with patch('monitor.logging') as mock_logging, \
patch('monitor.time.sleep', side_effect=[None, KeyboardInterrupt]), \
patch('monitor.Path') as mock_path, \
patch('builtins.open', mock_open(read_data='invalid json{')):
# Mock blockchain file exists
blockchain_path = Mock()
blockchain_path.exists.return_value = True
marketplace_path = Mock()
marketplace_path.exists.return_value = False
mock_path.side_effect = lambda x: blockchain_path if 'blockchain' in str(x) else marketplace_path
logger = mock_logging.getLogger.return_value
mock_logging.basicConfig.return_value = None
try:
monitor.main()
except KeyboardInterrupt:
pass
# Verify error was logged
error_calls = [call for call in logger.error.call_args_list if 'JSONDecodeError' in str(call)]
assert len(error_calls) > 0
@pytest.mark.unit
def test_file_not_found_error_handling():
"""Test FileNotFoundError is handled correctly"""
with patch('monitor.logging') as mock_logging, \
patch('monitor.time.sleep', side_effect=[None, KeyboardInterrupt]), \
patch('monitor.Path') as mock_path, \
patch('builtins.open', side_effect=FileNotFoundError("File not found")):
# Mock blockchain file exists
blockchain_path = Mock()
blockchain_path.exists.return_value = True
marketplace_path = Mock()
marketplace_path.exists.return_value = False
mock_path.side_effect = lambda x: blockchain_path if 'blockchain' in str(x) else marketplace_path
logger = mock_logging.getLogger.return_value
mock_logging.basicConfig.return_value = None
try:
monitor.main()
except KeyboardInterrupt:
pass
# Verify error was logged
error_calls = [call for call in logger.error.call_args_list if 'FileNotFoundError' in str(call)]
assert len(error_calls) > 0
@pytest.mark.unit
def test_permission_error_handling():
"""Test PermissionError is handled correctly"""
with patch('monitor.logging') as mock_logging, \
patch('monitor.time.sleep', side_effect=[None, KeyboardInterrupt]), \
patch('monitor.Path') as mock_path, \
patch('builtins.open', side_effect=PermissionError("Permission denied")):
# Mock blockchain file exists
blockchain_path = Mock()
blockchain_path.exists.return_value = True
marketplace_path = Mock()
marketplace_path.exists.return_value = False
mock_path.side_effect = lambda x: blockchain_path if 'blockchain' in str(x) else marketplace_path
logger = mock_logging.getLogger.return_value
mock_logging.basicConfig.return_value = None
try:
monitor.main()
except KeyboardInterrupt:
pass
# Verify error was logged
error_calls = [call for call in logger.error.call_args_list if 'PermissionError' in str(call)]
assert len(error_calls) > 0
@pytest.mark.unit
def test_io_error_handling():
"""Test IOError is handled correctly"""
with patch('monitor.logging') as mock_logging, \
patch('monitor.time.sleep', side_effect=[None, KeyboardInterrupt]), \
patch('monitor.Path') as mock_path, \
patch('builtins.open', side_effect=IOError("I/O error")):
# Mock blockchain file exists
blockchain_path = Mock()
blockchain_path.exists.return_value = True
marketplace_path = Mock()
marketplace_path.exists.return_value = False
mock_path.side_effect = lambda x: blockchain_path if 'blockchain' in str(x) else marketplace_path
logger = mock_logging.getLogger.return_value
mock_logging.basicConfig.return_value = None
try:
monitor.main()
except KeyboardInterrupt:
pass
# Verify error was logged
error_calls = [call for call in logger.error.call_args_list if 'IOError' in str(call) or 'OSError' in str(call)]
assert len(error_calls) > 0
@pytest.mark.unit
def test_psutil_error_handling():
"""Test psutil.Error is handled correctly"""
with patch('monitor.logging') as mock_logging, \
patch('monitor.time.sleep', side_effect=[None, KeyboardInterrupt]), \
patch('monitor.psutil.cpu_percent', side_effect=PsutilError("psutil error")):
logger = mock_logging.getLogger.return_value
mock_logging.basicConfig.return_value = None
try:
monitor.main()
except KeyboardInterrupt:
pass
# Verify error was logged
error_calls = [call for call in logger.error.call_args_list if 'psutil error' in str(call)]
assert len(error_calls) > 0
@pytest.mark.unit
def test_empty_blocks_array():
"""Test handling of empty blocks array in blockchain data"""
with patch('monitor.logging') as mock_logging, \
patch('monitor.time.sleep', side_effect=KeyboardInterrupt), \
patch('monitor.Path') as mock_path, \
patch('builtins.open', mock_open(read_data='{"blocks": []}')):
# Mock blockchain file exists
blockchain_path = Mock()
blockchain_path.exists.return_value = True
marketplace_path = Mock()
marketplace_path.exists.return_value = False
mock_path.side_effect = lambda x: blockchain_path if 'blockchain' in str(x) else marketplace_path
logger = mock_logging.getLogger.return_value
mock_logging.basicConfig.return_value = None
try:
monitor.main()
except KeyboardInterrupt:
pass
# Verify blockchain stats were logged with 0 blocks
blockchain_calls = [call for call in logger.info.call_args_list if 'Blockchain' in str(call)]
assert len(blockchain_calls) > 0
assert '0 blocks' in str(blockchain_calls[0])
@pytest.mark.unit
def test_missing_blocks_key():
"""Test handling of missing blocks key in blockchain data"""
with patch('monitor.logging') as mock_logging, \
patch('monitor.time.sleep', side_effect=KeyboardInterrupt), \
patch('monitor.Path') as mock_path, \
patch('builtins.open', mock_open(read_data='{"height": 100}')):
# Mock blockchain file exists
blockchain_path = Mock()
blockchain_path.exists.return_value = True
marketplace_path = Mock()
marketplace_path.exists.return_value = False
mock_path.side_effect = lambda x: blockchain_path if 'blockchain' in str(x) else marketplace_path
logger = mock_logging.getLogger.return_value
mock_logging.basicConfig.return_value = None
try:
monitor.main()
except KeyboardInterrupt:
pass
# Verify blockchain stats were logged with 0 blocks (default)
blockchain_calls = [call for call in logger.info.call_args_list if 'Blockchain' in str(call)]
assert len(blockchain_calls) > 0
assert '0 blocks' in str(blockchain_calls[0])

View File

@@ -0,0 +1,108 @@
"""Unit tests for monitor service"""
import sys
import pytest
import sys
from unittest.mock import Mock, patch, MagicMock, mock_open
from pathlib import Path
import json
# Create a proper psutil mock with Error exception class
class PsutilError(Exception):
pass
mock_psutil = MagicMock()
mock_psutil.cpu_percent = Mock(return_value=45.5)
mock_psutil.virtual_memory = Mock(return_value=MagicMock(percent=60.2))
mock_psutil.Error = PsutilError
sys.modules['psutil'] = mock_psutil
import monitor
@pytest.mark.unit
def test_main_system_stats_logging():
"""Test that system stats are logged correctly"""
with patch('monitor.logging') as mock_logging, \
patch('monitor.time.sleep', side_effect=KeyboardInterrupt), \
patch('monitor.Path') as mock_path:
mock_path.return_value.exists.return_value = False
logger = mock_logging.getLogger.return_value
mock_logging.basicConfig.return_value = None
try:
monitor.main()
except KeyboardInterrupt:
pass
# Verify system stats were logged
assert logger.info.call_count >= 1
system_call = logger.info.call_args_list[0]
assert 'CPU 45.5%' in str(system_call)
assert 'Memory 60.2%' in str(system_call)
@pytest.mark.unit
def test_main_blockchain_stats_logging():
"""Test that blockchain stats are logged when file exists"""
with patch('monitor.logging') as mock_logging, \
patch('monitor.time.sleep', side_effect=KeyboardInterrupt), \
patch('monitor.Path') as mock_path, \
patch('builtins.open', mock_open(read_data='{"blocks": [{"height": 1}, {"height": 2}]}')):
# Mock blockchain file exists
blockchain_path = Mock()
blockchain_path.exists.return_value = True
marketplace_path = Mock()
marketplace_path.exists.return_value = False
mock_path.side_effect = lambda x: blockchain_path if 'blockchain' in str(x) else marketplace_path
logger = mock_logging.getLogger.return_value
mock_logging.basicConfig.return_value = None
try:
monitor.main()
except KeyboardInterrupt:
pass
# Verify blockchain stats were logged
blockchain_calls = [call for call in logger.info.call_args_list if 'Blockchain' in str(call)]
assert len(blockchain_calls) > 0
assert '2 blocks' in str(blockchain_calls[0])
@pytest.mark.unit
def test_main_marketplace_stats_logging():
"""Test that marketplace stats are logged when file exists"""
with patch('monitor.logging') as mock_logging, \
patch('monitor.time.sleep', side_effect=KeyboardInterrupt), \
patch('monitor.Path') as mock_path, \
patch('builtins.open', mock_open(read_data='[{"id": 1, "gpu": "rtx3080"}, {"id": 2, "gpu": "rtx3090"}]')):
# Mock blockchain file doesn't exist, marketplace does
blockchain_path = Mock()
blockchain_path.exists.return_value = False
marketplace_path = Mock()
marketplace_path.exists.return_value = True
listings_file = Mock()
listings_file.exists.return_value = True
listings_file.__truediv__ = Mock(return_value=listings_file)
marketplace_path.__truediv__ = Mock(return_value=listings_file)
mock_path.side_effect = lambda x: listings_file if 'gpu_listings' in str(x) else (marketplace_path if 'marketplace' in str(x) else blockchain_path)
logger = mock_logging.getLogger.return_value
mock_logging.basicConfig.return_value = None
try:
monitor.main()
except KeyboardInterrupt:
pass
# Verify marketplace stats were logged
marketplace_calls = [call for call in logger.info.call_args_list if 'Marketplace' in str(call)]
assert len(marketplace_calls) > 0
assert '2 GPU listings' in str(marketplace_calls[0])

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,26 @@
[tool.poetry]
name = "monitoring-service"
version = "0.1.0"
description = "AITBC Monitoring Service for system health and metrics"
authors = ["AITBC Team"]
[tool.poetry.dependencies]
python = "^3.13"
fastapi = ">=0.115.6"
uvicorn = {extras = ["standard"], version = "^0.32.0"}
sqlmodel = "^0.0.37"
sqlalchemy = "^2.0.25"
pydantic = "^2.6.0"
pydantic-settings = "^2.1.0"
httpx = ">=0.28.1"
psutil = "^7.2.0"
[tool.poetry.group.dev.dependencies]
pytest = ">=9.0.3"
pytest-asyncio = ">=1.3.0"
black = ">=26.3.1"
ruff = "^0.1.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@@ -0,0 +1,266 @@
"""Monitoring Service for system health and metrics."""
from __future__ import annotations
import os
import logging
import asyncio
from datetime import datetime, timezone
from typing import Any
from fastapi import FastAPI
import httpx
logger = logging.getLogger(__name__)
app = FastAPI(
title="AITBC Monitoring Service",
description="System health and metrics monitoring service",
version="1.0.0"
)
# Service endpoints configuration
SERVICES = {
"gpu": {
"name": "GPU Service",
"port": 8101,
"url": "http://localhost:8101",
"description": "GPU marketplace and miner operations",
},
"marketplace": {
"name": "Marketplace Service",
"port": 8102,
"url": "http://localhost:8102",
"description": "Marketplace transactions",
},
"trading": {
"name": "Trading Service",
"port": 8104,
"url": "http://localhost:8104",
"description": "Trading and explorer operations",
},
"governance": {
"name": "Governance Service",
"port": 8105,
"url": "http://localhost:8105",
"description": "Governance transactions",
},
"ai": {
"name": "AI Service",
"port": 8106,
"url": "http://localhost:8106",
"description": "AI job operations",
},
}
@app.get("/health")
async def health():
"""Health check endpoint."""
return {"status": "healthy", "service": "monitoring-service"}
@app.get("/")
async def root():
"""Root endpoint."""
return {
"service": "AITBC Monitoring Service",
"version": "1.0.0",
"status": "operational"
}
@app.get("/dashboard")
async def monitoring_dashboard() -> dict[str, Any]:
"""
Unified monitoring dashboard for all services
"""
try:
# Collect health data from all services
health_data = await collect_all_health_data()
# Calculate overall metrics
overall_metrics = calculate_overall_metrics(health_data)
dashboard_data = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"overall_status": overall_metrics["overall_status"],
"services": health_data,
"metrics": overall_metrics,
"summary": {
"total_services": len(SERVICES),
"healthy_services": len([s for s in health_data.values() if s.get("status") == "healthy"]),
"degraded_services": len([s for s in health_data.values() if s.get("status") == "degraded"]),
"unhealthy_services": len([s for s in health_data.values() if s.get("status") == "unhealthy"]),
"last_updated": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC"),
},
}
logger.info("Monitoring dashboard data collected successfully")
return dashboard_data
except Exception as e:
logger.error(f"Failed to generate monitoring dashboard: {e}")
return {
"error": "Failed to generate dashboard",
"timestamp": datetime.now(timezone.utc).isoformat(),
"services": SERVICES,
"overall_status": "error",
"summary": {
"total_services": len(SERVICES),
"healthy_services": 0,
"degraded_services": 0,
"unhealthy_services": len(SERVICES),
"last_updated": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC"),
},
}
@app.get("/dashboard/summary")
async def services_summary() -> dict[str, Any]:
"""
Quick summary of all services status
"""
try:
health_data = await collect_all_health_data()
summary = {"timestamp": datetime.now(timezone.utc).isoformat(), "services": {}}
for service_id, service_info in SERVICES.items():
health = health_data.get(service_id, {})
summary["services"][service_id] = {
"name": service_info["name"],
"port": service_info["port"],
"status": health.get("status", "unknown"),
"description": service_info["description"],
"last_check": health.get("timestamp"),
}
return summary
except Exception as e:
logger.error(f"Failed to generate services summary: {e}")
return {"error": "Failed to generate summary", "timestamp": datetime.now(timezone.utc).isoformat()}
@app.get("/dashboard/metrics")
async def system_metrics() -> dict[str, Any]:
"""
System-wide performance metrics
"""
try:
import psutil
# System metrics
cpu_percent = psutil.cpu_percent(interval=1)
memory = psutil.virtual_memory()
disk = psutil.disk_usage("/")
# Network metrics
network = psutil.net_io_counters()
metrics = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"system": {
"cpu_percent": cpu_percent,
"cpu_count": psutil.cpu_count(),
"memory_percent": memory.percent,
"memory_total_gb": round(memory.total / (1024**3), 2),
"memory_available_gb": round(memory.available / (1024**3), 2),
"disk_percent": disk.percent,
"disk_total_gb": round(disk.total / (1024**3), 2),
"disk_free_gb": round(disk.free / (1024**3), 2),
},
"network": {
"bytes_sent": network.bytes_sent,
"bytes_recv": network.bytes_recv,
"packets_sent": network.packets_sent,
"packets_recv": network.packets_recv,
},
"services": {
"total_services": len(SERVICES),
"service_names": list(SERVICES.keys()),
},
}
return metrics
except Exception as e:
logger.error(f"Failed to collect system metrics: {e}")
return {"error": "Failed to collect metrics", "timestamp": datetime.now(timezone.utc).isoformat()}
async def collect_all_health_data() -> dict[str, Any]:
"""Collect health data from all services"""
health_data = {}
tasks = []
for service_id, service_info in SERVICES.items():
task = check_service_health(service_id, service_info)
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
for i, (service_id, service_info) in enumerate(SERVICES.items()):
result = results[i]
if isinstance(result, Exception):
health_data[service_id] = {
"status": "unhealthy",
"error": str(result),
"timestamp": datetime.now(timezone.utc).isoformat(),
}
else:
health_data[service_id] = result
return health_data
async def check_service_health(service_name: str, service_config: dict[str, Any]) -> dict[str, Any]:
"""
Check health status of a specific service
"""
try:
async with httpx.AsyncClient(timeout=5.0) as client:
health_url = f"{service_config['url']}/health"
response = await client.get(health_url)
return {
"status": "healthy",
"response_time": 0.1,
"last_check": datetime.now(timezone.utc).isoformat(),
"details": response.json(),
}
except Exception as e:
logger.warning(f"Service {service_name} health check failed: {e}")
return {
"status": "unhealthy",
"error": str(e),
"last_check": datetime.now(timezone.utc).isoformat(),
}
def calculate_overall_metrics(health_data: dict[str, Any]) -> dict[str, Any]:
"""Calculate overall system metrics from health data"""
status_counts = {"healthy": 0, "degraded": 0, "unhealthy": 0, "unknown": 0}
for service_health in health_data.values():
status = service_health.get("status", "unknown")
status_counts[status] = status_counts.get(status, 0) + 1
# Determine overall status
if status_counts["unhealthy"] > 0:
overall_status = "unhealthy"
elif status_counts["degraded"] > 0:
overall_status = "degraded"
else:
overall_status = "healthy"
return {
"overall_status": overall_status,
"status_counts": status_counts,
"health_percentage": (status_counts["healthy"] / len(health_data)) * 100 if health_data else 0,
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8107)

View File

@@ -0,0 +1,697 @@
"""
Multi-Region Load Balancing Service for AITBC
Handles intelligent load distribution across global regions
"""
import os
import asyncio
import json
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Dict, Any, List, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from aitbc import get_logger
logger = get_logger(__name__)
app = FastAPI(
title="AITBC Multi-Region Load Balancer",
description="Intelligent load balancing across global regions",
version="1.0.0"
)
# Data models
class LoadBalancingRule(BaseModel):
rule_id: str
name: str
algorithm: str # weighted_round_robin, least_connections, geographic, performance_based
target_regions: List[str]
weights: Dict[str, float] # Region weights
health_check_path: str
failover_enabled: bool
session_affinity: bool
class RegionHealth(BaseModel):
region_id: str
status: str # healthy, unhealthy, degraded
response_time_ms: float
success_rate: float
active_connections: int
last_check: datetime
class LoadBalancingMetrics(BaseModel):
balancer_id: str
timestamp: datetime
total_requests: int
requests_per_region: Dict[str, int]
average_response_time: float
error_rate: float
throughput: float
class GeographicRule(BaseModel):
rule_id: str
source_regions: List[str]
target_regions: List[str]
priority: int # Lower number = higher priority
latency_threshold_ms: float
# In-memory storage (in production, use database)
load_balancing_rules: Dict[str, Dict] = {}
region_health_status: Dict[str, RegionHealth] = {}
balancing_metrics: Dict[str, List[Dict]] = {}
geographic_rules: Dict[str, Dict] = {}
session_affinity_data: Dict[str, Dict] = {}
@app.get("/")
async def root():
return {
"service": "AITBC Multi-Region Load Balancer",
"status": "running",
"timestamp": datetime.now(timezone.utc).isoformat(),
"version": "1.0.0"
}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"total_rules": len(load_balancing_rules),
"active_rules": len([r for r in load_balancing_rules.values() if r["status"] == "active"]),
"monitored_regions": len(region_health_status),
"healthy_regions": len([r for r in region_health_status.values() if r.status == "healthy"])
}
@app.post("/api/v1/rules/create")
async def create_load_balancing_rule(rule: LoadBalancingRule):
"""Create a new load balancing rule"""
if rule.rule_id in load_balancing_rules:
raise HTTPException(status_code=400, detail="Load balancing rule already exists")
# Create rule record
rule_record = {
"rule_id": rule.rule_id,
"name": rule.name,
"algorithm": rule.algorithm,
"target_regions": rule.target_regions,
"weights": rule.weights,
"health_check_path": rule.health_check_path,
"failover_enabled": rule.failover_enabled,
"session_affinity": rule.session_affinity,
"status": "active",
"created_at": datetime.now(timezone.utc).isoformat(),
"total_requests": 0,
"failed_requests": 0,
"last_updated": datetime.now(timezone.utc).isoformat()
}
load_balancing_rules[rule.rule_id] = rule_record
# Start health monitoring for target regions
asyncio.create_task(start_health_monitoring(rule.rule_id))
logger.info(f"Load balancing rule created: {rule.name} ({rule.rule_id})")
return {
"rule_id": rule.rule_id,
"status": "created",
"name": rule.name,
"algorithm": rule.algorithm,
"created_at": rule_record["created_at"]
}
@app.get("/api/v1/rules")
async def list_load_balancing_rules():
"""List all load balancing rules"""
return {
"rules": list(load_balancing_rules.values()),
"total_rules": len(load_balancing_rules),
"active_rules": len([r for r in load_balancing_rules.values() if r["status"] == "active"])
}
@app.get("/api/v1/rules/{rule_id}")
async def get_load_balancing_rule(rule_id: str):
"""Get detailed load balancing rule information"""
if rule_id not in load_balancing_rules:
raise HTTPException(status_code=404, detail="Load balancing rule not found")
rule = load_balancing_rules[rule_id].copy()
# Add region health status
rule["region_health"] = {
region_id: region_health_status.get(region_id)
for region_id in rule["target_regions"]
if region_id in region_health_status
}
# Add performance metrics
rule["performance_metrics"] = balancing_metrics.get(rule_id, [])
return rule
@app.post("/api/v1/rules/{rule_id}/update-weights")
async def update_rule_weights(rule_id: str, weights: Dict[str, float]):
"""Update weights for a load balancing rule"""
if rule_id not in load_balancing_rules:
raise HTTPException(status_code=404, detail="Load balancing rule not found")
rule = load_balancing_rules[rule_id]
# Validate weights
total_weight = sum(weights.values())
if total_weight == 0:
raise HTTPException(status_code=400, detail="Total weight cannot be zero")
# Normalize weights
normalized_weights = {k: v / total_weight for k, v in weights.items()}
# Update rule weights
rule["weights"] = normalized_weights
rule["last_updated"] = datetime.now(timezone.utc).isoformat()
logger.info(f"Weights updated for rule {rule_id}: {normalized_weights}")
return {
"rule_id": rule_id,
"new_weights": normalized_weights,
"updated_at": rule["last_updated"]
}
@app.post("/api/v1/health/register")
async def register_region_health(health: RegionHealth):
"""Register or update health status for a region"""
region_health_status[health.region_id] = health
# Update load balancing rules that use this region
for rule_id, rule in load_balancing_rules.items():
if health.region_id in rule["target_regions"]:
# Update rule based on health status
if health.status == "unhealthy" and rule["failover_enabled"]:
logger.warning(f"Region {health.region_id} unhealthy, enabling failover for rule {rule_id}")
enable_failover(rule_id, health.region_id)
return {
"region_id": health.region_id,
"status": health.status,
"registered_at": datetime.now(timezone.utc).isoformat()
}
@app.get("/api/v1/health")
async def get_all_region_health():
"""Get health status for all monitored regions"""
return {
"region_health": {
region_id: health.dict()
for region_id, health in region_health_status.items()
},
"total_regions": len(region_health_status),
"healthy_regions": len([r for r in region_health_status.values() if r.status == "healthy"]),
"unhealthy_regions": len([r for r in region_health_status.values() if r.status == "unhealthy"]),
"degraded_regions": len([r for r in region_health_status.values() if r.status == "degraded"])
}
@app.post("/api/v1/geographic-rules/create")
async def create_geographic_rule(rule: GeographicRule):
"""Create a geographic routing rule"""
if rule.rule_id in geographic_rules:
raise HTTPException(status_code=400, detail="Geographic rule already exists")
# Create geographic rule record
rule_record = {
"rule_id": rule.rule_id,
"source_regions": rule.source_regions,
"target_regions": rule.target_regions,
"priority": rule.priority,
"latency_threshold_ms": rule.latency_threshold_ms,
"status": "active",
"created_at": datetime.now(timezone.utc).isoformat(),
"usage_count": 0
}
geographic_rules[rule.rule_id] = rule_record
logger.info(f"Geographic rule created: {rule.rule_id}")
return {
"rule_id": rule.rule_id,
"status": "created",
"priority": rule.priority,
"created_at": rule_record["created_at"]
}
@app.get("/api/v1/route/{client_region}")
async def get_optimal_region(client_region: str, rule_id: Optional[str] = None):
"""Get optimal target region for a client region"""
if rule_id and rule_id not in load_balancing_rules:
raise HTTPException(status_code=404, detail="Load balancing rule not found")
# Find optimal region based on rules
if rule_id:
optimal_region = select_region_by_algorithm(rule_id, client_region)
else:
optimal_region = select_region_geographically(client_region)
return {
"client_region": client_region,
"optimal_region": optimal_region,
"rule_id": rule_id,
"selection_reason": get_selection_reason(optimal_region, client_region, rule_id),
"timestamp": datetime.now(timezone.utc).isoformat()
}
@app.post("/api/v1/metrics/record")
async def record_balancing_metrics(metrics: LoadBalancingMetrics):
"""Record load balancing performance metrics"""
metrics_record = {
"metrics_id": f"metrics_{int(datetime.now(timezone.utc).timestamp())}",
"balancer_id": metrics.balancer_id,
"timestamp": metrics.timestamp.isoformat(),
"total_requests": metrics.total_requests,
"requests_per_region": metrics.requests_per_region,
"average_response_time": metrics.average_response_time,
"error_rate": metrics.error_rate,
"throughput": metrics.throughput
}
if metrics.balancer_id not in balancing_metrics:
balancing_metrics[metrics.balancer_id] = []
balancing_metrics[metrics.balancer_id].append(metrics_record)
# Keep only last 1000 records per balancer
if len(balancing_metrics[metrics.balancer_id]) > 1000:
balancing_metrics[metrics.balancer_id] = balancing_metrics[metrics.balancer_id][-1000:]
return {
"metrics_id": metrics_record["metrics_id"],
"status": "recorded",
"timestamp": metrics_record["timestamp"]
}
@app.get("/api/v1/metrics/{rule_id}")
async def get_balancing_metrics(rule_id: str, hours: int = 24):
"""Get performance metrics for a load balancing rule"""
if rule_id not in load_balancing_rules:
raise HTTPException(status_code=404, detail="Load balancing rule not found")
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours)
recent_metrics = [
m for m in balancing_metrics.get(rule_id, [])
if datetime.fromisoformat(m["timestamp"]) > cutoff_time
]
# Calculate statistics
if recent_metrics:
avg_response_time = sum(m["average_response_time"] for m in recent_metrics) / len(recent_metrics)
avg_error_rate = sum(m["error_rate"] for m in recent_metrics) / len(recent_metrics)
avg_throughput = sum(m["throughput"] for m in recent_metrics) / len(recent_metrics)
total_requests = sum(m["total_requests"] for m in recent_metrics)
else:
avg_response_time = avg_error_rate = avg_throughput = total_requests = 0.0
return {
"rule_id": rule_id,
"period_hours": hours,
"metrics": recent_metrics,
"statistics": {
"average_response_time_ms": round(avg_response_time, 3),
"average_error_rate": round(avg_error_rate, 4),
"average_throughput": round(avg_throughput, 2),
"total_requests": int(total_requests),
"total_samples": len(recent_metrics)
},
"generated_at": datetime.now(timezone.utc).isoformat()
}
@app.get("/api/v1/dashboard")
async def get_load_balancing_dashboard():
"""Get comprehensive load balancing dashboard"""
# Calculate overall statistics
total_rules = len(load_balancing_rules)
active_rules = len([r for r in load_balancing_rules.values() if r["status"] == "active"])
# Region health summary
health_summary = {
"total_regions": len(region_health_status),
"healthy": len([r for r in region_health_status.values() if r.status == "healthy"]),
"unhealthy": len([r for r in region_health_status.values() if r.status == "unhealthy"]),
"degraded": len([r for r in region_health_status.values() if r.status == "degraded"])
}
# Performance summary
performance_summary = {}
for rule_id, metrics_list in balancing_metrics.items():
if metrics_list:
latest_metrics = metrics_list[-1]
performance_summary[rule_id] = {
"total_requests": latest_metrics["total_requests"],
"average_response_time": latest_metrics["average_response_time"],
"error_rate": latest_metrics["error_rate"],
"throughput": latest_metrics["throughput"]
}
# Algorithm distribution
algorithm_distribution = {}
for rule in load_balancing_rules.values():
algorithm = rule["algorithm"]
algorithm_distribution[algorithm] = algorithm_distribution.get(algorithm, 0) + 1
return {
"dashboard": {
"overview": {
"total_rules": total_rules,
"active_rules": active_rules,
"geographic_rules": len(geographic_rules),
"algorithm_distribution": algorithm_distribution
},
"region_health": health_summary,
"performance": performance_summary,
"recent_activity": get_recent_activity()
},
"generated_at": datetime.now(timezone.utc).isoformat()
}
# Core load balancing functions
def select_region_by_algorithm(rule_id: str, client_region: str) -> Optional[str]:
"""Select optimal region based on load balancing algorithm"""
if rule_id not in load_balancing_rules:
return None
rule = load_balancing_rules[rule_id]
algorithm = rule["algorithm"]
target_regions = rule["target_regions"]
# Filter healthy regions
healthy_regions = [
region for region in target_regions
if region in region_health_status and region_health_status[region].status == "healthy"
]
if not healthy_regions:
# Fallback to any region if no healthy ones
healthy_regions = target_regions
if algorithm == "weighted_round_robin":
return select_weighted_round_robin(rule_id, healthy_regions)
elif algorithm == "least_connections":
return select_least_connections(healthy_regions)
elif algorithm == "geographic":
return select_geographic_optimal(client_region, healthy_regions)
elif algorithm == "performance_based":
return select_performance_optimal(healthy_regions)
else:
return healthy_regions[0] if healthy_regions else None
def select_weighted_round_robin(rule_id: str, regions: List[str]) -> str:
"""Select region using weighted round robin"""
rule = load_balancing_rules[rule_id]
weights = rule["weights"]
# Filter weights for available regions
available_weights = {r: weights.get(r, 1.0) for r in regions if r in weights}
if not available_weights:
return regions[0]
# Simple weighted selection (in production, use proper round robin state)
import random
total_weight = sum(available_weights.values())
rand_val = random.uniform(0, total_weight)
current_weight = 0
for region, weight in available_weights.items():
current_weight += weight
if rand_val <= current_weight:
return region
return list(available_weights.keys())[-1]
def select_least_connections(regions: List[str]) -> str:
"""Select region with least connections"""
min_connections = float('inf')
optimal_region = None
for region in regions:
if region in region_health_status:
connections = region_health_status[region].active_connections
if connections < min_connections:
min_connections = connections
optimal_region = region
return optimal_region or regions[0]
def select_geographic_optimal(client_region: str, target_regions: List[str]) -> str:
"""Select region based on geographic proximity"""
# Simplified geographic mapping (in production, use actual geographic data)
geographic_proximity = {
"us-east": ["us-east-1", "us-west-1"],
"us-west": ["us-west-1", "us-east-1"],
"europe": ["eu-west-1", "eu-central-1"],
"asia": ["ap-southeast-1", "ap-northeast-1"]
}
# Find closest regions
for geo_area, close_regions in geographic_proximity.items():
if client_region.lower() in geo_area.lower():
for close_region in close_regions:
if close_region in target_regions:
return close_region
# Fallback to first healthy region
return target_regions[0]
def select_performance_optimal(regions: List[str]) -> str:
"""Select region with best performance"""
best_region = None
best_score = float('inf')
for region in regions:
if region in region_health_status:
health = region_health_status[region]
# Calculate performance score (lower is better)
score = health.response_time_ms * (1 - health.success_rate)
if score < best_score:
best_score = score
best_region = region
return best_region or regions[0]
def select_region_geographically(client_region: str) -> Optional[str]:
"""Select region based on geographic rules"""
# Apply geographic rules
applicable_rules = [
rule for rule in geographic_rules.values()
if client_region in rule["source_regions"] and rule["status"] == "active"
]
# Sort by priority (lower number = higher priority)
applicable_rules.sort(key=lambda x: x["priority"])
for rule in applicable_rules:
# Find best target region based on latency
best_target = None
best_latency = float('inf')
for target_region in rule["target_regions"]:
if target_region in region_health_status:
latency = region_health_status[target_region].response_time_ms
if latency < best_latency and latency < rule["latency_threshold_ms"]:
best_latency = latency
best_target = target_region
if best_target:
rule["usage_count"] += 1
return best_target
# Fallback to any healthy region
healthy_regions = [
region for region, health in region_health_status.items()
if health.status == "healthy"
]
return healthy_regions[0] if healthy_regions else None
def get_selection_reason(region: str, client_region: str, rule_id: Optional[str]) -> str:
"""Get reason for region selection"""
if rule_id and rule_id in load_balancing_rules:
rule = load_balancing_rules[rule_id]
return f"Selected by {rule['algorithm']} algorithm using rule {rule['name']}"
else:
return f"Selected based on geographic proximity from {client_region}"
def enable_failover(rule_id: str, unhealthy_region: str):
"""Enable failover for unhealthy region"""
rule = load_balancing_rules[rule_id]
# Remove unhealthy region from rotation temporarily
if unhealthy_region in rule["target_regions"]:
rule["target_regions"].remove(unhealthy_region)
logger.warning(f"Region {unhealthy_region} removed from load balancing rule {rule_id}")
def get_recent_activity() -> List[Dict]:
"""Get recent load balancing activity"""
activity = []
# Recent health changes
for region_id, health in region_health_status.items():
if (datetime.now(timezone.utc) - health.last_check).total_seconds() < 3600: # Last hour
activity.append({
"type": "health_check",
"region": region_id,
"status": health.status,
"timestamp": health.last_check.isoformat()
})
# Recent rule updates
for rule_id, rule in load_balancing_rules.items():
if (datetime.now(timezone.utc) - datetime.fromisoformat(rule["last_updated"])).total_seconds() < 3600:
activity.append({
"type": "rule_update",
"rule_id": rule_id,
"name": rule["name"],
"timestamp": rule["last_updated"]
})
# Sort by timestamp (most recent first)
activity.sort(key=lambda x: x["timestamp"], reverse=True)
return activity[:20]
# Background task for health monitoring
async def start_health_monitoring(rule_id: str):
"""Start health monitoring for a load balancing rule"""
rule = load_balancing_rules[rule_id]
while rule["status"] == "active":
try:
# Check health of all target regions
for region_id in rule["target_regions"]:
await check_region_health(region_id)
await asyncio.sleep(30) # Check every 30 seconds
except Exception as e:
logger.error(f"Health monitoring error for rule {rule_id}: {str(e)}")
await asyncio.sleep(10)
async def check_region_health(region_id: str):
"""Check health of a specific region"""
# Simulate health check (in production, this would be actual health checks)
import random
# Simulate health metrics
response_time = random.uniform(20, 200)
success_rate = random.uniform(0.95, 1.0)
active_connections = random.randint(100, 1000)
# Determine health status
if response_time < 100 and success_rate > 0.99:
status = "healthy"
elif response_time < 200 and success_rate > 0.95:
status = "degraded"
else:
status = "unhealthy"
health = RegionHealth(
region_id=region_id,
status=status,
response_time_ms=response_time,
success_rate=success_rate,
active_connections=active_connections,
last_check=datetime.now(timezone.utc)
)
region_health_status[region_id] = health
# Initialize with some default rules
@app.on_event("startup")
async def startup_event():
logger.info("Starting AITBC Multi-Region Load Balancer")
# Initialize default load balancing rules
default_rules = [
{
"rule_id": "global-web-rule",
"name": "Global Web Load Balancer",
"algorithm": "weighted_round_robin",
"target_regions": ["us-east-1", "eu-west-1", "ap-southeast-1"],
"weights": {"us-east-1": 0.4, "eu-west-1": 0.35, "ap-southeast-1": 0.25},
"health_check_path": "/health",
"failover_enabled": True,
"session_affinity": False
},
{
"rule_id": "api-performance-rule",
"name": "API Performance Optimizer",
"algorithm": "performance_based",
"target_regions": ["us-east-1", "eu-west-1"],
"weights": {"us-east-1": 0.5, "eu-west-1": 0.5},
"health_check_path": "/api/health",
"failover_enabled": True,
"session_affinity": True
}
]
for rule_data in default_rules:
rule = LoadBalancingRule(**rule_data)
rule_record = {
"rule_id": rule.rule_id,
"name": rule.name,
"algorithm": rule.algorithm,
"target_regions": rule.target_regions,
"weights": rule.weights,
"health_check_path": rule.health_check_path,
"failover_enabled": rule.failover_enabled,
"session_affinity": rule.session_affinity,
"status": "active",
"created_at": datetime.now(timezone.utc).isoformat(),
"total_requests": 0,
"failed_requests": 0,
"last_updated": datetime.now(timezone.utc).isoformat()
}
load_balancing_rules[rule.rule_id] = rule_record
# Start health monitoring
asyncio.create_task(start_health_monitoring(rule.rule_id))
# Initialize default geographic rules
default_geo_rules = [
{
"rule_id": "us-to-us",
"source_regions": ["us-east", "us-west", "north-america"],
"target_regions": ["us-east-1", "us-west-1"],
"priority": 1,
"latency_threshold_ms": 50
},
{
"rule_id": "eu-to-eu",
"source_regions": ["europe", "eu-west", "eu-central"],
"target_regions": ["eu-west-1", "eu-central-1"],
"priority": 1,
"latency_threshold_ms": 30
}
]
for geo_rule_data in default_geo_rules:
geo_rule = GeographicRule(**geo_rule_data)
geo_rule_record = {
"rule_id": geo_rule.rule_id,
"source_regions": geo_rule.source_regions,
"target_regions": geo_rule.target_regions,
"priority": geo_rule.priority,
"latency_threshold_ms": geo_rule.latency_threshold_ms,
"status": "active",
"created_at": datetime.now(timezone.utc).isoformat(),
"usage_count": 0
}
geographic_rules[geo_rule.rule_id] = geo_rule_record
@app.on_event("shutdown")
async def shutdown_event():
logger.info("Shutting down AITBC Multi-Region Load Balancer")
if __name__ == "__main__":
import uvicorn
import os
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8019, log_level="info")

View File

@@ -0,0 +1 @@
"""Multi-region load balancer service tests"""

View File

@@ -0,0 +1,199 @@
"""Edge case and error handling tests for multi-region load balancer service"""
import pytest
import sys
import sys
from pathlib import Path
from fastapi.testclient import TestClient
from datetime import datetime, timezone
from main import app, LoadBalancingRule, RegionHealth, LoadBalancingMetrics, GeographicRule, load_balancing_rules, region_health_status, balancing_metrics, geographic_rules
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
load_balancing_rules.clear()
region_health_status.clear()
balancing_metrics.clear()
geographic_rules.clear()
yield
load_balancing_rules.clear()
region_health_status.clear()
balancing_metrics.clear()
geographic_rules.clear()
@pytest.mark.unit
def test_load_balancing_rule_empty_target_regions():
"""Test LoadBalancingRule with empty target regions"""
rule = LoadBalancingRule(
rule_id="rule_123",
name="Test Rule",
algorithm="round_robin",
target_regions=[],
weights={},
health_check_path="/health",
failover_enabled=False,
session_affinity=False
)
assert rule.target_regions == []
@pytest.mark.unit
def test_region_health_negative_success_rate():
"""Test RegionHealth with negative success rate"""
health = RegionHealth(
region_id="us-east-1",
status="healthy",
response_time_ms=45.5,
success_rate=-0.5,
active_connections=100,
last_check=datetime.now(timezone.utc)
)
assert health.success_rate == -0.5
@pytest.mark.unit
def test_region_health_negative_connections():
"""Test RegionHealth with negative connections"""
health = RegionHealth(
region_id="us-east-1",
status="healthy",
response_time_ms=45.5,
success_rate=0.99,
active_connections=-100,
last_check=datetime.now(timezone.utc)
)
assert health.active_connections == -100
@pytest.mark.unit
def test_load_balancing_metrics_negative_requests():
"""Test LoadBalancingMetrics with negative requests"""
metrics = LoadBalancingMetrics(
balancer_id="lb_123",
timestamp=datetime.now(timezone.utc),
total_requests=-1000,
requests_per_region={},
average_response_time=50.5,
error_rate=0.001,
throughput=100.0
)
assert metrics.total_requests == -1000
@pytest.mark.unit
def test_load_balancing_metrics_negative_response_time():
"""Test LoadBalancingMetrics with negative response time"""
metrics = LoadBalancingMetrics(
balancer_id="lb_123",
timestamp=datetime.now(timezone.utc),
total_requests=1000,
requests_per_region={},
average_response_time=-50.5,
error_rate=0.001,
throughput=100.0
)
assert metrics.average_response_time == -50.5
@pytest.mark.unit
def test_geographic_rule_empty_source_regions():
"""Test GeographicRule with empty source regions"""
rule = GeographicRule(
rule_id="geo_123",
source_regions=[],
target_regions=["us-east-1"],
priority=1,
latency_threshold_ms=50.0
)
assert rule.source_regions == []
@pytest.mark.unit
def test_geographic_rule_negative_priority():
"""Test GeographicRule with negative priority"""
rule = GeographicRule(
rule_id="geo_123",
source_regions=["us-east"],
target_regions=["us-east-1"],
priority=-5,
latency_threshold_ms=50.0
)
assert rule.priority == -5
@pytest.mark.unit
def test_geographic_rule_negative_latency_threshold():
"""Test GeographicRule with negative latency threshold"""
rule = GeographicRule(
rule_id="geo_123",
source_regions=["us-east"],
target_regions=["us-east-1"],
priority=1,
latency_threshold_ms=-50.0
)
assert rule.latency_threshold_ms == -50.0
@pytest.mark.integration
def test_list_rules_with_no_rules():
"""Test listing rules when no rules exist"""
client = TestClient(app)
response = client.get("/api/v1/rules")
assert response.status_code == 200
data = response.json()
assert data["total_rules"] == 0
@pytest.mark.integration
def test_get_region_health_with_no_regions():
"""Test getting region health when no regions exist"""
client = TestClient(app)
response = client.get("/api/v1/health")
assert response.status_code == 200
data = response.json()
assert data["total_regions"] == 0
@pytest.mark.integration
def test_get_balancing_metrics_hours_parameter():
"""Test getting balancing metrics with custom hours parameter"""
client = TestClient(app)
# Create a rule first
rule = LoadBalancingRule(
rule_id="rule_123",
name="Test Rule",
algorithm="weighted_round_robin",
target_regions=["us-east-1"],
weights={"us-east-1": 1.0},
health_check_path="/health",
failover_enabled=True,
session_affinity=False
)
client.post("/api/v1/rules/create", json=rule.model_dump())
response = client.get("/api/v1/metrics/rule_123?hours=12")
assert response.status_code == 200
data = response.json()
assert data["period_hours"] == 12
@pytest.mark.integration
def test_get_optimal_region_nonexistent_rule():
"""Test getting optimal region with nonexistent rule"""
client = TestClient(app)
response = client.get("/api/v1/route/us-east?rule_id=nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_dashboard_with_no_data():
"""Test dashboard with no data"""
client = TestClient(app)
response = client.get("/api/v1/dashboard")
assert response.status_code == 200
data = response.json()
assert data["dashboard"]["overview"]["total_rules"] == 0

View File

@@ -0,0 +1,341 @@
"""Integration tests for multi-region load balancer service"""
import pytest
import sys
import sys
from pathlib import Path
from fastapi.testclient import TestClient
from datetime import datetime, timezone
from main import app, LoadBalancingRule, RegionHealth, LoadBalancingMetrics, GeographicRule, load_balancing_rules, region_health_status, balancing_metrics, geographic_rules
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
load_balancing_rules.clear()
region_health_status.clear()
balancing_metrics.clear()
geographic_rules.clear()
yield
load_balancing_rules.clear()
region_health_status.clear()
balancing_metrics.clear()
geographic_rules.clear()
@pytest.mark.integration
def test_root_endpoint():
"""Test root endpoint"""
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
data = response.json()
assert data["service"] == "AITBC Multi-Region Load Balancer"
assert data["status"] == "running"
@pytest.mark.integration
def test_health_check_endpoint():
"""Test health check endpoint"""
client = TestClient(app)
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert "total_rules" in data
@pytest.mark.integration
def test_create_load_balancing_rule():
"""Test creating a load balancing rule"""
client = TestClient(app)
rule = LoadBalancingRule(
rule_id="rule_123",
name="Test Rule",
algorithm="weighted_round_robin",
target_regions=["us-east-1"],
weights={"us-east-1": 1.0},
health_check_path="/health",
failover_enabled=True,
session_affinity=False
)
response = client.post("/api/v1/rules/create", json=rule.model_dump())
assert response.status_code == 200
data = response.json()
assert data["rule_id"] == "rule_123"
assert data["status"] == "created"
@pytest.mark.integration
def test_create_duplicate_rule():
"""Test creating duplicate load balancing rule"""
client = TestClient(app)
rule = LoadBalancingRule(
rule_id="rule_123",
name="Test Rule",
algorithm="weighted_round_robin",
target_regions=["us-east-1"],
weights={"us-east-1": 1.0},
health_check_path="/health",
failover_enabled=True,
session_affinity=False
)
client.post("/api/v1/rules/create", json=rule.model_dump())
response = client.post("/api/v1/rules/create", json=rule.model_dump())
assert response.status_code == 400
@pytest.mark.integration
def test_list_load_balancing_rules():
"""Test listing load balancing rules"""
client = TestClient(app)
response = client.get("/api/v1/rules")
assert response.status_code == 200
data = response.json()
assert "rules" in data
assert "total_rules" in data
@pytest.mark.integration
def test_get_load_balancing_rule():
"""Test getting specific load balancing rule"""
client = TestClient(app)
rule = LoadBalancingRule(
rule_id="rule_123",
name="Test Rule",
algorithm="weighted_round_robin",
target_regions=["us-east-1"],
weights={"us-east-1": 1.0},
health_check_path="/health",
failover_enabled=True,
session_affinity=False
)
client.post("/api/v1/rules/create", json=rule.model_dump())
response = client.get("/api/v1/rules/rule_123")
assert response.status_code == 200
data = response.json()
assert data["rule_id"] == "rule_123"
@pytest.mark.integration
def test_get_load_balancing_rule_not_found():
"""Test getting nonexistent load balancing rule"""
client = TestClient(app)
response = client.get("/api/v1/rules/nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_update_rule_weights():
"""Test updating rule weights"""
client = TestClient(app)
rule = LoadBalancingRule(
rule_id="rule_123",
name="Test Rule",
algorithm="weighted_round_robin",
target_regions=["us-east-1", "eu-west-1"],
weights={"us-east-1": 0.5, "eu-west-1": 0.5},
health_check_path="/health",
failover_enabled=True,
session_affinity=False
)
client.post("/api/v1/rules/create", json=rule.model_dump())
new_weights = {"us-east-1": 0.7, "eu-west-1": 0.3}
response = client.post("/api/v1/rules/rule_123/update-weights", json=new_weights)
assert response.status_code == 200
data = response.json()
assert data["rule_id"] == "rule_123"
assert "new_weights" in data
@pytest.mark.integration
def test_update_rule_weights_not_found():
"""Test updating weights for nonexistent rule"""
client = TestClient(app)
new_weights = {"us-east-1": 1.0}
response = client.post("/api/v1/rules/nonexistent/update-weights", json=new_weights)
assert response.status_code == 404
@pytest.mark.integration
def test_update_rule_weights_zero_total():
"""Test updating weights with zero total"""
client = TestClient(app)
rule = LoadBalancingRule(
rule_id="rule_123",
name="Test Rule",
algorithm="weighted_round_robin",
target_regions=["us-east-1"],
weights={"us-east-1": 1.0},
health_check_path="/health",
failover_enabled=True,
session_affinity=False
)
client.post("/api/v1/rules/create", json=rule.model_dump())
new_weights = {"us-east-1": 0.0}
response = client.post("/api/v1/rules/rule_123/update-weights", json=new_weights)
assert response.status_code == 400
@pytest.mark.integration
def test_register_region_health():
"""Test registering region health"""
client = TestClient(app)
health = RegionHealth(
region_id="us-east-1",
status="healthy",
response_time_ms=45.5,
success_rate=0.99,
active_connections=100,
last_check=datetime.now(timezone.utc)
)
response = client.post("/api/v1/health/register", json=health.model_dump(mode='json'))
assert response.status_code == 200
data = response.json()
assert data["region_id"] == "us-east-1"
@pytest.mark.integration
def test_get_all_region_health():
"""Test getting all region health"""
client = TestClient(app)
response = client.get("/api/v1/health")
assert response.status_code == 200
data = response.json()
assert "region_health" in data
@pytest.mark.integration
def test_create_geographic_rule():
"""Test creating geographic rule"""
client = TestClient(app)
rule = GeographicRule(
rule_id="geo_123",
source_regions=["us-east"],
target_regions=["us-east-1"],
priority=1,
latency_threshold_ms=50.0
)
response = client.post("/api/v1/geographic-rules/create", json=rule.model_dump())
assert response.status_code == 200
data = response.json()
assert data["rule_id"] == "geo_123"
assert data["status"] == "created"
@pytest.mark.integration
def test_create_duplicate_geographic_rule():
"""Test creating duplicate geographic rule"""
client = TestClient(app)
rule = GeographicRule(
rule_id="geo_123",
source_regions=["us-east"],
target_regions=["us-east-1"],
priority=1,
latency_threshold_ms=50.0
)
client.post("/api/v1/geographic-rules/create", json=rule.model_dump())
response = client.post("/api/v1/geographic-rules/create", json=rule.model_dump())
assert response.status_code == 400
@pytest.mark.integration
def test_get_optimal_region():
"""Test getting optimal region"""
client = TestClient(app)
response = client.get("/api/v1/route/us-east")
assert response.status_code == 200
data = response.json()
assert "client_region" in data
assert "optimal_region" in data
@pytest.mark.integration
def test_get_optimal_region_with_rule():
"""Test getting optimal region with specific rule"""
client = TestClient(app)
# Create a rule first
rule = LoadBalancingRule(
rule_id="rule_123",
name="Test Rule",
algorithm="weighted_round_robin",
target_regions=["us-east-1"],
weights={"us-east-1": 1.0},
health_check_path="/health",
failover_enabled=True,
session_affinity=False
)
client.post("/api/v1/rules/create", json=rule.model_dump())
response = client.get("/api/v1/route/us-east?rule_id=rule_123")
assert response.status_code == 200
data = response.json()
assert data["rule_id"] == "rule_123"
@pytest.mark.integration
def test_record_balancing_metrics():
"""Test recording balancing metrics"""
client = TestClient(app)
metrics = LoadBalancingMetrics(
balancer_id="lb_123",
timestamp=datetime.now(timezone.utc),
total_requests=1000,
requests_per_region={"us-east-1": 500},
average_response_time=50.5,
error_rate=0.001,
throughput=100.0
)
response = client.post("/api/v1/metrics/record", json=metrics.model_dump(mode='json'))
assert response.status_code == 200
data = response.json()
assert data["metrics_id"]
assert data["status"] == "recorded"
@pytest.mark.integration
def test_get_balancing_metrics():
"""Test getting balancing metrics"""
client = TestClient(app)
# Create a rule first
rule = LoadBalancingRule(
rule_id="rule_123",
name="Test Rule",
algorithm="weighted_round_robin",
target_regions=["us-east-1"],
weights={"us-east-1": 1.0},
health_check_path="/health",
failover_enabled=True,
session_affinity=False
)
client.post("/api/v1/rules/create", json=rule.model_dump())
response = client.get("/api/v1/metrics/rule_123")
assert response.status_code == 200
data = response.json()
assert data["rule_id"] == "rule_123"
@pytest.mark.integration
def test_get_balancing_metrics_not_found():
"""Test getting metrics for nonexistent rule"""
client = TestClient(app)
response = client.get("/api/v1/metrics/nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_get_load_balancing_dashboard():
"""Test getting load balancing dashboard"""
client = TestClient(app)
response = client.get("/api/v1/dashboard")
assert response.status_code == 200
data = response.json()
assert "dashboard" in data

View File

@@ -0,0 +1,120 @@
"""Unit tests for multi-region load balancer service"""
import pytest
import sys
import sys
from pathlib import Path
from datetime import datetime, timezone
from main import app, LoadBalancingRule, RegionHealth, LoadBalancingMetrics, GeographicRule
@pytest.mark.unit
def test_app_initialization():
"""Test that the FastAPI app initializes correctly"""
assert app is not None
assert app.title == "AITBC Multi-Region Load Balancer"
assert app.version == "1.0.0"
@pytest.mark.unit
def test_load_balancing_rule_model():
"""Test LoadBalancingRule model"""
rule = LoadBalancingRule(
rule_id="rule_123",
name="Test Rule",
algorithm="weighted_round_robin",
target_regions=["us-east-1", "eu-west-1"],
weights={"us-east-1": 0.5, "eu-west-1": 0.5},
health_check_path="/health",
failover_enabled=True,
session_affinity=False
)
assert rule.rule_id == "rule_123"
assert rule.name == "Test Rule"
assert rule.algorithm == "weighted_round_robin"
assert rule.failover_enabled is True
assert rule.session_affinity is False
@pytest.mark.unit
def test_region_health_model():
"""Test RegionHealth model"""
health = RegionHealth(
region_id="us-east-1",
status="healthy",
response_time_ms=45.5,
success_rate=0.99,
active_connections=100,
last_check=datetime.now(timezone.utc)
)
assert health.region_id == "us-east-1"
assert health.status == "healthy"
assert health.response_time_ms == 45.5
assert health.success_rate == 0.99
assert health.active_connections == 100
@pytest.mark.unit
def test_load_balancing_metrics_model():
"""Test LoadBalancingMetrics model"""
metrics = LoadBalancingMetrics(
balancer_id="lb_123",
timestamp=datetime.now(timezone.utc),
total_requests=1000,
requests_per_region={"us-east-1": 500, "eu-west-1": 500},
average_response_time=50.5,
error_rate=0.001,
throughput=100.0
)
assert metrics.balancer_id == "lb_123"
assert metrics.total_requests == 1000
assert metrics.average_response_time == 50.5
assert metrics.error_rate == 0.001
@pytest.mark.unit
def test_geographic_rule_model():
"""Test GeographicRule model"""
rule = GeographicRule(
rule_id="geo_123",
source_regions=["us-east", "us-west"],
target_regions=["us-east-1", "us-west-1"],
priority=1,
latency_threshold_ms=50.0
)
assert rule.rule_id == "geo_123"
assert rule.source_regions == ["us-east", "us-west"]
assert rule.priority == 1
assert rule.latency_threshold_ms == 50.0
@pytest.mark.unit
def test_load_balancing_rule_empty_weights():
"""Test LoadBalancingRule with empty weights"""
rule = LoadBalancingRule(
rule_id="rule_123",
name="Test Rule",
algorithm="round_robin",
target_regions=["us-east-1"],
weights={},
health_check_path="/health",
failover_enabled=False,
session_affinity=False
)
assert rule.weights == {}
@pytest.mark.unit
def test_region_health_negative_response_time():
"""Test RegionHealth with negative response time"""
health = RegionHealth(
region_id="us-east-1",
status="healthy",
response_time_ms=-45.5,
success_rate=0.99,
active_connections=100,
last_check=datetime.now(timezone.utc)
)
assert health.response_time_ms == -45.5

View File

@@ -0,0 +1,655 @@
"""
Plugin Analytics and Usage Tracking Service for AITBC
Handles plugin analytics, usage tracking, and performance monitoring
"""
import os
import asyncio
import json
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Dict, Any, List, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from aitbc import get_logger
logger = get_logger(__name__)
app = FastAPI(
title="AITBC Plugin Analytics Service",
description="Plugin analytics, usage tracking, and performance monitoring",
version="1.0.0"
)
# Data models
class PluginUsage(BaseModel):
plugin_id: str
user_id: str
action: str # install, uninstall, enable, disable, use
timestamp: datetime
metadata: Dict[str, Any] = {}
class PluginPerformance(BaseModel):
plugin_id: str
version: str
cpu_usage: float
memory_usage: float
response_time: float
error_rate: float
uptime: float
timestamp: datetime
class PluginRating(BaseModel):
plugin_id: str
user_id: str
rating: int # 1-5
review: Optional[str] = None
timestamp: datetime
class PluginEvent(BaseModel):
event_type: str
plugin_id: str
user_id: Optional[str] = None
data: Dict[str, Any] = {}
timestamp: datetime
# In-memory storage (in production, use database)
plugin_usage_data: Dict[str, List[Dict]] = {}
plugin_performance_data: Dict[str, List[Dict]] = {}
plugin_ratings: Dict[str, List[Dict]] = {}
plugin_events: Dict[str, List[Dict]] = {}
analytics_cache: Dict[str, Dict] = {}
usage_trends: Dict[str, Dict] = {}
@app.get("/")
async def root():
return {
"service": "AITBC Plugin Analytics Service",
"status": "running",
"timestamp": datetime.now(timezone.utc).isoformat(),
"version": "1.0.0"
}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"total_usage_records": sum(len(data) for data in plugin_usage_data.values()),
"total_performance_records": sum(len(data) for data in plugin_performance_data.values()),
"total_ratings": sum(len(data) for data in plugin_ratings.values()),
"total_events": sum(len(data) for data in plugin_events.values()),
"cache_size": len(analytics_cache)
}
@app.post("/api/v1/analytics/usage")
async def record_plugin_usage(usage: PluginUsage):
"""Record plugin usage event"""
usage_record = {
"usage_id": f"usage_{int(datetime.now(timezone.utc).timestamp())}",
"plugin_id": usage.plugin_id,
"user_id": usage.user_id,
"action": usage.action,
"timestamp": usage.timestamp.isoformat(),
"metadata": usage.metadata
}
if usage.plugin_id not in plugin_usage_data:
plugin_usage_data[usage.plugin_id] = []
plugin_usage_data[usage.plugin_id].append(usage_record)
# Update usage trends
update_usage_trends(usage.plugin_id, usage.action, usage.timestamp)
logger.info(f"Usage recorded: {usage.plugin_id} - {usage.action} by {usage.user_id}")
return {
"usage_id": usage_record["usage_id"],
"status": "recorded",
"timestamp": usage_record["timestamp"]
}
@app.post("/api/v1/analytics/performance")
async def record_plugin_performance(performance: PluginPerformance):
"""Record plugin performance metrics"""
performance_record = {
"performance_id": f"perf_{int(datetime.now(timezone.utc).timestamp())}",
"plugin_id": performance.plugin_id,
"version": performance.version,
"cpu_usage": performance.cpu_usage,
"memory_usage": performance.memory_usage,
"response_time": performance.response_time,
"error_rate": performance.error_rate,
"uptime": performance.uptime,
"timestamp": performance.timestamp.isoformat()
}
if performance.plugin_id not in plugin_performance_data:
plugin_performance_data[performance.plugin_id] = []
plugin_performance_data[performance.plugin_id].append(performance_record)
logger.info(f"Performance recorded: {performance.plugin_id} - CPU: {performance.cpu_usage}%, Memory: {performance.memory_usage}%")
return {
"performance_id": performance_record["performance_id"],
"status": "recorded",
"timestamp": performance_record["timestamp"]
}
@app.post("/api/v1/analytics/rating")
async def record_plugin_rating(rating: PluginRating):
"""Record plugin rating and review"""
rating_record = {
"rating_id": f"rating_{int(datetime.now(timezone.utc).timestamp())}",
"plugin_id": rating.plugin_id,
"user_id": rating.user_id,
"rating": rating.rating,
"review": rating.review,
"timestamp": rating.timestamp.isoformat()
}
if rating.plugin_id not in plugin_ratings:
plugin_ratings[rating.plugin_id] = []
plugin_ratings[rating.plugin_id].append(rating_record)
logger.info(f"Rating recorded: {rating.plugin_id} - {rating.rating} stars by {rating.user_id}")
return {
"rating_id": rating_record["rating_id"],
"status": "recorded",
"timestamp": rating_record["timestamp"]
}
@app.post("/api/v1/analytics/event")
async def record_plugin_event(event: PluginEvent):
"""Record generic plugin event"""
event_record = {
"event_id": f"event_{int(datetime.now(timezone.utc).timestamp())}",
"event_type": event.event_type,
"plugin_id": event.plugin_id,
"user_id": event.user_id,
"data": event.data,
"timestamp": event.timestamp.isoformat()
}
if event.plugin_id not in plugin_events:
plugin_events[event.plugin_id] = []
plugin_events[event.plugin_id].append(event_record)
logger.info(f"Event recorded: {event.event_type} for {event.plugin_id}")
return {
"event_id": event_record["event_id"],
"status": "recorded",
"timestamp": event_record["timestamp"]
}
@app.get("/api/v1/analytics/usage/{plugin_id}")
async def get_plugin_usage(plugin_id: str, days: int = 30):
"""Get usage analytics for a specific plugin"""
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)
usage_records = plugin_usage_data.get(plugin_id, [])
recent_usage = [r for r in usage_records
if datetime.fromisoformat(r["timestamp"]) > cutoff_date]
# Calculate usage statistics
usage_stats = calculate_usage_statistics(recent_usage)
return {
"plugin_id": plugin_id,
"period_days": days,
"usage_statistics": usage_stats,
"total_records": len(recent_usage),
"generated_at": datetime.now(timezone.utc).isoformat()
}
@app.get("/api/v1/analytics/performance/{plugin_id}")
async def get_plugin_performance(plugin_id: str, hours: int = 24):
"""Get performance analytics for a specific plugin"""
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours)
performance_records = plugin_performance_data.get(plugin_id, [])
recent_performance = [r for r in performance_records
if datetime.fromisoformat(r["timestamp"]) > cutoff_time]
# Calculate performance statistics
performance_stats = calculate_performance_statistics(recent_performance)
return {
"plugin_id": plugin_id,
"period_hours": hours,
"performance_statistics": performance_stats,
"total_records": len(recent_performance),
"generated_at": datetime.now(timezone.utc).isoformat()
}
@app.get("/api/v1/analytics/ratings/{plugin_id}")
async def get_plugin_ratings(plugin_id: str):
"""Get ratings and reviews for a specific plugin"""
rating_records = plugin_ratings.get(plugin_id, [])
# Calculate rating statistics
rating_stats = calculate_rating_statistics(rating_records)
return {
"plugin_id": plugin_id,
"rating_statistics": rating_stats,
"total_ratings": len(rating_records),
"recent_ratings": rating_records[-10:], # Last 10 ratings
"generated_at": datetime.now(timezone.utc).isoformat()
}
@app.get("/api/v1/analytics/dashboard")
async def get_analytics_dashboard():
"""Get comprehensive analytics dashboard"""
dashboard_data = {
"overview": get_overview_statistics(),
"trending_plugins": get_trending_plugins(),
"usage_trends": get_global_usage_trends(),
"performance_summary": get_performance_summary(),
"rating_summary": get_rating_summary(),
"recent_events": get_recent_events()
}
return {
"dashboard": dashboard_data,
"generated_at": datetime.now(timezone.utc).isoformat()
}
@app.get("/api/v1/analytics/trends")
async def get_usage_trends(plugin_id: Optional[str] = None, days: int = 30):
"""Get usage trends data"""
if plugin_id:
return get_plugin_trends(plugin_id, days)
else:
return get_global_usage_trends(days)
@app.get("/api/v1/analytics/reports")
async def generate_analytics_report(report_type: str, plugin_id: Optional[str] = None):
"""Generate various analytics reports"""
if report_type == "usage":
return generate_usage_report(plugin_id)
elif report_type == "performance":
return generate_performance_report(plugin_id)
elif report_type == "ratings":
return generate_ratings_report(plugin_id)
elif report_type == "summary":
return generate_summary_report(plugin_id)
else:
raise HTTPException(status_code=400, detail="Invalid report type")
# Analytics calculation functions
def calculate_usage_statistics(usage_records: List[Dict]) -> Dict[str, Any]:
"""Calculate usage statistics from usage records"""
if not usage_records:
return {
"total_actions": 0,
"unique_users": 0,
"action_distribution": {},
"daily_usage": {}
}
# Basic statistics
total_actions = len(usage_records)
unique_users = len(set(r["user_id"] for r in usage_records))
# Action distribution
action_counts = {}
for record in usage_records:
action = record["action"]
action_counts[action] = action_counts.get(action, 0) + 1
# Daily usage
daily_usage = {}
for record in usage_records:
date = datetime.fromisoformat(record["timestamp"]).date().isoformat()
daily_usage[date] = daily_usage.get(date, 0) + 1
return {
"total_actions": total_actions,
"unique_users": unique_users,
"action_distribution": action_counts,
"daily_usage": daily_usage,
"most_common_action": max(action_counts.items(), key=lambda x: x[1])[0] if action_counts else None
}
def calculate_performance_statistics(performance_records: List[Dict]) -> Dict[str, Any]:
"""Calculate performance statistics from performance records"""
if not performance_records:
return {
"avg_cpu_usage": 0.0,
"avg_memory_usage": 0.0,
"avg_response_time": 0.0,
"avg_error_rate": 0.0,
"avg_uptime": 0.0
}
# Calculate averages
cpu_usage = sum(r["cpu_usage"] for r in performance_records) / len(performance_records)
memory_usage = sum(r["memory_usage"] for r in performance_records) / len(performance_records)
response_time = sum(r["response_time"] for r in performance_records) / len(performance_records)
error_rate = sum(r["error_rate"] for r in performance_records) / len(performance_records)
uptime = sum(r["uptime"] for r in performance_records) / len(performance_records)
# Calculate min/max
min_cpu = min(r["cpu_usage"] for r in performance_records)
max_cpu = max(r["cpu_usage"] for r in performance_records)
return {
"avg_cpu_usage": round(cpu_usage, 2),
"avg_memory_usage": round(memory_usage, 2),
"avg_response_time": round(response_time, 3),
"avg_error_rate": round(error_rate, 4),
"avg_uptime": round(uptime, 2),
"min_cpu_usage": round(min_cpu, 2),
"max_cpu_usage": round(max_cpu, 2),
"total_samples": len(performance_records)
}
def calculate_rating_statistics(rating_records: List[Dict]) -> Dict[str, Any]:
"""Calculate rating statistics from rating records"""
if not rating_records:
return {
"average_rating": 0.0,
"total_ratings": 0,
"rating_distribution": {1: 0, 2: 0, 3: 0, 4: 0, 5: 0}
}
# Calculate average rating
total_rating = sum(r["rating"] for r in rating_records)
average_rating = total_rating / len(rating_records)
# Rating distribution
rating_distribution = {1: 0, 2: 0, 3: 0, 4: 0, 5: 0}
for record in rating_records:
rating_distribution[record["rating"]] += 1
return {
"average_rating": round(average_rating, 2),
"total_ratings": len(rating_records),
"rating_distribution": rating_distribution,
"latest_rating": rating_records[-1]["rating"] if rating_records else 0
}
def update_usage_trends(plugin_id: str, action: str, timestamp: datetime):
"""Update usage trends data"""
if plugin_id not in usage_trends:
usage_trends[plugin_id] = {
"daily": {},
"weekly": {},
"monthly": {}
}
# Update daily trends
date_key = timestamp.date().isoformat()
if date_key not in usage_trends[plugin_id]["daily"]:
usage_trends[plugin_id]["daily"][date_key] = {}
usage_trends[plugin_id]["daily"][date_key][action] = usage_trends[plugin_id]["daily"][date_key].get(action, 0) + 1
def get_overview_statistics() -> Dict[str, Any]:
"""Get overview statistics for all plugins"""
total_plugins = len(set(plugin_usage_data.keys()) | set(plugin_performance_data.keys()) | set(plugin_ratings.keys()))
total_usage = sum(len(data) for data in plugin_usage_data.values())
total_ratings = sum(len(data) for data in plugin_ratings.values())
# Calculate active plugins (plugins with usage in last 7 days)
cutoff_date = datetime.now(timezone.utc) - timedelta(days=7)
active_plugins = 0
for plugin_id, usage_records in plugin_usage_data.items():
recent_usage = [r for r in usage_records
if datetime.fromisoformat(r["timestamp"]) > cutoff_date]
if recent_usage:
active_plugins += 1
return {
"total_plugins": total_plugins,
"active_plugins": active_plugins,
"total_usage_events": total_usage,
"total_ratings": total_ratings,
"average_ratings_per_plugin": round(total_ratings / total_plugins, 2) if total_plugins > 0 else 0
}
def get_trending_plugins(limit: int = 10) -> List[Dict]:
"""Get trending plugins based on recent usage"""
cutoff_date = datetime.now(timezone.utc) - timedelta(days=7)
plugin_scores = []
for plugin_id, usage_records in plugin_usage_data.items():
recent_usage = [r for r in usage_records
if datetime.fromisoformat(r["timestamp"]) > cutoff_date]
if recent_usage:
# Calculate trend score (simplified)
score = len(recent_usage) + len(set(r["user_id"] for r in recent_usage))
plugin_scores.append({
"plugin_id": plugin_id,
"trend_score": score,
"recent_usage": len(recent_usage),
"unique_users": len(set(r["user_id"] for r in recent_usage))
})
# Sort by trend score
plugin_scores.sort(key=lambda x: x["trend_score"], reverse=True)
return plugin_scores[:limit]
def get_global_usage_trends(days: int = 30) -> Dict[str, Any]:
"""Get global usage trends"""
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)
global_trends = {}
for plugin_id, usage_records in plugin_usage_data.items():
recent_usage = [r for r in usage_records
if datetime.fromisoformat(r["timestamp"]) > cutoff_date]
if recent_usage:
daily_counts = {}
for record in recent_usage:
date = datetime.fromisoformat(record["timestamp"]).date().isoformat()
daily_counts[date] = daily_counts.get(date, 0) + 1
global_trends[plugin_id] = daily_counts
return {
"trends": global_trends,
"period_days": days,
"total_plugins": len(global_trends)
}
def get_performance_summary() -> Dict[str, Any]:
"""Get performance summary for all plugins"""
all_performance = []
for plugin_id, performance_records in plugin_performance_data.items():
if performance_records:
latest_record = performance_records[-1]
all_performance.append({
"plugin_id": plugin_id,
"cpu_usage": latest_record["cpu_usage"],
"memory_usage": latest_record["memory_usage"],
"response_time": latest_record["response_time"],
"error_rate": latest_record["error_rate"]
})
# Calculate averages
if all_performance:
avg_cpu = sum(p["cpu_usage"] for p in all_performance) / len(all_performance)
avg_memory = sum(p["memory_usage"] for p in all_performance) / len(all_performance)
avg_response = sum(p["response_time"] for p in all_performance) / len(all_performance)
avg_error = sum(p["error_rate"] for p in all_performance) / len(all_performance)
else:
avg_cpu = avg_memory = avg_response = avg_error = 0.0
return {
"total_plugins": len(all_performance),
"average_cpu_usage": round(avg_cpu, 2),
"average_memory_usage": round(avg_memory, 2),
"average_response_time": round(avg_response, 3),
"average_error_rate": round(avg_error, 4),
"top_cpu_users": sorted(all_performance, key=lambda x: x["cpu_usage"], reverse=True)[:5]
}
def get_rating_summary() -> Dict[str, Any]:
"""Get rating summary for all plugins"""
all_ratings = []
for plugin_id, rating_records in plugin_ratings.items():
if rating_records:
avg_rating = sum(r["rating"] for r in rating_records) / len(rating_records)
all_ratings.append({
"plugin_id": plugin_id,
"average_rating": round(avg_rating, 2),
"total_ratings": len(rating_records)
})
# Sort by rating
all_ratings.sort(key=lambda x: x["average_rating"], reverse=True)
return {
"total_plugins": len(all_ratings),
"top_rated": all_ratings[:10],
"average_rating_all": round(sum(r["average_rating"] for r in all_ratings) / len(all_ratings), 2) if all_ratings else 0.0
}
def get_recent_events(limit: int = 20) -> List[Dict]:
"""Get recent plugin events"""
all_events = []
for plugin_id, events in plugin_events.items():
for event in events:
all_events.append({
"plugin_id": plugin_id,
"event_type": event["event_type"],
"timestamp": event["timestamp"],
"user_id": event.get("user_id")
})
# Sort by timestamp (most recent first)
all_events.sort(key=lambda x: x["timestamp"], reverse=True)
return all_events[:limit]
def get_plugin_trends(plugin_id: str, days: int) -> Dict[str, Any]:
"""Get trends for a specific plugin"""
plugin_trends = usage_trends.get(plugin_id, {})
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)
date_key = cutoff_date.date().isoformat()
return {
"plugin_id": plugin_id,
"trends": plugin_trends,
"period_days": days,
"generated_at": datetime.now(timezone.utc).isoformat()
}
# Report generation functions
def generate_usage_report(plugin_id: Optional[str] = None) -> Dict[str, Any]:
"""Generate usage report"""
if plugin_id:
return get_plugin_usage(plugin_id, days=30)
else:
return get_global_usage_trends(days=30)
def generate_performance_report(plugin_id: Optional[str] = None) -> Dict[str, Any]:
"""Generate performance report"""
if plugin_id:
return get_plugin_performance(plugin_id, hours=24)
else:
return get_performance_summary()
def generate_ratings_report(plugin_id: Optional[str] = None) -> Dict[str, Any]:
"""Generate ratings report"""
if plugin_id:
return get_plugin_ratings(plugin_id)
else:
return get_rating_summary()
def generate_summary_report(plugin_id: Optional[str] = None) -> Dict[str, Any]:
"""Generate comprehensive summary report"""
if plugin_id:
return {
"plugin_id": plugin_id,
"usage": get_plugin_usage(plugin_id, days=30),
"performance": get_plugin_performance(plugin_id, hours=24),
"ratings": get_plugin_ratings(plugin_id),
"generated_at": datetime.now(timezone.utc).isoformat()
}
else:
return get_analytics_dashboard()
# Background task for analytics processing
async def process_analytics():
"""Background task to process analytics data"""
while True:
await asyncio.sleep(3600) # Process every hour
# Update analytics cache
update_analytics_cache()
# Clean old data (older than 90 days)
cleanup_old_data()
logger.info("Analytics processing completed")
def update_analytics_cache():
"""Update analytics cache with frequently accessed data"""
# Cache trending plugins
analytics_cache["trending_plugins"] = get_trending_plugins()
# Cache overview statistics
analytics_cache["overview"] = get_overview_statistics()
# Cache performance summary
analytics_cache["performance_summary"] = get_performance_summary()
def cleanup_old_data():
"""Clean up old analytics data"""
cutoff_date = datetime.now(timezone.utc) - timedelta(days=90)
cutoff_iso = cutoff_date.isoformat()
# Clean usage data
for plugin_id in plugin_usage_data:
plugin_usage_data[plugin_id] = [
r for r in plugin_usage_data[plugin_id]
if r["timestamp"] > cutoff_iso
]
# Clean performance data
for plugin_id in plugin_performance_data:
plugin_performance_data[plugin_id] = [
r for r in plugin_performance_data[plugin_id]
if r["timestamp"] > cutoff_iso
]
# Clean events data
for plugin_id in plugin_events:
plugin_events[plugin_id] = [
r for r in plugin_events[plugin_id]
if r["timestamp"] > cutoff_iso
]
@app.on_event("startup")
async def startup_event():
logger.info("Starting AITBC Plugin Analytics Service")
# Initialize analytics cache
update_analytics_cache()
# Start analytics processing
asyncio.create_task(process_analytics())
@app.on_event("shutdown")
async def shutdown_event():
logger.info("Shutting down AITBC Plugin Analytics Service")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8016, log_level="info")

View File

@@ -0,0 +1 @@
"""Plugin analytics service tests"""

View File

@@ -0,0 +1,168 @@
"""Edge case and error handling tests for plugin analytics service"""
import pytest
import sys
import sys
from pathlib import Path
from fastapi.testclient import TestClient
from datetime import datetime, timezone
from main import app, PluginUsage, PluginPerformance, PluginRating, PluginEvent, plugin_usage_data, plugin_performance_data, plugin_ratings, plugin_events
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
plugin_usage_data.clear()
plugin_performance_data.clear()
plugin_ratings.clear()
plugin_events.clear()
yield
plugin_usage_data.clear()
plugin_performance_data.clear()
plugin_ratings.clear()
plugin_events.clear()
@pytest.mark.unit
def test_plugin_usage_empty_plugin_id():
"""Test PluginUsage with empty plugin_id"""
usage = PluginUsage(
plugin_id="",
user_id="user_123",
action="install",
timestamp=datetime.now(timezone.utc)
)
assert usage.plugin_id == ""
@pytest.mark.unit
def test_plugin_performance_negative_values():
"""Test PluginPerformance with negative values"""
perf = PluginPerformance(
plugin_id="plugin_123",
version="1.0.0",
cpu_usage=-10.0,
memory_usage=-5.0,
response_time=-0.1,
error_rate=-0.01,
uptime=-50.0,
timestamp=datetime.now(timezone.utc)
)
assert perf.cpu_usage == -10.0
assert perf.memory_usage == -5.0
@pytest.mark.unit
def test_plugin_rating_out_of_range():
"""Test PluginRating with out of range rating"""
rating = PluginRating(
plugin_id="plugin_123",
user_id="user_123",
rating=10,
timestamp=datetime.now(timezone.utc)
)
assert rating.rating == 10
@pytest.mark.unit
def test_plugin_rating_zero():
"""Test PluginRating with zero rating"""
rating = PluginRating(
plugin_id="plugin_123",
user_id="user_123",
rating=0,
timestamp=datetime.now(timezone.utc)
)
assert rating.rating == 0
@pytest.mark.integration
def test_get_plugin_usage_no_data():
"""Test getting plugin usage when no data exists"""
client = TestClient(app)
response = client.get("/api/v1/analytics/usage/nonexistent")
assert response.status_code == 200
data = response.json()
assert data["total_records"] == 0
@pytest.mark.integration
def test_get_plugin_performance_no_data():
"""Test getting plugin performance when no data exists"""
client = TestClient(app)
response = client.get("/api/v1/analytics/performance/nonexistent")
assert response.status_code == 200
data = response.json()
assert data["total_records"] == 0
@pytest.mark.integration
def test_get_plugin_ratings_no_data():
"""Test getting plugin ratings when no data exists"""
client = TestClient(app)
response = client.get("/api/v1/analytics/ratings/nonexistent")
assert response.status_code == 200
data = response.json()
assert data["total_ratings"] == 0
@pytest.mark.integration
def test_dashboard_with_no_data():
"""Test dashboard with no data"""
client = TestClient(app)
response = client.get("/api/v1/analytics/dashboard")
assert response.status_code == 200
data = response.json()
assert data["dashboard"]["overview"]["total_plugins"] == 0
@pytest.mark.integration
def test_record_multiple_usage_events():
"""Test recording multiple usage events for same plugin"""
client = TestClient(app)
for i in range(5):
usage = PluginUsage(
plugin_id="plugin_123",
user_id=f"user_{i}",
action="use",
timestamp=datetime.now(timezone.utc)
)
client.post("/api/v1/analytics/usage", json=usage.model_dump(mode='json'))
response = client.get("/api/v1/analytics/usage/plugin_123")
assert response.status_code == 200
data = response.json()
assert data["total_records"] == 5
@pytest.mark.integration
def test_usage_trends_days_parameter():
"""Test usage trends with custom days parameter"""
client = TestClient(app)
response = client.get("/api/v1/analytics/trends?days=7")
assert response.status_code == 200
data = response.json()
assert "trends" in data
@pytest.mark.integration
def test_get_plugin_usage_days_parameter():
"""Test getting plugin usage with custom days parameter"""
client = TestClient(app)
response = client.get("/api/v1/analytics/usage/plugin_123?days=7")
assert response.status_code == 200
data = response.json()
assert data["period_days"] == 7
@pytest.mark.integration
def test_get_plugin_performance_hours_parameter():
"""Test getting plugin performance with custom hours parameter"""
client = TestClient(app)
response = client.get("/api/v1/analytics/performance/plugin_123?hours=12")
assert response.status_code == 200
data = response.json()
assert data["period_hours"] == 12

View File

@@ -0,0 +1,253 @@
"""Integration tests for plugin analytics service"""
import pytest
import sys
import sys
from pathlib import Path
from fastapi.testclient import TestClient
from datetime import datetime, timezone
from main import app, PluginUsage, PluginPerformance, PluginRating, PluginEvent, plugin_usage_data, plugin_performance_data, plugin_ratings, plugin_events
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
plugin_usage_data.clear()
plugin_performance_data.clear()
plugin_ratings.clear()
plugin_events.clear()
yield
plugin_usage_data.clear()
plugin_performance_data.clear()
plugin_ratings.clear()
plugin_events.clear()
@pytest.mark.integration
def test_root_endpoint():
"""Test root endpoint"""
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
data = response.json()
assert data["service"] == "AITBC Plugin Analytics Service"
assert data["status"] == "running"
@pytest.mark.integration
def test_health_check_endpoint():
"""Test health check endpoint"""
client = TestClient(app)
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert "total_usage_records" in data
assert "total_performance_records" in data
@pytest.mark.integration
def test_record_plugin_usage():
"""Test recording plugin usage"""
client = TestClient(app)
usage = PluginUsage(
plugin_id="plugin_123",
user_id="user_123",
action="install",
timestamp=datetime.now(timezone.utc)
)
response = client.post("/api/v1/analytics/usage", json=usage.model_dump(mode='json'))
assert response.status_code == 200
data = response.json()
assert data["usage_id"]
assert data["status"] == "recorded"
@pytest.mark.integration
def test_record_plugin_performance():
"""Test recording plugin performance"""
client = TestClient(app)
perf = PluginPerformance(
plugin_id="plugin_123",
version="1.0.0",
cpu_usage=50.5,
memory_usage=30.2,
response_time=0.123,
error_rate=0.001,
uptime=99.9,
timestamp=datetime.now(timezone.utc)
)
response = client.post("/api/v1/analytics/performance", json=perf.model_dump(mode='json'))
assert response.status_code == 200
data = response.json()
assert data["performance_id"]
assert data["status"] == "recorded"
@pytest.mark.integration
def test_record_plugin_rating():
"""Test recording plugin rating"""
client = TestClient(app)
rating = PluginRating(
plugin_id="plugin_123",
user_id="user_123",
rating=5,
review="Great plugin!",
timestamp=datetime.now(timezone.utc)
)
response = client.post("/api/v1/analytics/rating", json=rating.model_dump(mode='json'))
assert response.status_code == 200
data = response.json()
assert data["rating_id"]
assert data["status"] == "recorded"
@pytest.mark.integration
def test_record_plugin_event():
"""Test recording plugin event"""
client = TestClient(app)
event = PluginEvent(
event_type="error",
plugin_id="plugin_123",
user_id="user_123",
data={"error": "timeout"},
timestamp=datetime.now(timezone.utc)
)
response = client.post("/api/v1/analytics/event", json=event.model_dump(mode='json'))
assert response.status_code == 200
data = response.json()
assert data["event_id"]
assert data["status"] == "recorded"
@pytest.mark.integration
def test_get_plugin_usage():
"""Test getting plugin usage analytics"""
client = TestClient(app)
# Record usage first
usage = PluginUsage(
plugin_id="plugin_123",
user_id="user_123",
action="install",
timestamp=datetime.now(timezone.utc)
)
client.post("/api/v1/analytics/usage", json=usage.model_dump(mode='json'))
response = client.get("/api/v1/analytics/usage/plugin_123")
assert response.status_code == 200
data = response.json()
assert data["plugin_id"] == "plugin_123"
assert "usage_statistics" in data
@pytest.mark.integration
def test_get_plugin_performance():
"""Test getting plugin performance analytics"""
client = TestClient(app)
# Record performance first
perf = PluginPerformance(
plugin_id="plugin_123",
version="1.0.0",
cpu_usage=50.5,
memory_usage=30.2,
response_time=0.123,
error_rate=0.001,
uptime=99.9,
timestamp=datetime.now(timezone.utc)
)
client.post("/api/v1/analytics/performance", json=perf.model_dump(mode='json'))
response = client.get("/api/v1/analytics/performance/plugin_123")
assert response.status_code == 200
data = response.json()
assert data["plugin_id"] == "plugin_123"
assert "performance_statistics" in data
@pytest.mark.integration
def test_get_plugin_ratings():
"""Test getting plugin ratings"""
client = TestClient(app)
# Record rating first
rating = PluginRating(
plugin_id="plugin_123",
user_id="user_123",
rating=5,
timestamp=datetime.now(timezone.utc)
)
client.post("/api/v1/analytics/rating", json=rating.model_dump(mode='json'))
response = client.get("/api/v1/analytics/ratings/plugin_123")
assert response.status_code == 200
data = response.json()
assert data["plugin_id"] == "plugin_123"
assert "rating_statistics" in data
@pytest.mark.integration
def test_get_analytics_dashboard():
"""Test getting analytics dashboard"""
client = TestClient(app)
response = client.get("/api/v1/analytics/dashboard")
assert response.status_code == 200
data = response.json()
assert "dashboard" in data
assert "overview" in data["dashboard"]
assert "trending_plugins" in data["dashboard"]
@pytest.mark.integration
def test_get_usage_trends():
"""Test getting usage trends"""
client = TestClient(app)
response = client.get("/api/v1/analytics/trends")
assert response.status_code == 200
data = response.json()
assert "trends" in data
@pytest.mark.integration
def test_get_usage_trends_plugin_specific():
"""Test getting usage trends for specific plugin"""
client = TestClient(app)
response = client.get("/api/v1/analytics/trends?plugin_id=plugin_123")
assert response.status_code == 200
data = response.json()
assert "plugin_id" in data
@pytest.mark.integration
def test_generate_analytics_report_usage():
"""Test generating usage report"""
client = TestClient(app)
response = client.get("/api/v1/analytics/reports?report_type=usage")
assert response.status_code == 200
data = response.json()
@pytest.mark.integration
def test_generate_analytics_report_performance():
"""Test generating performance report"""
client = TestClient(app)
response = client.get("/api/v1/analytics/reports?report_type=performance")
assert response.status_code == 200
data = response.json()
@pytest.mark.integration
def test_generate_analytics_report_ratings():
"""Test generating ratings report"""
client = TestClient(app)
response = client.get("/api/v1/analytics/reports?report_type=ratings")
assert response.status_code == 200
data = response.json()
@pytest.mark.integration
def test_generate_analytics_report_invalid():
"""Test generating analytics report with invalid type"""
client = TestClient(app)
response = client.get("/api/v1/analytics/reports?report_type=invalid")
assert response.status_code == 400

View File

@@ -0,0 +1,123 @@
"""Unit tests for plugin analytics service"""
import pytest
import sys
import sys
from pathlib import Path
from datetime import datetime, timezone
from main import app, PluginUsage, PluginPerformance, PluginRating, PluginEvent
@pytest.mark.unit
def test_app_initialization():
"""Test that the FastAPI app initializes correctly"""
assert app is not None
assert app.title == "AITBC Plugin Analytics Service"
assert app.version == "1.0.0"
@pytest.mark.unit
def test_plugin_usage_model():
"""Test PluginUsage model"""
usage = PluginUsage(
plugin_id="plugin_123",
user_id="user_123",
action="install",
timestamp=datetime.now(timezone.utc),
metadata={"source": "marketplace"}
)
assert usage.plugin_id == "plugin_123"
assert usage.user_id == "user_123"
assert usage.action == "install"
assert usage.metadata == {"source": "marketplace"}
@pytest.mark.unit
def test_plugin_usage_defaults():
"""Test PluginUsage with default metadata"""
usage = PluginUsage(
plugin_id="plugin_123",
user_id="user_123",
action="use",
timestamp=datetime.now(timezone.utc)
)
assert usage.metadata == {}
@pytest.mark.unit
def test_plugin_performance_model():
"""Test PluginPerformance model"""
perf = PluginPerformance(
plugin_id="plugin_123",
version="1.0.0",
cpu_usage=50.5,
memory_usage=30.2,
response_time=0.123,
error_rate=0.001,
uptime=99.9,
timestamp=datetime.now(timezone.utc)
)
assert perf.plugin_id == "plugin_123"
assert perf.version == "1.0.0"
assert perf.cpu_usage == 50.5
assert perf.memory_usage == 30.2
assert perf.response_time == 0.123
assert perf.error_rate == 0.001
assert perf.uptime == 99.9
@pytest.mark.unit
def test_plugin_rating_model():
"""Test PluginRating model"""
rating = PluginRating(
plugin_id="plugin_123",
user_id="user_123",
rating=5,
review="Great plugin!",
timestamp=datetime.now(timezone.utc)
)
assert rating.plugin_id == "plugin_123"
assert rating.rating == 5
assert rating.review == "Great plugin!"
@pytest.mark.unit
def test_plugin_rating_defaults():
"""Test PluginRating with default review"""
rating = PluginRating(
plugin_id="plugin_123",
user_id="user_123",
rating=4,
timestamp=datetime.now(timezone.utc)
)
assert rating.review is None
@pytest.mark.unit
def test_plugin_event_model():
"""Test PluginEvent model"""
event = PluginEvent(
event_type="error",
plugin_id="plugin_123",
user_id="user_123",
data={"error": "timeout"},
timestamp=datetime.now(timezone.utc)
)
assert event.event_type == "error"
assert event.plugin_id == "plugin_123"
assert event.user_id == "user_123"
assert event.data == {"error": "timeout"}
@pytest.mark.unit
def test_plugin_event_defaults():
"""Test PluginEvent with default values"""
event = PluginEvent(
event_type="info",
plugin_id="plugin_123",
timestamp=datetime.now(timezone.utc)
)
assert event.user_id is None
assert event.data == {}

View File

@@ -0,0 +1,604 @@
"""
Plugin Marketplace Frontend Service for AITBC
Provides web interface and marketplace functionality for plugins
"""
import os
import asyncio
import json
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Dict, Any, List, Optional
from fastapi import FastAPI, HTTPException, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from aitbc import get_logger
logger = get_logger(__name__)
app = FastAPI(
title="AITBC Plugin Marketplace",
description="Plugin marketplace frontend and community features",
version="1.0.0"
)
# Data models
class MarketplaceReview(BaseModel):
plugin_id: str
user_id: str
rating: int # 1-5 stars
title: str
content: str
pros: List[str] = []
cons: List[str] = []
class PluginPurchase(BaseModel):
plugin_id: str
user_id: str
price: float
payment_method: str
class DeveloperApplication(BaseModel):
developer_name: str
email: str
company: Optional[str] = None
experience: str
portfolio_url: Optional[str] = None
github_username: Optional[str] = None
description: str
# In-memory storage (in production, use database)
marketplace_data: Dict[str, Dict] = {}
reviews: Dict[str, List[Dict]] = {}
purchases: Dict[str, List[Dict]] = {}
developer_applications: Dict[str, Dict] = {}
verified_developers: Dict[str, Dict] = {}
revenue_sharing: Dict[str, Dict] = {}
# Static files and templates
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
@app.get("/")
async def marketplace_home(request: Request):
"""Marketplace homepage"""
return templates.TemplateResponse("index.html", {
"request": request,
"featured_plugins": get_featured_plugins(),
"popular_plugins": get_popular_plugins(),
"recent_plugins": get_recent_plugins(),
"categories": get_categories(),
"stats": get_marketplace_stats()
})
@app.get("/plugins")
async def plugins_page(request: Request):
"""Plugin listing page"""
return templates.TemplateResponse("plugins.html", {
"request": request,
"plugins": get_all_plugins(),
"categories": get_categories(),
"tags": get_all_tags()
})
@app.get("/plugins/{plugin_id}")
async def plugin_detail(request: Request, plugin_id: str):
"""Individual plugin detail page"""
plugin = get_plugin_details(plugin_id)
if not plugin:
raise HTTPException(status_code=404, detail="Plugin not found")
return templates.TemplateResponse("plugin_detail.html", {
"request": request,
"plugin": plugin,
"reviews": get_plugin_reviews(plugin_id),
"related_plugins": get_related_plugins(plugin_id)
})
@app.get("/developers")
async def developers_page(request: Request):
"""Developer portal page"""
return templates.TemplateResponse("developers.html", {
"request": request,
"verified_developers": get_verified_developers(),
"developer_stats": get_developer_stats()
})
@app.get("/submit")
async def submit_plugin_page(request: Request):
"""Plugin submission page"""
return templates.TemplateResponse("submit.html", {
"request": request,
"categories": get_categories(),
"guidelines": get_submission_guidelines()
})
# API endpoints
@app.get("/api/v1/marketplace/featured")
async def get_featured_plugins_api():
"""Get featured plugins for marketplace"""
return {
"featured_plugins": get_featured_plugins(),
"generated_at": datetime.now(timezone.utc).isoformat()
}
@app.get("/api/v1/marketplace/popular")
async def get_popular_plugins_api(limit: int = 12):
"""Get popular plugins"""
return {
"popular_plugins": get_popular_plugins(limit),
"generated_at": datetime.now(timezone.utc).isoformat()
}
@app.get("/api/v1/marketplace/recent")
async def get_recent_plugins_api(limit: int = 12):
"""Get recently added plugins"""
return {
"recent_plugins": get_recent_plugins(limit),
"generated_at": datetime.now(timezone.utc).isoformat()
}
@app.get("/api/v1/marketplace/stats")
async def get_marketplace_stats_api():
"""Get marketplace statistics"""
return {
"stats": get_marketplace_stats(),
"generated_at": datetime.now(timezone.utc).isoformat()
}
@app.post("/api/v1/reviews")
async def create_review(review: MarketplaceReview):
"""Create a plugin review"""
review_id = f"review_{int(datetime.now(timezone.utc).timestamp())}"
review_record = {
"review_id": review_id,
"plugin_id": review.plugin_id,
"user_id": review.user_id,
"rating": review.rating,
"title": review.title,
"content": review.content,
"pros": review.pros,
"cons": review.cons,
"helpful_votes": 0,
"created_at": datetime.now(timezone.utc).isoformat(),
"verified_purchase": False
}
if review.plugin_id not in reviews:
reviews[review.plugin_id] = []
reviews[review.plugin_id].append(review_record)
logger.info(f"Review created for plugin {review.plugin_id}: {review.rating} stars")
return {
"review_id": review_id,
"status": "created",
"rating": review.rating,
"created_at": review_record["created_at"]
}
@app.get("/api/v1/reviews/{plugin_id}")
async def get_plugin_reviews_api(plugin_id: str):
"""Get all reviews for a plugin"""
plugin_reviews = reviews.get(plugin_id, [])
# Calculate average rating
if plugin_reviews:
avg_rating = sum(r["rating"] for r in plugin_reviews) / len(plugin_reviews)
else:
avg_rating = 0.0
return {
"plugin_id": plugin_id,
"reviews": plugin_reviews,
"total_reviews": len(plugin_reviews),
"average_rating": avg_rating,
"rating_distribution": get_rating_distribution(plugin_reviews)
}
@app.post("/api/v1/purchases")
async def create_purchase(purchase: PluginPurchase):
"""Create a plugin purchase"""
purchase_id = f"purchase_{int(datetime.now(timezone.utc).timestamp())}"
purchase_record = {
"purchase_id": purchase_id,
"plugin_id": purchase.plugin_id,
"user_id": purchase.user_id,
"price": purchase.price,
"payment_method": purchase.payment_method,
"status": "completed",
"created_at": datetime.now(timezone.utc).isoformat(),
"refund_deadline": (datetime.now(timezone.utc) + timedelta(days=30)).isoformat()
}
if purchase.plugin_id not in purchases:
purchases[purchase.plugin_id] = []
purchases[purchase.plugin_id].append(purchase_record)
# Update revenue sharing
update_revenue_sharing(purchase.plugin_id, purchase.price)
logger.info(f"Purchase created for plugin {purchase.plugin_id}: ${purchase.price}")
return {
"purchase_id": purchase_id,
"status": "completed",
"price": purchase.price,
"created_at": purchase_record["created_at"]
}
@app.post("/api/v1/developers/apply")
async def apply_developer(application: DeveloperApplication):
"""Apply to become a verified developer"""
application_id = f"dev_app_{int(datetime.now(timezone.utc).timestamp())}"
application_record = {
"application_id": application_id,
"developer_name": application.developer_name,
"email": application.email,
"company": application.company,
"experience": application.experience,
"portfolio_url": application.portfolio_url,
"github_username": application.github_username,
"description": application.description,
"status": "pending",
"submitted_at": datetime.now(timezone.utc).isoformat(),
"reviewed_at": None,
"reviewer_notes": None
}
developer_applications[application_id] = application_record
logger.info(f"Developer application submitted: {application.developer_name}")
return {
"application_id": application_id,
"status": "pending",
"submitted_at": application_record["submitted_at"]
}
@app.get("/api/v1/developers/verified")
async def get_verified_developers_api():
"""Get list of verified developers"""
return {
"verified_developers": get_verified_developers(),
"total_developers": len(verified_developers),
"generated_at": datetime.now(timezone.utc).isoformat()
}
@app.get("/api/v1/revenue/{developer_id}")
async def get_developer_revenue(developer_id: str):
"""Get revenue information for a developer"""
developer_revenue = revenue_sharing.get(developer_id, {
"total_revenue": 0.0,
"plugin_revenue": {},
"monthly_revenue": {},
"last_updated": datetime.now(timezone.utc).isoformat()
})
return developer_revenue
# Helper functions
def get_featured_plugins() -> List[Dict]:
"""Get featured plugins"""
# In production, this would be based on editorial selection or algorithm
featured_plugins = []
# Mock data for demo
featured_plugins = [
{
"plugin_id": "ai_trading_bot",
"name": "AI Trading Bot",
"description": "Advanced AI-powered trading automation",
"author": "AITBC Labs",
"category": "ai",
"rating": 4.8,
"downloads": 15420,
"price": 99.99,
"featured": True
},
{
"plugin_id": "blockchain_analyzer",
"name": "Blockchain Analyzer",
"description": "Comprehensive blockchain analytics and monitoring",
"author": "CryptoTools",
"category": "blockchain",
"rating": 4.6,
"downloads": 12350,
"price": 149.99,
"featured": True
}
]
return featured_plugins
def get_popular_plugins(limit: int = 12) -> List[Dict]:
"""Get popular plugins"""
# Mock data for demo
popular_plugins = [
{
"plugin_id": "cli_enhancer",
"name": "CLI Enhancer",
"description": "Enhanced CLI commands and shortcuts",
"author": "DevTools",
"category": "cli",
"rating": 4.7,
"downloads": 8920,
"price": 29.99
},
{
"plugin_id": "web_dashboard",
"name": "Web Dashboard",
"description": "Beautiful web dashboard for AITBC",
"author": "WebCraft",
"category": "web",
"rating": 4.5,
"downloads": 7650,
"price": 79.99
}
]
return popular_plugins[:limit]
def get_recent_plugins(limit: int = 12) -> List[Dict]:
"""Get recently added plugins"""
# Mock data for demo
recent_plugins = [
{
"plugin_id": "security_scanner",
"name": "Security Scanner",
"description": "Advanced security vulnerability scanner",
"author": "SecureDev",
"category": "security",
"rating": 4.9,
"downloads": 2340,
"price": 199.99,
"created_at": (datetime.now(timezone.utc) - timedelta(days=3)).isoformat()
},
{
"plugin_id": "performance_monitor",
"name": "Performance Monitor",
"description": "Real-time performance monitoring and alerts",
"author": "PerfTools",
"category": "monitoring",
"rating": 4.4,
"downloads": 1890,
"price": 59.99,
"created_at": (datetime.now(timezone.utc) - timedelta(days=7)).isoformat()
}
]
return recent_plugins[:limit]
def get_categories() -> List[Dict]:
"""Get plugin categories"""
categories = [
{"name": "ai", "display_name": "AI & Machine Learning", "count": 45},
{"name": "blockchain", "display_name": "Blockchain", "count": 32},
{"name": "cli", "display_name": "CLI Tools", "count": 28},
{"name": "web", "display_name": "Web & UI", "count": 24},
{"name": "security", "display_name": "Security", "count": 18},
{"name": "monitoring", "display_name": "Monitoring", "count": 15}
]
return categories
def get_all_plugins() -> List[Dict]:
"""Get all plugins"""
# Mock data for demo
all_plugins = get_featured_plugins() + get_popular_plugins() + get_recent_plugins()
return all_plugins
def get_all_tags() -> List[str]:
"""Get all plugin tags"""
tags = ["automation", "trading", "analytics", "security", "monitoring", "dashboard", "cli", "ai", "blockchain", "web"]
return tags
def get_plugin_details(plugin_id: str) -> Optional[Dict]:
"""Get detailed plugin information"""
# Mock data for demo
plugins = {
"ai_trading_bot": {
"plugin_id": "ai_trading_bot",
"name": "AI Trading Bot",
"description": "Advanced AI-powered trading automation with machine learning algorithms for optimal trading strategies",
"author": "AITBC Labs",
"category": "ai",
"tags": ["automation", "trading", "ai", "machine-learning"],
"rating": 4.8,
"downloads": 15420,
"price": 99.99,
"version": "2.1.0",
"last_updated": (datetime.now(timezone.utc) - timedelta(days=15)).isoformat(),
"repository_url": "https://github.com/aitbc-labs/ai-trading-bot",
"homepage_url": "https://aitbc-trading-bot.com",
"license": "MIT",
"screenshots": [
"/static/screenshots/trading-bot-1.png",
"/static/screenshots/trading-bot-2.png"
],
"changelog": "Added new ML models, improved performance, bug fixes",
"compatibility": ["v1.0.0+", "v2.0.0+"]
}
}
return plugins.get(plugin_id)
def get_plugin_reviews(plugin_id: str) -> List[Dict]:
"""Get reviews for a plugin"""
# Mock data for demo
mock_reviews = [
{
"review_id": "review_1",
"user_id": "user123",
"rating": 5,
"title": "Excellent Trading Bot",
"content": "This plugin has transformed my trading strategy. Highly recommended!",
"pros": ["Easy to use", "Great performance", "Good documentation"],
"cons": ["Initial setup complexity"],
"helpful_votes": 23,
"created_at": (datetime.now(timezone.utc) - timedelta(days=10)).isoformat()
},
{
"review_id": "review_2",
"user_id": "user456",
"rating": 4,
"title": "Good but needs improvements",
"content": "Solid plugin with room for improvement in the UI.",
"pros": ["Powerful features", "Good support"],
"cons": ["UI could be better", "Learning curve"],
"helpful_votes": 15,
"created_at": (datetime.now(timezone.utc) - timedelta(days=25)).isoformat()
}
]
return mock_reviews
def get_related_plugins(plugin_id: str) -> List[Dict]:
"""Get related plugins"""
# Mock data for demo
related_plugins = [
{
"plugin_id": "market_analyzer",
"name": "Market Analyzer",
"description": "Advanced market analysis tools",
"rating": 4.6,
"price": 79.99
},
{
"plugin_id": "risk_manager",
"name": "Risk Manager",
"description": "Comprehensive risk management system",
"rating": 4.5,
"price": 89.99
}
]
return related_plugins
def get_verified_developers() -> List[Dict]:
"""Get verified developers"""
# Mock data for demo
verified_devs = [
{
"developer_id": "aitbc_labs",
"name": "AITBC Labs",
"description": "Official AITBC development team",
"plugins_count": 12,
"total_downloads": 45680,
"verified_since": "2025-01-15",
"avatar": "/static/avatars/aitbc-labs.png"
},
{
"developer_id": "crypto_tools",
"name": "CryptoTools",
"description": "Professional blockchain tools provider",
"plugins_count": 8,
"total_downloads": 23450,
"verified_since": "2025-03-01",
"avatar": "/static/avatars/crypto-tools.png"
}
]
return verified_devs
def get_developer_stats() -> Dict:
"""Get developer statistics"""
return {
"total_developers": 156,
"verified_developers": 23,
"total_revenue_paid": 1250000.00,
"active_developers": 89
}
def get_submission_guidelines() -> Dict:
"""Get plugin submission guidelines"""
return {
"requirements": [
"Plugin must be compatible with AITBC v2.0+",
"Code must be open source with appropriate license",
"Comprehensive documentation required",
"Security scan must pass",
"Unit tests with 80%+ coverage"
],
"process": [
"Submit plugin for review",
"Security and quality assessment",
"Community review period",
"Final approval and publication"
],
"benefits": [
"Revenue sharing (70% to developer)",
"Featured placement opportunities",
"Developer support and resources",
"Community recognition"
]
}
def get_marketplace_stats() -> Dict:
"""Get marketplace statistics"""
return {
"total_plugins": 234,
"total_developers": 156,
"total_downloads": 1256780,
"total_revenue": 2345678.90,
"active_users": 45678,
"featured_plugins": 12,
"categories": 8
}
def get_rating_distribution(reviews: List[Dict]) -> Dict:
"""Get rating distribution"""
distribution = {1: 0, 2: 0, 3: 0, 4: 0, 5: 0}
for review in reviews:
distribution[review["rating"]] += 1
return distribution
def update_revenue_sharing(plugin_id: str, price: float):
"""Update revenue sharing records"""
# Mock implementation - in production, this would calculate actual revenue sharing
developer_share = price * 0.7 # 70% to developer
platform_share = price * 0.3 # 30% to platform
# Update records (simplified for demo)
if "revenue_sharing" not in revenue_sharing:
revenue_sharing["revenue_sharing"] = {
"total_revenue": 0.0,
"developer_revenue": 0.0,
"platform_revenue": 0.0
}
revenue_sharing["revenue_sharing"]["total_revenue"] += price
revenue_sharing["revenue_sharing"]["developer_revenue"] += developer_share
revenue_sharing["revenue_sharing"]["platform_revenue"] += platform_share
# Background task for marketplace analytics
async def update_marketplace_analytics():
"""Background task to update marketplace analytics"""
while True:
await asyncio.sleep(3600) # Update every hour
# Update trending plugins
# Update revenue calculations
# Update user engagement metrics
logger.info("Marketplace analytics updated")
@app.on_event("startup")
async def startup_event():
logger.info("Starting AITBC Plugin Marketplace")
# Start analytics processing
asyncio.create_task(update_marketplace_analytics())
@app.on_event("shutdown")
async def shutdown_event():
logger.info("Shutting down AITBC Plugin Marketplace")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8014, log_level="info")

View File

@@ -0,0 +1 @@
"""Plugin marketplace service tests"""

View File

@@ -0,0 +1,176 @@
"""Edge case and error handling tests for plugin marketplace service"""
import pytest
import sys
import sys
from pathlib import Path
from fastapi.testclient import TestClient
from main import app, MarketplaceReview, PluginPurchase, DeveloperApplication, reviews, purchases, developer_applications
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
reviews.clear()
purchases.clear()
developer_applications.clear()
yield
reviews.clear()
purchases.clear()
developer_applications.clear()
@pytest.mark.unit
def test_marketplace_review_out_of_range_rating():
"""Test MarketplaceReview with out of range rating"""
review = MarketplaceReview(
plugin_id="plugin_123",
user_id="user_123",
rating=10,
title="Great plugin",
content="Excellent"
)
assert review.rating == 10
@pytest.mark.unit
def test_marketplace_review_zero_rating():
"""Test MarketplaceReview with zero rating"""
review = MarketplaceReview(
plugin_id="plugin_123",
user_id="user_123",
rating=0,
title="Bad plugin",
content="Poor"
)
assert review.rating == 0
@pytest.mark.unit
def test_marketplace_review_negative_rating():
"""Test MarketplaceReview with negative rating"""
review = MarketplaceReview(
plugin_id="plugin_123",
user_id="user_123",
rating=-5,
title="Terrible",
content="Worst"
)
assert review.rating == -5
@pytest.mark.unit
def test_marketplace_review_empty_fields():
"""Test MarketplaceReview with empty fields"""
review = MarketplaceReview(
plugin_id="",
user_id="",
rating=3,
title="",
content=""
)
assert review.plugin_id == ""
assert review.title == ""
@pytest.mark.unit
def test_plugin_purchase_zero_price():
"""Test PluginPurchase with zero price"""
purchase = PluginPurchase(
plugin_id="plugin_123",
user_id="user_123",
price=0.0,
payment_method="free"
)
assert purchase.price == 0.0
@pytest.mark.unit
def test_developer_application_empty_fields():
"""Test DeveloperApplication with empty fields"""
application = DeveloperApplication(
developer_name="",
email="",
experience="",
description=""
)
assert application.developer_name == ""
assert application.email == ""
@pytest.mark.integration
def test_get_popular_plugins_with_limit():
"""Test getting popular plugins with limit parameter"""
client = TestClient(app)
response = client.get("/api/v1/marketplace/popular?limit=5")
assert response.status_code == 200
data = response.json()
assert "popular_plugins" in data
@pytest.mark.integration
def test_get_recent_plugins_with_limit():
"""Test getting recent plugins with limit parameter"""
client = TestClient(app)
response = client.get("/api/v1/marketplace/recent?limit=5")
assert response.status_code == 200
data = response.json()
assert "recent_plugins" in data
@pytest.mark.integration
def test_create_multiple_reviews():
"""Test creating multiple reviews for same plugin"""
client = TestClient(app)
for i in range(3):
review = MarketplaceReview(
plugin_id="plugin_123",
user_id=f"user_{i}",
rating=5,
title="Great",
content="Excellent"
)
client.post("/api/v1/reviews", json=review.model_dump())
response = client.get("/api/v1/reviews/plugin_123")
assert response.status_code == 200
data = response.json()
assert data["total_reviews"] == 3
@pytest.mark.integration
def test_create_multiple_purchases():
"""Test creating multiple purchases for same plugin"""
client = TestClient(app)
for i in range(3):
purchase = PluginPurchase(
plugin_id="plugin_123",
user_id=f"user_{i}",
price=99.99,
payment_method="credit_card"
)
client.post("/api/v1/purchases", json=purchase.model_dump())
response = client.get("/api/v1/revenue/revenue_sharing")
assert response.status_code == 200
@pytest.mark.integration
def test_developer_application_with_company():
"""Test developer application with company"""
client = TestClient(app)
application = DeveloperApplication(
developer_name="Dev Name",
email="dev@example.com",
company="Dev Corp",
experience="5 years",
description="Experienced"
)
response = client.post("/api/v1/developers/apply", json=application.model_dump())
assert response.status_code == 200
data = response.json()
assert data["application_id"]

View File

@@ -0,0 +1,165 @@
"""Integration tests for plugin marketplace service"""
import pytest
import sys
import sys
from pathlib import Path
from fastapi.testclient import TestClient
from main import app, MarketplaceReview, PluginPurchase, DeveloperApplication, reviews, purchases, developer_applications
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
reviews.clear()
purchases.clear()
developer_applications.clear()
yield
reviews.clear()
purchases.clear()
developer_applications.clear()
@pytest.mark.integration
def test_get_featured_plugins_api():
"""Test getting featured plugins API"""
client = TestClient(app)
response = client.get("/api/v1/marketplace/featured")
assert response.status_code == 200
data = response.json()
assert "featured_plugins" in data
@pytest.mark.integration
def test_get_popular_plugins_api():
"""Test getting popular plugins API"""
client = TestClient(app)
response = client.get("/api/v1/marketplace/popular")
assert response.status_code == 200
data = response.json()
assert "popular_plugins" in data
@pytest.mark.integration
def test_get_recent_plugins_api():
"""Test getting recent plugins API"""
client = TestClient(app)
response = client.get("/api/v1/marketplace/recent")
assert response.status_code == 200
data = response.json()
assert "recent_plugins" in data
@pytest.mark.integration
def test_get_marketplace_stats_api():
"""Test getting marketplace stats API"""
client = TestClient(app)
response = client.get("/api/v1/marketplace/stats")
assert response.status_code == 200
data = response.json()
assert "stats" in data
@pytest.mark.integration
def test_create_review():
"""Test creating a review"""
client = TestClient(app)
review = MarketplaceReview(
plugin_id="plugin_123",
user_id="user_123",
rating=5,
title="Great plugin",
content="Excellent functionality"
)
response = client.post("/api/v1/reviews", json=review.model_dump())
assert response.status_code == 200
data = response.json()
assert data["review_id"]
assert data["status"] == "created"
@pytest.mark.integration
def test_get_plugin_reviews_api():
"""Test getting plugin reviews API"""
client = TestClient(app)
# Create a review first
review = MarketplaceReview(
plugin_id="plugin_123",
user_id="user_123",
rating=5,
title="Great plugin",
content="Excellent functionality"
)
client.post("/api/v1/reviews", json=review.model_dump())
response = client.get("/api/v1/reviews/plugin_123")
assert response.status_code == 200
data = response.json()
assert data["plugin_id"] == "plugin_123"
assert "reviews" in data
@pytest.mark.integration
def test_get_plugin_reviews_no_reviews():
"""Test getting plugin reviews when no reviews exist"""
client = TestClient(app)
response = client.get("/api/v1/reviews/nonexistent")
assert response.status_code == 200
data = response.json()
assert data["total_reviews"] == 0
@pytest.mark.integration
def test_create_purchase():
"""Test creating a purchase"""
client = TestClient(app)
purchase = PluginPurchase(
plugin_id="plugin_123",
user_id="user_123",
price=99.99,
payment_method="credit_card"
)
response = client.post("/api/v1/purchases", json=purchase.model_dump())
assert response.status_code == 200
data = response.json()
assert data["purchase_id"]
assert data["status"] == "completed"
@pytest.mark.integration
def test_apply_developer():
"""Test applying to become a developer"""
client = TestClient(app)
application = DeveloperApplication(
developer_name="Dev Name",
email="dev@example.com",
experience="5 years",
description="Experienced developer"
)
response = client.post("/api/v1/developers/apply", json=application.model_dump())
assert response.status_code == 200
data = response.json()
assert data["application_id"]
assert data["status"] == "pending"
@pytest.mark.integration
def test_get_verified_developers_api():
"""Test getting verified developers API"""
client = TestClient(app)
response = client.get("/api/v1/developers/verified")
assert response.status_code == 200
data = response.json()
assert "verified_developers" in data
@pytest.mark.integration
def test_get_developer_revenue():
"""Test getting developer revenue"""
client = TestClient(app)
response = client.get("/api/v1/revenue/dev_123")
assert response.status_code == 200
data = response.json()
assert "total_revenue" in data

View File

@@ -0,0 +1,108 @@
"""Unit tests for plugin marketplace service"""
import pytest
import sys
import sys
from pathlib import Path
from main import app, MarketplaceReview, PluginPurchase, DeveloperApplication
@pytest.mark.unit
def test_app_initialization():
"""Test that the FastAPI app initializes correctly"""
assert app is not None
assert app.title == "AITBC Plugin Marketplace"
assert app.version == "1.0.0"
@pytest.mark.unit
def test_marketplace_review_model():
"""Test MarketplaceReview model"""
review = MarketplaceReview(
plugin_id="plugin_123",
user_id="user_123",
rating=5,
title="Great plugin",
content="Excellent functionality",
pros=["Easy to use", "Fast"],
cons=["Learning curve"]
)
assert review.plugin_id == "plugin_123"
assert review.rating == 5
assert review.title == "Great plugin"
assert review.pros == ["Easy to use", "Fast"]
assert review.cons == ["Learning curve"]
@pytest.mark.unit
def test_marketplace_review_defaults():
"""Test MarketplaceReview with default values"""
review = MarketplaceReview(
plugin_id="plugin_123",
user_id="user_123",
rating=4,
title="Good plugin",
content="Nice functionality"
)
assert review.pros == []
assert review.cons == []
@pytest.mark.unit
def test_plugin_purchase_model():
"""Test PluginPurchase model"""
purchase = PluginPurchase(
plugin_id="plugin_123",
user_id="user_123",
price=99.99,
payment_method="credit_card"
)
assert purchase.plugin_id == "plugin_123"
assert purchase.price == 99.99
assert purchase.payment_method == "credit_card"
@pytest.mark.unit
def test_plugin_purchase_negative_price():
"""Test PluginPurchase with negative price"""
purchase = PluginPurchase(
plugin_id="plugin_123",
user_id="user_123",
price=-99.99,
payment_method="credit_card"
)
assert purchase.price == -99.99
@pytest.mark.unit
def test_developer_application_model():
"""Test DeveloperApplication model"""
application = DeveloperApplication(
developer_name="Dev Name",
email="dev@example.com",
company="Dev Corp",
experience="5 years",
portfolio_url="https://portfolio.com",
github_username="devuser",
description="Experienced developer"
)
assert application.developer_name == "Dev Name"
assert application.email == "dev@example.com"
assert application.company == "Dev Corp"
assert application.github_username == "devuser"
@pytest.mark.unit
def test_developer_application_defaults():
"""Test DeveloperApplication with optional fields"""
application = DeveloperApplication(
developer_name="Dev Name",
email="dev@example.com",
experience="3 years",
description="New developer"
)
assert application.company is None
assert application.portfolio_url is None
assert application.github_username is None

View File

@@ -0,0 +1,485 @@
"""
Production Plugin Registry Service for AITBC
Handles plugin registration, discovery, versioning, and security validation
"""
import os
import asyncio
import json
import hashlib
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Dict, Any, List, Optional
from fastapi import FastAPI, HTTPException, UploadFile, File
from pydantic import BaseModel
from aitbc import get_logger
logger = get_logger(__name__)
app = FastAPI(
title="AITBC Plugin Registry",
description="Production plugin registry for AITBC ecosystem",
version="1.0.0"
)
# Data models
class PluginRegistration(BaseModel):
name: str
version: str
description: str
author: str
category: str
tags: List[str]
repository_url: str
homepage_url: Optional[str] = None
license: str
dependencies: List[str] = []
aitbc_version: str
plugin_type: str # cli, blockchain, ai, web, etc.
class PluginVersion(BaseModel):
version: str
changelog: str
download_url: str
checksum: str
aitbc_compatibility: List[str]
release_date: datetime
class SecurityScan(BaseModel):
scan_id: str
plugin_id: str
version: str
scan_date: datetime
vulnerabilities: List[Dict[str, Any]]
risk_score: str # low, medium, high, critical
passed: bool
# In-memory storage (in production, use database)
plugins: Dict[str, Dict] = {}
plugin_versions: Dict[str, List[Dict]] = {}
security_scans: Dict[str, Dict] = {}
analytics: Dict[str, Dict] = {}
downloads: Dict[str, List[Dict]] = {}
@app.get("/")
async def root():
return {
"service": "AITBC Plugin Registry",
"status": "running",
"timestamp": datetime.now(timezone.utc).isoformat(),
"version": "1.0.0"
}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"total_plugins": len(plugins),
"total_versions": sum(len(versions) for versions in plugin_versions.values()),
"security_scans": len(security_scans),
"downloads_today": len([d for downloads_list in downloads.values()
for d in downloads_list
if datetime.fromisoformat(d["timestamp"]).date() == datetime.now(timezone.utc).date()])
}
@app.post("/api/v1/plugins/register")
async def register_plugin(plugin: PluginRegistration):
"""Register a new plugin"""
plugin_id = f"{plugin.name.lower().replace(' ', '_')}"
if plugin_id in plugins:
raise HTTPException(status_code=400, detail="Plugin already registered")
# Create plugin record
plugin_record = {
"plugin_id": plugin_id,
"name": plugin.name,
"description": plugin.description,
"author": plugin.author,
"category": plugin.category,
"tags": plugin.tags,
"repository_url": plugin.repository_url,
"homepage_url": plugin.homepage_url,
"license": plugin.license,
"dependencies": plugin.dependencies,
"aitbc_version": plugin.aitbc_version,
"plugin_type": plugin.plugin_type,
"status": "active",
"created_at": datetime.now(timezone.utc).isoformat(),
"updated_at": datetime.now(timezone.utc).isoformat(),
"verified": False,
"featured": False,
"download_count": 0,
"rating": 0.0,
"rating_count": 0,
"latest_version": plugin.version
}
plugins[plugin_id] = plugin_record
plugin_versions[plugin_id] = []
# Initialize analytics
analytics[plugin_id] = {
"downloads": [],
"views": [],
"ratings": [],
"daily_stats": {}
}
logger.info(f"Plugin registered: {plugin.name}")
return {
"plugin_id": plugin_id,
"status": "registered",
"name": plugin.name,
"created_at": plugin_record["created_at"]
}
@app.post("/api/v1/plugins/{plugin_id}/versions")
async def add_plugin_version(plugin_id: str, version: PluginVersion):
"""Add a new version to an existing plugin"""
if plugin_id not in plugins:
raise HTTPException(status_code=404, detail="Plugin not found")
# Check if version already exists
for existing_version in plugin_versions[plugin_id]:
if existing_version["version"] == version.version:
raise HTTPException(status_code=400, detail="Version already exists")
# Create version record
version_record = {
"version_id": f"{plugin_id}_v_{version.version}",
"plugin_id": plugin_id,
"version": version.version,
"changelog": version.changelog,
"download_url": version.download_url,
"checksum": version.checksum,
"aitbc_compatibility": version.aitbc_compatibility,
"release_date": version.release_date.isoformat(),
"downloads": 0,
"security_scan_passed": False,
"created_at": datetime.now(timezone.utc).isoformat()
}
plugin_versions[plugin_id].append(version_record)
# Update plugin's latest version
plugins[plugin_id]["latest_version"] = version.version
plugins[plugin_id]["updated_at"] = datetime.now(timezone.utc).isoformat()
# Sort versions by version number (semantic versioning)
plugin_versions[plugin_id].sort(key=lambda x: x["version"], reverse=True)
logger.info(f"Version added to plugin {plugin_id}: {version.version}")
return {
"plugin_id": plugin_id,
"version": version.version,
"status": "added",
"created_at": version_record["created_at"]
}
@app.get("/api/v1/plugins")
async def list_plugins(category: Optional[str] = None, tag: Optional[str] = None,
search: Optional[str] = None, sort_by: str = "created_at"):
"""List all plugins with filtering and sorting"""
filtered_plugins = []
for plugin in plugins.values():
# Apply filters
if category and plugin["category"] != category:
continue
if tag and tag not in plugin["tags"]:
continue
if search and search.lower() not in plugin["name"].lower() and search.lower() not in plugin["description"].lower():
continue
filtered_plugins.append(plugin.copy())
# Sort plugins
if sort_by == "created_at":
filtered_plugins.sort(key=lambda x: x["created_at"], reverse=True)
elif sort_by == "updated_at":
filtered_plugins.sort(key=lambda x: x["updated_at"], reverse=True)
elif sort_by == "name":
filtered_plugins.sort(key=lambda x: x["name"])
elif sort_by == "downloads":
filtered_plugins.sort(key=lambda x: x["download_count"], reverse=True)
elif sort_by == "rating":
filtered_plugins.sort(key=lambda x: x["rating"], reverse=True)
return {
"plugins": filtered_plugins,
"total_plugins": len(filtered_plugins),
"filters": {
"category": category,
"tag": tag,
"search": search,
"sort_by": sort_by
}
}
@app.get("/api/v1/plugins/{plugin_id}")
async def get_plugin(plugin_id: str):
"""Get detailed plugin information"""
if plugin_id not in plugins:
raise HTTPException(status_code=404, detail="Plugin not found")
plugin = plugins[plugin_id].copy()
# Add version information
plugin["versions"] = plugin_versions.get(plugin_id, [])
# Add analytics
plugin_analytics = analytics.get(plugin_id, {})
plugin["analytics"] = {
"total_downloads": len(plugin_analytics.get("downloads", [])),
"total_views": len(plugin_analytics.get("views", [])),
"average_rating": sum(plugin_analytics.get("ratings", [])) / len(plugin_analytics.get("ratings", [])) if plugin_analytics.get("ratings") else 0.0,
"rating_count": len(plugin_analytics.get("ratings", []))
}
return plugin
@app.get("/api/v1/plugins/{plugin_id}/versions")
async def get_plugin_versions(plugin_id: str):
"""Get all versions of a plugin"""
if plugin_id not in plugins:
raise HTTPException(status_code=404, detail="Plugin not found")
return {
"plugin_id": plugin_id,
"versions": plugin_versions.get(plugin_id, []),
"total_versions": len(plugin_versions.get(plugin_id, []))
}
@app.get("/api/v1/plugins/{plugin_id}/download/{version}")
async def download_plugin(plugin_id: str, version: str):
"""Download a specific plugin version"""
if plugin_id not in plugins:
raise HTTPException(status_code=404, detail="Plugin not found")
# Find the version
version_record = None
for v in plugin_versions.get(plugin_id, []):
if v["version"] == version:
version_record = v
break
if not version_record:
raise HTTPException(status_code=404, detail="Version not found")
# Record download
download_record = {
"version": version,
"timestamp": datetime.now(timezone.utc).isoformat(),
"ip_address": "client_ip", # In production, get actual IP
"user_agent": "user_agent" # In production, get actual user agent
}
if plugin_id not in downloads:
downloads[plugin_id] = []
downloads[plugin_id].append(download_record)
# Update analytics
if plugin_id not in analytics:
analytics[plugin_id] = {"downloads": [], "views": [], "ratings": []}
analytics[plugin_id]["downloads"].append(datetime.now(timezone.utc).timestamp())
# Update plugin download count
plugins[plugin_id]["download_count"] += 1
version_record["downloads"] += 1
# In production, this would return the actual file
return {
"plugin_id": plugin_id,
"version": version,
"download_url": version_record["download_url"],
"checksum": version_record["checksum"],
"download_count": version_record["downloads"]
}
@app.post("/api/v1/plugins/{plugin_id}/security-scan")
async def create_security_scan(plugin_id: str, scan: SecurityScan):
"""Create a security scan record for a plugin version"""
if plugin_id not in plugins:
raise HTTPException(status_code=404, detail="Plugin not found")
# Verify version exists
version_exists = any(v["version"] == scan.version for v in plugin_versions.get(plugin_id, []))
if not version_exists:
raise HTTPException(status_code=404, detail="Version not found")
# Create security scan record
security_scans[scan.scan_id] = {
"scan_id": scan.scan_id,
"plugin_id": plugin_id,
"version": scan.version,
"scan_date": scan.scan_date.isoformat(),
"vulnerabilities": scan.vulnerabilities,
"risk_score": scan.risk_score,
"passed": scan.passed,
"created_at": datetime.now(timezone.utc).isoformat()
}
# Update version security status
for version_record in plugin_versions.get(plugin_id, []):
if version_record["version"] == scan.version:
version_record["security_scan_passed"] = scan.passed
break
logger.info(f"Security scan created for {plugin_id} v{scan.version}: {scan.risk_score}")
return {
"scan_id": scan.scan_id,
"plugin_id": plugin_id,
"version": scan.version,
"risk_score": scan.risk_score,
"passed": scan.passed,
"scan_date": scan.scan_date.isoformat()
}
@app.get("/api/v1/plugins/{plugin_id}/security")
async def get_plugin_security(plugin_id: str):
"""Get security information for a plugin"""
if plugin_id not in plugins:
raise HTTPException(status_code=404, detail="Plugin not found")
plugin_scans = []
for scan_id, scan in security_scans.items():
if scan["plugin_id"] == plugin_id:
plugin_scans.append(scan)
# Sort by scan date
plugin_scans.sort(key=lambda x: x["scan_date"], reverse=True)
return {
"plugin_id": plugin_id,
"security_scans": plugin_scans,
"total_scans": len(plugin_scans),
"latest_scan": plugin_scans[0] if plugin_scans else None
}
@app.get("/api/v1/categories")
async def get_categories():
"""Get all plugin categories"""
categories = {}
for plugin in plugins.values():
category = plugin["category"]
if category not in categories:
categories[category] = {
"name": category,
"plugin_count": 0,
"description": f"Plugins in {category} category"
}
categories[category]["plugin_count"] += 1
return {
"categories": list(categories.values()),
"total_categories": len(categories)
}
@app.get("/api/v1/tags")
async def get_tags():
"""Get all plugin tags"""
tag_counts = {}
for plugin in plugins.values():
for tag in plugin["tags"]:
tag_counts[tag] = tag_counts.get(tag, 0) + 1
return {
"tags": [{"tag": tag, "count": count} for tag, count in sorted(tag_counts.items(), key=lambda x: x[1], reverse=True)],
"total_tags": len(tag_counts)
}
@app.get("/api/v1/analytics/popular")
async def get_popular_plugins(limit: int = 10):
"""Get most popular plugins by downloads"""
popular_plugins = sorted(plugins.values(), key=lambda x: x["download_count"], reverse=True)[:limit]
return {
"popular_plugins": popular_plugins,
"limit": limit,
"generated_at": datetime.now(timezone.utc).isoformat()
}
@app.get("/api/v1/analytics/recent")
async def get_recent_plugins(limit: int = 10):
"""Get recently updated plugins"""
recent_plugins = sorted(plugins.values(), key=lambda x: x["updated_at"], reverse=True)[:limit]
return {
"recent_plugins": recent_plugins,
"limit": limit,
"generated_at": datetime.now(timezone.utc).isoformat()
}
@app.get("/api/v1/analytics/dashboard")
async def get_analytics_dashboard():
"""Get registry analytics dashboard"""
total_plugins = len(plugins)
total_versions = sum(len(versions) for versions in plugin_versions.values())
total_downloads = sum(plugin["download_count"] for plugin in plugins.values())
# Category distribution
category_stats = {}
for plugin in plugins.values():
category = plugin["category"]
category_stats[category] = category_stats.get(category, 0) + 1
# Recent activity
recent_downloads = 0
today = datetime.now(timezone.utc).date()
for download_list in downloads.values():
recent_downloads += len([d for d in download_list
if datetime.fromisoformat(d["timestamp"]).date() == today])
return {
"dashboard": {
"total_plugins": total_plugins,
"total_versions": total_versions,
"total_downloads": total_downloads,
"recent_downloads_today": recent_downloads,
"categories": category_stats,
"security_scans": len(security_scans),
"passed_scans": len([s for s in security_scans.values() if s["passed"]])
},
"generated_at": datetime.now(timezone.utc).isoformat()
}
# Background task for analytics processing
async def process_analytics():
"""Background task to process analytics data"""
while True:
await asyncio.sleep(3600) # Process every hour
# Update daily statistics
current_date = datetime.now(timezone.utc).date()
for plugin_id, plugin_analytics in analytics.items():
daily_key = current_date.isoformat()
if daily_key not in plugin_analytics["daily_stats"]:
plugin_analytics["daily_stats"][daily_key] = {
"downloads": len([d for d in plugin_analytics.get("downloads", [])
if datetime.fromtimestamp(d).date() == current_date]),
"views": len([v for v in plugin_analytics.get("views", [])
if datetime.fromtimestamp(v).date() == current_date]),
"ratings": len([r for r in plugin_analytics.get("ratings", [])
if datetime.fromtimestamp(r).date() == current_date])
}
@app.on_event("startup")
async def startup_event():
logger.info("Starting AITBC Plugin Registry")
# Start analytics processing
asyncio.create_task(process_analytics())
@app.on_event("shutdown")
async def shutdown_event():
logger.info("Shutting down AITBC Plugin Registry")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8013, log_level="info")

View File

@@ -0,0 +1 @@
"""Plugin registry service tests"""

View File

@@ -0,0 +1,317 @@
"""Edge case and error handling tests for plugin registry service"""
import pytest
import sys
import sys
from pathlib import Path
from fastapi.testclient import TestClient
from datetime import datetime, timezone
from main import app, PluginRegistration, PluginVersion, SecurityScan, plugins, plugin_versions, security_scans, analytics, downloads
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
plugins.clear()
plugin_versions.clear()
security_scans.clear()
analytics.clear()
downloads.clear()
yield
plugins.clear()
plugin_versions.clear()
security_scans.clear()
analytics.clear()
downloads.clear()
@pytest.mark.unit
def test_plugin_registration_empty_name():
"""Test PluginRegistration with empty name"""
plugin = PluginRegistration(
name="",
version="1.0.0",
description="A test plugin",
author="Test Author",
category="testing",
tags=[],
repository_url="https://github.com/test/plugin",
license="MIT",
dependencies=[],
aitbc_version="1.0.0",
plugin_type="cli"
)
assert plugin.name == ""
@pytest.mark.unit
def test_plugin_registration_empty_tags():
"""Test PluginRegistration with empty tags"""
plugin = PluginRegistration(
name="Test Plugin",
version="1.0.0",
description="A test plugin",
author="Test Author",
category="testing",
tags=[],
repository_url="https://github.com/test/plugin",
license="MIT",
dependencies=[],
aitbc_version="1.0.0",
plugin_type="cli"
)
assert plugin.tags == []
@pytest.mark.unit
def test_plugin_version_empty_changelog():
"""Test PluginVersion with empty changelog"""
version = PluginVersion(
version="1.0.0",
changelog="",
download_url="https://github.com/test/plugin/archive/v1.0.0.tar.gz",
checksum="abc123",
aitbc_compatibility=["1.0.0"],
release_date=datetime.now(timezone.utc)
)
assert version.changelog == ""
@pytest.mark.unit
def test_security_scan_empty_vulnerabilities():
"""Test SecurityScan with empty vulnerabilities"""
scan = SecurityScan(
scan_id="scan_123",
plugin_id="test_plugin",
version="1.0.0",
scan_date=datetime.now(timezone.utc),
vulnerabilities=[],
risk_score="low",
passed=True
)
assert scan.vulnerabilities == []
@pytest.mark.integration
def test_add_version_nonexistent_plugin():
"""Test adding version to nonexistent plugin"""
client = TestClient(app)
version = PluginVersion(
version="1.0.0",
changelog="Initial release",
download_url="https://github.com/test/plugin/archive/v1.0.0.tar.gz",
checksum="abc123",
aitbc_compatibility=["1.0.0"],
release_date=datetime.now(timezone.utc)
)
response = client.post("/api/v1/plugins/nonexistent/versions", json=version.model_dump(mode='json'))
assert response.status_code == 404
@pytest.mark.integration
def test_download_nonexistent_plugin():
"""Test downloading nonexistent plugin"""
client = TestClient(app)
response = client.get("/api/v1/plugins/nonexistent/download/1.0.0")
assert response.status_code == 404
@pytest.mark.integration
def test_download_nonexistent_version():
"""Test downloading nonexistent version"""
client = TestClient(app)
# Register plugin first
plugin = PluginRegistration(
name="Test Plugin",
version="1.0.0",
description="A test plugin",
author="Test Author",
category="testing",
tags=[],
repository_url="https://github.com/test/plugin",
license="MIT",
dependencies=[],
aitbc_version="1.0.0",
plugin_type="cli"
)
client.post("/api/v1/plugins/register", json=plugin.model_dump())
# Try to download nonexistent version
response = client.get("/api/v1/plugins/test_plugin/download/2.0.0")
assert response.status_code == 404
@pytest.mark.integration
def test_security_scan_nonexistent_plugin():
"""Test creating security scan for nonexistent plugin"""
client = TestClient(app)
scan = SecurityScan(
scan_id="scan_123",
plugin_id="nonexistent",
version="1.0.0",
scan_date=datetime.now(timezone.utc),
vulnerabilities=[],
risk_score="low",
passed=True
)
response = client.post("/api/v1/plugins/nonexistent/security-scan", json=scan.model_dump(mode='json'))
assert response.status_code == 404
@pytest.mark.integration
def test_security_scan_nonexistent_version():
"""Test creating security scan for nonexistent version"""
client = TestClient(app)
# Register plugin first
plugin = PluginRegistration(
name="Test Plugin",
version="1.0.0",
description="A test plugin",
author="Test Author",
category="testing",
tags=[],
repository_url="https://github.com/test/plugin",
license="MIT",
dependencies=[],
aitbc_version="1.0.0",
plugin_type="cli"
)
client.post("/api/v1/plugins/register", json=plugin.model_dump())
# Create security scan for nonexistent version
scan = SecurityScan(
scan_id="scan_123",
plugin_id="test_plugin",
version="2.0.0",
scan_date=datetime.now(timezone.utc),
vulnerabilities=[],
risk_score="low",
passed=True
)
response = client.post("/api/v1/plugins/test_plugin/security-scan", json=scan.model_dump(mode='json'))
assert response.status_code == 404
@pytest.mark.integration
def test_list_plugins_with_filters():
"""Test listing plugins with filters"""
client = TestClient(app)
# Register multiple plugins
plugin1 = PluginRegistration(
name="Test Plugin 1",
version="1.0.0",
description="A test plugin",
author="Test Author",
category="testing",
tags=["test"],
repository_url="https://github.com/test/plugin1",
license="MIT",
dependencies=[],
aitbc_version="1.0.0",
plugin_type="cli"
)
client.post("/api/v1/plugins/register", json=plugin1.model_dump())
plugin2 = PluginRegistration(
name="Production Plugin",
version="1.0.0",
description="A production plugin",
author="Test Author",
category="production",
tags=["prod"],
repository_url="https://github.com/test/plugin2",
license="MIT",
dependencies=[],
aitbc_version="1.0.0",
plugin_type="web"
)
client.post("/api/v1/plugins/register", json=plugin2.model_dump())
# Filter by category
response = client.get("/api/v1/plugins?category=testing")
assert response.status_code == 200
data = response.json()
assert data["total_plugins"] == 1
assert data["plugins"][0]["category"] == "testing"
@pytest.mark.integration
def test_list_plugins_with_search():
"""Test listing plugins with search"""
client = TestClient(app)
# Register plugin
plugin = PluginRegistration(
name="Test Plugin",
version="1.0.0",
description="A test plugin for testing",
author="Test Author",
category="testing",
tags=["test"],
repository_url="https://github.com/test/plugin",
license="MIT",
dependencies=[],
aitbc_version="1.0.0",
plugin_type="cli"
)
client.post("/api/v1/plugins/register", json=plugin.model_dump())
# Search for plugin
response = client.get("/api/v1/plugins?search=test")
assert response.status_code == 200
data = response.json()
assert data["total_plugins"] == 1
@pytest.mark.integration
def test_security_scan_failed():
"""Test security scan that failed"""
client = TestClient(app)
# Register plugin first
plugin = PluginRegistration(
name="Test Plugin",
version="1.0.0",
description="A test plugin",
author="Test Author",
category="testing",
tags=[],
repository_url="https://github.com/test/plugin",
license="MIT",
dependencies=[],
aitbc_version="1.0.0",
plugin_type="cli"
)
client.post("/api/v1/plugins/register", json=plugin.model_dump())
# Add version first
version = PluginVersion(
version="1.0.0",
changelog="Initial release",
download_url="https://github.com/test/plugin/archive/v1.0.0.tar.gz",
checksum="abc123",
aitbc_compatibility=["1.0.0"],
release_date=datetime.now(timezone.utc)
)
client.post("/api/v1/plugins/test_plugin/versions", json=version.model_dump(mode='json'))
# Create failed security scan
scan = SecurityScan(
scan_id="scan_123",
plugin_id="test_plugin",
version="1.0.0",
scan_date=datetime.now(timezone.utc),
vulnerabilities=[{"severity": "high", "description": "Critical issue"}],
risk_score="high",
passed=False
)
response = client.post("/api/v1/plugins/test_plugin/security-scan", json=scan.model_dump(mode='json'))
assert response.status_code == 200
data = response.json()
assert data["passed"] is False
assert data["risk_score"] == "high"

View File

@@ -0,0 +1,422 @@
"""Integration tests for plugin registry service"""
import pytest
import sys
import sys
from pathlib import Path
from fastapi.testclient import TestClient
from datetime import datetime, timezone
from main import app, PluginRegistration, PluginVersion, SecurityScan, plugins, plugin_versions, security_scans, analytics, downloads
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
plugins.clear()
plugin_versions.clear()
security_scans.clear()
analytics.clear()
downloads.clear()
yield
plugins.clear()
plugin_versions.clear()
security_scans.clear()
analytics.clear()
downloads.clear()
@pytest.mark.integration
def test_root_endpoint():
"""Test root endpoint"""
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
data = response.json()
assert data["service"] == "AITBC Plugin Registry"
assert data["status"] == "running"
@pytest.mark.integration
def test_health_check_endpoint():
"""Test health check endpoint"""
client = TestClient(app)
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert "total_plugins" in data
assert "total_versions" in data
@pytest.mark.integration
def test_register_plugin():
"""Test plugin registration"""
client = TestClient(app)
plugin = PluginRegistration(
name="Test Plugin",
version="1.0.0",
description="A test plugin",
author="Test Author",
category="testing",
tags=["test", "demo"],
repository_url="https://github.com/test/plugin",
license="MIT",
dependencies=[],
aitbc_version="1.0.0",
plugin_type="cli"
)
response = client.post("/api/v1/plugins/register", json=plugin.model_dump())
assert response.status_code == 200
data = response.json()
assert data["plugin_id"] == "test_plugin"
assert data["status"] == "registered"
assert data["name"] == "Test Plugin"
@pytest.mark.integration
def test_register_duplicate_plugin():
"""Test registering duplicate plugin"""
client = TestClient(app)
plugin = PluginRegistration(
name="Test Plugin",
version="1.0.0",
description="A test plugin",
author="Test Author",
category="testing",
tags=[],
repository_url="https://github.com/test/plugin",
license="MIT",
dependencies=[],
aitbc_version="1.0.0",
plugin_type="cli"
)
# First registration
client.post("/api/v1/plugins/register", json=plugin.model_dump())
# Second registration should fail
response = client.post("/api/v1/plugins/register", json=plugin.model_dump())
assert response.status_code == 400
@pytest.mark.integration
def test_add_plugin_version():
"""Test adding plugin version"""
client = TestClient(app)
# Register plugin first
plugin = PluginRegistration(
name="Test Plugin",
version="1.0.0",
description="A test plugin",
author="Test Author",
category="testing",
tags=[],
repository_url="https://github.com/test/plugin",
license="MIT",
dependencies=[],
aitbc_version="1.0.0",
plugin_type="cli"
)
client.post("/api/v1/plugins/register", json=plugin.model_dump())
# Add version
version = PluginVersion(
version="1.1.0",
changelog="Bug fixes",
download_url="https://github.com/test/plugin/archive/v1.1.0.tar.gz",
checksum="def456",
aitbc_compatibility=["1.0.0"],
release_date=datetime.now(timezone.utc)
)
response = client.post("/api/v1/plugins/test_plugin/versions", json=version.model_dump(mode='json'))
assert response.status_code == 200
data = response.json()
assert data["version"] == "1.1.0"
assert data["status"] == "added"
@pytest.mark.integration
def test_add_duplicate_version():
"""Test adding duplicate version"""
client = TestClient(app)
# Register plugin first
plugin = PluginRegistration(
name="Test Plugin",
version="1.0.0",
description="A test plugin",
author="Test Author",
category="testing",
tags=[],
repository_url="https://github.com/test/plugin",
license="MIT",
dependencies=[],
aitbc_version="1.0.0",
plugin_type="cli"
)
client.post("/api/v1/plugins/register", json=plugin.model_dump())
# Add version
version = PluginVersion(
version="1.1.0",
changelog="Bug fixes",
download_url="https://github.com/test/plugin/archive/v1.1.0.tar.gz",
checksum="def456",
aitbc_compatibility=["1.0.0"],
release_date=datetime.now(timezone.utc)
)
client.post("/api/v1/plugins/test_plugin/versions", json=version.model_dump(mode='json'))
# Add same version again should fail
response = client.post("/api/v1/plugins/test_plugin/versions", json=version.model_dump(mode='json'))
assert response.status_code == 400
@pytest.mark.integration
def test_list_plugins():
"""Test listing plugins"""
client = TestClient(app)
response = client.get("/api/v1/plugins")
assert response.status_code == 200
data = response.json()
assert "plugins" in data
assert "total_plugins" in data
@pytest.mark.integration
def test_get_plugin():
"""Test getting specific plugin"""
client = TestClient(app)
# Register plugin first
plugin = PluginRegistration(
name="Test Plugin",
version="1.0.0",
description="A test plugin",
author="Test Author",
category="testing",
tags=[],
repository_url="https://github.com/test/plugin",
license="MIT",
dependencies=[],
aitbc_version="1.0.0",
plugin_type="cli"
)
client.post("/api/v1/plugins/register", json=plugin.model_dump())
# Get plugin
response = client.get("/api/v1/plugins/test_plugin")
assert response.status_code == 200
data = response.json()
assert data["plugin_id"] == "test_plugin"
assert data["name"] == "Test Plugin"
@pytest.mark.integration
def test_get_plugin_not_found():
"""Test getting nonexistent plugin"""
client = TestClient(app)
response = client.get("/api/v1/plugins/nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_get_plugin_versions():
"""Test getting plugin versions"""
client = TestClient(app)
# Register plugin first
plugin = PluginRegistration(
name="Test Plugin",
version="1.0.0",
description="A test plugin",
author="Test Author",
category="testing",
tags=[],
repository_url="https://github.com/test/plugin",
license="MIT",
dependencies=[],
aitbc_version="1.0.0",
plugin_type="cli"
)
client.post("/api/v1/plugins/register", json=plugin.model_dump())
# Get versions
response = client.get("/api/v1/plugins/test_plugin/versions")
assert response.status_code == 200
data = response.json()
assert data["plugin_id"] == "test_plugin"
assert "versions" in data
@pytest.mark.integration
def test_download_plugin():
"""Test downloading plugin"""
client = TestClient(app)
# Register plugin first
plugin = PluginRegistration(
name="Test Plugin",
version="1.0.0",
description="A test plugin",
author="Test Author",
category="testing",
tags=[],
repository_url="https://github.com/test/plugin",
license="MIT",
dependencies=[],
aitbc_version="1.0.0",
plugin_type="cli"
)
client.post("/api/v1/plugins/register", json=plugin.model_dump())
# Add version first
version = PluginVersion(
version="1.0.0",
changelog="Initial release",
download_url="https://github.com/test/plugin/archive/v1.0.0.tar.gz",
checksum="abc123",
aitbc_compatibility=["1.0.0"],
release_date=datetime.now(timezone.utc)
)
client.post("/api/v1/plugins/test_plugin/versions", json=version.model_dump(mode='json'))
# Download plugin
response = client.get("/api/v1/plugins/test_plugin/download/1.0.0")
assert response.status_code == 200
data = response.json()
assert data["plugin_id"] == "test_plugin"
assert data["version"] == "1.0.0"
@pytest.mark.integration
def test_create_security_scan():
"""Test creating security scan"""
client = TestClient(app)
# Register plugin first
plugin = PluginRegistration(
name="Test Plugin",
version="1.0.0",
description="A test plugin",
author="Test Author",
category="testing",
tags=[],
repository_url="https://github.com/test/plugin",
license="MIT",
dependencies=[],
aitbc_version="1.0.0",
plugin_type="cli"
)
client.post("/api/v1/plugins/register", json=plugin.model_dump())
# Add version first
version = PluginVersion(
version="1.0.0",
changelog="Initial release",
download_url="https://github.com/test/plugin/archive/v1.0.0.tar.gz",
checksum="abc123",
aitbc_compatibility=["1.0.0"],
release_date=datetime.now(timezone.utc)
)
client.post("/api/v1/plugins/test_plugin/versions", json=version.model_dump(mode='json'))
# Create security scan
scan = SecurityScan(
scan_id="scan_123",
plugin_id="test_plugin",
version="1.0.0",
scan_date=datetime.now(timezone.utc),
vulnerabilities=[],
risk_score="low",
passed=True
)
response = client.post("/api/v1/plugins/test_plugin/security-scan", json=scan.model_dump(mode='json'))
assert response.status_code == 200
data = response.json()
assert data["scan_id"] == "scan_123"
assert data["passed"] is True
@pytest.mark.integration
def test_get_plugin_security():
"""Test getting plugin security info"""
client = TestClient(app)
# Register plugin first
plugin = PluginRegistration(
name="Test Plugin",
version="1.0.0",
description="A test plugin",
author="Test Author",
category="testing",
tags=[],
repository_url="https://github.com/test/plugin",
license="MIT",
dependencies=[],
aitbc_version="1.0.0",
plugin_type="cli"
)
client.post("/api/v1/plugins/register", json=plugin.model_dump())
# Get security info
response = client.get("/api/v1/plugins/test_plugin/security")
assert response.status_code == 200
data = response.json()
assert data["plugin_id"] == "test_plugin"
assert "security_scans" in data
@pytest.mark.integration
def test_get_categories():
"""Test getting categories"""
client = TestClient(app)
response = client.get("/api/v1/categories")
assert response.status_code == 200
data = response.json()
assert "categories" in data
assert "total_categories" in data
@pytest.mark.integration
def test_get_tags():
"""Test getting tags"""
client = TestClient(app)
response = client.get("/api/v1/tags")
assert response.status_code == 200
data = response.json()
assert "tags" in data
assert "total_tags" in data
@pytest.mark.integration
def test_get_popular_plugins():
"""Test getting popular plugins"""
client = TestClient(app)
response = client.get("/api/v1/analytics/popular")
assert response.status_code == 200
data = response.json()
assert "popular_plugins" in data
@pytest.mark.integration
def test_get_recent_plugins():
"""Test getting recent plugins"""
client = TestClient(app)
response = client.get("/api/v1/analytics/recent")
assert response.status_code == 200
data = response.json()
assert "recent_plugins" in data
@pytest.mark.integration
def test_get_analytics_dashboard():
"""Test getting analytics dashboard"""
client = TestClient(app)
response = client.get("/api/v1/analytics/dashboard")
assert response.status_code == 200
data = response.json()
assert "dashboard" in data

View File

@@ -0,0 +1,101 @@
"""Unit tests for plugin registry service"""
import pytest
import sys
import sys
from pathlib import Path
from datetime import datetime, timezone
from main import app, PluginRegistration, PluginVersion, SecurityScan
@pytest.mark.unit
def test_app_initialization():
"""Test that the FastAPI app initializes correctly"""
assert app is not None
assert app.title == "AITBC Plugin Registry"
assert app.version == "1.0.0"
@pytest.mark.unit
def test_plugin_registration_model():
"""Test PluginRegistration model"""
plugin = PluginRegistration(
name="Test Plugin",
version="1.0.0",
description="A test plugin",
author="Test Author",
category="testing",
tags=["test", "demo"],
repository_url="https://github.com/test/plugin",
homepage_url="https://test.com",
license="MIT",
dependencies=["dependency1"],
aitbc_version="1.0.0",
plugin_type="cli"
)
assert plugin.name == "Test Plugin"
assert plugin.version == "1.0.0"
assert plugin.author == "Test Author"
assert plugin.category == "testing"
assert plugin.tags == ["test", "demo"]
assert plugin.license == "MIT"
assert plugin.plugin_type == "cli"
@pytest.mark.unit
def test_plugin_registration_defaults():
"""Test PluginRegistration default values"""
plugin = PluginRegistration(
name="Test Plugin",
version="1.0.0",
description="A test plugin",
author="Test Author",
category="testing",
tags=[],
repository_url="https://github.com/test/plugin",
license="MIT",
aitbc_version="1.0.0",
plugin_type="cli"
)
assert plugin.homepage_url is None
assert plugin.dependencies == []
@pytest.mark.unit
def test_plugin_version_model():
"""Test PluginVersion model"""
version = PluginVersion(
version="1.0.0",
changelog="Initial release",
download_url="https://github.com/test/plugin/archive/v1.0.0.tar.gz",
checksum="abc123",
aitbc_compatibility=["1.0.0", "1.1.0"],
release_date=datetime.now(timezone.utc)
)
assert version.version == "1.0.0"
assert version.changelog == "Initial release"
assert version.download_url == "https://github.com/test/plugin/archive/v1.0.0.tar.gz"
assert version.checksum == "abc123"
assert version.aitbc_compatibility == ["1.0.0", "1.1.0"]
@pytest.mark.unit
def test_security_scan_model():
"""Test SecurityScan model"""
scan = SecurityScan(
scan_id="scan_123",
plugin_id="test_plugin",
version="1.0.0",
scan_date=datetime.now(timezone.utc),
vulnerabilities=[{"severity": "low", "description": "Test"}],
risk_score="low",
passed=True
)
assert scan.scan_id == "scan_123"
assert scan.plugin_id == "test_plugin"
assert scan.version == "1.0.0"
assert scan.risk_score == "low"
assert scan.passed is True
assert len(scan.vulnerabilities) == 1

View File

@@ -0,0 +1,659 @@
"""
Plugin Security Validation Service for AITBC
Handles plugin security scanning, vulnerability detection, and validation
"""
import asyncio
import json
import subprocess
import tempfile
import os
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Dict, Any, List, Optional
from fastapi import FastAPI, HTTPException, UploadFile, File
from pydantic import BaseModel
from aitbc import get_logger
logger = get_logger(__name__)
app = FastAPI(
title="AITBC Plugin Security Service",
description="Security validation and vulnerability scanning for AITBC plugins",
version="1.0.0"
)
# Data models
class SecurityScan(BaseModel):
plugin_id: str
version: str
plugin_type: str
scan_type: str # basic, comprehensive, deep
priority: str # low, medium, high, critical
class Vulnerability(BaseModel):
cve_id: Optional[str]
severity: str # low, medium, high, critical
title: str
description: str
affected_file: str
line_number: Optional[int]
recommendation: str
class SecurityReport(BaseModel):
scan_id: str
plugin_id: str
version: str
scan_date: datetime
scan_duration: float
overall_score: str # passed, warning, failed, critical
vulnerabilities: List[Vulnerability]
security_metrics: Dict[str, Any]
recommendations: List[str]
# In-memory storage (in production, use database)
scan_reports: Dict[str, Dict] = {}
security_policies: Dict[str, Dict] = {}
scan_queue: List[Dict] = []
vulnerability_database: Dict[str, Dict] = {}
@app.get("/")
async def root():
return {
"service": "AITBC Plugin Security Service",
"status": "running",
"timestamp": datetime.now(timezone.utc).isoformat(),
"version": "1.0.0"
}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"total_scans": len(scan_reports),
"queue_size": len(scan_queue),
"vulnerabilities_db": len(vulnerability_database),
"active_policies": len(security_policies)
}
@app.post("/api/v1/security/scan")
async def initiate_security_scan(scan: SecurityScan):
"""Initiate a security scan for a plugin"""
scan_id = f"scan_{int(datetime.now(timezone.utc).timestamp())}"
# Create scan record
scan_record = {
"scan_id": scan_id,
"plugin_id": scan.plugin_id,
"version": scan.version,
"plugin_type": scan.plugin_type,
"scan_type": scan.scan_type,
"priority": scan.priority,
"status": "queued",
"created_at": datetime.now(timezone.utc).isoformat(),
"started_at": None,
"completed_at": None,
"duration": None,
"result": None
}
scan_queue.append(scan_record)
# Sort queue by priority
priority_order = {"critical": 0, "high": 1, "medium": 2, "low": 3}
scan_queue.sort(key=lambda x: priority_order.get(x["priority"], 4))
logger.info(f"Security scan queued: {scan_id} for {scan.plugin_id} v{scan.version}")
return {
"scan_id": scan_id,
"status": "queued",
"queue_position": scan_queue.index(scan_record) + 1,
"estimated_time": estimate_scan_time(scan.scan_type)
}
@app.get("/api/v1/security/scan/{scan_id}")
async def get_scan_status(scan_id: str):
"""Get scan status and results"""
if scan_id not in scan_reports and not any(s["scan_id"] == scan_id for s in scan_queue):
raise HTTPException(status_code=404, detail="Scan not found")
# Check if scan is in queue
for scan_record in scan_queue:
if scan_record["scan_id"] == scan_id:
return {
"scan_id": scan_id,
"status": scan_record["status"],
"queue_position": scan_queue.index(scan_record) + 1,
"created_at": scan_record["created_at"]
}
# Return completed scan results
return scan_reports.get(scan_id, {"status": "not_found"})
@app.get("/api/v1/security/reports")
async def list_security_reports(plugin_id: Optional[str] = None,
status: Optional[str] = None,
limit: int = 50):
"""List security scan reports"""
reports = list(scan_reports.values())
# Apply filters
if plugin_id:
reports = [r for r in reports if r.get("plugin_id") == plugin_id]
if status:
reports = [r for r in reports if r.get("status") == status]
# Sort by scan date (most recent first)
reports.sort(key=lambda x: x.get("scan_date", ""), reverse=True)
return {
"reports": reports[:limit],
"total_reports": len(reports),
"filters": {
"plugin_id": plugin_id,
"status": status,
"limit": limit
}
}
@app.get("/api/v1/security/vulnerabilities")
async def list_vulnerabilities(severity: Optional[str] = None,
plugin_id: Optional[str] = None):
"""List known vulnerabilities"""
vulnerabilities = list(vulnerability_database.values())
# Apply filters
if severity:
vulnerabilities = [v for v in vulnerabilities if v["severity"] == severity]
if plugin_id:
vulnerabilities = [v for v in vulnerabilities if v.get("plugin_id") == plugin_id]
return {
"vulnerabilities": vulnerabilities,
"total_vulnerabilities": len(vulnerabilities),
"filters": {
"severity": severity,
"plugin_id": plugin_id
}
}
@app.post("/api/v1/security/policies")
async def create_security_policy(policy: Dict[str, Any]):
"""Create a new security policy"""
policy_id = f"policy_{int(datetime.now(timezone.utc).timestamp())}"
policy_record = {
"policy_id": policy_id,
"name": policy.get("name"),
"description": policy.get("description"),
"rules": policy.get("rules", []),
"severity_thresholds": policy.get("severity_thresholds", {
"critical": 0,
"high": 0,
"medium": 5,
"low": 10
}),
"plugin_types": policy.get("plugin_types", []),
"active": True,
"created_at": datetime.now(timezone.utc).isoformat(),
"updated_at": datetime.now(timezone.utc).isoformat()
}
security_policies[policy_id] = policy_record
logger.info(f"Security policy created: {policy_id} - {policy.get('name')}")
return {
"policy_id": policy_id,
"name": policy.get("name"),
"status": "created",
"active": True
}
@app.get("/api/v1/security/policies")
async def list_security_policies():
"""List all security policies"""
return {
"policies": list(security_policies.values()),
"total_policies": len(security_policies),
"active_policies": len([p for p in security_policies.values() if p["active"]])
}
@app.post("/api/v1/security/upload")
async def upload_plugin_for_scan(plugin_id: str, version: str,
file: UploadFile = File(...)):
"""Upload plugin file for security scanning"""
# Validate file
if not file.filename.endswith(('.py', '.zip', '.tar.gz')):
raise HTTPException(status_code=400, detail="Invalid file type")
# Save uploaded file temporarily
with tempfile.NamedTemporaryFile(delete=False, suffix=file.filename) as tmp_file:
content = await file.read()
tmp_file.write(content)
tmp_file_path = tmp_file.name
# Initiate scan
scan = SecurityScan(
plugin_id=plugin_id,
version=version,
plugin_type="uploaded",
scan_type="comprehensive",
priority="medium"
)
scan_result = await initiate_security_scan(scan)
# Start async scan process
asyncio.create_task(process_scan_file(scan_result["scan_id"], tmp_file_path, file.filename))
return {
"scan_id": scan_result["scan_id"],
"filename": file.filename,
"file_size": len(content),
"status": "uploaded_and_queued"
}
@app.get("/api/v1/security/dashboard")
async def get_security_dashboard():
"""Get security dashboard data"""
total_scans = len(scan_reports)
recent_scans = [r for r in scan_reports.values()
if datetime.fromisoformat(r["scan_date"]) > datetime.now(timezone.utc) - timedelta(days=7)]
# Calculate statistics
scan_results = list(scan_reports.values())
passed_scans = len([r for r in scan_results if r.get("overall_score") == "passed"])
warning_scans = len([r for r in scan_results if r.get("overall_score") == "warning"])
failed_scans = len([r for r in scan_results if r.get("overall_score") in ["failed", "critical"]])
# Vulnerability statistics
all_vulnerabilities = []
for report in scan_results:
all_vulnerabilities.extend(report.get("vulnerabilities", []))
vuln_by_severity = {"critical": 0, "high": 0, "medium": 0, "low": 0}
for vuln in all_vulnerabilities:
vuln_by_severity[vuln["severity"]] = vuln_by_severity.get(vuln["severity"], 0) + 1
return {
"dashboard": {
"total_scans": total_scans,
"recent_scans": len(recent_scans),
"scan_results": {
"passed": passed_scans,
"warning": warning_scans,
"failed": failed_scans
},
"vulnerabilities": {
"total": len(all_vulnerabilities),
"by_severity": vuln_by_severity
},
"queue_size": len(scan_queue),
"active_policies": len([p for p in security_policies.values() if p["active"]])
},
"generated_at": datetime.now(timezone.utc).isoformat()
}
# Core security scanning functions
async def process_scan_file(scan_id: str, file_path: str, filename: str):
"""Process uploaded file for security scanning"""
try:
# Update scan status
for scan_record in scan_queue:
if scan_record["scan_id"] == scan_id:
scan_record["status"] = "running"
scan_record["started_at"] = datetime.now(timezone.utc).isoformat()
break
start_time = datetime.now(timezone.utc)
# Perform security scan
scan_result = await perform_security_scan(file_path, filename)
end_time = datetime.now(timezone.utc)
duration = (end_time - start_time).total_seconds()
# Create security report
security_report = SecurityReport(
scan_id=scan_id,
plugin_id=scan_record["plugin_id"],
version=scan_record["version"],
scan_date=end_time,
scan_duration=duration,
overall_score=calculate_overall_score(scan_result),
vulnerabilities=scan_result["vulnerabilities"],
security_metrics=scan_result["metrics"],
recommendations=scan_result["recommendations"]
)
# Save report
report_data = {
"scan_id": scan_id,
"plugin_id": scan_record["plugin_id"],
"version": scan_record["version"],
"scan_date": security_report.scan_date.isoformat(),
"scan_duration": security_report.scan_duration,
"overall_score": security_report.overall_score,
"vulnerabilities": [v.dict() for v in security_report.vulnerabilities],
"security_metrics": security_report.security_metrics,
"recommendations": security_report.recommendations,
"status": "completed",
"completed_at": security_report.scan_date.isoformat()
}
scan_reports[scan_id] = report_data
# Remove from queue
scan_queue[:] = [s for s in scan_queue if s["scan_id"] != scan_id]
# Clean up temporary file
os.unlink(file_path)
logger.info(f"Security scan completed: {scan_id} - {security_report.overall_score}")
except Exception as e:
logger.error(f"Error processing scan {scan_id}: {str(e)}")
# Update scan status to failed
for scan_record in scan_queue:
if scan_record["scan_id"] == scan_id:
scan_record["status"] = "failed"
scan_record["completed_at"] = datetime.now(timezone.utc).isoformat()
break
async def perform_security_scan(file_path: str, filename: str) -> Dict[str, Any]:
"""Perform actual security scanning"""
vulnerabilities = []
metrics = {}
recommendations = []
# File analysis
try:
# Basic file checks
file_size = os.path.getsize(file_path)
metrics["file_size"] = file_size
# Check for suspicious patterns (simplified for demo)
if filename.endswith('.py'):
vulnerabilities.extend(scan_python_file(file_path))
elif filename.endswith('.zip'):
vulnerabilities.extend(scan_zip_file(file_path))
# Check common vulnerabilities
vulnerabilities.extend(check_common_vulnerabilities(file_path))
# Generate recommendations
recommendations = generate_recommendations(vulnerabilities)
# Calculate metrics
metrics.update({
"vulnerability_count": len(vulnerabilities),
"severity_distribution": get_severity_distribution(vulnerabilities),
"file_type": filename.split('.')[-1],
"scan_timestamp": datetime.now(timezone.utc).isoformat()
})
except Exception as e:
logger.error(f"Error during security scan: {str(e)}")
vulnerabilities.append({
"severity": "medium",
"title": "Scan Error",
"description": f"Error during scanning: {str(e)}",
"affected_file": filename,
"recommendation": "Review file and rescan"
})
return {
"vulnerabilities": vulnerabilities,
"metrics": metrics,
"recommendations": recommendations
}
async def scan_python_file(file_path: str) -> List[Dict]:
"""Scan Python file for security issues"""
vulnerabilities = []
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
lines = content.split('\n')
# Check for suspicious patterns
suspicious_patterns = {
"eval": "Use of eval() function",
"exec": "Use of exec() function",
"subprocess.call": "Unsafe subprocess usage",
"os.system": "Use of os.system() function",
"pickle.loads": "Unsafe pickle deserialization",
"input(": "Use of input() function"
}
for i, line in enumerate(lines, 1):
for pattern, description in suspicious_patterns.items():
if pattern in line:
vulnerabilities.append({
"severity": "medium",
"title": "Suspicious Code Pattern",
"description": description,
"affected_file": file_path,
"line_number": i,
"recommendation": f"Review usage of {pattern} and consider safer alternatives"
})
# Check for hardcoded credentials
if any('password' in line.lower() or 'secret' in line.lower() or 'key' in line.lower()
for line in lines):
vulnerabilities.append({
"severity": "high",
"title": "Potential Hardcoded Credentials",
"description": "Possible hardcoded sensitive information detected",
"affected_file": file_path,
"recommendation": "Use environment variables or secure configuration management"
})
except Exception as e:
logger.error(f"Error scanning Python file: {str(e)}")
return vulnerabilities
async def scan_zip_file(file_path: str) -> List[Dict]:
"""Scan ZIP file for security issues"""
vulnerabilities = []
try:
import zipfile
with zipfile.ZipFile(file_path, 'r') as zip_file:
# Check for suspicious files
for file_info in zip_file.filelist:
filename = file_info.filename.lower()
# Check for suspicious file types
suspicious_extensions = ['.exe', '.bat', '.cmd', '.scr', '.dll', '.so']
if any(filename.endswith(ext) for ext in suspicious_extensions):
vulnerabilities.append({
"severity": "high",
"title": "Suspicious File Type",
"description": f"Suspicious file found in archive: {filename}",
"affected_file": file_path,
"recommendation": "Review file contents and ensure they are safe"
})
# Check for large files (potential data exfiltration)
if file_info.file_size > 100 * 1024 * 1024: # 100MB
vulnerabilities.append({
"severity": "medium",
"title": "Large File Detected",
"description": f"Large file detected: {filename} ({file_info.file_size} bytes)",
"affected_file": file_path,
"recommendation": "Verify file contents and necessity"
})
except Exception as e:
logger.error(f"Error scanning ZIP file: {str(e)}")
vulnerabilities.append({
"severity": "medium",
"title": "ZIP Scan Error",
"description": f"Error scanning ZIP file: {str(e)}",
"affected_file": file_path,
"recommendation": "Verify ZIP file integrity"
})
return vulnerabilities
async def check_common_vulnerabilities(file_path: str) -> List[Dict]:
"""Check for common security vulnerabilities"""
vulnerabilities = []
# Mock vulnerability database check
known_vulnerabilities = {
"requests": "Check for outdated requests library",
"urllib": "Check for urllib security issues",
"socket": "Check for unsafe socket usage"
}
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
for lib, issue in known_vulnerabilities.items():
if lib in content:
vulnerabilities.append({
"severity": "low",
"title": f"Library Security Check",
"description": issue,
"affected_file": file_path,
"recommendation": f"Update {lib} to latest secure version"
})
except Exception as e:
logger.error(f"Error checking common vulnerabilities: {str(e)}")
return vulnerabilities
def calculate_overall_score(scan_result: Dict[str, Any]) -> str:
"""Calculate overall security score"""
vulnerabilities = scan_result["vulnerabilities"]
if not vulnerabilities:
return "passed"
# Count by severity
critical_count = len([v for v in vulnerabilities if v["severity"] == "critical"])
high_count = len([v for v in vulnerabilities if v["severity"] == "high"])
medium_count = len([v for v in vulnerabilities if v["severity"] == "medium"])
low_count = len([v for v in vulnerabilities if v["severity"] == "low"])
# Determine overall score
if critical_count > 0:
return "critical"
elif high_count > 2:
return "failed"
elif high_count > 0 or medium_count > 5:
return "warning"
else:
return "passed"
def generate_recommendations(vulnerabilities: List[Dict]) -> List[str]:
"""Generate security recommendations"""
recommendations = []
if not vulnerabilities:
recommendations.append("No security issues detected. Plugin appears secure.")
return recommendations
# Generate recommendations based on vulnerabilities
severity_counts = {}
for vuln in vulnerabilities:
severity = vuln["severity"]
severity_counts[severity] = severity_counts.get(severity, 0) + 1
if severity_counts.get("critical", 0) > 0:
recommendations.append("CRITICAL: Address critical security vulnerabilities immediately.")
if severity_counts.get("high", 0) > 0:
recommendations.append("HIGH: Review and fix high-severity security issues.")
if severity_counts.get("medium", 0) > 3:
recommendations.append("MEDIUM: Consider addressing medium-severity issues.")
recommendations.append("Regular security scans recommended for ongoing protection.")
recommendations.append("Keep all dependencies updated to latest secure versions.")
return recommendations
def get_severity_distribution(vulnerabilities: List[Dict]) -> Dict[str, int]:
"""Get vulnerability severity distribution"""
distribution = {"critical": 0, "high": 0, "medium": 0, "low": 0}
for vuln in vulnerabilities:
severity = vuln["severity"]
distribution[severity] = distribution.get(severity, 0) + 1
return distribution
def estimate_scan_time(scan_type: str) -> str:
"""Estimate scan time based on scan type"""
estimates = {
"basic": "1-2 minutes",
"comprehensive": "5-10 minutes",
"deep": "15-30 minutes"
}
return estimates.get(scan_type, "5-10 minutes")
# Background task for processing scan queue
async def process_scan_queue():
"""Background task to process security scan queue"""
while True:
await asyncio.sleep(10) # Check queue every 10 seconds
if scan_queue:
# Get next scan from queue
scan_record = scan_queue[0]
# Process scan (in production, this would be more sophisticated)
logger.info(f"Processing scan from queue: {scan_record['scan_id']}")
# Simulate processing time
await asyncio.sleep(2)
@app.on_event("startup")
async def startup_event():
logger.info("Starting AITBC Plugin Security Service")
# Initialize vulnerability database
initialize_vulnerability_database()
# Start queue processing
asyncio.create_task(process_scan_queue())
@app.on_event("shutdown")
async def shutdown_event():
logger.info("Shutting down AITBC Plugin Security Service")
def initialize_vulnerability_database():
"""Initialize vulnerability database with known issues"""
# Mock data for demo
vulnerabilities = [
{
"vuln_id": "CVE-2023-1234",
"severity": "high",
"title": "Buffer Overflow in Library X",
"description": "Buffer overflow vulnerability in commonly used library",
"affected_plugins": ["plugin1", "plugin2"],
"recommendation": "Update to latest version"
},
{
"vuln_id": "CVE-2023-5678",
"severity": "medium",
"title": "Information Disclosure",
"description": "Potential information disclosure in logging",
"affected_plugins": ["plugin3"],
"recommendation": "Review logging implementation"
}
]
for vuln in vulnerabilities:
vulnerability_database[vuln["vuln_id"]] = vuln
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8015, log_level="info")

View File

@@ -0,0 +1 @@
"""Plugin security service tests"""

View File

@@ -0,0 +1,159 @@
"""Edge case and error handling tests for plugin security service"""
import pytest
import sys
import sys
from pathlib import Path
from fastapi.testclient import TestClient
from datetime import datetime
from main import app, SecurityScan, scan_reports, security_policies, scan_queue, vulnerability_database
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
scan_reports.clear()
security_policies.clear()
scan_queue.clear()
vulnerability_database.clear()
yield
scan_reports.clear()
security_policies.clear()
scan_queue.clear()
vulnerability_database.clear()
@pytest.mark.unit
def test_security_scan_empty_fields():
"""Test SecurityScan with empty fields"""
scan = SecurityScan(
plugin_id="",
version="",
plugin_type="",
scan_type="",
priority=""
)
assert scan.plugin_id == ""
assert scan.version == ""
@pytest.mark.unit
def test_vulnerability_empty_description():
"""Test Vulnerability with empty description"""
vuln = {
"severity": "low",
"title": "Test",
"description": "",
"affected_file": "file.py",
"recommendation": "Fix"
}
assert vuln["description"] == ""
@pytest.mark.integration
def test_create_security_policy_minimal():
"""Test creating security policy with minimal fields"""
client = TestClient(app)
policy = {
"name": "Minimal Policy"
}
response = client.post("/api/v1/security/policies", json=policy)
assert response.status_code == 200
data = response.json()
assert data["policy_id"]
assert data["name"] == "Minimal Policy"
@pytest.mark.integration
def test_create_security_policy_empty_name():
"""Test creating security policy with empty name"""
client = TestClient(app)
policy = {}
response = client.post("/api/v1/security/policies", json=policy)
assert response.status_code == 200
@pytest.mark.integration
def test_list_security_reports_with_no_reports():
"""Test listing security reports when no reports exist"""
client = TestClient(app)
response = client.get("/api/v1/security/reports")
assert response.status_code == 200
data = response.json()
assert data["total_reports"] == 0
@pytest.mark.integration
def test_list_vulnerabilities_with_no_vulnerabilities():
"""Test listing vulnerabilities when no vulnerabilities exist"""
client = TestClient(app)
response = client.get("/api/v1/security/vulnerabilities")
assert response.status_code == 200
data = response.json()
assert data["total_vulnerabilities"] == 0
@pytest.mark.integration
def test_list_security_policies_with_no_policies():
"""Test listing security policies when no policies exist"""
client = TestClient(app)
response = client.get("/api/v1/security/policies")
assert response.status_code == 200
data = response.json()
assert data["total_policies"] == 0
@pytest.mark.integration
def test_scan_priority_ordering():
"""Test that scan queue respects priority ordering"""
client = TestClient(app)
# Add scans in random priority order
priorities = ["low", "critical", "medium", "high"]
for priority in priorities:
scan = SecurityScan(
plugin_id=f"plugin_{priority}",
version="1.0.0",
plugin_type="cli",
scan_type="basic",
priority=priority
)
client.post("/api/v1/security/scan", json=scan.model_dump())
# Critical should be first, low should be last
response = client.get("/api/v1/security/scan/nonexistent")
# This will fail, but we can check queue size
assert len(scan_queue) == 4
@pytest.mark.integration
def test_security_dashboard_with_no_data():
"""Test security dashboard with no data"""
client = TestClient(app)
response = client.get("/api/v1/security/dashboard")
assert response.status_code == 200
data = response.json()
assert data["dashboard"]["total_scans"] == 0
assert data["dashboard"]["queue_size"] == 0
@pytest.mark.integration
def test_list_reports_limit_parameter():
"""Test listing reports with limit parameter"""
client = TestClient(app)
response = client.get("/api/v1/security/reports?limit=5")
assert response.status_code == 200
data = response.json()
assert "reports" in data
@pytest.mark.integration
def test_list_vulnerabilities_invalid_filter():
"""Test listing vulnerabilities with invalid filter"""
client = TestClient(app)
response = client.get("/api/v1/security/vulnerabilities?severity=invalid")
assert response.status_code == 200
data = response.json()
assert data["total_vulnerabilities"] == 0

View File

@@ -0,0 +1,217 @@
"""Integration tests for plugin security service"""
import pytest
import sys
import sys
from pathlib import Path
from fastapi.testclient import TestClient
from datetime import datetime
from main import app, SecurityScan, scan_reports, security_policies, scan_queue, vulnerability_database
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
scan_reports.clear()
security_policies.clear()
scan_queue.clear()
vulnerability_database.clear()
yield
scan_reports.clear()
security_policies.clear()
scan_queue.clear()
vulnerability_database.clear()
@pytest.mark.integration
def test_root_endpoint():
"""Test root endpoint"""
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
data = response.json()
assert data["service"] == "AITBC Plugin Security Service"
assert data["status"] == "running"
@pytest.mark.integration
def test_health_check_endpoint():
"""Test health check endpoint"""
client = TestClient(app)
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert "total_scans" in data
assert "queue_size" in data
@pytest.mark.integration
def test_initiate_security_scan():
"""Test initiating a security scan"""
client = TestClient(app)
scan = SecurityScan(
plugin_id="plugin_123",
version="1.0.0",
plugin_type="cli",
scan_type="comprehensive",
priority="high"
)
response = client.post("/api/v1/security/scan", json=scan.model_dump())
assert response.status_code == 200
data = response.json()
assert data["scan_id"]
assert data["status"] == "queued"
assert "queue_position" in data
@pytest.mark.integration
def test_get_scan_status_queued():
"""Test getting scan status for queued scan"""
client = TestClient(app)
scan = SecurityScan(
plugin_id="plugin_123",
version="1.0.0",
plugin_type="cli",
scan_type="basic",
priority="medium"
)
scan_response = client.post("/api/v1/security/scan", json=scan.model_dump())
scan_id = scan_response.json()["scan_id"]
response = client.get(f"/api/v1/security/scan/{scan_id}")
assert response.status_code == 200
data = response.json()
assert data["scan_id"] == scan_id
assert data["status"] == "queued"
@pytest.mark.integration
def test_get_scan_status_not_found():
"""Test getting scan status for nonexistent scan"""
client = TestClient(app)
response = client.get("/api/v1/security/scan/nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_list_security_reports():
"""Test listing security reports"""
client = TestClient(app)
response = client.get("/api/v1/security/reports")
assert response.status_code == 200
data = response.json()
assert "reports" in data
assert "total_reports" in data
@pytest.mark.integration
def test_list_security_reports_with_filters():
"""Test listing security reports with filters"""
client = TestClient(app)
response = client.get("/api/v1/security/reports?plugin_id=plugin_123&status=completed")
assert response.status_code == 200
data = response.json()
assert "reports" in data
@pytest.mark.integration
def test_list_vulnerabilities():
"""Test listing vulnerabilities"""
client = TestClient(app)
response = client.get("/api/v1/security/vulnerabilities")
assert response.status_code == 200
data = response.json()
assert "vulnerabilities" in data
assert "total_vulnerabilities" in data
@pytest.mark.integration
def test_list_vulnerabilities_with_filters():
"""Test listing vulnerabilities with filters"""
client = TestClient(app)
response = client.get("/api/v1/security/vulnerabilities?severity=high&plugin_id=plugin_123")
assert response.status_code == 200
data = response.json()
assert "vulnerabilities" in data
@pytest.mark.integration
def test_create_security_policy():
"""Test creating a security policy"""
client = TestClient(app)
policy = {
"name": "Test Policy",
"description": "A test security policy",
"rules": ["rule1", "rule2"],
"severity_thresholds": {
"critical": 0,
"high": 0,
"medium": 5,
"low": 10
},
"plugin_types": ["cli", "web"]
}
response = client.post("/api/v1/security/policies", json=policy)
assert response.status_code == 200
data = response.json()
assert data["policy_id"]
assert data["name"] == "Test Policy"
assert data["active"] is True
@pytest.mark.integration
def test_list_security_policies():
"""Test listing security policies"""
client = TestClient(app)
response = client.get("/api/v1/security/policies")
assert response.status_code == 200
data = response.json()
assert "policies" in data
assert "total_policies" in data
@pytest.mark.integration
def test_get_security_dashboard():
"""Test getting security dashboard"""
client = TestClient(app)
response = client.get("/api/v1/security/dashboard")
assert response.status_code == 200
data = response.json()
assert "dashboard" in data
assert "total_scans" in data["dashboard"]
assert "vulnerabilities" in data["dashboard"]
@pytest.mark.integration
def test_scan_priority_queueing():
"""Test that scans are queued by priority"""
client = TestClient(app)
# Add low priority scan
scan_low = SecurityScan(
plugin_id="plugin_low",
version="1.0.0",
plugin_type="cli",
scan_type="basic",
priority="low"
)
client.post("/api/v1/security/scan", json=scan_low.model_dump())
# Add critical priority scan
scan_critical = SecurityScan(
plugin_id="plugin_critical",
version="1.0.0",
plugin_type="cli",
scan_type="basic",
priority="critical"
)
response = client.post("/api/v1/security/scan", json=scan_critical.model_dump())
scan_id = response.json()["scan_id"]
# Critical scan should be at position 1
response = client.get(f"/api/v1/security/scan/{scan_id}")
data = response.json()
assert data["queue_position"] == 1

View File

@@ -0,0 +1,205 @@
"""Unit tests for plugin security service"""
import pytest
import sys
import sys
from pathlib import Path
from datetime import datetime, timezone
from main import app, SecurityScan, Vulnerability, SecurityReport, calculate_overall_score, generate_recommendations, get_severity_distribution, estimate_scan_time
@pytest.mark.unit
def test_app_initialization():
"""Test that the FastAPI app initializes correctly"""
assert app is not None
assert app.title == "AITBC Plugin Security Service"
assert app.version == "1.0.0"
@pytest.mark.unit
def test_security_scan_model():
"""Test SecurityScan model"""
scan = SecurityScan(
plugin_id="plugin_123",
version="1.0.0",
plugin_type="cli",
scan_type="comprehensive",
priority="high"
)
assert scan.plugin_id == "plugin_123"
assert scan.version == "1.0.0"
assert scan.plugin_type == "cli"
assert scan.scan_type == "comprehensive"
assert scan.priority == "high"
@pytest.mark.unit
def test_vulnerability_model():
"""Test Vulnerability model"""
vuln = Vulnerability(
cve_id="CVE-2023-1234",
severity="high",
title="Buffer Overflow",
description="Buffer overflow vulnerability",
affected_file="file.py",
line_number=42,
recommendation="Update to latest version"
)
assert vuln.cve_id == "CVE-2023-1234"
assert vuln.severity == "high"
assert vuln.title == "Buffer Overflow"
assert vuln.line_number == 42
@pytest.mark.unit
def test_vulnerability_model_optional_fields():
"""Test Vulnerability model with optional fields"""
vuln = Vulnerability(
cve_id=None,
severity="low",
title="Minor issue",
description="Description",
affected_file="file.py",
line_number=None,
recommendation="Fix it"
)
assert vuln.cve_id is None
assert vuln.line_number is None
@pytest.mark.unit
def test_security_report_model():
"""Test SecurityReport model"""
report = SecurityReport(
scan_id="scan_123",
plugin_id="plugin_123",
version="1.0.0",
scan_date=datetime.now(timezone.utc),
scan_duration=120.5,
overall_score="passed",
vulnerabilities=[],
security_metrics={},
recommendations=[]
)
assert report.scan_id == "scan_123"
assert report.overall_score == "passed"
assert report.scan_duration == 120.5
@pytest.mark.unit
def test_calculate_overall_score_passed():
"""Test calculate overall score with no vulnerabilities"""
scan_result = {"vulnerabilities": []}
score = calculate_overall_score(scan_result)
assert score == "passed"
@pytest.mark.unit
def test_calculate_overall_score_critical():
"""Test calculate overall score with critical vulnerability"""
scan_result = {
"vulnerabilities": [
{"severity": "critical"},
{"severity": "low"}
]
}
score = calculate_overall_score(scan_result)
assert score == "critical"
@pytest.mark.unit
def test_calculate_overall_score_failed():
"""Test calculate overall score with multiple high vulnerabilities"""
scan_result = {
"vulnerabilities": [
{"severity": "high"},
{"severity": "high"},
{"severity": "high"}
]
}
score = calculate_overall_score(scan_result)
assert score == "failed"
@pytest.mark.unit
def test_calculate_overall_score_warning():
"""Test calculate overall score with high and medium vulnerabilities"""
scan_result = {
"vulnerabilities": [
{"severity": "high"},
{"severity": "medium"},
{"severity": "medium"},
{"severity": "medium"},
{"severity": "medium"},
{"severity": "medium"}
]
}
score = calculate_overall_score(scan_result)
assert score == "warning"
@pytest.mark.unit
def test_generate_recommendations_no_vulnerabilities():
"""Test generate recommendations with no vulnerabilities"""
recommendations = generate_recommendations([])
assert len(recommendations) == 1
assert "No security issues detected" in recommendations[0]
@pytest.mark.unit
def test_generate_recommendations_critical():
"""Test generate recommendations with critical vulnerabilities"""
vulnerabilities = [
{"severity": "critical"},
{"severity": "high"}
]
recommendations = generate_recommendations(vulnerabilities)
assert any("CRITICAL" in r for r in recommendations)
assert any("HIGH" in r for r in recommendations)
@pytest.mark.unit
def test_get_severity_distribution():
"""Test get severity distribution"""
vulnerabilities = [
{"severity": "critical"},
{"severity": "high"},
{"severity": "high"},
{"severity": "medium"},
{"severity": "low"}
]
distribution = get_severity_distribution(vulnerabilities)
assert distribution["critical"] == 1
assert distribution["high"] == 2
assert distribution["medium"] == 1
assert distribution["low"] == 1
@pytest.mark.unit
def test_estimate_scan_time_basic():
"""Test estimate scan time for basic scan"""
time = estimate_scan_time("basic")
assert time == "1-2 minutes"
@pytest.mark.unit
def test_estimate_scan_time_comprehensive():
"""Test estimate scan time for comprehensive scan"""
time = estimate_scan_time("comprehensive")
assert time == "5-10 minutes"
@pytest.mark.unit
def test_estimate_scan_time_deep():
"""Test estimate scan time for deep scan"""
time = estimate_scan_time("deep")
assert time == "15-30 minutes"
@pytest.mark.unit
def test_estimate_scan_time_unknown():
"""Test estimate scan time for unknown scan type"""
time = estimate_scan_time("unknown")
assert time == "5-10 minutes"

1292
examples/stubs/plugin-service/poetry.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,25 @@
[tool.poetry]
name = "plugin-service"
version = "0.1.0"
description = "AITBC Plugin Service for plugin registration, marketplace, and analytics"
authors = ["AITBC Team"]
[tool.poetry.dependencies]
python = "^3.13"
fastapi = ">=0.115.6"
uvicorn = {extras = ["standard"], version = "^0.32.0"}
sqlmodel = "^0.0.37"
sqlalchemy = "^2.0.25"
pydantic = "^2.6.0"
pydantic-settings = "^2.1.0"
httpx = ">=0.28.1"
[tool.poetry.group.dev.dependencies]
pytest = ">=9.0.3"
pytest-asyncio = ">=1.3.0"
black = ">=26.3.1"
ruff = "^0.1.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@@ -0,0 +1,71 @@
"""Plugin Service for plugin registration, marketplace, and analytics."""
from __future__ import annotations
import os
import logging
from typing import Any
from fastapi import FastAPI
logger = logging.getLogger(__name__)
app = FastAPI(
title="AITBC Plugin Service",
description="Plugin registration, marketplace, and analytics service",
version="1.0.0"
)
@app.get("/health")
async def health():
"""Health check endpoint."""
return {"status": "healthy", "service": "plugin-service"}
@app.get("/")
async def root():
"""Root endpoint."""
return {
"service": "AITBC Plugin Service",
"version": "1.0.0",
"status": "operational"
}
@app.post("/register")
async def register_plugin(request: dict[str, Any]) -> dict[str, Any]:
"""Register a new plugin"""
return {
"plugin_id": "plugin_123",
"name": request.get("name", "unknown"),
"version": request.get("version", "1.0.0"),
"status": "registered"
}
@app.get("/marketplace/plugins")
async def list_marketplace_plugins() -> dict[str, Any]:
"""List all plugins in marketplace"""
return {
"plugins": [
{"id": "plugin_1", "name": "GPU Optimizer", "version": "1.0.0", "category": "performance"},
{"id": "plugin_2", "name": "Analytics Dashboard", "version": "2.0.0", "category": "analytics"},
],
"total": 2
}
@app.get("/analytics/plugins")
async def get_plugin_analytics() -> dict[str, Any]:
"""Get plugin analytics data"""
return {
"total_plugins": 2,
"active_installs": 150,
"downloads": 500,
"popular_categories": ["performance", "analytics"]
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8109)

View File

@@ -0,0 +1,238 @@
#!/usr/bin/env python3
"""
Simple AITBC Blockchain Explorer - Demonstrating the issues described in the analysis
"""
import os
import asyncio
import re
from datetime import datetime
from typing import Dict, Any, Optional
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
import uvicorn
from aitbc.network.http_client import AsyncAITBCHTTPClient
from aitbc.aitbc_logging import get_logger
from aitbc.exceptions import NetworkError
app = FastAPI(title="Simple AITBC Explorer", version="0.1.0")
# Initialize logger
logger = get_logger(__name__)
# Configuration
BLOCKCHAIN_RPC_URL = "http://localhost:8025"
# Validation patterns for user inputs to prevent SSRF
TX_HASH_PATTERN = re.compile(r'^[a-fA-F0-9]{64}$') # 64-character hex string for transaction hash
def validate_tx_hash(tx_hash: str) -> bool:
"""Validate transaction hash to prevent SSRF"""
if not tx_hash:
return False
# Check for path traversal or URL manipulation
if any(char in tx_hash for char in ['/', '\\', '..', '\n', '\r', '\t', '?', '&']):
return False
# Validate against hash pattern
return bool(TX_HASH_PATTERN.match(tx_hash))
# HTML Template with the problematic frontend
HTML_TEMPLATE = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Simple AITBC Explorer</title>
<script src="https://cdn.tailwindcss.com"></script>
</head>
<body class="bg-gray-50">
<div class="container mx-auto px-4 py-8">
<h1 class="text-3xl font-bold mb-8">AITBC Blockchain Explorer</h1>
<!-- Search -->
<div class="bg-white rounded-lg shadow p-6 mb-8">
<h2 class="text-xl font-semibold mb-4">Search</h2>
<div class="flex space-x-4">
<input type="text" id="search-input" placeholder="Search by transaction hash (64 chars)"
class="flex-1 px-4 py-2 border rounded-lg">
<button onclick="performSearch()" class="bg-blue-600 text-white px-6 py-2 rounded-lg">
Search
</button>
</div>
</div>
<!-- Results -->
<div id="results" class="hidden bg-white rounded-lg shadow p-6">
<h2 class="text-xl font-semibold mb-4">Transaction Details</h2>
<div id="tx-details"></div>
</div>
<!-- Latest Blocks -->
<div class="bg-white rounded-lg shadow p-6">
<h2 class="text-xl font-semibold mb-4">Latest Blocks</h2>
<div id="blocks-list"></div>
</div>
</div>
<script>
// Problem 1: Frontend calls /api/transactions/{hash} but backend doesn't have it
async function performSearch() {
const query = document.getElementById('search-input').value.trim();
if (!query) return;
if (/^[a-fA-F0-9]{64}$/.test(query)) {
try {
const tx = await fetch(`/api/transactions/${query}`).then(r => {
if (!r.ok) throw new Error('Transaction not found');
return r.json();
});
showTransactionDetails(tx);
} catch (error) {
alert('Transaction not found');
}
} else {
alert('Please enter a valid 64-character hex transaction hash');
}
}
// Problem 2: UI expects tx.hash, tx.from, tx.to, tx.amount, tx.fee
// But RPC returns tx_hash, sender, recipient, payload, created_at
function showTransactionDetails(tx) {
const resultsDiv = document.getElementById('results');
const detailsDiv = document.getElementById('tx-details');
detailsDiv.innerHTML = `
<div class="space-y-4">
<div><strong>Hash:</strong> ${tx.hash || 'N/A'}</div>
<div><strong>From:</strong> ${tx.from || 'N/A'}</div>
<div><strong>To:</strong> ${tx.to || 'N/A'}</div>
<div><strong>Amount:</strong> ${tx.amount || 'N/A'}</div>
<div><strong>Fee:</strong> ${tx.fee || 'N/A'}</div>
<div><strong>Timestamp:</strong> ${formatTimestamp(tx.timestamp)}</div>
</div>
`;
resultsDiv.classList.remove('hidden');
}
// Problem 3: formatTimestamp now handles both numeric and ISO string timestamps
function formatTimestamp(timestamp) {
if (!timestamp) return 'N/A';
// Handle ISO string timestamps
if (typeof timestamp === 'string') {
try {
return new Date(timestamp).toLocaleString();
} catch (e) {
return 'Invalid timestamp';
}
}
// Handle numeric timestamps (Unix seconds)
if (typeof timestamp === 'number') {
try {
return new Date(timestamp * 1000).toLocaleString();
} catch (e) {
return 'Invalid timestamp';
}
}
return 'Invalid timestamp format';
}
// Load latest blocks
async function loadBlocks() {
try {
const head = await fetch('/api/chain/head').then(r => r.json());
const blocksList = document.getElementById('blocks-list');
let html = '<div class="space-y-4">';
for (let i = 0; i < 5 && head.height - i >= 0; i++) {
const block = await fetch(`/api/blocks/${head.height - i}`).then(r => r.json());
html += `
<div class="border rounded p-4">
<div><strong>Height:</strong> ${block.height}</div>
<div><strong>Hash:</strong> ${block.hash ? block.hash.substring(0, 16) + '...' : 'N/A'}</div>
<div><strong>Time:</strong> ${formatTimestamp(block.timestamp)}</div>
</div>
`;
}
html += '</div>';
blocksList.innerHTML = html;
} catch (error) {
console.error('Failed to load blocks:', error);
}
}
// Initialize
document.addEventListener('DOMContentLoaded', () => {
loadBlocks();
});
</script>
</body>
</html>
"""
# Problem 1: Only /api/chain/head and /api/blocks/{height} defined, missing /api/transactions/{hash}
@app.get("/api/chain/head")
async def get_chain_head():
"""Get current chain head"""
try:
client = AsyncAITBCHTTPClient(base_url=BLOCKCHAIN_RPC_URL, timeout=10)
response = await client.async_get("/rpc/head")
if response:
return response
except NetworkError as e:
logger.error(f"Error getting chain head: {e}")
return {"height": 0, "hash": "", "timestamp": None}
@app.get("/api/blocks/{height}")
async def get_block(height: int):
"""Get block by height"""
# Validate height is non-negative and reasonable
if height < 0 or height > 10000000:
return {"height": height, "hash": "", "timestamp": None, "transactions": []}
try:
client = AsyncAITBCHTTPClient(base_url=BLOCKCHAIN_RPC_URL, timeout=10)
response = await client.async_get(f"/rpc/blocks/{height}")
if response:
return response
except NetworkError as e:
logger.error(f"Error getting block: {e}")
return {"height": height, "hash": "", "timestamp": None, "transactions": []}
@app.get("/api/transactions/{tx_hash}")
async def get_transaction(tx_hash: str):
"""Get transaction by hash - Problem 1: This endpoint was missing"""
if not validate_tx_hash(tx_hash):
return {"hash": tx_hash, "from": "unknown", "to": "unknown", "amount": 0, "timestamp": None}
try:
client = AsyncAITBCHTTPClient(base_url=BLOCKCHAIN_RPC_URL, timeout=10)
response = await client.async_get(f"/rpc/tx/{tx_hash}")
if response:
# Problem 2: Map RPC schema to UI schema
return {
"hash": response.get("tx_hash", tx_hash), # tx_hash -> hash
"from": response.get("sender", "unknown"), # sender -> from
"to": response.get("recipient", "unknown"), # recipient -> to
"amount": response.get("payload", {}).get("value", "0"), # payload.value -> amount
"fee": response.get("payload", {}).get("fee", "0"), # payload.fee -> fee
"timestamp": response.get("created_at"), # created_at -> timestamp
"block_height": response.get("block_height", "pending")
}
except NetworkError as e:
logger.error(f"Error getting transaction {tx_hash}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to fetch transaction: {str(e)}")
# Missing: @app.get("/api/transactions/{tx_hash}") - THIS IS THE PROBLEM
@app.get("/", response_class=HTMLResponse)
async def root():
"""Serve the explorer UI"""
return HTML_TEMPLATE
if __name__ == "__main__":
uvicorn.run(app, host="127.0.0.1", port=8017)

View File

@@ -0,0 +1 @@
"""Simple explorer service tests"""

View File

@@ -0,0 +1,221 @@
"""Edge case and error handling tests for simple explorer service"""
import pytest
import sys
import sys
from pathlib import Path
from unittest.mock import Mock, patch, AsyncMock
from fastapi.testclient import TestClient
# Mock httpx before importing
sys.modules['httpx'] = Mock()
from main import app
@pytest.mark.unit
def test_get_transaction_missing_fields():
"""Test transaction mapping with missing fields"""
client = TestClient(app)
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"tx_hash": "0x" + "a" * 64,
# Missing sender, recipient, payload
"created_at": "2026-01-01T00:00:00"
}
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.get.return_value = mock_response
with patch('main.httpx.AsyncClient', return_value=mock_client):
response = client.get("/api/transactions/" + "a" * 64)
assert response.status_code == 200
data = response.json()
assert data["from"] == "unknown"
assert data["to"] == "unknown"
assert data["amount"] == "0"
assert data["fee"] == "0"
@pytest.mark.unit
def test_get_transaction_empty_payload():
"""Test transaction mapping with empty payload"""
client = TestClient(app)
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"tx_hash": "0x" + "a" * 64,
"sender": "0xsender",
"recipient": "0xrecipient",
"payload": {},
"created_at": "2026-01-01T00:00:00"
}
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.get.return_value = mock_response
with patch('main.httpx.AsyncClient', return_value=mock_client):
response = client.get("/api/transactions/" + "a" * 64)
assert response.status_code == 200
data = response.json()
assert data["amount"] == "0"
assert data["fee"] == "0"
@pytest.mark.unit
def test_get_transaction_missing_created_at():
"""Test transaction mapping with missing created_at"""
client = TestClient(app)
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"tx_hash": "0x" + "a" * 64,
"sender": "0xsender",
"recipient": "0xrecipient",
"payload": {"value": "1000", "fee": "10"}
# Missing created_at
}
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.get.return_value = mock_response
with patch('main.httpx.AsyncClient', return_value=mock_client):
response = client.get("/api/transactions/" + "a" * 64)
assert response.status_code == 200
data = response.json()
assert data["timestamp"] is None
@pytest.mark.unit
def test_get_transaction_missing_block_height():
"""Test transaction mapping with missing block_height"""
client = TestClient(app)
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"tx_hash": "0x" + "a" * 64,
"sender": "0xsender",
"recipient": "0xrecipient",
"payload": {"value": "1000", "fee": "10"},
"created_at": "2026-01-01T00:00:00"
# Missing block_height
}
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.get.return_value = mock_response
with patch('main.httpx.AsyncClient', return_value=mock_client):
response = client.get("/api/transactions/" + "a" * 64)
assert response.status_code == 200
data = response.json()
assert data["block_height"] == "pending"
@pytest.mark.unit
def test_get_block_negative_height():
"""Test /api/blocks/{height} with negative height"""
client = TestClient(app)
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"height": -1,
"hash": "0xblock",
"timestamp": 1234567890,
"transactions": []
}
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.get.return_value = mock_response
with patch('main.httpx.AsyncClient', return_value=mock_client):
response = client.get("/api/blocks/-1")
assert response.status_code == 200
data = response.json()
assert data["height"] == -1
@pytest.mark.unit
def test_get_block_zero_height():
"""Test /api/blocks/{height} with zero height"""
client = TestClient(app)
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"height": 0,
"hash": "0xgenesis",
"timestamp": 1234567890,
"transactions": []
}
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.get.return_value = mock_response
with patch('main.httpx.AsyncClient', return_value=mock_client):
response = client.get("/api/blocks/0")
assert response.status_code == 200
data = response.json()
assert data["height"] == 0
@pytest.mark.unit
def test_get_transaction_short_hash():
"""Test /api/transactions/{tx_hash} with short hash"""
client = TestClient(app)
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"tx_hash": "0x" + "a" * 64,
"sender": "0xsender",
"recipient": "0xrecipient",
"payload": {"value": "1000", "fee": "10"},
"created_at": "2026-01-01T00:00:00",
"block_height": 100
}
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.get.return_value = mock_response
with patch('main.httpx.AsyncClient', return_value=mock_client):
response = client.get("/api/transactions/abc")
assert response.status_code in [200, 404, 500] # Any valid response
@pytest.mark.unit
def test_get_transaction_invalid_hex_hash():
"""Test /api/transactions/{tx_hash} with invalid hex characters"""
client = TestClient(app)
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"tx_hash": "0x" + "a" * 64,
"sender": "0xsender",
"recipient": "0xrecipient",
"payload": {"value": "1000", "fee": "10"},
"created_at": "2026-01-01T00:00:00",
"block_height": 100
}
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.get.return_value = mock_response
with patch('main.httpx.AsyncClient', return_value=mock_client):
response = client.get("/api/transactions/" + "z" * 64)
assert response.status_code in [200, 404, 500]

View File

@@ -0,0 +1,170 @@
"""Integration tests for simple explorer service"""
import pytest
import sys
import sys
from pathlib import Path
from unittest.mock import Mock, patch, AsyncMock
from fastapi.testclient import TestClient
# Mock httpx before importing
sys.modules['httpx'] = Mock()
from main import app
@pytest.mark.integration
def test_root_endpoint():
"""Test root endpoint serves HTML"""
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
assert "text/html" in response.headers["content-type"]
assert "AITBC Blockchain Explorer" in response.text
@pytest.mark.integration
def test_get_chain_head_success():
"""Test /api/chain/head endpoint with successful response"""
client = TestClient(app)
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"height": 100, "hash": "0xabc123", "timestamp": 1234567890}
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.get.return_value = mock_response
with patch('main.httpx.AsyncClient', return_value=mock_client):
response = client.get("/api/chain/head")
assert response.status_code == 200
data = response.json()
assert data["height"] == 100
assert data["hash"] == "0xabc123"
@pytest.mark.integration
def test_get_chain_head_error():
"""Test /api/chain/head endpoint with error"""
client = TestClient(app)
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.get.side_effect = Exception("RPC error")
with patch('main.httpx.AsyncClient', return_value=mock_client):
response = client.get("/api/chain/head")
assert response.status_code == 200
data = response.json()
assert data["height"] == 0
assert data["hash"] == ""
@pytest.mark.integration
def test_get_block_success():
"""Test /api/blocks/{height} endpoint with successful response"""
client = TestClient(app)
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"height": 50,
"hash": "0xblock50",
"timestamp": 1234567890,
"transactions": []
}
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.get.return_value = mock_response
with patch('main.httpx.AsyncClient', return_value=mock_client):
response = client.get("/api/blocks/50")
assert response.status_code == 200
data = response.json()
assert data["height"] == 50
assert data["hash"] == "0xblock50"
@pytest.mark.integration
def test_get_block_error():
"""Test /api/blocks/{height} endpoint with error"""
client = TestClient(app)
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.get.side_effect = Exception("RPC error")
with patch('main.httpx.AsyncClient', return_value=mock_client):
response = client.get("/api/blocks/50")
assert response.status_code == 200
data = response.json()
assert data["height"] == 50
assert data["hash"] == ""
@pytest.mark.integration
def test_get_transaction_success():
"""Test /api/transactions/{tx_hash} endpoint with successful response"""
client = TestClient(app)
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"tx_hash": "0x" + "a" * 64,
"sender": "0xsender",
"recipient": "0xrecipient",
"payload": {
"value": "1000",
"fee": "10"
},
"created_at": "2026-01-01T00:00:00",
"block_height": 100
}
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.get.return_value = mock_response
with patch('main.httpx.AsyncClient', return_value=mock_client):
response = client.get("/api/transactions/" + "a" * 64)
assert response.status_code == 200
data = response.json()
assert data["hash"] == "0x" + "a" * 64
assert data["from"] == "0xsender"
assert data["to"] == "0xrecipient"
assert data["amount"] == "1000"
assert data["fee"] == "10"
@pytest.mark.integration
def test_get_transaction_not_found():
"""Test /api/transactions/{tx_hash} endpoint with 404 response"""
client = TestClient(app)
mock_response = Mock()
mock_response.status_code = 404
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.get.return_value = mock_response
with patch('main.httpx.AsyncClient', return_value=mock_client):
response = client.get("/api/transactions/" + "a" * 64)
assert response.status_code == 404
@pytest.mark.integration
def test_get_transaction_error():
"""Test /api/transactions/{tx_hash} endpoint with error"""
client = TestClient(app)
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.get.side_effect = Exception("RPC error")
with patch('main.httpx.AsyncClient', return_value=mock_client):
response = client.get("/api/transactions/" + "a" * 64)
assert response.status_code == 500

View File

@@ -0,0 +1,70 @@
"""Unit tests for simple explorer service"""
import pytest
import sys
import sys
from pathlib import Path
from unittest.mock import Mock, patch, AsyncMock
from datetime import datetime
# Mock httpx before importing
sys.modules['httpx'] = Mock()
from main import app, BLOCKCHAIN_RPC_URL, HTML_TEMPLATE
@pytest.mark.unit
def test_app_initialization():
"""Test that the FastAPI app initializes correctly"""
assert app is not None
assert app.title == "Simple AITBC Explorer"
assert app.version == "0.1.0"
@pytest.mark.unit
def test_blockchain_rpc_url():
"""Test that the blockchain RPC URL is configured"""
assert BLOCKCHAIN_RPC_URL == "http://localhost:8025"
@pytest.mark.unit
def test_html_template_exists():
"""Test that the HTML template is defined"""
assert HTML_TEMPLATE is not None
assert "<!DOCTYPE html>" in HTML_TEMPLATE
assert "AITBC Blockchain Explorer" in HTML_TEMPLATE
@pytest.mark.unit
def test_html_template_has_search():
"""Test that the HTML template has search functionality"""
assert "search-input" in HTML_TEMPLATE
assert "performSearch()" in HTML_TEMPLATE
@pytest.mark.unit
def test_html_template_has_blocks_section():
"""Test that the HTML template has blocks section"""
assert "Latest Blocks" in HTML_TEMPLATE
assert "blocks-list" in HTML_TEMPLATE
@pytest.mark.unit
def test_html_template_has_results_section():
"""Test that the HTML template has results section"""
assert "Transaction Details" in HTML_TEMPLATE
assert "tx-details" in HTML_TEMPLATE
@pytest.mark.unit
def test_html_template_has_tailwind():
"""Test that the HTML template includes Tailwind CSS"""
assert "tailwindcss" in HTML_TEMPLATE
@pytest.mark.unit
def test_html_template_format_timestamp_function():
"""Test that the HTML template has formatTimestamp function"""
assert "formatTimestamp" in HTML_TEMPLATE
assert "toLocaleString" in HTML_TEMPLATE

View File

@@ -0,0 +1,584 @@
"""
Production Trading Engine for AITBC
Handles order matching, trade execution, and settlement
"""
import os
import asyncio
import json
from collections import defaultdict, deque
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
from aitbc import get_logger
logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
logger.info("Starting AITBC Trading Engine")
# Start background market simulation
asyncio.create_task(simulate_market_activity())
yield
# Shutdown
logger.info("Shutting down AITBC Trading Engine")
app = FastAPI(
title="AITBC Trading Engine",
description="High-performance order matching and trade execution",
version="1.0.0",
lifespan=lifespan
)
# Data models
class Order(BaseModel):
order_id: str
symbol: str
side: str # buy/sell
type: str # market/limit
quantity: float
price: Optional[float] = None
user_id: str
timestamp: datetime
class Trade(BaseModel):
trade_id: str
symbol: str
buy_order_id: str
sell_order_id: str
quantity: float
price: float
timestamp: datetime
class OrderBookEntry(BaseModel):
price: float
quantity: float
orders_count: int
# In-memory order books (in production, use more sophisticated data structures)
order_books: Dict[str, Dict] = {}
orders: Dict[str, Dict] = {}
trades: Dict[str, Dict] = {}
market_data: Dict[str, Dict] = {}
@app.get("/")
async def root():
return {
"service": "AITBC Trading Engine",
"status": "running",
"timestamp": datetime.now(timezone.utc).isoformat(),
"version": "1.0.0"
}
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"active_order_books": len(order_books),
"total_orders": len(orders),
"total_trades": len(trades),
"uptime": "running"
}
@app.post("/api/v1/orders/submit")
async def submit_order(order: Order):
"""Submit a new order to the trading engine"""
symbol = order.symbol
# Initialize order book if not exists
if symbol not in order_books:
order_books[symbol] = {
"bids": defaultdict(list), # buy orders
"asks": defaultdict(list), # sell orders
"last_price": None,
"volume_24h": 0.0,
"high_24h": None,
"low_24h": None,
"created_at": datetime.now(timezone.utc).isoformat()
}
# Store order
order_data = {
"order_id": order.order_id,
"symbol": order.symbol,
"side": order.side,
"type": order.type,
"quantity": order.quantity,
"remaining_quantity": order.quantity,
"price": order.price,
"user_id": order.user_id,
"timestamp": order.timestamp.isoformat(),
"status": "open",
"filled_quantity": 0.0,
"average_price": None
}
orders[order.order_id] = order_data
# Process order
trades_executed = await process_order(order_data)
logger.info(f"Order submitted: {order.order_id} - {order.side} {order.quantity} {order.symbol}")
return {
"order_id": order.order_id,
"status": order_data["status"],
"filled_quantity": order_data["filled_quantity"],
"remaining_quantity": order_data["remaining_quantity"],
"trades_executed": len(trades_executed),
"average_price": order_data["average_price"]
}
@app.get("/api/v1/orders/{order_id}")
async def get_order(order_id: str):
"""Get order details"""
if order_id not in orders:
raise HTTPException(status_code=404, detail="Order not found")
return orders[order_id]
@app.get("/api/v1/orders")
async def list_orders():
"""List all orders"""
return {
"orders": list(orders.values()),
"total_orders": len(orders),
"open_orders": len([o for o in orders.values() if o["status"] == "open"]),
"filled_orders": len([o for o in orders.values() if o["status"] == "filled"])
}
@app.get("/api/v1/orderbook/{symbol}")
async def get_order_book(symbol: str, depth: int = 10):
"""Get order book for a trading pair"""
if symbol not in order_books:
raise HTTPException(status_code=404, detail="Order book not found")
book = order_books[symbol]
# Get best bids and asks
bids = sorted(book["bids"].items(), reverse=True)[:depth]
asks = sorted(book["asks"].items())[:depth]
return {
"symbol": symbol,
"bids": [
{
"price": price,
"quantity": sum(order["remaining_quantity"] for order in orders_list),
"orders_count": len(orders_list)
}
for price, orders_list in bids
],
"asks": [
{
"price": price,
"quantity": sum(order["remaining_quantity"] for order in orders_list),
"orders_count": len(orders_list)
}
for price, orders_list in asks
],
"last_price": book["last_price"],
"volume_24h": book["volume_24h"],
"high_24h": book["high_24h"],
"low_24h": book["low_24h"],
"timestamp": datetime.now(timezone.utc).isoformat()
}
@app.get("/api/v1/trades")
async def list_trades(symbol: Optional[str] = None, limit: int = 100):
"""List recent trades"""
all_trades = list(trades.values())
if symbol:
all_trades = [t for t in all_trades if t["symbol"] == symbol]
# Sort by timestamp (most recent first)
all_trades.sort(key=lambda x: x["timestamp"], reverse=True)
return {
"trades": all_trades[:limit],
"total_trades": len(all_trades)
}
@app.get("/api/v1/ticker/{symbol}")
async def get_ticker(symbol: str):
"""Get ticker information for a trading pair"""
if symbol not in order_books:
raise HTTPException(status_code=404, detail="Trading pair not found")
book = order_books[symbol]
# Calculate 24h statistics
trades_24h = [t for t in trades.values()
if t["symbol"] == symbol and
datetime.fromisoformat(t["timestamp"]) >
datetime.now(timezone.utc) - timedelta(hours=24)]
if trades_24h:
prices = [t["price"] for t in trades_24h]
volume = sum(t["quantity"] for t in trades_24h)
ticker = {
"symbol": symbol,
"last_price": book["last_price"],
"bid_price": max(book["bids"].keys()) if book["bids"] else None,
"ask_price": min(book["asks"].keys()) if book["asks"] else None,
"high_24h": max(prices),
"low_24h": min(prices),
"volume_24h": volume,
"change_24h": prices[-1] - prices[0] if len(prices) > 1 else 0,
"change_percent_24h": ((prices[-1] - prices[0]) / prices[0] * 100) if len(prices) > 1 else 0
}
else:
ticker = {
"symbol": symbol,
"last_price": book["last_price"],
"bid_price": None,
"ask_price": None,
"high_24h": None,
"low_24h": None,
"volume_24h": 0.0,
"change_24h": 0.0,
"change_percent_24h": 0.0
}
return ticker
@app.delete("/api/v1/orders/{order_id}")
async def cancel_order(order_id: str):
"""Cancel an order"""
if order_id not in orders:
raise HTTPException(status_code=404, detail="Order not found")
order = orders[order_id]
if order["status"] != "open":
raise HTTPException(status_code=400, detail="Order cannot be cancelled")
# Remove from order book
symbol = order["symbol"]
if symbol in order_books:
book = order_books[symbol]
price_key = str(order["price"])
if order["side"] == "buy" and price_key in book["bids"]:
book["bids"][price_key] = [o for o in book["bids"][price_key] if o["order_id"] != order_id]
if not book["bids"][price_key]:
del book["bids"][price_key]
elif order["side"] == "sell" and price_key in book["asks"]:
book["asks"][price_key] = [o for o in book["asks"][price_key] if o["order_id"] != order_id]
if not book["asks"][price_key]:
del book["asks"][price_key]
# Update order status
order["status"] = "cancelled"
order["cancelled_at"] = datetime.now(timezone.utc).isoformat()
logger.info(f"Order cancelled: {order_id}")
return {
"order_id": order_id,
"status": "cancelled",
"cancelled_at": order["cancelled_at"]
}
@app.get("/api/v1/market-data")
async def get_market_data():
"""Get market data for all symbols"""
market_summary = {}
for symbol, book in order_books.items():
trades_24h = [t for t in trades.values()
if t["symbol"] == symbol and
datetime.fromisoformat(t["timestamp"]) >
datetime.now(timezone.utc) - timedelta(hours=24)]
market_summary[symbol] = {
"last_price": book["last_price"],
"volume_24h": book["volume_24h"],
"high_24h": book["high_24h"],
"low_24h": book["low_24h"],
"trades_count_24h": len(trades_24h),
"bid_price": max(book["bids"].keys()) if book["bids"] else None,
"ask_price": min(book["asks"].keys()) if book["asks"] else None
}
return {
"market_data": market_summary,
"total_symbols": len(market_summary),
"generated_at": datetime.now(timezone.utc).isoformat()
}
@app.get("/api/v1/engine/stats")
async def get_engine_stats():
"""Get trading engine statistics"""
total_orders = len(orders)
total_trades = len(trades)
total_volume = sum(t["quantity"] * t["price"] for t in trades.values())
orders_by_status = defaultdict(int)
for order in orders.values():
orders_by_status[order["status"]] += 1
trades_by_symbol = defaultdict(int)
for trade in trades.values():
trades_by_symbol[trade["symbol"]] += 1
return {
"engine_stats": {
"total_orders": total_orders,
"total_trades": total_trades,
"total_volume": total_volume,
"orders_by_status": dict(orders_by_status),
"trades_by_symbol": dict(trades_by_symbol),
"active_order_books": len(order_books),
"uptime": "running"
},
"generated_at": datetime.now(timezone.utc).isoformat()
}
# Core trading engine logic
async def process_order(order: Dict) -> List[Dict]:
"""Process an order and execute trades"""
symbol = order["symbol"]
book = order_books[symbol]
trades_executed = []
if order["type"] == "market":
trades_executed = await process_market_order(order, book)
else:
trades_executed = await process_limit_order(order, book)
# Update market data
update_market_data(symbol, trades_executed)
return trades_executed
async def process_market_order(order: Dict, book: Dict) -> List[Dict]:
"""Process a market order"""
trades_executed = []
if order["side"] == "buy":
# Match against asks (sell orders)
ask_prices = sorted(book["asks"].keys())
for price in ask_prices:
if order["remaining_quantity"] <= 0:
break
orders_at_price = book["asks"][price][:]
for matching_order in orders_at_price:
if order["remaining_quantity"] <= 0:
break
trade = await execute_trade(order, matching_order, price)
if trade:
trades_executed.append(trade)
else: # sell order
# Match against bids (buy orders)
bid_prices = sorted(book["bids"].keys(), reverse=True)
for price in bid_prices:
if order["remaining_quantity"] <= 0:
break
orders_at_price = book["bids"][price][:]
for matching_order in orders_at_price:
if order["remaining_quantity"] <= 0:
break
trade = await execute_trade(order, matching_order, price)
if trade:
trades_executed.append(trade)
return trades_executed
async def process_limit_order(order: Dict, book: Dict) -> List[Dict]:
"""Process a limit order"""
trades_executed = []
if order["side"] == "buy":
# Match against asks at or below the limit price
ask_prices = sorted([p for p in book["asks"].keys() if float(p) <= order["price"]])
for price in ask_prices:
if order["remaining_quantity"] <= 0:
break
orders_at_price = book["asks"][price][:]
for matching_order in orders_at_price:
if order["remaining_quantity"] <= 0:
break
trade = await execute_trade(order, matching_order, price)
if trade:
trades_executed.append(trade)
# Add remaining quantity to order book
if order["remaining_quantity"] > 0:
price_key = str(order["price"])
book["bids"][price_key].append(order)
else: # sell order
# Match against bids at or above the limit price
bid_prices = sorted([p for p in book["bids"].keys() if float(p) >= order["price"]], reverse=True)
for price in bid_prices:
if order["remaining_quantity"] <= 0:
break
orders_at_price = book["bids"][price][:]
for matching_order in orders_at_price:
if order["remaining_quantity"] <= 0:
break
trade = await execute_trade(order, matching_order, price)
if trade:
trades_executed.append(trade)
# Add remaining quantity to order book
if order["remaining_quantity"] > 0:
price_key = str(order["price"])
book["asks"][price_key].append(order)
return trades_executed
async def execute_trade(order1: Dict, order2: Dict, price: float) -> Optional[Dict]:
"""Execute a trade between two orders"""
# Determine trade quantity
trade_quantity = min(order1["remaining_quantity"], order2["remaining_quantity"])
if trade_quantity <= 0:
return None
# Create trade record
trade_id = f"trade_{int(datetime.now(timezone.utc).timestamp())}_{len(trades)}"
trade = {
"trade_id": trade_id,
"symbol": order1["symbol"],
"buy_order_id": order1["order_id"] if order1["side"] == "buy" else order2["order_id"],
"sell_order_id": order2["order_id"] if order2["side"] == "sell" else order1["order_id"],
"quantity": trade_quantity,
"price": price,
"timestamp": datetime.now(timezone.utc).isoformat()
}
trades[trade_id] = trade
# Update orders
for order in [order1, order2]:
order["filled_quantity"] += trade_quantity
order["remaining_quantity"] -= trade_quantity
if order["remaining_quantity"] <= 0:
order["status"] = "filled"
order["filled_at"] = trade["timestamp"]
else:
order["status"] = "partially_filled"
# Update average price
if order["average_price"] is None:
order["average_price"] = price
else:
total_value = order["average_price"] * (order["filled_quantity"] - trade_quantity) + price * trade_quantity
order["average_price"] = total_value / order["filled_quantity"]
# Remove filled orders from order book
symbol = order1["symbol"]
book = order_books[symbol]
price_key = str(price)
for order in [order1, order2]:
if order["remaining_quantity"] <= 0:
if order["side"] == "buy" and price_key in book["bids"]:
book["bids"][price_key] = [o for o in book["bids"][price_key] if o["order_id"] != order["order_id"]]
if not book["bids"][price_key]:
del book["bids"][price_key]
elif order["side"] == "sell" and price_key in book["asks"]:
book["asks"][price_key] = [o for o in book["asks"][price_key] if o["order_id"] != order["order_id"]]
if not book["asks"][price_key]:
del book["asks"][price_key]
logger.info(f"Trade executed: {trade_id} - {trade_quantity} @ {price}")
return trade
def update_market_data(symbol: str, trades_executed: List[Dict]):
"""Update market data after trades"""
if not trades_executed:
return
book = order_books[symbol]
# Update last price
last_trade = trades_executed[-1]
book["last_price"] = last_trade["price"]
# Update 24h high/low
trades_24h = [t for t in trades.values()
if t["symbol"] == symbol and
datetime.fromisoformat(t["timestamp"]) >
datetime.now(timezone.utc) - timedelta(hours=24)]
if trades_24h:
prices = [t["price"] for t in trades_24h]
book["high_24h"] = max(prices)
book["low_24h"] = min(prices)
book["volume_24h"] = sum(t["quantity"] for t in trades_24h)
# Background task for market data simulation
async def simulate_market_activity():
"""Background task to simulate market activity"""
while True:
await asyncio.sleep(60) # Simulate activity every minute
# Create some random market orders for demo
if len(order_books) > 0:
import random
for symbol in list(order_books.keys())[:3]: # Limit to 3 symbols
if random.random() < 0.3: # 30% chance of market activity
# Create random market order
side = random.choice(["buy", "sell"])
quantity = random.uniform(10, 1000)
order_id = f"sim_order_{int(datetime.now(timezone.utc).timestamp())}"
order = Order(
order_id=order_id,
symbol=symbol,
side=side,
type="market",
quantity=quantity,
user_id="sim_user",
timestamp=datetime.now(timezone.utc)
)
order_data = {
"order_id": order.order_id,
"symbol": order.symbol,
"side": order.side,
"type": order.type,
"quantity": order.quantity,
"remaining_quantity": order.quantity,
"price": order.price,
"user_id": order.user_id,
"timestamp": order.timestamp.isoformat(),
"status": "open",
"filled_quantity": 0.0,
"average_price": None
}
orders[order_id] = order_data
await process_order(order_data)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8012, log_level="info")

View File

@@ -0,0 +1 @@
"""Trading engine service tests"""

View File

@@ -0,0 +1,208 @@
"""Edge case and error handling tests for trading engine service"""
import pytest
import sys
import sys
from pathlib import Path
from fastapi.testclient import TestClient
from datetime import datetime, timezone
from main import app, Order, order_books, orders, trades
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
order_books.clear()
orders.clear()
trades.clear()
yield
order_books.clear()
orders.clear()
trades.clear()
@pytest.mark.unit
def test_order_zero_quantity():
"""Test Order with zero quantity"""
order = Order(
order_id="order_123",
symbol="AITBC/BTC",
side="buy",
type="limit",
quantity=0.0,
price=0.00001,
user_id="user_123",
timestamp=datetime.now(timezone.utc)
)
assert order.quantity == 0.0
@pytest.mark.unit
def test_order_negative_quantity():
"""Test Order with negative quantity"""
order = Order(
order_id="order_123",
symbol="AITBC/BTC",
side="buy",
type="limit",
quantity=-100.0,
price=0.00001,
user_id="user_123",
timestamp=datetime.now(timezone.utc)
)
assert order.quantity == -100.0
@pytest.mark.unit
def test_order_negative_price():
"""Test Order with negative price"""
order = Order(
order_id="order_123",
symbol="AITBC/BTC",
side="buy",
type="limit",
quantity=100.0,
price=-0.00001,
user_id="user_123",
timestamp=datetime.now(timezone.utc)
)
assert order.price == -0.00001
@pytest.mark.unit
def test_order_empty_symbol():
"""Test Order with empty symbol"""
order = Order(
order_id="order_123",
symbol="",
side="buy",
type="limit",
quantity=100.0,
price=0.00001,
user_id="user_123",
timestamp=datetime.now(timezone.utc)
)
assert order.symbol == ""
@pytest.mark.integration
def test_cancel_filled_order():
"""Test cancelling a filled order"""
client = TestClient(app)
order = Order(
order_id="order_129",
symbol="AITBC/BTC",
side="buy",
type="limit",
quantity=100.0,
price=0.00001,
user_id="user_123",
timestamp=datetime.now(timezone.utc)
)
client.post("/api/v1/orders/submit", json=order.model_dump(mode='json'))
# Manually mark as filled
orders["order_129"]["status"] = "filled"
response = client.delete("/api/v1/orders/order_129")
assert response.status_code == 400
@pytest.mark.integration
def test_submit_order_with_slash_in_symbol():
"""Test submitting order with slash in symbol"""
client = TestClient(app)
order = Order(
order_id="order_130",
symbol="AITBC/BTC",
side="buy",
type="limit",
quantity=100.0,
price=0.00001,
user_id="user_123",
timestamp=datetime.now(timezone.utc)
)
response = client.post("/api/v1/orders/submit", json=order.model_dump(mode='json'))
assert response.status_code == 200
@pytest.mark.integration
def test_submit_order_with_hyphen_in_symbol():
"""Test submitting order with hyphen in symbol"""
client = TestClient(app)
order = Order(
order_id="order_131",
symbol="AITBC-BTC",
side="buy",
type="limit",
quantity=100.0,
price=0.00001,
user_id="user_123",
timestamp=datetime.now(timezone.utc)
)
response = client.post("/api/v1/orders/submit", json=order.model_dump(mode='json'))
assert response.status_code == 200
@pytest.mark.integration
def test_list_orders_with_no_orders():
"""Test listing orders when no orders exist"""
client = TestClient(app)
response = client.get("/api/v1/orders")
assert response.status_code == 200
data = response.json()
assert data["total_orders"] == 0
@pytest.mark.integration
def test_list_trades_with_no_trades():
"""Test listing trades when no trades exist"""
client = TestClient(app)
response = client.get("/api/v1/trades")
assert response.status_code == 200
data = response.json()
assert data["total_trades"] == 0
@pytest.mark.integration
def test_get_market_data_with_no_symbols():
"""Test getting market data when no symbols exist"""
client = TestClient(app)
response = client.get("/api/v1/market-data")
assert response.status_code == 200
data = response.json()
assert data["total_symbols"] == 0
@pytest.mark.integration
def test_order_book_depth_parameter():
"""Test order book with depth parameter"""
client = TestClient(app)
order = Order(
order_id="order_132",
symbol="AITBC-BTC",
side="buy",
type="limit",
quantity=100.0,
price=0.00001,
user_id="user_123",
timestamp=datetime.now(timezone.utc)
)
client.post("/api/v1/orders/submit", json=order.model_dump(mode='json'))
response = client.get("/api/v1/orderbook/AITBC-BTC?depth=5")
assert response.status_code == 200
data = response.json()
assert data["symbol"] == "AITBC-BTC"
@pytest.mark.integration
def test_list_trades_limit_parameter():
"""Test listing trades with limit parameter"""
client = TestClient(app)
response = client.get("/api/v1/trades?limit=10")
assert response.status_code == 200
data = response.json()
assert "trades" in data

View File

@@ -0,0 +1,264 @@
"""Integration tests for trading engine service"""
import pytest
import sys
import sys
from pathlib import Path
from fastapi.testclient import TestClient
from datetime import datetime, timezone
from main import app, Order, order_books, orders, trades
@pytest.fixture(autouse=True)
def reset_state():
"""Reset global state before each test"""
order_books.clear()
orders.clear()
trades.clear()
yield
order_books.clear()
orders.clear()
trades.clear()
@pytest.mark.integration
def test_root_endpoint():
"""Test root endpoint"""
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
data = response.json()
assert data["service"] == "AITBC Trading Engine"
assert data["status"] == "running"
@pytest.mark.integration
def test_health_check_endpoint():
"""Test health check endpoint"""
client = TestClient(app)
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert "active_order_books" in data
assert "total_orders" in data
@pytest.mark.integration
def test_submit_market_order():
"""Test submitting a market order"""
client = TestClient(app)
order = Order(
order_id="order_123",
symbol="AITBC/BTC",
side="buy",
type="market",
quantity=100.0,
user_id="user_123",
timestamp=datetime.now(timezone.utc)
)
response = client.post("/api/v1/orders/submit", json=order.model_dump(mode='json'))
assert response.status_code == 200
data = response.json()
assert data["order_id"] == "order_123"
assert "status" in data
@pytest.mark.integration
def test_submit_limit_order():
"""Test submitting a limit order"""
client = TestClient(app)
order = Order(
order_id="order_124",
symbol="AITBC/BTC",
side="buy",
type="limit",
quantity=100.0,
price=0.00001,
user_id="user_123",
timestamp=datetime.now(timezone.utc)
)
response = client.post("/api/v1/orders/submit", json=order.model_dump(mode='json'))
assert response.status_code == 200
data = response.json()
assert data["order_id"] == "order_124"
assert "status" in data
@pytest.mark.integration
def test_get_order():
"""Test getting order details"""
client = TestClient(app)
order = Order(
order_id="order_125",
symbol="AITBC/BTC",
side="buy",
type="limit",
quantity=100.0,
price=0.00001,
user_id="user_123",
timestamp=datetime.now(timezone.utc)
)
client.post("/api/v1/orders/submit", json=order.model_dump(mode='json'))
response = client.get("/api/v1/orders/order_125")
assert response.status_code == 200
data = response.json()
assert data["order_id"] == "order_125"
@pytest.mark.integration
def test_get_order_not_found():
"""Test getting nonexistent order"""
client = TestClient(app)
response = client.get("/api/v1/orders/nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_list_orders():
"""Test listing all orders"""
client = TestClient(app)
response = client.get("/api/v1/orders")
assert response.status_code == 200
data = response.json()
assert "orders" in data
assert "total_orders" in data
@pytest.mark.integration
def test_get_order_book():
"""Test getting order book"""
client = TestClient(app)
# Create some orders first
order1 = Order(
order_id="order_126",
symbol="AITBC-BTC",
side="buy",
type="limit",
quantity=100.0,
price=0.00001,
user_id="user_123",
timestamp=datetime.now(timezone.utc)
)
client.post("/api/v1/orders/submit", json=order1.model_dump(mode='json'))
response = client.get("/api/v1/orderbook/AITBC-BTC")
assert response.status_code == 200
data = response.json()
assert data["symbol"] == "AITBC-BTC"
assert "bids" in data
assert "asks" in data
@pytest.mark.integration
def test_get_order_book_not_found():
"""Test getting order book for nonexistent symbol"""
client = TestClient(app)
response = client.get("/api/v1/orderbook/NONEXISTENT")
assert response.status_code == 404
@pytest.mark.integration
def test_list_trades():
"""Test listing trades"""
client = TestClient(app)
response = client.get("/api/v1/trades")
assert response.status_code == 200
data = response.json()
assert "trades" in data
assert "total_trades" in data
@pytest.mark.integration
def test_list_trades_by_symbol():
"""Test listing trades by symbol"""
client = TestClient(app)
response = client.get("/api/v1/trades?symbol=AITBC-BTC")
assert response.status_code == 200
data = response.json()
assert "trades" in data
@pytest.mark.integration
def test_get_ticker():
"""Test getting ticker information"""
client = TestClient(app)
# Create order book first
order = Order(
order_id="order_127",
symbol="AITBC-BTC",
side="buy",
type="limit",
quantity=100.0,
price=0.00001,
user_id="user_123",
timestamp=datetime.now(timezone.utc)
)
client.post("/api/v1/orders/submit", json=order.model_dump(mode='json'))
response = client.get("/api/v1/ticker/AITBC-BTC")
assert response.status_code == 200
data = response.json()
assert data["symbol"] == "AITBC-BTC"
@pytest.mark.integration
def test_get_ticker_not_found():
"""Test getting ticker for nonexistent symbol"""
client = TestClient(app)
response = client.get("/api/v1/ticker/NONEXISTENT")
assert response.status_code == 404
@pytest.mark.integration
def test_cancel_order():
"""Test cancelling an order"""
client = TestClient(app)
order = Order(
order_id="order_128",
symbol="AITBC/BTC",
side="buy",
type="limit",
quantity=100.0,
price=0.00001,
user_id="user_123",
timestamp=datetime.now(timezone.utc)
)
client.post("/api/v1/orders/submit", json=order.model_dump(mode='json'))
response = client.delete("/api/v1/orders/order_128")
assert response.status_code == 200
data = response.json()
assert data["status"] == "cancelled"
@pytest.mark.integration
def test_cancel_order_not_found():
"""Test cancelling nonexistent order"""
client = TestClient(app)
response = client.delete("/api/v1/orders/nonexistent")
assert response.status_code == 404
@pytest.mark.integration
def test_get_market_data():
"""Test getting market data"""
client = TestClient(app)
response = client.get("/api/v1/market-data")
assert response.status_code == 200
data = response.json()
assert "market_data" in data
assert "total_symbols" in data
@pytest.mark.integration
def test_get_engine_stats():
"""Test getting engine statistics"""
client = TestClient(app)
response = client.get("/api/v1/engine/stats")
assert response.status_code == 200
data = response.json()
assert "engine_stats" in data

View File

@@ -0,0 +1,89 @@
"""Unit tests for trading engine service"""
import pytest
import sys
import sys
from pathlib import Path
from datetime import datetime, timezone
from main import app, Order, Trade, OrderBookEntry
@pytest.mark.unit
def test_app_initialization():
"""Test that the FastAPI app initializes correctly"""
assert app is not None
assert app.title == "AITBC Trading Engine"
assert app.version == "1.0.0"
@pytest.mark.unit
def test_order_model():
"""Test Order model"""
order = Order(
order_id="order_123",
symbol="AITBC/BTC",
side="buy",
type="limit",
quantity=100.0,
price=0.00001,
user_id="user_123",
timestamp=datetime.now(timezone.utc)
)
assert order.order_id == "order_123"
assert order.symbol == "AITBC/BTC"
assert order.side == "buy"
assert order.type == "limit"
assert order.quantity == 100.0
assert order.price == 0.00001
assert order.user_id == "user_123"
@pytest.mark.unit
def test_order_model_market_order():
"""Test Order model for market order"""
order = Order(
order_id="order_123",
symbol="AITBC/BTC",
side="sell",
type="market",
quantity=50.0,
user_id="user_123",
timestamp=datetime.now(timezone.utc)
)
assert order.type == "market"
assert order.price is None
@pytest.mark.unit
def test_trade_model():
"""Test Trade model"""
trade = Trade(
trade_id="trade_123",
symbol="AITBC/BTC",
buy_order_id="buy_order_123",
sell_order_id="sell_order_123",
quantity=100.0,
price=0.00001,
timestamp=datetime.now(timezone.utc)
)
assert trade.trade_id == "trade_123"
assert trade.symbol == "AITBC/BTC"
assert trade.buy_order_id == "buy_order_123"
assert trade.sell_order_id == "sell_order_123"
assert trade.quantity == 100.0
assert trade.price == 0.00001
@pytest.mark.unit
def test_order_book_entry_model():
"""Test OrderBookEntry model"""
entry = OrderBookEntry(
price=0.00001,
quantity=1000.0,
orders_count=5
)
assert entry.price == 0.00001
assert entry.quantity == 1000.0
assert entry.orders_count == 5