refactor: improve error handling and remove hardcoded credentials
- Changed bare except clauses to specific exception types in web3_utils.py, testing.py, messages.py, and message_storage.py - Replaced print() calls with logger in testing.py, agent_discovery.py, compliance_agent.py, coordinator.py, trading_agent.py, keys.py, escrow.py, persistent_spending_tracker.py, sync_cli.py, and client.py - Added logger initialization using get_logger(__name__) in compliance_agent.py, coordinator.py, trading_agent.py, keys.py, escrow.py, persistent_spending_tracker.py, and client.py - Removed hardcoded secret
This commit is contained in:
40
examples/stubs/README.md
Normal file
40
examples/stubs/README.md
Normal file
@@ -0,0 +1,40 @@
|
||||
# Stub Services
|
||||
|
||||
This directory contains stub and placeholder services that are not yet fully implemented or are minimal implementations.
|
||||
|
||||
## Services in this Directory
|
||||
|
||||
The following services have <10 files and are considered stubs or placeholders:
|
||||
|
||||
- **hermes-service** (4 files) - Hermes agent communication service
|
||||
- **monitor** (7 files) - Monitoring stub
|
||||
- **monitoring-service** (4 files) - Monitoring service stub
|
||||
- **plugin-service** (4 files) - Plugin service stub
|
||||
- **ai-service** (8 files) - AI service stub
|
||||
- **compliance-service** (9 files) - Compliance checking stub
|
||||
- **exchange-integration** (9 files) - Exchange integration stub
|
||||
- **global-ai-agents** (9 files) - Global AI agents stub
|
||||
- **global-infrastructure** (9 files) - Global infrastructure stub
|
||||
- **multi-region-load-balancer** (9 files) - Multi-region load balancer stub
|
||||
- **plugin-analytics** (9 files) - Plugin analytics stub
|
||||
- **plugin-marketplace** (9 files) - Plugin marketplace stub
|
||||
- **plugin-registry** (9 files) - Plugin registry stub
|
||||
- **plugin-security** (9 files) - Plugin security stub
|
||||
- **simple-explorer** (9 files) - Simple blockchain explorer stub
|
||||
- **trading-engine** (9 files) - Trading engine stub
|
||||
|
||||
## Purpose
|
||||
|
||||
These services are placeholders for future functionality. They may be:
|
||||
- Minimal implementations for testing
|
||||
- Skeletons for future development
|
||||
- Experimental features not yet production-ready
|
||||
|
||||
## Active Services
|
||||
|
||||
Active services with full implementations remain in the parent `apps/` directory:
|
||||
- blockchain-node, coordinator-api, exchange, marketplace, wallet, etc.
|
||||
|
||||
## Future Work
|
||||
|
||||
As stub services are fully implemented, they should be moved from this directory to the main `apps/` directory.
|
||||
1718
examples/stubs/ai-service/poetry.lock
generated
Normal file
1718
examples/stubs/ai-service/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
28
examples/stubs/ai-service/pyproject.toml
Normal file
28
examples/stubs/ai-service/pyproject.toml
Normal file
@@ -0,0 +1,28 @@
|
||||
[tool.poetry]
|
||||
name = "ai-service"
|
||||
version = "0.1.0"
|
||||
description = "AITBC AI Service for job operations"
|
||||
authors = ["AITBC Team"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.13"
|
||||
fastapi = ">=0.115.6"
|
||||
uvicorn = {extras = ["standard"], version = "^0.32.0"}
|
||||
sqlmodel = "^0.0.37"
|
||||
sqlalchemy = "^2.0.25"
|
||||
pydantic = "^2.6.0"
|
||||
pydantic-settings = "^2.1.0"
|
||||
python-jose = {extras = ["cryptography"], version = "^3.3.0"}
|
||||
passlib = {extras = ["bcrypt"], version = "^1.7.4"}
|
||||
httpx = ">=0.28.1"
|
||||
asyncpg = ">=0.30.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = ">=9.0.3"
|
||||
pytest-asyncio = ">=1.3.0"
|
||||
black = ">=26.3.1"
|
||||
ruff = "^0.1.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
68
examples/stubs/ai-service/src/ai_service/domain/jobs.py
Normal file
68
examples/stubs/ai-service/src/ai_service/domain/jobs.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Domain models for AI job operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from enum import StrEnum
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import JSON, Column
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class JobState(StrEnum):
|
||||
"""Job execution states."""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELED = "canceled"
|
||||
EXPIRED = "expired"
|
||||
|
||||
|
||||
class Job(SQLModel, table=True):
|
||||
"""AI job model."""
|
||||
|
||||
__tablename__ = "jobs"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
id: str = Field(default_factory=lambda: f"job_{uuid4().hex[:12]}", primary_key=True)
|
||||
client_id: str = Field(index=True)
|
||||
task_type: str = Field(index=True)
|
||||
task_data: dict = Field(default_factory=dict, sa_column=Column(JSON, nullable=False))
|
||||
state: JobState = Field(default=JobState.PENDING, index=True)
|
||||
result: dict | None = Field(default=None, sa_column=Column(JSON, nullable=True))
|
||||
error: str | None = Field(default=None)
|
||||
|
||||
# Payment information
|
||||
payment_id: str | None = Field(default=None, index=True)
|
||||
payment_amount: float = Field(default=0.0)
|
||||
payment_status: str = Field(default="none", index=True)
|
||||
|
||||
# Timestamps
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), nullable=False, index=True)
|
||||
requested_at: datetime | None = Field(default=None)
|
||||
started_at: datetime | None = Field(default=None)
|
||||
completed_at: datetime | None = Field(default=None)
|
||||
expires_at: datetime | None = Field(default=None)
|
||||
|
||||
# Metadata
|
||||
priority: int = Field(default=0)
|
||||
assigned_miner_id: str | None = Field(default=None, index=True)
|
||||
receipt: dict | None = Field(default=None, sa_column=Column(JSON, nullable=True))
|
||||
receipt_id: str | None = Field(default=None)
|
||||
|
||||
|
||||
class JobReceipt(SQLModel, table=True):
|
||||
"""Job receipts for verification."""
|
||||
|
||||
__tablename__ = "job_receipts"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
id: str = Field(default_factory=lambda: f"rcpt_{uuid4().hex[:12]}", primary_key=True)
|
||||
job_id: str = Field(index=True)
|
||||
miner_id: str = Field(index=True)
|
||||
result: dict = Field(default_factory=dict, sa_column=Column(JSON, nullable=False))
|
||||
metrics: dict = Field(default_factory=dict, sa_column=Column(JSON, nullable=False))
|
||||
signature: str = Field(default="")
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), nullable=False, index=True)
|
||||
407
examples/stubs/ai-service/src/ai_service/main.py
Normal file
407
examples/stubs/ai-service/src/ai_service/main.py
Normal file
@@ -0,0 +1,407 @@
|
||||
"""AI Service for job operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import FastAPI, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import Field, SQLModel, select
|
||||
|
||||
from .storage import get_session
|
||||
from .domain.jobs import Job, JobState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
app = FastAPI(
|
||||
title="AITBC AI Service",
|
||||
description="AI job operations service",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy", "service": "ai-service"}
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint."""
|
||||
return {
|
||||
"service": "AITBC AI Service",
|
||||
"version": "1.0.0",
|
||||
"status": "operational"
|
||||
}
|
||||
|
||||
|
||||
async def get_session_dep():
|
||||
"""Dependency for database session."""
|
||||
async with get_session() as session:
|
||||
yield session
|
||||
|
||||
|
||||
# Request/Response models
|
||||
class JobCreate(SQLModel):
|
||||
task_type: str
|
||||
task_data: dict = Field(default_factory=dict)
|
||||
payment_amount: float = 0.0
|
||||
payment_currency: str = "aitbc_token"
|
||||
priority: int = 0
|
||||
|
||||
|
||||
class JobView(SQLModel):
|
||||
id: str
|
||||
client_id: str
|
||||
task_type: str
|
||||
state: str
|
||||
created_at: datetime
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
result: dict | None = None
|
||||
error: str | None = None
|
||||
payment_status: str = "none"
|
||||
|
||||
|
||||
class JobResult(SQLModel):
|
||||
id: str
|
||||
result: dict | None = None
|
||||
error: str | None = None
|
||||
completed_at: datetime | None = None
|
||||
receipt: dict | None = None
|
||||
|
||||
|
||||
@app.post("/jobs", response_model=JobView, status_code=status.HTTP_201_CREATED)
|
||||
async def submit_job(
|
||||
session: Annotated[AsyncSession, Depends(get_session_dep)],
|
||||
req: JobCreate,
|
||||
client_id: str = "default_client",
|
||||
):
|
||||
"""Submit a job for execution."""
|
||||
try:
|
||||
job = Job(
|
||||
client_id=client_id,
|
||||
task_type=req.task_type,
|
||||
task_data=req.task_data,
|
||||
payment_amount=req.payment_amount,
|
||||
priority=req.priority,
|
||||
state=JobState.PENDING,
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
session.add(job)
|
||||
await session.commit()
|
||||
await session.refresh(job)
|
||||
|
||||
return JobView(
|
||||
id=job.id,
|
||||
client_id=job.client_id,
|
||||
task_type=job.task_type,
|
||||
state=job.state,
|
||||
created_at=job.created_at,
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
result=job.result,
|
||||
error=job.error,
|
||||
payment_status=job.payment_status
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Submit job error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/jobs/{job_id}", response_model=JobView)
|
||||
async def get_job(
|
||||
session: Annotated[AsyncSession, Depends(get_session_dep)],
|
||||
job_id: str,
|
||||
client_id: str = "default_client",
|
||||
):
|
||||
"""Get job status."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(Job).where(Job.id == job_id, Job.client_id == client_id)
|
||||
)
|
||||
job = result.scalar_one_or_none()
|
||||
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
return JobView(
|
||||
id=job.id,
|
||||
client_id=job.client_id,
|
||||
task_type=job.task_type,
|
||||
state=job.state,
|
||||
created_at=job.created_at,
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
result=job.result,
|
||||
error=job.error,
|
||||
payment_status=job.payment_status
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Get job error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/jobs/{job_id}/result", response_model=JobResult)
|
||||
async def get_job_result(
|
||||
session: Annotated[AsyncSession, Depends(get_session_dep)],
|
||||
job_id: str,
|
||||
client_id: str = "default_client",
|
||||
):
|
||||
"""Get job result."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(Job).where(Job.id == job_id, Job.client_id == client_id)
|
||||
)
|
||||
job = result.scalar_one_or_none()
|
||||
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
if job.state not in {JobState.COMPLETED, JobState.FAILED, JobState.CANCELED, JobState.EXPIRED}:
|
||||
raise HTTPException(status_code=425, detail="Job not ready")
|
||||
|
||||
return JobResult(
|
||||
id=job.id,
|
||||
result=job.result,
|
||||
error=job.error,
|
||||
completed_at=job.completed_at,
|
||||
receipt=job.receipt
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Get job result error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/jobs/{job_id}/cancel", response_model=JobView)
|
||||
async def cancel_job(
|
||||
session: Annotated[AsyncSession, Depends(get_session_dep)],
|
||||
job_id: str,
|
||||
client_id: str = "default_client",
|
||||
):
|
||||
"""Cancel a job."""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(Job).where(Job.id == job_id, Job.client_id == client_id)
|
||||
)
|
||||
job = result.scalar_one_or_none()
|
||||
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
|
||||
if job.state in {JobState.COMPLETED, JobState.FAILED, JobState.CANCELED, JobState.EXPIRED}:
|
||||
raise HTTPException(status_code=400, detail="Job already completed")
|
||||
|
||||
job.state = JobState.CANCELED
|
||||
job.completed_at = datetime.now(timezone.utc)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(job)
|
||||
|
||||
return JobView(
|
||||
id=job.id,
|
||||
client_id=job.client_id,
|
||||
task_type=job.task_type,
|
||||
state=job.state,
|
||||
created_at=job.created_at,
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
result=job.result,
|
||||
error=job.error,
|
||||
payment_status=job.payment_status
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Cancel job error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/jobs")
|
||||
async def list_jobs(
|
||||
session: Annotated[AsyncSession, Depends(get_session_dep)],
|
||||
client_id: str = "default_client",
|
||||
limit: int = 10,
|
||||
state: str | None = None,
|
||||
):
|
||||
"""List jobs with filtering."""
|
||||
try:
|
||||
query = select(Job).where(Job.client_id == client_id)
|
||||
|
||||
if state:
|
||||
query = query.where(Job.state == state)
|
||||
|
||||
query = query.order_by(Job.created_at.desc()).limit(limit)
|
||||
|
||||
result = await session.execute(query)
|
||||
jobs = result.scalars().all()
|
||||
|
||||
return {
|
||||
"jobs": [
|
||||
JobView(
|
||||
id=job.id,
|
||||
client_id=job.client_id,
|
||||
task_type=job.task_type,
|
||||
state=job.state,
|
||||
created_at=job.created_at,
|
||||
started_at=job.started_at,
|
||||
completed_at=job.completed_at,
|
||||
result=job.result,
|
||||
error=job.error,
|
||||
payment_status=job.payment_status
|
||||
)
|
||||
for job in jobs
|
||||
],
|
||||
"limit": limit,
|
||||
"total": len(jobs)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"List jobs error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/multimodal/process")
|
||||
async def process_multimodal(request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Process multimodal AI requests (text, image, audio, video)"""
|
||||
return {
|
||||
"task_id": "multimodal_123",
|
||||
"modality": request.get("modality", "text"),
|
||||
"status": "processing",
|
||||
"result": "multimodal processing initiated"
|
||||
}
|
||||
|
||||
|
||||
@app.post("/multimodal/benchmark")
|
||||
async def benchmark_multimodal(request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Benchmark multimodal AI performance"""
|
||||
return {
|
||||
"benchmark_id": "bench_456",
|
||||
"modality": request.get("modality", "text"),
|
||||
"performance_score": 95.5,
|
||||
"latency_ms": 150,
|
||||
"throughput": "high"
|
||||
}
|
||||
|
||||
|
||||
@app.get("/multimodal/agents")
|
||||
async def list_multimodal_agents() -> dict[str, Any]:
|
||||
"""List available multimodal AI agents"""
|
||||
return {
|
||||
"agents": [
|
||||
{"id": "agent_1", "name": "Text-Image Agent", "capabilities": ["text", "image"]},
|
||||
{"id": "agent_2", "name": "Audio-Video Agent", "capabilities": ["audio", "video"]},
|
||||
],
|
||||
"total": 2
|
||||
}
|
||||
|
||||
|
||||
@app.get("/multimodal/health")
|
||||
async def multimodal_health() -> dict[str, Any]:
|
||||
"""Multi-Modal Agent Service Health"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "multimodal-agent",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"capabilities": {
|
||||
"text_processing": True,
|
||||
"image_processing": True,
|
||||
"audio_processing": True,
|
||||
"video_processing": True,
|
||||
"tabular_processing": True,
|
||||
"graph_processing": True,
|
||||
},
|
||||
"performance": {
|
||||
"text_processing_time": "0.02s",
|
||||
"image_processing_time": "0.15s",
|
||||
"audio_processing_time": "0.22s",
|
||||
"video_processing_time": "0.35s",
|
||||
"tabular_processing_time": "0.05s",
|
||||
"graph_processing_time": "0.08s",
|
||||
"average_accuracy": "94%",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@app.get("/multimodal/health/deep")
|
||||
async def multimodal_deep_health() -> dict[str, Any]:
|
||||
"""Deep Multi-Modal Service Health with modality tests"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "multimodal-agent",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"modality_tests": {
|
||||
"text": {"status": "pass", "processing_time": "0.02s", "accuracy": "92%"},
|
||||
"image": {"status": "pass", "processing_time": "0.15s", "accuracy": "87%"},
|
||||
"audio": {"status": "pass", "processing_time": "0.22s", "accuracy": "89%"},
|
||||
"video": {"status": "pass", "processing_time": "0.35s", "accuracy": "85%"},
|
||||
},
|
||||
"overall_health": "pass"
|
||||
}
|
||||
|
||||
|
||||
@app.post("/optimization/tune")
|
||||
async def tune_optimization(request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Tune AI model optimization parameters"""
|
||||
return {
|
||||
"tuning_id": "tune_789",
|
||||
"model": request.get("model", "default"),
|
||||
"parameters": {"learning_rate": 0.001, "batch_size": 32},
|
||||
"status": "tuned"
|
||||
}
|
||||
|
||||
|
||||
@app.post("/optimization/predict")
|
||||
async def predict_optimization(request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Predict optimal model performance"""
|
||||
return {
|
||||
"prediction_id": "pred_101",
|
||||
"model": request.get("model", "default"),
|
||||
"expected_performance": "high",
|
||||
"estimated_accuracy": 95.5
|
||||
}
|
||||
|
||||
|
||||
@app.get("/optimization/agents")
|
||||
async def list_optimization_agents() -> dict[str, Any]:
|
||||
"""List available optimization agents"""
|
||||
return {
|
||||
"agents": [
|
||||
{"id": "opt_1", "name": "Gradient Descent Optimizer", "type": "gradient"},
|
||||
{"id": "opt_2", "name": "Genetic Algorithm", "type": "evolutionary"},
|
||||
],
|
||||
"total": 2
|
||||
}
|
||||
|
||||
|
||||
@app.get("/optimization/health")
|
||||
async def optimization_health() -> dict[str, Any]:
|
||||
"""Optimization Service Health"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "modality-optimization",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"capabilities": {
|
||||
"text_optimization": True,
|
||||
"image_optimization": True,
|
||||
"audio_optimization": True,
|
||||
"video_optimization": True,
|
||||
},
|
||||
"performance": {
|
||||
"optimization_speedup": "150x average",
|
||||
"memory_reduction": "60% average",
|
||||
"accuracy_retention": "95% average",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8106)
|
||||
26
examples/stubs/ai-service/src/ai_service/storage.py
Normal file
26
examples/stubs/ai-service/src/ai_service/storage.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Database storage configuration for AI Service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
DATABASE_URL = os.getenv(
|
||||
"AI_SERVICE_DATABASE_URL",
|
||||
"postgresql+asyncpg://aitbc_ai:password@localhost:5432/aitbc_ai"
|
||||
)
|
||||
|
||||
engine = create_async_engine(DATABASE_URL, echo=False)
|
||||
AsyncSessionLocal = sessionmaker(
|
||||
engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_session():
|
||||
"""Async context manager for database sessions."""
|
||||
async with AsyncSessionLocal() as session:
|
||||
yield session
|
||||
434
examples/stubs/compliance-service/main.py
Executable file
434
examples/stubs/compliance-service/main.py
Executable file
@@ -0,0 +1,434 @@
|
||||
"""
|
||||
Production Compliance Service for AITBC
|
||||
Handles KYC/AML, regulatory compliance, and monitoring
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from aitbc import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup
|
||||
logger.info("Starting AITBC Compliance Service")
|
||||
# Start background compliance checks
|
||||
asyncio.create_task(periodic_compliance_checks())
|
||||
yield
|
||||
# Shutdown
|
||||
logger.info("Shutting down AITBC Compliance Service")
|
||||
|
||||
app = FastAPI(
|
||||
title="AITBC Compliance Service",
|
||||
description="Regulatory compliance and monitoring for AITBC operations",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# Data models
|
||||
class KYCRequest(BaseModel):
|
||||
user_id: str
|
||||
name: str
|
||||
email: str
|
||||
document_type: str
|
||||
document_number: str
|
||||
address: Dict[str, str]
|
||||
|
||||
class ComplianceReport(BaseModel):
|
||||
report_type: str
|
||||
description: str
|
||||
severity: str # low, medium, high, critical
|
||||
details: Dict[str, Any]
|
||||
|
||||
class TransactionMonitoring(BaseModel):
|
||||
transaction_id: str
|
||||
user_id: str
|
||||
amount: float
|
||||
currency: str
|
||||
counterparty: str
|
||||
timestamp: datetime
|
||||
|
||||
# In-memory storage (in production, use database)
|
||||
kyc_records: Dict[str, Dict] = {}
|
||||
compliance_reports: Dict[str, Dict] = {}
|
||||
suspicious_transactions: Dict[str, Dict] = {}
|
||||
compliance_rules: Dict[str, Dict] = {}
|
||||
risk_scores: Dict[str, Dict] = {}
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {
|
||||
"service": "AITBC Compliance Service",
|
||||
"status": "running",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {
|
||||
"status": "healthy",
|
||||
"kyc_records": len(kyc_records),
|
||||
"compliance_reports": len(compliance_reports),
|
||||
"suspicious_transactions": len(suspicious_transactions),
|
||||
"active_rules": len(compliance_rules)
|
||||
}
|
||||
|
||||
@app.post("/api/v1/kyc/submit")
|
||||
async def submit_kyc(kyc_request: KYCRequest):
|
||||
"""Submit KYC verification request"""
|
||||
if kyc_request.user_id in kyc_records:
|
||||
raise HTTPException(status_code=400, detail="KYC already submitted for this user")
|
||||
|
||||
# Create KYC record
|
||||
kyc_record = {
|
||||
"user_id": kyc_request.user_id,
|
||||
"name": kyc_request.name,
|
||||
"email": kyc_request.email,
|
||||
"document_type": kyc_request.document_type,
|
||||
"document_number": kyc_request.document_number,
|
||||
"address": kyc_request.address,
|
||||
"status": "pending",
|
||||
"submitted_at": datetime.now(timezone.utc).isoformat(),
|
||||
"reviewed_at": None,
|
||||
"approved_at": None,
|
||||
"risk_score": "medium",
|
||||
"notes": []
|
||||
}
|
||||
|
||||
kyc_records[kyc_request.user_id] = kyc_record
|
||||
|
||||
# Simulate KYC verification process
|
||||
await asyncio.sleep(2) # Simulate verification delay
|
||||
|
||||
# Auto-approve for demo (in production, this would involve actual verification)
|
||||
kyc_record["status"] = "approved"
|
||||
kyc_record["reviewed_at"] = datetime.now(timezone.utc).isoformat()
|
||||
kyc_record["approved_at"] = datetime.now(timezone.utc).isoformat()
|
||||
kyc_record["risk_score"] = "low"
|
||||
|
||||
logger.info(f"KYC approved for user: {kyc_request.user_id}")
|
||||
|
||||
return {
|
||||
"user_id": kyc_request.user_id,
|
||||
"status": kyc_record["status"],
|
||||
"risk_score": kyc_record["risk_score"],
|
||||
"approved_at": kyc_record["approved_at"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/kyc/{user_id}")
|
||||
async def get_kyc_status(user_id: str):
|
||||
"""Get KYC status for a user"""
|
||||
if user_id not in kyc_records:
|
||||
raise HTTPException(status_code=404, detail="KYC record not found")
|
||||
|
||||
return kyc_records[user_id]
|
||||
|
||||
@app.get("/api/v1/kyc")
|
||||
async def list_kyc_records():
|
||||
"""List all KYC records"""
|
||||
return {
|
||||
"kyc_records": list(kyc_records.values()),
|
||||
"total_records": len(kyc_records),
|
||||
"approved": len([r for r in kyc_records.values() if r["status"] == "approved"]),
|
||||
"pending": len([r for r in kyc_records.values() if r["status"] == "pending"]),
|
||||
"rejected": len([r for r in kyc_records.values() if r["status"] == "rejected"])
|
||||
}
|
||||
|
||||
@app.post("/api/v1/compliance/report")
|
||||
async def create_compliance_report(report: ComplianceReport):
|
||||
"""Create a compliance report"""
|
||||
report_id = f"report_{int(datetime.now(timezone.utc).timestamp())}"
|
||||
|
||||
compliance_record = {
|
||||
"report_id": report_id,
|
||||
"report_type": report.report_type,
|
||||
"description": report.description,
|
||||
"severity": report.severity,
|
||||
"details": report.details,
|
||||
"status": "open",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"assigned_to": None,
|
||||
"resolved_at": None,
|
||||
"resolution": None
|
||||
}
|
||||
|
||||
compliance_reports[report_id] = compliance_record
|
||||
|
||||
logger.info(f"Compliance report created: {report_id} - {report.report_type}")
|
||||
|
||||
return {
|
||||
"report_id": report_id,
|
||||
"status": "created",
|
||||
"severity": report.severity,
|
||||
"created_at": compliance_record["created_at"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/compliance/reports")
|
||||
async def list_compliance_reports():
|
||||
"""List all compliance reports"""
|
||||
return {
|
||||
"reports": list(compliance_reports.values()),
|
||||
"total_reports": len(compliance_reports),
|
||||
"open": len([r for r in compliance_reports.values() if r["status"] == "open"]),
|
||||
"resolved": len([r for r in compliance_reports.values() if r["status"] == "resolved"])
|
||||
}
|
||||
|
||||
@app.post("/api/v1/monitoring/transaction")
|
||||
async def monitor_transaction(transaction: TransactionMonitoring):
|
||||
"""Monitor transaction for compliance"""
|
||||
transaction_id = transaction.transaction_id
|
||||
|
||||
# Create transaction monitoring record
|
||||
monitoring_record = {
|
||||
"transaction_id": transaction_id,
|
||||
"user_id": transaction.user_id,
|
||||
"amount": transaction.amount,
|
||||
"currency": transaction.currency,
|
||||
"counterparty": transaction.counterparty,
|
||||
"timestamp": transaction.timestamp.isoformat(),
|
||||
"monitored_at": datetime.now(timezone.utc).isoformat(),
|
||||
"risk_score": calculate_transaction_risk(transaction),
|
||||
"flags": [],
|
||||
"status": "monitored"
|
||||
}
|
||||
|
||||
suspicious_transactions[transaction_id] = monitoring_record
|
||||
|
||||
# Check for suspicious patterns
|
||||
flags = check_suspicious_patterns(transaction)
|
||||
if flags:
|
||||
monitoring_record["flags"] = flags
|
||||
monitoring_record["status"] = "flagged"
|
||||
|
||||
# Create compliance report for suspicious transaction
|
||||
await create_suspicious_transaction_report(transaction, flags)
|
||||
|
||||
return {
|
||||
"transaction_id": transaction_id,
|
||||
"risk_score": monitoring_record["risk_score"],
|
||||
"flags": flags,
|
||||
"status": monitoring_record["status"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/monitoring/transactions")
|
||||
async def list_monitored_transactions():
|
||||
"""List all monitored transactions"""
|
||||
return {
|
||||
"transactions": list(suspicious_transactions.values()),
|
||||
"total_transactions": len(suspicious_transactions),
|
||||
"flagged": len([t for t in suspicious_transactions.values() if t["status"] == "flagged"]),
|
||||
"suspicious": len([t for t in suspicious_transactions.values() if t["risk_score"] == "high"])
|
||||
}
|
||||
|
||||
@app.post("/api/v1/rules/create")
|
||||
async def create_compliance_rule(rule_data: Dict[str, Any]):
|
||||
"""Create a new compliance rule"""
|
||||
rule_id = f"rule_{int(datetime.now(timezone.utc).timestamp())}"
|
||||
|
||||
rule = {
|
||||
"rule_id": rule_id,
|
||||
"name": rule_data.get("name"),
|
||||
"description": rule_data.get("description"),
|
||||
"type": rule_data.get("type"),
|
||||
"conditions": rule_data.get("conditions", {}),
|
||||
"actions": rule_data.get("actions", []),
|
||||
"severity": rule_data.get("severity", "medium"),
|
||||
"active": True,
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"trigger_count": 0
|
||||
}
|
||||
|
||||
compliance_rules[rule_id] = rule
|
||||
|
||||
logger.info(f"Compliance rule created: {rule_id} - {rule['name']}")
|
||||
|
||||
return {
|
||||
"rule_id": rule_id,
|
||||
"name": rule["name"],
|
||||
"status": "created",
|
||||
"active": rule["active"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/rules")
|
||||
async def list_compliance_rules():
|
||||
"""List all compliance rules"""
|
||||
return {
|
||||
"rules": list(compliance_rules.values()),
|
||||
"total_rules": len(compliance_rules),
|
||||
"active": len([r for r in compliance_rules.values() if r["active"]])
|
||||
}
|
||||
|
||||
@app.get("/api/v1/dashboard")
|
||||
async def compliance_dashboard():
|
||||
"""Get compliance dashboard data"""
|
||||
total_users = len(kyc_records)
|
||||
approved_users = len([r for r in kyc_records.values() if r["status"] == "approved"])
|
||||
pending_reviews = len([r for r in kyc_records.values() if r["status"] == "pending"])
|
||||
|
||||
total_reports = len(compliance_reports)
|
||||
open_reports = len([r for r in compliance_reports.values() if r["status"] == "open"])
|
||||
|
||||
total_transactions = len(suspicious_transactions)
|
||||
flagged_transactions = len([t for t in suspicious_transactions.values() if t["status"] == "flagged"])
|
||||
|
||||
return {
|
||||
"summary": {
|
||||
"total_users": total_users,
|
||||
"approved_users": approved_users,
|
||||
"pending_reviews": pending_reviews,
|
||||
"approval_rate": (approved_users / total_users * 100) if total_users > 0 else 0,
|
||||
"total_reports": total_reports,
|
||||
"open_reports": open_reports,
|
||||
"total_transactions": total_transactions,
|
||||
"flagged_transactions": flagged_transactions,
|
||||
"flag_rate": (flagged_transactions / total_transactions * 100) if total_transactions > 0 else 0
|
||||
},
|
||||
"risk_distribution": get_risk_distribution(),
|
||||
"recent_activity": get_recent_activity(),
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
# Helper functions
|
||||
def calculate_transaction_risk(transaction: TransactionMonitoring) -> str:
|
||||
"""Calculate risk score for a transaction"""
|
||||
risk_score = 0
|
||||
|
||||
# Amount-based risk
|
||||
if transaction.amount > 10000:
|
||||
risk_score += 3
|
||||
elif transaction.amount > 1000:
|
||||
risk_score += 2
|
||||
elif transaction.amount > 100:
|
||||
risk_score += 1
|
||||
|
||||
# Time-based risk (transactions outside business hours)
|
||||
hour = transaction.timestamp.hour
|
||||
if hour < 9 or hour > 17:
|
||||
risk_score += 1
|
||||
|
||||
# Convert to risk level
|
||||
if risk_score >= 4:
|
||||
return "high"
|
||||
elif risk_score >= 2:
|
||||
return "medium"
|
||||
else:
|
||||
return "low"
|
||||
|
||||
def check_suspicious_patterns(transaction: TransactionMonitoring) -> List[str]:
|
||||
"""Check for suspicious transaction patterns"""
|
||||
flags = []
|
||||
|
||||
# High value transaction
|
||||
if transaction.amount > 50000:
|
||||
flags.append("high_value_transaction")
|
||||
|
||||
# Rapid transactions (check if user has multiple transactions in short time)
|
||||
user_transactions = [t for t in suspicious_transactions.values()
|
||||
if t["user_id"] == transaction.user_id]
|
||||
|
||||
recent_transactions = [t for t in user_transactions
|
||||
if datetime.fromisoformat(t["monitored_at"]) >
|
||||
datetime.now(timezone.utc) - timedelta(hours=1)]
|
||||
|
||||
if len(recent_transactions) > 5:
|
||||
flags.append("rapid_transactions")
|
||||
|
||||
# Unusual counterparty
|
||||
if transaction.counterparty in ["high_risk_entity_1", "high_risk_entity_2"]:
|
||||
flags.append("high_risk_counterparty")
|
||||
|
||||
return flags
|
||||
|
||||
async def create_suspicious_transaction_report(transaction: TransactionMonitoring, flags: List[str]):
|
||||
"""Create compliance report for suspicious transaction"""
|
||||
report_data = ComplianceReport(
|
||||
report_type="suspicious_transaction",
|
||||
description=f"Suspicious transaction detected: {transaction.transaction_id}",
|
||||
severity="high",
|
||||
details={
|
||||
"transaction_id": transaction.transaction_id,
|
||||
"user_id": transaction.user_id,
|
||||
"amount": transaction.amount,
|
||||
"flags": flags,
|
||||
"timestamp": transaction.timestamp.isoformat()
|
||||
}
|
||||
)
|
||||
|
||||
await create_compliance_report(report_data)
|
||||
|
||||
def get_risk_distribution() -> Dict[str, int]:
|
||||
"""Get distribution of risk scores"""
|
||||
distribution = {"low": 0, "medium": 0, "high": 0}
|
||||
|
||||
for record in kyc_records.values():
|
||||
distribution[record["risk_score"]] = distribution.get(record["risk_score"], 0) + 1
|
||||
|
||||
for transaction in suspicious_transactions.values():
|
||||
distribution[transaction["risk_score"]] = distribution.get(transaction["risk_score"], 0) + 1
|
||||
|
||||
return distribution
|
||||
|
||||
def get_recent_activity() -> List[Dict]:
|
||||
"""Get recent compliance activity"""
|
||||
activities = []
|
||||
|
||||
# Recent KYC approvals
|
||||
recent_kyc = [r for r in kyc_records.values()
|
||||
if r.get("approved_at") and
|
||||
datetime.fromisoformat(r["approved_at"]) >
|
||||
datetime.now(timezone.utc) - timedelta(hours=24)]
|
||||
|
||||
for kyc in recent_kyc[:5]:
|
||||
activities.append({
|
||||
"type": "kyc_approved",
|
||||
"description": f"KYC approved for {kyc['name']}",
|
||||
"timestamp": kyc["approved_at"]
|
||||
})
|
||||
|
||||
# Recent compliance reports
|
||||
recent_reports = [r for r in compliance_reports.values()
|
||||
if datetime.fromisoformat(r["created_at"]) >
|
||||
datetime.now(timezone.utc) - timedelta(hours=24)]
|
||||
|
||||
for report in recent_reports[:5]:
|
||||
activities.append({
|
||||
"type": "compliance_report",
|
||||
"description": f"Report: {report['description']}",
|
||||
"timestamp": report["created_at"]
|
||||
})
|
||||
|
||||
# Sort by timestamp
|
||||
activities.sort(key=lambda x: x["timestamp"], reverse=True)
|
||||
|
||||
return activities[:10]
|
||||
|
||||
# Background task for periodic compliance checks
|
||||
async def periodic_compliance_checks():
|
||||
"""Background task for periodic compliance monitoring"""
|
||||
while True:
|
||||
await asyncio.sleep(300) # Check every 5 minutes
|
||||
|
||||
# Check for expired KYC records
|
||||
current_time = datetime.now(timezone.utc)
|
||||
for user_id, kyc_record in kyc_records.items():
|
||||
if kyc_record["status"] == "approved":
|
||||
approved_time = datetime.fromisoformat(kyc_record["approved_at"])
|
||||
if current_time - approved_time > timedelta(days=365):
|
||||
# Flag for re-verification
|
||||
kyc_record["status"] = "reverification_required"
|
||||
logger.info(f"KYC re-verification required for user: {user_id}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8011, log_level="info")
|
||||
1
examples/stubs/compliance-service/tests/__init__.py
Normal file
1
examples/stubs/compliance-service/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Compliance service tests"""
|
||||
@@ -0,0 +1,193 @@
|
||||
"""Edge case and error handling tests for compliance service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, KYCRequest, ComplianceReport, TransactionMonitoring, kyc_records, compliance_reports, suspicious_transactions, compliance_rules
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
kyc_records.clear()
|
||||
compliance_reports.clear()
|
||||
suspicious_transactions.clear()
|
||||
compliance_rules.clear()
|
||||
yield
|
||||
kyc_records.clear()
|
||||
compliance_reports.clear()
|
||||
suspicious_transactions.clear()
|
||||
compliance_rules.clear()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_kyc_request_empty_fields():
|
||||
"""Test KYCRequest with empty fields"""
|
||||
kyc = KYCRequest(
|
||||
user_id="",
|
||||
name="",
|
||||
email="",
|
||||
document_type="",
|
||||
document_number="",
|
||||
address={}
|
||||
)
|
||||
assert kyc.user_id == ""
|
||||
assert kyc.name == ""
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_compliance_report_invalid_severity():
|
||||
"""Test ComplianceReport with invalid severity"""
|
||||
report = ComplianceReport(
|
||||
report_type="test",
|
||||
description="test",
|
||||
severity="invalid", # Not in low/medium/high/critical
|
||||
details={}
|
||||
)
|
||||
assert report.severity == "invalid"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_transaction_monitoring_zero_amount():
|
||||
"""Test TransactionMonitoring with zero amount"""
|
||||
tx = TransactionMonitoring(
|
||||
transaction_id="tx123",
|
||||
user_id="user123",
|
||||
amount=0.0,
|
||||
currency="BTC",
|
||||
counterparty="counterparty1",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert tx.amount == 0.0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_transaction_monitoring_negative_amount():
|
||||
"""Test TransactionMonitoring with negative amount"""
|
||||
tx = TransactionMonitoring(
|
||||
transaction_id="tx123",
|
||||
user_id="user123",
|
||||
amount=-1000.0,
|
||||
currency="BTC",
|
||||
counterparty="counterparty1",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert tx.amount == -1000.0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_kyc_with_missing_address_fields():
|
||||
"""Test KYC submission with missing address fields"""
|
||||
client = TestClient(app)
|
||||
kyc = KYCRequest(
|
||||
user_id="user123",
|
||||
name="John Doe",
|
||||
email="john@example.com",
|
||||
document_type="passport",
|
||||
document_number="ABC123",
|
||||
address={"city": "New York"} # Missing other fields
|
||||
)
|
||||
response = client.post("/api/v1/kyc/submit", json=kyc.model_dump())
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_compliance_report_empty_details():
|
||||
"""Test compliance report with empty details"""
|
||||
client = TestClient(app)
|
||||
report = ComplianceReport(
|
||||
report_type="test",
|
||||
description="test",
|
||||
severity="low",
|
||||
details={}
|
||||
)
|
||||
response = client.post("/api/v1/compliance/report", json=report.model_dump())
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_compliance_rule_missing_fields():
|
||||
"""Test compliance rule with missing fields"""
|
||||
client = TestClient(app)
|
||||
rule_data = {
|
||||
"name": "Test Rule"
|
||||
# Missing description, type, etc.
|
||||
}
|
||||
response = client.post("/api/v1/rules/create", json=rule_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "Test Rule"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_dashboard_with_no_data():
|
||||
"""Test dashboard with no data"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/dashboard")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["summary"]["total_users"] == 0
|
||||
assert data["summary"]["total_reports"] == 0
|
||||
assert data["summary"]["total_transactions"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_monitor_transaction_with_future_timestamp():
|
||||
"""Test monitoring transaction with future timestamp"""
|
||||
client = TestClient(app)
|
||||
tx = TransactionMonitoring(
|
||||
transaction_id="tx123",
|
||||
user_id="user123",
|
||||
amount=1000.0,
|
||||
currency="BTC",
|
||||
counterparty="counterparty1",
|
||||
timestamp=datetime(2030, 1, 1) # Future timestamp
|
||||
)
|
||||
response = client.post("/api/v1/monitoring/transaction", json=tx.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_monitor_transaction_with_past_timestamp():
|
||||
"""Test monitoring transaction with past timestamp"""
|
||||
client = TestClient(app)
|
||||
tx = TransactionMonitoring(
|
||||
transaction_id="tx123",
|
||||
user_id="user123",
|
||||
amount=1000.0,
|
||||
currency="BTC",
|
||||
counterparty="counterparty1",
|
||||
timestamp=datetime(2020, 1, 1) # Past timestamp
|
||||
)
|
||||
response = client.post("/api/v1/monitoring/transaction", json=tx.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_kyc_list_with_multiple_records():
|
||||
"""Test listing KYC with multiple records"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Create multiple KYC records
|
||||
for i in range(5):
|
||||
kyc = KYCRequest(
|
||||
user_id=f"user{i}",
|
||||
name=f"User {i}",
|
||||
email=f"user{i}@example.com",
|
||||
document_type="passport",
|
||||
document_number=f"ABC{i}",
|
||||
address={"city": "New York"}
|
||||
)
|
||||
client.post("/api/v1/kyc/submit", json=kyc.model_dump())
|
||||
|
||||
response = client.get("/api/v1/kyc")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_records"] == 5
|
||||
assert data["approved"] == 5
|
||||
@@ -0,0 +1,252 @@
|
||||
"""Integration tests for compliance service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, KYCRequest, ComplianceReport, TransactionMonitoring, kyc_records, compliance_reports, suspicious_transactions, compliance_rules
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
kyc_records.clear()
|
||||
compliance_reports.clear()
|
||||
suspicious_transactions.clear()
|
||||
compliance_rules.clear()
|
||||
yield
|
||||
kyc_records.clear()
|
||||
compliance_reports.clear()
|
||||
suspicious_transactions.clear()
|
||||
compliance_rules.clear()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_root_endpoint():
|
||||
"""Test root endpoint"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["service"] == "AITBC Compliance Service"
|
||||
assert data["status"] == "running"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_health_check_endpoint():
|
||||
"""Test health check endpoint"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert "kyc_records" in data
|
||||
assert "compliance_reports" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_submit_kyc():
|
||||
"""Test KYC submission"""
|
||||
client = TestClient(app)
|
||||
kyc = KYCRequest(
|
||||
user_id="user123",
|
||||
name="John Doe",
|
||||
email="john@example.com",
|
||||
document_type="passport",
|
||||
document_number="ABC123",
|
||||
address={"street": "123 Main St", "city": "New York", "country": "USA"}
|
||||
)
|
||||
response = client.post("/api/v1/kyc/submit", json=kyc.model_dump())
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == "user123"
|
||||
assert data["status"] == "approved"
|
||||
assert data["risk_score"] == "low"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_submit_duplicate_kyc():
|
||||
"""Test submitting duplicate KYC"""
|
||||
client = TestClient(app)
|
||||
kyc = KYCRequest(
|
||||
user_id="user123",
|
||||
name="John Doe",
|
||||
email="john@example.com",
|
||||
document_type="passport",
|
||||
document_number="ABC123",
|
||||
address={"street": "123 Main St", "city": "New York", "country": "USA"}
|
||||
)
|
||||
|
||||
# First submission
|
||||
client.post("/api/v1/kyc/submit", json=kyc.model_dump())
|
||||
|
||||
# Second submission should fail
|
||||
response = client.post("/api/v1/kyc/submit", json=kyc.model_dump())
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_kyc_status():
|
||||
"""Test getting KYC status"""
|
||||
client = TestClient(app)
|
||||
kyc = KYCRequest(
|
||||
user_id="user123",
|
||||
name="John Doe",
|
||||
email="john@example.com",
|
||||
document_type="passport",
|
||||
document_number="ABC123",
|
||||
address={"street": "123 Main St", "city": "New York", "country": "USA"}
|
||||
)
|
||||
|
||||
# Submit KYC first
|
||||
client.post("/api/v1/kyc/submit", json=kyc.model_dump())
|
||||
|
||||
# Get KYC status
|
||||
response = client.get("/api/v1/kyc/user123")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == "user123"
|
||||
assert data["status"] == "approved"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_kyc_status_not_found():
|
||||
"""Test getting KYC status for nonexistent user"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/kyc/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_kyc_records():
|
||||
"""Test listing KYC records"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/kyc")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "kyc_records" in data
|
||||
assert "total_records" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_compliance_report():
|
||||
"""Test creating compliance report"""
|
||||
client = TestClient(app)
|
||||
report = ComplianceReport(
|
||||
report_type="suspicious_activity",
|
||||
description="Suspicious transaction detected",
|
||||
severity="high",
|
||||
details={"transaction_id": "tx123"}
|
||||
)
|
||||
response = client.post("/api/v1/compliance/report", json=report.model_dump())
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["severity"] == "high"
|
||||
assert data["status"] == "created"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_compliance_reports():
|
||||
"""Test listing compliance reports"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/compliance/reports")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "reports" in data
|
||||
assert "total_reports" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_monitor_transaction():
|
||||
"""Test transaction monitoring"""
|
||||
client = TestClient(app)
|
||||
tx = TransactionMonitoring(
|
||||
transaction_id="tx123",
|
||||
user_id="user123",
|
||||
amount=1000.0,
|
||||
currency="BTC",
|
||||
counterparty="counterparty1",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
response = client.post("/api/v1/monitoring/transaction", json=tx.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["transaction_id"] == "tx123"
|
||||
assert "risk_score" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_monitor_suspicious_transaction():
|
||||
"""Test monitoring suspicious transaction"""
|
||||
client = TestClient(app)
|
||||
tx = TransactionMonitoring(
|
||||
transaction_id="tx123",
|
||||
user_id="user123",
|
||||
amount=100000.0,
|
||||
currency="BTC",
|
||||
counterparty="high_risk_entity_1",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
response = client.post("/api/v1/monitoring/transaction", json=tx.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "flagged"
|
||||
assert len(data["flags"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_monitored_transactions():
|
||||
"""Test listing monitored transactions"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/monitoring/transactions")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "transactions" in data
|
||||
assert "total_transactions" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_compliance_rule():
|
||||
"""Test creating compliance rule"""
|
||||
client = TestClient(app)
|
||||
rule_data = {
|
||||
"name": "High Value Transaction Rule",
|
||||
"description": "Flag transactions over $50,000",
|
||||
"type": "transaction_monitoring",
|
||||
"conditions": {"min_amount": 50000},
|
||||
"actions": ["flag", "report"],
|
||||
"severity": "high"
|
||||
}
|
||||
response = client.post("/api/v1/rules/create", json=rule_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "High Value Transaction Rule"
|
||||
assert data["active"] is True
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_compliance_rules():
|
||||
"""Test listing compliance rules"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/rules")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "rules" in data
|
||||
assert "total_rules" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_compliance_dashboard():
|
||||
"""Test compliance dashboard"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/dashboard")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "summary" in data
|
||||
assert "risk_distribution" in data
|
||||
assert "recent_activity" in data
|
||||
@@ -0,0 +1,161 @@
|
||||
"""Unit tests for compliance service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, KYCRequest, ComplianceReport, TransactionMonitoring, calculate_transaction_risk, check_suspicious_patterns
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_app_initialization():
|
||||
"""Test that the FastAPI app initializes correctly"""
|
||||
assert app is not None
|
||||
assert app.title == "AITBC Compliance Service"
|
||||
assert app.version == "1.0.0"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_kyc_request_model():
|
||||
"""Test KYCRequest model"""
|
||||
kyc = KYCRequest(
|
||||
user_id="user123",
|
||||
name="John Doe",
|
||||
email="john@example.com",
|
||||
document_type="passport",
|
||||
document_number="ABC123",
|
||||
address={"street": "123 Main St", "city": "New York", "country": "USA"}
|
||||
)
|
||||
assert kyc.user_id == "user123"
|
||||
assert kyc.name == "John Doe"
|
||||
assert kyc.email == "john@example.com"
|
||||
assert kyc.document_type == "passport"
|
||||
assert kyc.document_number == "ABC123"
|
||||
assert kyc.address["city"] == "New York"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_compliance_report_model():
|
||||
"""Test ComplianceReport model"""
|
||||
report = ComplianceReport(
|
||||
report_type="suspicious_activity",
|
||||
description="Suspicious transaction detected",
|
||||
severity="high",
|
||||
details={"transaction_id": "tx123"}
|
||||
)
|
||||
assert report.report_type == "suspicious_activity"
|
||||
assert report.description == "Suspicious transaction detected"
|
||||
assert report.severity == "high"
|
||||
assert report.details["transaction_id"] == "tx123"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_transaction_monitoring_model():
|
||||
"""Test TransactionMonitoring model"""
|
||||
tx = TransactionMonitoring(
|
||||
transaction_id="tx123",
|
||||
user_id="user123",
|
||||
amount=1000.0,
|
||||
currency="BTC",
|
||||
counterparty="counterparty1",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert tx.transaction_id == "tx123"
|
||||
assert tx.user_id == "user123"
|
||||
assert tx.amount == 1000.0
|
||||
assert tx.currency == "BTC"
|
||||
assert tx.counterparty == "counterparty1"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_transaction_risk_low():
|
||||
"""Test risk calculation for low risk transaction"""
|
||||
tx = TransactionMonitoring(
|
||||
transaction_id="tx123",
|
||||
user_id="user123",
|
||||
amount=50.0,
|
||||
currency="BTC",
|
||||
counterparty="counterparty1",
|
||||
timestamp=datetime(2026, 1, 1, 10, 0, 0) # Business hours
|
||||
)
|
||||
risk = calculate_transaction_risk(tx)
|
||||
assert risk == "low"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_transaction_risk_medium():
|
||||
"""Test risk calculation for medium risk transaction"""
|
||||
tx = TransactionMonitoring(
|
||||
transaction_id="tx123",
|
||||
user_id="user123",
|
||||
amount=5000.0,
|
||||
currency="BTC",
|
||||
counterparty="counterparty1",
|
||||
timestamp=datetime(2026, 1, 1, 10, 0, 0)
|
||||
)
|
||||
risk = calculate_transaction_risk(tx)
|
||||
assert risk == "medium"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_transaction_risk_high():
|
||||
"""Test risk calculation for high risk transaction"""
|
||||
tx = TransactionMonitoring(
|
||||
transaction_id="tx123",
|
||||
user_id="user123",
|
||||
amount=20000.0,
|
||||
currency="BTC",
|
||||
counterparty="counterparty1",
|
||||
timestamp=datetime(2026, 1, 1, 8, 0, 0) # Outside business hours
|
||||
)
|
||||
risk = calculate_transaction_risk(tx)
|
||||
assert risk == "high"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_check_suspicious_patterns_high_value():
|
||||
"""Test suspicious pattern detection for high value"""
|
||||
tx = TransactionMonitoring(
|
||||
transaction_id="tx123",
|
||||
user_id="user123",
|
||||
amount=100000.0,
|
||||
currency="BTC",
|
||||
counterparty="counterparty1",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
flags = check_suspicious_patterns(tx)
|
||||
assert "high_value_transaction" in flags
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_check_suspicious_patterns_high_risk_counterparty():
|
||||
"""Test suspicious pattern detection for high risk counterparty"""
|
||||
tx = TransactionMonitoring(
|
||||
transaction_id="tx123",
|
||||
user_id="user123",
|
||||
amount=1000.0,
|
||||
currency="BTC",
|
||||
counterparty="high_risk_entity_1",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
flags = check_suspicious_patterns(tx)
|
||||
assert "high_risk_counterparty" in flags
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_check_suspicious_patterns_none():
|
||||
"""Test suspicious pattern detection with no flags"""
|
||||
tx = TransactionMonitoring(
|
||||
transaction_id="tx123",
|
||||
user_id="user123",
|
||||
amount=1000.0,
|
||||
currency="BTC",
|
||||
counterparty="safe_counterparty",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
flags = check_suspicious_patterns(tx)
|
||||
assert len(flags) == 0
|
||||
324
examples/stubs/exchange-integration/main.py
Executable file
324
examples/stubs/exchange-integration/main.py
Executable file
@@ -0,0 +1,324 @@
|
||||
"""
|
||||
Production Exchange API Integration Service
|
||||
Handles real exchange connections and trading operations
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
import aiohttp
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from aitbc import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
app = FastAPI(
|
||||
title="AITBC Exchange Integration Service",
|
||||
description="Production exchange API integration for AITBC trading",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Data models
|
||||
class ExchangeRegistration(BaseModel):
|
||||
name: str
|
||||
api_key: str
|
||||
sandbox: bool = True
|
||||
description: Optional[str] = None
|
||||
|
||||
class TradingPair(BaseModel):
|
||||
symbol: str
|
||||
base_asset: str
|
||||
quote_asset: str
|
||||
min_order_size: float
|
||||
price_precision: int
|
||||
quantity_precision: int
|
||||
|
||||
class OrderRequest(BaseModel):
|
||||
symbol: str
|
||||
side: str # buy/sell
|
||||
type: str # market/limit
|
||||
quantity: float
|
||||
price: Optional[float] = None
|
||||
|
||||
# In-memory storage (in production, use database)
|
||||
exchanges: Dict[str, Dict] = {}
|
||||
trading_pairs: Dict[str, Dict] = {}
|
||||
orders: Dict[str, Dict] = {}
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {
|
||||
"service": "AITBC Exchange Integration",
|
||||
"status": "running",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {
|
||||
"status": "healthy",
|
||||
"exchanges_connected": len([e for e in exchanges.values() if e.get("connected")]),
|
||||
"active_pairs": len(trading_pairs),
|
||||
"total_orders": len(orders)
|
||||
}
|
||||
|
||||
@app.post("/api/v1/exchanges/register")
|
||||
async def register_exchange(registration: ExchangeRegistration):
|
||||
"""Register a new exchange connection"""
|
||||
exchange_id = registration.name.lower()
|
||||
|
||||
if exchange_id in exchanges:
|
||||
raise HTTPException(status_code=400, detail="Exchange already registered")
|
||||
|
||||
# Create exchange configuration
|
||||
exchange_config = {
|
||||
"exchange_id": exchange_id,
|
||||
"name": registration.name,
|
||||
"api_key": registration.api_key,
|
||||
"sandbox": registration.sandbox,
|
||||
"description": registration.description,
|
||||
"connected": False,
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"last_sync": None,
|
||||
"trading_pairs": []
|
||||
}
|
||||
|
||||
exchanges[exchange_id] = exchange_config
|
||||
|
||||
logger.info(f"Exchange registered: {registration.name}")
|
||||
|
||||
return {
|
||||
"exchange_id": exchange_id,
|
||||
"status": "registered",
|
||||
"name": registration.name,
|
||||
"sandbox": registration.sandbox,
|
||||
"created_at": exchange_config["created_at"]
|
||||
}
|
||||
|
||||
@app.post("/api/v1/exchanges/{exchange_id}/connect")
|
||||
async def connect_exchange(exchange_id: str):
|
||||
"""Connect to a registered exchange"""
|
||||
if exchange_id not in exchanges:
|
||||
raise HTTPException(status_code=404, detail="Exchange not found")
|
||||
|
||||
exchange = exchanges[exchange_id]
|
||||
|
||||
if exchange["connected"]:
|
||||
return {"status": "already_connected", "exchange_id": exchange_id}
|
||||
|
||||
# Simulate exchange connection
|
||||
# In production, this would make actual API calls to the exchange
|
||||
await asyncio.sleep(1) # Simulate connection delay
|
||||
|
||||
exchange["connected"] = True
|
||||
exchange["last_sync"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
logger.info(f"Exchange connected: {exchange_id}")
|
||||
|
||||
return {
|
||||
"exchange_id": exchange_id,
|
||||
"status": "connected",
|
||||
"connected_at": exchange["last_sync"]
|
||||
}
|
||||
|
||||
@app.post("/api/v1/pairs/create")
|
||||
async def create_trading_pair(pair: TradingPair):
|
||||
"""Create a new trading pair"""
|
||||
pair_id = f"{pair.symbol.lower()}"
|
||||
|
||||
if pair_id in trading_pairs:
|
||||
raise HTTPException(status_code=400, detail="Trading pair already exists")
|
||||
|
||||
# Create trading pair configuration
|
||||
pair_config = {
|
||||
"pair_id": pair_id,
|
||||
"symbol": pair.symbol,
|
||||
"base_asset": pair.base_asset,
|
||||
"quote_asset": pair.quote_asset,
|
||||
"min_order_size": pair.min_order_size,
|
||||
"price_precision": pair.price_precision,
|
||||
"quantity_precision": pair.quantity_precision,
|
||||
"status": "active",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"current_price": None,
|
||||
"volume_24h": 0.0,
|
||||
"orders": []
|
||||
}
|
||||
|
||||
trading_pairs[pair_id] = pair_config
|
||||
|
||||
logger.info(f"Trading pair created: {pair.symbol}")
|
||||
|
||||
return {
|
||||
"pair_id": pair_id,
|
||||
"symbol": pair.symbol,
|
||||
"status": "created",
|
||||
"created_at": pair_config["created_at"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/pairs")
|
||||
async def list_trading_pairs():
|
||||
"""List all trading pairs"""
|
||||
return {
|
||||
"pairs": list(trading_pairs.values()),
|
||||
"total_pairs": len(trading_pairs)
|
||||
}
|
||||
|
||||
@app.get("/api/v1/pairs/{pair_id}")
|
||||
async def get_trading_pair(pair_id: str):
|
||||
"""Get specific trading pair information"""
|
||||
if pair_id not in trading_pairs:
|
||||
raise HTTPException(status_code=404, detail="Trading pair not found")
|
||||
|
||||
return trading_pairs[pair_id]
|
||||
|
||||
@app.post("/api/v1/orders")
|
||||
async def create_order(order: OrderRequest):
|
||||
"""Create a new trading order"""
|
||||
pair_id = order.symbol.lower()
|
||||
|
||||
if pair_id not in trading_pairs:
|
||||
raise HTTPException(status_code=404, detail="Trading pair not found")
|
||||
|
||||
# Generate order ID
|
||||
order_id = f"order_{int(datetime.now(timezone.utc).timestamp())}"
|
||||
|
||||
# Create order
|
||||
order_data = {
|
||||
"order_id": order_id,
|
||||
"symbol": order.symbol,
|
||||
"side": order.side,
|
||||
"type": order.type,
|
||||
"quantity": order.quantity,
|
||||
"price": order.price,
|
||||
"status": "submitted",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"filled_quantity": 0.0,
|
||||
"remaining_quantity": order.quantity,
|
||||
"average_price": None
|
||||
}
|
||||
|
||||
orders[order_id] = order_data
|
||||
|
||||
# Add to trading pair
|
||||
trading_pairs[pair_id]["orders"].append(order_id)
|
||||
|
||||
# Simulate order processing
|
||||
await asyncio.sleep(0.5) # Simulate processing delay
|
||||
|
||||
# Mark as filled (for demo)
|
||||
order_data["status"] = "filled"
|
||||
order_data["filled_quantity"] = order.quantity
|
||||
order_data["remaining_quantity"] = 0.0
|
||||
order_data["average_price"] = order.price or 0.00001 # Default price for demo
|
||||
order_data["filled_at"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
logger.info(f"Order created and filled: {order_id}")
|
||||
|
||||
return order_data
|
||||
|
||||
@app.get("/api/v1/orders")
|
||||
async def list_orders():
|
||||
"""List all orders"""
|
||||
return {
|
||||
"orders": list(orders.values()),
|
||||
"total_orders": len(orders)
|
||||
}
|
||||
|
||||
@app.get("/api/v1/orders/{order_id}")
|
||||
async def get_order(order_id: str):
|
||||
"""Get specific order information"""
|
||||
if order_id not in orders:
|
||||
raise HTTPException(status_code=404, detail="Order not found")
|
||||
|
||||
return orders[order_id]
|
||||
|
||||
@app.get("/api/v1/exchanges")
|
||||
async def list_exchanges():
|
||||
"""List all registered exchanges"""
|
||||
return {
|
||||
"exchanges": list(exchanges.values()),
|
||||
"total_exchanges": len(exchanges)
|
||||
}
|
||||
|
||||
@app.get("/api/v1/exchanges/{exchange_id}")
|
||||
async def get_exchange(exchange_id: str):
|
||||
"""Get specific exchange information"""
|
||||
if exchange_id not in exchanges:
|
||||
raise HTTPException(status_code=404, detail="Exchange not found")
|
||||
|
||||
return exchanges[exchange_id]
|
||||
|
||||
@app.post("/api/v1/market-data/{pair_id}/price")
|
||||
async def update_market_price(pair_id: str, price_data: Dict[str, Any]):
|
||||
"""Update market price for a trading pair"""
|
||||
if pair_id not in trading_pairs:
|
||||
raise HTTPException(status_code=404, detail="Trading pair not found")
|
||||
|
||||
pair = trading_pairs[pair_id]
|
||||
pair["current_price"] = price_data.get("price")
|
||||
pair["volume_24h"] = price_data.get("volume", pair["volume_24h"])
|
||||
pair["last_price_update"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
return {
|
||||
"pair_id": pair_id,
|
||||
"current_price": pair["current_price"],
|
||||
"updated_at": pair["last_price_update"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/market-data")
|
||||
async def get_market_data():
|
||||
"""Get market data for all pairs"""
|
||||
market_data = {}
|
||||
for pair_id, pair in trading_pairs.items():
|
||||
market_data[pair_id] = {
|
||||
"symbol": pair["symbol"],
|
||||
"current_price": pair.get("current_price"),
|
||||
"volume_24h": pair.get("volume_24h"),
|
||||
"last_update": pair.get("last_price_update")
|
||||
}
|
||||
|
||||
return {
|
||||
"market_data": market_data,
|
||||
"total_pairs": len(market_data),
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
# Background task for simulating market data
|
||||
async def simulate_market_data():
|
||||
"""Background task to simulate market data updates"""
|
||||
while True:
|
||||
await asyncio.sleep(30) # Update every 30 seconds
|
||||
|
||||
for pair_id, pair in trading_pairs.items():
|
||||
if pair["status"] == "active":
|
||||
# Simulate price changes
|
||||
import random
|
||||
base_price = 0.00001 # Base price for AITBC
|
||||
variation = random.uniform(-0.02, 0.02) # ±2% variation
|
||||
new_price = round(base_price * (1 + variation), 8)
|
||||
|
||||
pair["current_price"] = new_price
|
||||
pair["volume_24h"] += random.uniform(100, 1000)
|
||||
pair["last_price_update"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# Start background task on startup
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
logger.info("Starting AITBC Exchange Integration Service")
|
||||
# Start background market data simulation
|
||||
asyncio.create_task(simulate_market_data())
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
logger.info("Shutting down AITBC Exchange Integration Service")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8010, log_level="info")
|
||||
1
examples/stubs/exchange-integration/tests/__init__.py
Normal file
1
examples/stubs/exchange-integration/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Exchange integration service tests"""
|
||||
@@ -0,0 +1,256 @@
|
||||
"""Edge case and error handling tests for exchange integration service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# Mock aiohttp before importing
|
||||
sys.modules['aiohttp'] = Mock()
|
||||
|
||||
from main import app, ExchangeRegistration, TradingPair, OrderRequest, exchanges, trading_pairs, orders
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
exchanges.clear()
|
||||
trading_pairs.clear()
|
||||
orders.clear()
|
||||
yield
|
||||
exchanges.clear()
|
||||
trading_pairs.clear()
|
||||
orders.clear()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_exchange_registration_empty_name():
|
||||
"""Test ExchangeRegistration with empty name"""
|
||||
registration = ExchangeRegistration(
|
||||
name="",
|
||||
api_key="test_key_123"
|
||||
)
|
||||
assert registration.name == ""
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_exchange_registration_empty_api_key():
|
||||
"""Test ExchangeRegistration with empty API key"""
|
||||
registration = ExchangeRegistration(
|
||||
name="TestExchange",
|
||||
api_key=""
|
||||
)
|
||||
assert registration.api_key == ""
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_trading_pair_zero_min_order_size():
|
||||
"""Test TradingPair with zero min order size"""
|
||||
pair = TradingPair(
|
||||
symbol="AITBC/BTC",
|
||||
base_asset="AITBC",
|
||||
quote_asset="BTC",
|
||||
min_order_size=0.0,
|
||||
price_precision=8,
|
||||
quantity_precision=6
|
||||
)
|
||||
assert pair.min_order_size == 0.0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_trading_pair_negative_min_order_size():
|
||||
"""Test TradingPair with negative min order size"""
|
||||
pair = TradingPair(
|
||||
symbol="AITBC/BTC",
|
||||
base_asset="AITBC",
|
||||
quote_asset="BTC",
|
||||
min_order_size=-0.001,
|
||||
price_precision=8,
|
||||
quantity_precision=6
|
||||
)
|
||||
assert pair.min_order_size == -0.001
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_order_request_zero_quantity():
|
||||
"""Test OrderRequest with zero quantity"""
|
||||
order = OrderRequest(
|
||||
symbol="AITBC/BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=0.0,
|
||||
price=0.00001
|
||||
)
|
||||
assert order.quantity == 0.0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_order_request_negative_quantity():
|
||||
"""Test OrderRequest with negative quantity"""
|
||||
order = OrderRequest(
|
||||
symbol="AITBC/BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=-100.0,
|
||||
price=0.00001
|
||||
)
|
||||
assert order.quantity == -100.0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_order_request_invalid_side():
|
||||
"""Test OrderRequest with invalid side"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Create trading pair first
|
||||
pair = TradingPair(
|
||||
symbol="AITBC/BTC",
|
||||
base_asset="AITBC",
|
||||
quote_asset="BTC",
|
||||
min_order_size=0.001,
|
||||
price_precision=8,
|
||||
quantity_precision=6
|
||||
)
|
||||
client.post("/api/v1/pairs/create", json=pair.model_dump())
|
||||
|
||||
# Create order with invalid side (API doesn't validate, but test the behavior)
|
||||
order = OrderRequest(
|
||||
symbol="AITBC/BTC",
|
||||
side="invalid",
|
||||
type="limit",
|
||||
quantity=100.0,
|
||||
price=0.00001
|
||||
)
|
||||
# This will be accepted by the API as it doesn't validate the side
|
||||
response = client.post("/api/v1/orders", json=order.model_dump())
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_order_request_invalid_type():
|
||||
"""Test OrderRequest with invalid type"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Create trading pair first
|
||||
pair = TradingPair(
|
||||
symbol="AITBC/BTC",
|
||||
base_asset="AITBC",
|
||||
quote_asset="BTC",
|
||||
min_order_size=0.001,
|
||||
price_precision=8,
|
||||
quantity_precision=6
|
||||
)
|
||||
client.post("/api/v1/pairs/create", json=pair.model_dump())
|
||||
|
||||
# Create order with invalid type (API doesn't validate, but test the behavior)
|
||||
order = OrderRequest(
|
||||
symbol="AITBC/BTC",
|
||||
side="buy",
|
||||
type="invalid",
|
||||
quantity=100.0,
|
||||
price=0.00001
|
||||
)
|
||||
# This will be accepted by the API as it doesn't validate the type
|
||||
response = client.post("/api/v1/orders", json=order.model_dump())
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_connect_already_connected_exchange():
|
||||
"""Test connecting to already connected exchange"""
|
||||
client = TestClient(app)
|
||||
registration = ExchangeRegistration(
|
||||
name="TestExchange",
|
||||
api_key="test_key_123"
|
||||
)
|
||||
|
||||
# Register exchange
|
||||
client.post("/api/v1/exchanges/register", json=registration.model_dump())
|
||||
|
||||
# Connect first time
|
||||
client.post("/api/v1/exchanges/testexchange/connect")
|
||||
|
||||
# Connect second time should return already_connected
|
||||
response = client.post("/api/v1/exchanges/testexchange/connect")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "already_connected"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_update_market_price_missing_fields():
|
||||
"""Test updating market price with missing fields"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Create trading pair first
|
||||
pair = TradingPair(
|
||||
symbol="AITBC-BTC",
|
||||
base_asset="AITBC",
|
||||
quote_asset="BTC",
|
||||
min_order_size=0.001,
|
||||
price_precision=8,
|
||||
quantity_precision=6
|
||||
)
|
||||
create_response = client.post("/api/v1/pairs/create", json=pair.model_dump())
|
||||
assert create_response.status_code == 200
|
||||
|
||||
# Update with missing price
|
||||
price_data = {"volume": 50000.0}
|
||||
response = client.post("/api/v1/market-data/aitbc-btc/price", json=price_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Should use None for missing price
|
||||
assert data["current_price"] is None
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_update_market_price_zero_price():
|
||||
"""Test updating market price with zero price"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Create trading pair first
|
||||
pair = TradingPair(
|
||||
symbol="AITBC-BTC",
|
||||
base_asset="AITBC",
|
||||
quote_asset="BTC",
|
||||
min_order_size=0.001,
|
||||
price_precision=8,
|
||||
quantity_precision=6
|
||||
)
|
||||
create_response = client.post("/api/v1/pairs/create", json=pair.model_dump())
|
||||
assert create_response.status_code == 200
|
||||
|
||||
# Update with zero price
|
||||
price_data = {"price": 0.0}
|
||||
response = client.post("/api/v1/market-data/aitbc-btc/price", json=price_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["current_price"] == 0.0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_update_market_price_negative_price():
|
||||
"""Test updating market price with negative price"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Create trading pair first
|
||||
pair = TradingPair(
|
||||
symbol="AITBC-BTC",
|
||||
base_asset="AITBC",
|
||||
quote_asset="BTC",
|
||||
min_order_size=0.001,
|
||||
price_precision=8,
|
||||
quantity_precision=6
|
||||
)
|
||||
create_response = client.post("/api/v1/pairs/create", json=pair.model_dump())
|
||||
assert create_response.status_code == 200
|
||||
|
||||
# Update with negative price
|
||||
price_data = {"price": -0.00001}
|
||||
response = client.post("/api/v1/market-data/aitbc-btc/price", json=price_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["current_price"] == -0.00001
|
||||
@@ -0,0 +1,378 @@
|
||||
"""Integration tests for exchange integration service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# Mock aiohttp before importing
|
||||
sys.modules['aiohttp'] = Mock()
|
||||
|
||||
from main import app, ExchangeRegistration, TradingPair, OrderRequest, exchanges, trading_pairs, orders
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
exchanges.clear()
|
||||
trading_pairs.clear()
|
||||
orders.clear()
|
||||
yield
|
||||
exchanges.clear()
|
||||
trading_pairs.clear()
|
||||
orders.clear()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_root_endpoint():
|
||||
"""Test root endpoint"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["service"] == "AITBC Exchange Integration"
|
||||
assert data["status"] == "running"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_health_check_endpoint():
|
||||
"""Test health check endpoint"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert "exchanges_connected" in data
|
||||
assert "active_pairs" in data
|
||||
assert "total_orders" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_register_exchange():
|
||||
"""Test exchange registration"""
|
||||
client = TestClient(app)
|
||||
registration = ExchangeRegistration(
|
||||
name="TestExchange",
|
||||
api_key="test_key_123",
|
||||
sandbox=True
|
||||
)
|
||||
response = client.post("/api/v1/exchanges/register", json=registration.model_dump())
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["exchange_id"] == "testexchange"
|
||||
assert data["status"] == "registered"
|
||||
assert data["name"] == "TestExchange"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_register_duplicate_exchange():
|
||||
"""Test registering duplicate exchange"""
|
||||
client = TestClient(app)
|
||||
registration = ExchangeRegistration(
|
||||
name="TestExchange",
|
||||
api_key="test_key_123"
|
||||
)
|
||||
|
||||
# First registration
|
||||
client.post("/api/v1/exchanges/register", json=registration.model_dump())
|
||||
|
||||
# Second registration should fail
|
||||
response = client.post("/api/v1/exchanges/register", json=registration.model_dump())
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_connect_exchange():
|
||||
"""Test connecting to exchange"""
|
||||
client = TestClient(app)
|
||||
registration = ExchangeRegistration(
|
||||
name="TestExchange",
|
||||
api_key="test_key_123"
|
||||
)
|
||||
|
||||
# Register exchange first
|
||||
client.post("/api/v1/exchanges/register", json=registration.model_dump())
|
||||
|
||||
# Connect to exchange
|
||||
response = client.post("/api/v1/exchanges/testexchange/connect")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["exchange_id"] == "testexchange"
|
||||
assert data["status"] == "connected"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_connect_nonexistent_exchange():
|
||||
"""Test connecting to nonexistent exchange"""
|
||||
client = TestClient(app)
|
||||
response = client.post("/api/v1/exchanges/nonexistent/connect")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_trading_pair():
|
||||
"""Test creating trading pair"""
|
||||
client = TestClient(app)
|
||||
pair = TradingPair(
|
||||
symbol="AITBC/BTC",
|
||||
base_asset="AITBC",
|
||||
quote_asset="BTC",
|
||||
min_order_size=0.001,
|
||||
price_precision=8,
|
||||
quantity_precision=6
|
||||
)
|
||||
response = client.post("/api/v1/pairs/create", json=pair.model_dump())
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["pair_id"] == "aitbc/btc"
|
||||
assert data["symbol"] == "AITBC/BTC"
|
||||
assert data["status"] == "created"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_duplicate_trading_pair():
|
||||
"""Test creating duplicate trading pair"""
|
||||
client = TestClient(app)
|
||||
pair = TradingPair(
|
||||
symbol="AITBC/BTC",
|
||||
base_asset="AITBC",
|
||||
quote_asset="BTC",
|
||||
min_order_size=0.001,
|
||||
price_precision=8,
|
||||
quantity_precision=6
|
||||
)
|
||||
|
||||
# First creation
|
||||
client.post("/api/v1/pairs/create", json=pair.model_dump())
|
||||
|
||||
# Second creation should fail
|
||||
response = client.post("/api/v1/pairs/create", json=pair.model_dump())
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_trading_pairs():
|
||||
"""Test listing trading pairs"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/pairs")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "pairs" in data
|
||||
assert "total_pairs" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_trading_pair():
|
||||
"""Test getting specific trading pair"""
|
||||
client = TestClient(app)
|
||||
pair = TradingPair(
|
||||
symbol="AITBC-BTC",
|
||||
base_asset="AITBC",
|
||||
quote_asset="BTC",
|
||||
min_order_size=0.001,
|
||||
price_precision=8,
|
||||
quantity_precision=6
|
||||
)
|
||||
|
||||
# Create pair first
|
||||
client.post("/api/v1/pairs/create", json=pair.model_dump())
|
||||
|
||||
# Get pair with lowercase symbol as pair_id
|
||||
response = client.get("/api/v1/pairs/aitbc-btc")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["symbol"] == "AITBC-BTC"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_nonexistent_trading_pair():
|
||||
"""Test getting nonexistent trading pair"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/pairs/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_order():
|
||||
"""Test creating order"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Create trading pair first
|
||||
pair = TradingPair(
|
||||
symbol="AITBC/BTC",
|
||||
base_asset="AITBC",
|
||||
quote_asset="BTC",
|
||||
min_order_size=0.001,
|
||||
price_precision=8,
|
||||
quantity_precision=6
|
||||
)
|
||||
client.post("/api/v1/pairs/create", json=pair.model_dump())
|
||||
|
||||
# Create order
|
||||
order = OrderRequest(
|
||||
symbol="AITBC/BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=100.0,
|
||||
price=0.00001
|
||||
)
|
||||
response = client.post("/api/v1/orders", json=order.model_dump())
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["symbol"] == "AITBC/BTC"
|
||||
assert data["side"] == "buy"
|
||||
assert data["status"] == "filled"
|
||||
assert data["filled_quantity"] == 100.0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_order_nonexistent_pair():
|
||||
"""Test creating order for nonexistent pair"""
|
||||
client = TestClient(app)
|
||||
order = OrderRequest(
|
||||
symbol="NONEXISTENT/BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=100.0,
|
||||
price=0.00001
|
||||
)
|
||||
response = client.post("/api/v1/orders", json=order.model_dump())
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_orders():
|
||||
"""Test listing orders"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/orders")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "orders" in data
|
||||
assert "total_orders" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_order():
|
||||
"""Test getting specific order"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Create trading pair first
|
||||
pair = TradingPair(
|
||||
symbol="AITBC/BTC",
|
||||
base_asset="AITBC",
|
||||
quote_asset="BTC",
|
||||
min_order_size=0.001,
|
||||
price_precision=8,
|
||||
quantity_precision=6
|
||||
)
|
||||
client.post("/api/v1/pairs/create", json=pair.model_dump())
|
||||
|
||||
# Create order
|
||||
order = OrderRequest(
|
||||
symbol="AITBC/BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=100.0,
|
||||
price=0.00001
|
||||
)
|
||||
create_response = client.post("/api/v1/orders", json=order.model_dump())
|
||||
order_id = create_response.json()["order_id"]
|
||||
|
||||
# Get order
|
||||
response = client.get(f"/api/v1/orders/{order_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["order_id"] == order_id
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_nonexistent_order():
|
||||
"""Test getting nonexistent order"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/orders/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_exchanges():
|
||||
"""Test listing exchanges"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/exchanges")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "exchanges" in data
|
||||
assert "total_exchanges" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_exchange():
|
||||
"""Test getting specific exchange"""
|
||||
client = TestClient(app)
|
||||
registration = ExchangeRegistration(
|
||||
name="TestExchange",
|
||||
api_key="test_key_123"
|
||||
)
|
||||
|
||||
# Register exchange first
|
||||
client.post("/api/v1/exchanges/register", json=registration.model_dump())
|
||||
|
||||
# Get exchange
|
||||
response = client.get("/api/v1/exchanges/testexchange")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["exchange_id"] == "testexchange"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_nonexistent_exchange():
|
||||
"""Test getting nonexistent exchange"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/exchanges/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_update_market_price():
|
||||
"""Test updating market price"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Create trading pair first
|
||||
pair = TradingPair(
|
||||
symbol="AITBC-BTC",
|
||||
base_asset="AITBC",
|
||||
quote_asset="BTC",
|
||||
min_order_size=0.001,
|
||||
price_precision=8,
|
||||
quantity_precision=6
|
||||
)
|
||||
client.post("/api/v1/pairs/create", json=pair.model_dump())
|
||||
|
||||
# Update price
|
||||
price_data = {"price": 0.000015, "volume": 50000.0}
|
||||
response = client.post("/api/v1/market-data/aitbc-btc/price", json=price_data)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["current_price"] == 0.000015
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_update_price_nonexistent_pair():
|
||||
"""Test updating price for nonexistent pair"""
|
||||
client = TestClient(app)
|
||||
price_data = {"price": 0.000015}
|
||||
response = client.post("/api/v1/market-data/nonexistent/price", json=price_data)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_market_data():
|
||||
"""Test getting market data"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/market-data")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "market_data" in data
|
||||
assert "total_pairs" in data
|
||||
@@ -0,0 +1,101 @@
|
||||
"""Unit tests for exchange integration service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
|
||||
# Mock aiohttp before importing
|
||||
sys.modules['aiohttp'] = Mock()
|
||||
|
||||
from main import app, ExchangeRegistration, TradingPair, OrderRequest
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_app_initialization():
|
||||
"""Test that the FastAPI app initializes correctly"""
|
||||
assert app is not None
|
||||
assert app.title == "AITBC Exchange Integration Service"
|
||||
assert app.version == "1.0.0"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_exchange_registration_model():
|
||||
"""Test ExchangeRegistration model"""
|
||||
registration = ExchangeRegistration(
|
||||
name="TestExchange",
|
||||
api_key="test_key_123",
|
||||
sandbox=True,
|
||||
description="Test exchange"
|
||||
)
|
||||
assert registration.name == "TestExchange"
|
||||
assert registration.api_key == "test_key_123"
|
||||
assert registration.sandbox is True
|
||||
assert registration.description == "Test exchange"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_exchange_registration_defaults():
|
||||
"""Test ExchangeRegistration default values"""
|
||||
registration = ExchangeRegistration(
|
||||
name="TestExchange",
|
||||
api_key="test_key_123"
|
||||
)
|
||||
assert registration.name == "TestExchange"
|
||||
assert registration.api_key == "test_key_123"
|
||||
assert registration.sandbox is True
|
||||
assert registration.description is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_trading_pair_model():
|
||||
"""Test TradingPair model"""
|
||||
pair = TradingPair(
|
||||
symbol="AITBC/BTC",
|
||||
base_asset="AITBC",
|
||||
quote_asset="BTC",
|
||||
min_order_size=0.001,
|
||||
price_precision=8,
|
||||
quantity_precision=6
|
||||
)
|
||||
assert pair.symbol == "AITBC/BTC"
|
||||
assert pair.base_asset == "AITBC"
|
||||
assert pair.quote_asset == "BTC"
|
||||
assert pair.min_order_size == 0.001
|
||||
assert pair.price_precision == 8
|
||||
assert pair.quantity_precision == 6
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_order_request_model():
|
||||
"""Test OrderRequest model"""
|
||||
order = OrderRequest(
|
||||
symbol="AITBC/BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=100.0,
|
||||
price=0.00001
|
||||
)
|
||||
assert order.symbol == "AITBC/BTC"
|
||||
assert order.side == "buy"
|
||||
assert order.type == "limit"
|
||||
assert order.quantity == 100.0
|
||||
assert order.price == 0.00001
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_order_request_market_order():
|
||||
"""Test OrderRequest for market order"""
|
||||
order = OrderRequest(
|
||||
symbol="AITBC/BTC",
|
||||
side="sell",
|
||||
type="market",
|
||||
quantity=50.0
|
||||
)
|
||||
assert order.symbol == "AITBC/BTC"
|
||||
assert order.side == "sell"
|
||||
assert order.type == "market"
|
||||
assert order.quantity == 50.0
|
||||
assert order.price is None
|
||||
662
examples/stubs/global-ai-agents/main.py
Executable file
662
examples/stubs/global-ai-agents/main.py
Executable file
@@ -0,0 +1,662 @@
|
||||
"""
|
||||
Global AI Agent Communication Service for AITBC
|
||||
Handles cross-chain and cross-region AI agent communication with global optimization
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from aitbc import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
app = FastAPI(
|
||||
title="AITBC Global AI Agent Communication Service",
|
||||
description="Global AI agent communication and collaboration platform",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Data models
|
||||
class Agent(BaseModel):
|
||||
agent_id: str
|
||||
name: str
|
||||
type: str # ai, blockchain, oracle, market_maker, etc.
|
||||
region: str
|
||||
capabilities: List[str]
|
||||
status: str # active, inactive, busy
|
||||
languages: List[str] # Languages the agent can communicate in
|
||||
specialization: str
|
||||
performance_score: float
|
||||
|
||||
class AgentMessage(BaseModel):
|
||||
message_id: str
|
||||
sender_id: str
|
||||
recipient_id: Optional[str] # None for broadcast
|
||||
message_type: str # request, response, collaboration, data_share
|
||||
content: Dict[str, Any]
|
||||
priority: str # low, medium, high, critical
|
||||
language: str
|
||||
timestamp: datetime
|
||||
encryption_key: Optional[str] = None
|
||||
|
||||
class CollaborationSession(BaseModel):
|
||||
session_id: str
|
||||
participants: List[str]
|
||||
session_type: str # task_force, research, trading, governance
|
||||
objective: str
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
status: str # active, completed, expired
|
||||
|
||||
class AgentPerformance(BaseModel):
|
||||
agent_id: str
|
||||
timestamp: datetime
|
||||
tasks_completed: int
|
||||
response_time_ms: float
|
||||
accuracy_score: float
|
||||
collaboration_score: float
|
||||
resource_usage: Dict[str, float]
|
||||
|
||||
# In-memory storage (in production, use database)
|
||||
global_agents: Dict[str, Dict] = {}
|
||||
agent_messages: Dict[str, List[Dict]] = {}
|
||||
collaboration_sessions: Dict[str, Dict] = {}
|
||||
agent_performance: Dict[str, List[Dict]] = {}
|
||||
global_network_stats: Dict[str, Any] = {}
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {
|
||||
"service": "AITBC Global AI Agent Communication Service",
|
||||
"status": "running",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {
|
||||
"status": "healthy",
|
||||
"total_agents": len(global_agents),
|
||||
"active_agents": len([a for a in global_agents.values() if a["status"] == "active"]),
|
||||
"active_sessions": len([s for s in collaboration_sessions.values() if s["status"] == "active"]),
|
||||
"total_messages": sum(len(messages) for messages in agent_messages.values())
|
||||
}
|
||||
|
||||
@app.post("/api/v1/agents/register")
|
||||
async def register_agent(agent: Agent):
|
||||
"""Register a new AI agent in the global network"""
|
||||
if agent.agent_id in global_agents:
|
||||
raise HTTPException(status_code=400, detail="Agent already registered")
|
||||
|
||||
# Create agent record
|
||||
agent_record = {
|
||||
"agent_id": agent.agent_id,
|
||||
"name": agent.name,
|
||||
"type": agent.type,
|
||||
"region": agent.region,
|
||||
"capabilities": agent.capabilities,
|
||||
"status": agent.status,
|
||||
"languages": agent.languages,
|
||||
"specialization": agent.specialization,
|
||||
"performance_score": agent.performance_score,
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"last_active": datetime.now(timezone.utc).isoformat(),
|
||||
"total_messages_sent": 0,
|
||||
"total_messages_received": 0,
|
||||
"collaborations_participated": 0,
|
||||
"tasks_completed": 0,
|
||||
"reputation_score": 5.0,
|
||||
"network_connections": []
|
||||
}
|
||||
|
||||
global_agents[agent.agent_id] = agent_record
|
||||
agent_messages[agent.agent_id] = []
|
||||
|
||||
logger.info(f"Agent registered: {agent.name} ({agent.agent_id}) in {agent.region}")
|
||||
|
||||
return {
|
||||
"agent_id": agent.agent_id,
|
||||
"status": "registered",
|
||||
"name": agent.name,
|
||||
"region": agent.region,
|
||||
"created_at": agent_record["created_at"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/agents")
|
||||
async def list_agents(region: Optional[str] = None,
|
||||
agent_type: Optional[str] = None,
|
||||
status: Optional[str] = None):
|
||||
"""List all agents with filtering"""
|
||||
agents = list(global_agents.values())
|
||||
|
||||
# Apply filters
|
||||
if region:
|
||||
agents = [a for a in agents if a["region"] == region]
|
||||
if agent_type:
|
||||
agents = [a for a in agents if a["type"] == agent_type]
|
||||
if status:
|
||||
agents = [a for a in agents if a["status"] == status]
|
||||
|
||||
return {
|
||||
"agents": agents,
|
||||
"total_agents": len(agents),
|
||||
"filters": {
|
||||
"region": region,
|
||||
"agent_type": agent_type,
|
||||
"status": status
|
||||
}
|
||||
}
|
||||
|
||||
@app.get("/api/v1/agents/{agent_id}")
|
||||
async def get_agent(agent_id: str):
|
||||
"""Get detailed agent information"""
|
||||
if agent_id not in global_agents:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
|
||||
agent = global_agents[agent_id].copy()
|
||||
|
||||
# Add recent messages
|
||||
agent["recent_messages"] = agent_messages.get(agent_id, [])[-10:]
|
||||
|
||||
# Add performance metrics
|
||||
agent["performance_metrics"] = agent_performance.get(agent_id, [])
|
||||
|
||||
return agent
|
||||
|
||||
@app.post("/api/v1/messages/send")
|
||||
async def send_message(message: AgentMessage):
|
||||
"""Send a message from one agent to another or broadcast"""
|
||||
# Validate sender
|
||||
if message.sender_id not in global_agents:
|
||||
raise HTTPException(status_code=400, detail="Sender agent not found")
|
||||
|
||||
# Create message record
|
||||
message_record = {
|
||||
"message_id": message.message_id,
|
||||
"sender_id": message.sender_id,
|
||||
"recipient_id": message.recipient_id,
|
||||
"message_type": message.message_type,
|
||||
"content": message.content,
|
||||
"priority": message.priority,
|
||||
"language": message.language,
|
||||
"timestamp": message.timestamp.isoformat(),
|
||||
"encryption_key": message.encryption_key,
|
||||
"status": "delivered",
|
||||
"delivered_at": datetime.now(timezone.utc).isoformat(),
|
||||
"read_at": None
|
||||
}
|
||||
|
||||
# Handle broadcast
|
||||
if message.recipient_id is None:
|
||||
# Broadcast to all active agents
|
||||
for agent_id in global_agents:
|
||||
if agent_id != message.sender_id and global_agents[agent_id]["status"] == "active":
|
||||
if agent_id not in agent_messages:
|
||||
agent_messages[agent_id] = []
|
||||
agent_messages[agent_id].append(message_record.copy())
|
||||
|
||||
# Update sender stats
|
||||
global_agents[message.sender_id]["total_messages_sent"] += len(global_agents) - 1
|
||||
|
||||
logger.info(f"Broadcast message sent from {message.sender_id} to all agents")
|
||||
|
||||
else:
|
||||
# Direct message
|
||||
if message.recipient_id not in global_agents:
|
||||
raise HTTPException(status_code=400, detail="Recipient agent not found")
|
||||
|
||||
if message.recipient_id not in agent_messages:
|
||||
agent_messages[message.recipient_id] = []
|
||||
|
||||
agent_messages[message.recipient_id].append(message_record)
|
||||
|
||||
# Update stats
|
||||
global_agents[message.sender_id]["total_messages_sent"] += 1
|
||||
global_agents[message.recipient_id]["total_messages_received"] += 1
|
||||
|
||||
logger.info(f"Message sent from {message.sender_id} to {message.recipient_id}")
|
||||
|
||||
return {
|
||||
"message_id": message.message_id,
|
||||
"status": "delivered",
|
||||
"delivered_at": message_record["delivered_at"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/messages/{agent_id}")
|
||||
async def get_agent_messages(agent_id: str, limit: int = 50):
|
||||
"""Get messages for an agent"""
|
||||
if agent_id not in global_agents:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
|
||||
messages = agent_messages.get(agent_id, [])
|
||||
|
||||
# Sort by timestamp (most recent first)
|
||||
messages.sort(key=lambda x: x["timestamp"], reverse=True)
|
||||
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"messages": messages[:limit],
|
||||
"total_messages": len(messages),
|
||||
"unread_count": len([m for m in messages if m.get("read_at") is None])
|
||||
}
|
||||
|
||||
@app.post("/api/v1/collaborations/create")
|
||||
async def create_collaboration(session: CollaborationSession):
|
||||
"""Create a new collaboration session"""
|
||||
# Validate participants
|
||||
for participant_id in session.participants:
|
||||
if participant_id not in global_agents:
|
||||
raise HTTPException(status_code=400, detail=f"Participant {participant_id} not found")
|
||||
|
||||
# Create collaboration session
|
||||
session_record = {
|
||||
"session_id": session.session_id,
|
||||
"participants": session.participants,
|
||||
"session_type": session.session_type,
|
||||
"objective": session.objective,
|
||||
"created_at": session.created_at.isoformat(),
|
||||
"expires_at": session.expires_at.isoformat(),
|
||||
"status": session.status,
|
||||
"messages": [],
|
||||
"shared_resources": {},
|
||||
"task_progress": {},
|
||||
"outcome": None
|
||||
}
|
||||
|
||||
collaboration_sessions[session.session_id] = session_record
|
||||
|
||||
# Update participant stats
|
||||
for participant_id in session.participants:
|
||||
global_agents[participant_id]["collaborations_participated"] += 1
|
||||
|
||||
# Notify participants
|
||||
notification = {
|
||||
"type": "collaboration_invite",
|
||||
"session_id": session.session_id,
|
||||
"objective": session.objective,
|
||||
"participants": session.participants
|
||||
}
|
||||
|
||||
for participant_id in session.participants:
|
||||
message_record = {
|
||||
"message_id": f"collab_{int(datetime.now(timezone.utc).timestamp())}",
|
||||
"sender_id": "system",
|
||||
"recipient_id": participant_id,
|
||||
"message_type": "notification",
|
||||
"content": notification,
|
||||
"priority": "medium",
|
||||
"language": "english",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"status": "delivered",
|
||||
"delivered_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
if participant_id not in agent_messages:
|
||||
agent_messages[participant_id] = []
|
||||
agent_messages[participant_id].append(message_record)
|
||||
|
||||
logger.info(f"Collaboration session created: {session.session_id} with {len(session.participants)} participants")
|
||||
|
||||
return {
|
||||
"session_id": session.session_id,
|
||||
"status": "created",
|
||||
"participants": session.participants,
|
||||
"objective": session.objective,
|
||||
"created_at": session_record["created_at"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/collaborations/{session_id}")
|
||||
async def get_collaboration(session_id: str):
|
||||
"""Get collaboration session details"""
|
||||
if session_id not in collaboration_sessions:
|
||||
raise HTTPException(status_code=404, detail="Collaboration session not found")
|
||||
|
||||
return collaboration_sessions[session_id]
|
||||
|
||||
@app.post("/api/v1/collaborations/{session_id}/message")
|
||||
async def send_collaboration_message(session_id: str, sender_id: str, content: Dict[str, Any]):
|
||||
"""Send a message within a collaboration session"""
|
||||
if session_id not in collaboration_sessions:
|
||||
raise HTTPException(status_code=404, detail="Collaboration session not found")
|
||||
|
||||
if sender_id not in collaboration_sessions[session_id]["participants"]:
|
||||
raise HTTPException(status_code=400, detail="Sender not a participant in this session")
|
||||
|
||||
# Create collaboration message
|
||||
message_record = {
|
||||
"message_id": f"collab_msg_{int(datetime.now(timezone.utc).timestamp())}",
|
||||
"sender_id": sender_id,
|
||||
"session_id": session_id,
|
||||
"content": content,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"type": "collaboration_message"
|
||||
}
|
||||
|
||||
collaboration_sessions[session_id]["messages"].append(message_record)
|
||||
|
||||
# Notify all participants
|
||||
for participant_id in collaboration_sessions[session_id]["participants"]:
|
||||
if participant_id != sender_id:
|
||||
notification = {
|
||||
"type": "collaboration_message",
|
||||
"session_id": session_id,
|
||||
"sender_id": sender_id,
|
||||
"content": content
|
||||
}
|
||||
|
||||
msg_record = {
|
||||
"message_id": f"notif_{int(datetime.now(timezone.utc).timestamp())}",
|
||||
"sender_id": "system",
|
||||
"recipient_id": participant_id,
|
||||
"message_type": "notification",
|
||||
"content": notification,
|
||||
"priority": "medium",
|
||||
"language": "english",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"status": "delivered",
|
||||
"delivered_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
if participant_id not in agent_messages:
|
||||
agent_messages[participant_id] = []
|
||||
agent_messages[participant_id].append(msg_record)
|
||||
|
||||
return {
|
||||
"message_id": message_record["message_id"],
|
||||
"status": "delivered",
|
||||
"timestamp": message_record["timestamp"]
|
||||
}
|
||||
|
||||
@app.post("/api/v1/performance/record")
|
||||
async def record_agent_performance(performance: AgentPerformance):
|
||||
"""Record performance metrics for an agent"""
|
||||
if performance.agent_id not in global_agents:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
|
||||
# Create performance record
|
||||
performance_record = {
|
||||
"performance_id": f"perf_{int(datetime.now(timezone.utc).timestamp())}",
|
||||
"agent_id": performance.agent_id,
|
||||
"timestamp": performance.timestamp.isoformat(),
|
||||
"tasks_completed": performance.tasks_completed,
|
||||
"response_time_ms": performance.response_time_ms,
|
||||
"accuracy_score": performance.accuracy_score,
|
||||
"collaboration_score": performance.collaboration_score,
|
||||
"resource_usage": performance.resource_usage
|
||||
}
|
||||
|
||||
if performance.agent_id not in agent_performance:
|
||||
agent_performance[performance.agent_id] = []
|
||||
|
||||
agent_performance[performance.agent_id].append(performance_record)
|
||||
|
||||
# Update agent's performance score
|
||||
recent_performances = agent_performance[performance.agent_id][-10:] # Last 10 records
|
||||
if recent_performances:
|
||||
avg_accuracy = sum(p["accuracy_score"] for p in recent_performances) / len(recent_performances)
|
||||
avg_collaboration = sum(p["collaboration_score"] for p in recent_performances) / len(recent_performances)
|
||||
|
||||
# Update overall performance score
|
||||
new_score = (avg_accuracy * 0.6 + avg_collaboration * 0.4)
|
||||
global_agents[performance.agent_id]["performance_score"] = round(new_score, 2)
|
||||
|
||||
# Update tasks completed
|
||||
global_agents[performance.agent_id]["tasks_completed"] += performance.tasks_completed
|
||||
|
||||
return {
|
||||
"performance_id": performance_record["performance_id"],
|
||||
"status": "recorded",
|
||||
"updated_performance_score": global_agents[performance.agent_id]["performance_score"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/performance/{agent_id}")
|
||||
async def get_agent_performance(agent_id: str, hours: int = 24):
|
||||
"""Get performance metrics for an agent"""
|
||||
if agent_id not in global_agents:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours)
|
||||
performance_records = agent_performance.get(agent_id, [])
|
||||
recent_performance = [
|
||||
p for p in performance_records
|
||||
if datetime.fromisoformat(p["timestamp"]) > cutoff_time
|
||||
]
|
||||
|
||||
# Calculate statistics
|
||||
if recent_performance:
|
||||
avg_response_time = sum(p["response_time_ms"] for p in recent_performance) / len(recent_performance)
|
||||
avg_accuracy = sum(p["accuracy_score"] for p in recent_performance) / len(recent_performance)
|
||||
avg_collaboration = sum(p["collaboration_score"] for p in recent_performance) / len(recent_performance)
|
||||
total_tasks = sum(p["tasks_completed"] for p in recent_performance)
|
||||
else:
|
||||
avg_response_time = avg_accuracy = avg_collaboration = total_tasks = 0.0
|
||||
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"period_hours": hours,
|
||||
"performance_records": recent_performance,
|
||||
"statistics": {
|
||||
"average_response_time_ms": round(avg_response_time, 2),
|
||||
"average_accuracy_score": round(avg_accuracy, 3),
|
||||
"average_collaboration_score": round(avg_collaboration, 3),
|
||||
"total_tasks_completed": int(total_tasks),
|
||||
"total_records": len(recent_performance)
|
||||
},
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.get("/api/v1/network/dashboard")
|
||||
async def get_network_dashboard():
|
||||
"""Get global AI agent network dashboard"""
|
||||
# Calculate network statistics
|
||||
total_agents = len(global_agents)
|
||||
active_agents = len([a for a in global_agents.values() if a["status"] == "active"])
|
||||
|
||||
# Agent type distribution
|
||||
type_distribution = {}
|
||||
for agent in global_agents.values():
|
||||
agent_type = agent["type"]
|
||||
type_distribution[agent_type] = type_distribution.get(agent_type, 0) + 1
|
||||
|
||||
# Regional distribution
|
||||
region_distribution = {}
|
||||
for agent in global_agents.values():
|
||||
region = agent["region"]
|
||||
region_distribution[region] = region_distribution.get(region, 0) + 1
|
||||
|
||||
# Performance summary
|
||||
performance_scores = [a["performance_score"] for a in global_agents.values()]
|
||||
avg_performance = sum(performance_scores) / len(performance_scores) if performance_scores else 0.0
|
||||
|
||||
# Recent activity
|
||||
recent_messages = 0
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
for messages in agent_messages.values():
|
||||
recent_messages += len([m for m in messages if datetime.fromisoformat(m["timestamp"]) > cutoff_time])
|
||||
|
||||
return {
|
||||
"dashboard": {
|
||||
"network_overview": {
|
||||
"total_agents": total_agents,
|
||||
"active_agents": active_agents,
|
||||
"agent_utilization": round((active_agents / total_agents * 100) if total_agents > 0 else 0, 2),
|
||||
"average_performance_score": round(avg_performance, 3)
|
||||
},
|
||||
"agent_distribution": {
|
||||
"by_type": type_distribution,
|
||||
"by_region": region_distribution
|
||||
},
|
||||
"collaborations": {
|
||||
"total_sessions": len(collaboration_sessions),
|
||||
"active_sessions": len([s for s in collaboration_sessions.values() if s["status"] == "active"]),
|
||||
"total_participants": sum(len(s["participants"]) for s in collaboration_sessions.values())
|
||||
},
|
||||
"activity": {
|
||||
"recent_messages_hour": recent_messages,
|
||||
"total_messages_sent": sum(a["total_messages_sent"] for a in global_agents.values()),
|
||||
"total_tasks_completed": sum(a["tasks_completed"] for a in global_agents.values())
|
||||
}
|
||||
},
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.get("/api/v1/network/optimize")
|
||||
async def optimize_network():
|
||||
"""Optimize global agent network performance"""
|
||||
optimization_results = {
|
||||
"recommendations": [],
|
||||
"actions_taken": [],
|
||||
"performance_improvements": {}
|
||||
}
|
||||
|
||||
# Find underperforming agents
|
||||
for agent_id, agent in global_agents.items():
|
||||
if agent["performance_score"] < 3.0 and agent["status"] == "active":
|
||||
optimization_results["recommendations"].append({
|
||||
"type": "agent_performance",
|
||||
"agent_id": agent_id,
|
||||
"issue": "Low performance score",
|
||||
"recommendation": "Consider agent retraining or resource allocation"
|
||||
})
|
||||
|
||||
# Find overloaded regions
|
||||
region_load = {}
|
||||
for agent in global_agents.values():
|
||||
if agent["status"] == "active":
|
||||
region = agent["region"]
|
||||
region_load[region] = region_load.get(region, 0) + 1
|
||||
|
||||
total_capacity = len(global_agents)
|
||||
for region, load in region_load.items():
|
||||
if load > total_capacity * 0.4: # More than 40% of agents in one region
|
||||
optimization_results["recommendations"].append({
|
||||
"type": "regional_balance",
|
||||
"region": region,
|
||||
"issue": "Agent concentration imbalance",
|
||||
"recommendation": "Redistribute agents to other regions"
|
||||
})
|
||||
|
||||
# Find inactive agents with good performance
|
||||
for agent_id, agent in global_agents.items():
|
||||
if agent["status"] == "inactive" and agent["performance_score"] > 4.0:
|
||||
optimization_results["actions_taken"].append({
|
||||
"type": "agent_activation",
|
||||
"agent_id": agent_id,
|
||||
"action": "Activated high-performing inactive agent"
|
||||
})
|
||||
agent["status"] = "active"
|
||||
|
||||
return {
|
||||
"optimization_results": optimization_results,
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
# Background task for network monitoring
|
||||
async def network_monitoring_task():
|
||||
"""Background task for global network monitoring"""
|
||||
while True:
|
||||
await asyncio.sleep(300) # Monitor every 5 minutes
|
||||
|
||||
# Update network statistics
|
||||
global_network_stats["last_update"] = datetime.now(timezone.utc).isoformat()
|
||||
global_network_stats["total_agents"] = len(global_agents)
|
||||
global_network_stats["active_agents"] = len([a for a in global_agents.values() if a["status"] == "active"])
|
||||
|
||||
# Check for expired collaboration sessions
|
||||
current_time = datetime.now(timezone.utc)
|
||||
for session_id, session in collaboration_sessions.items():
|
||||
if datetime.fromisoformat(session["expires_at"]) < current_time and session["status"] == "active":
|
||||
session["status"] = "expired"
|
||||
logger.info(f"Collaboration session expired: {session_id}")
|
||||
|
||||
# Clean up old messages (older than 7 days)
|
||||
cutoff_time = current_time - timedelta(days=7)
|
||||
for agent_id in agent_messages:
|
||||
agent_messages[agent_id] = [
|
||||
m for m in agent_messages[agent_id]
|
||||
if datetime.fromisoformat(m["timestamp"]) > cutoff_time
|
||||
]
|
||||
|
||||
# Initialize with some default AI agents
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
logger.info("Starting AITBC Global AI Agent Communication Service")
|
||||
|
||||
# Initialize default AI agents
|
||||
default_agents = [
|
||||
{
|
||||
"agent_id": "ai-trader-001",
|
||||
"name": "AlphaTrader",
|
||||
"type": "trading",
|
||||
"region": "us-east-1",
|
||||
"capabilities": ["market_analysis", "trading", "risk_management"],
|
||||
"status": "active",
|
||||
"languages": ["english", "chinese", "japanese", "spanish"],
|
||||
"specialization": "cryptocurrency_trading",
|
||||
"performance_score": 4.7
|
||||
},
|
||||
{
|
||||
"agent_id": "ai-oracle-001",
|
||||
"name": "OraclePro",
|
||||
"type": "oracle",
|
||||
"region": "eu-west-1",
|
||||
"capabilities": ["price_feeds", "data_analysis", "prediction"],
|
||||
"status": "active",
|
||||
"languages": ["english", "german", "french"],
|
||||
"specialization": "price_discovery",
|
||||
"performance_score": 4.9
|
||||
},
|
||||
{
|
||||
"agent_id": "ai-research-001",
|
||||
"name": "ResearchNova",
|
||||
"type": "research",
|
||||
"region": "ap-southeast-1",
|
||||
"capabilities": ["data_analysis", "pattern_recognition", "reporting"],
|
||||
"status": "active",
|
||||
"languages": ["english", "chinese", "korean"],
|
||||
"specialization": "blockchain_research",
|
||||
"performance_score": 4.5
|
||||
}
|
||||
]
|
||||
|
||||
for agent_data in default_agents:
|
||||
agent = Agent(**agent_data)
|
||||
agent_record = {
|
||||
"agent_id": agent.agent_id,
|
||||
"name": agent.name,
|
||||
"type": agent.type,
|
||||
"region": agent.region,
|
||||
"capabilities": agent.capabilities,
|
||||
"status": agent.status,
|
||||
"languages": agent.languages,
|
||||
"specialization": agent.specialization,
|
||||
"performance_score": agent.performance_score,
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"last_active": datetime.now(timezone.utc).isoformat(),
|
||||
"total_messages_sent": 0,
|
||||
"total_messages_received": 0,
|
||||
"collaborations_participated": 0,
|
||||
"tasks_completed": 0,
|
||||
"reputation_score": 5.0,
|
||||
"network_connections": []
|
||||
}
|
||||
global_agents[agent.agent_id] = agent_record
|
||||
agent_messages[agent.agent_id] = []
|
||||
|
||||
# Start network monitoring
|
||||
asyncio.create_task(network_monitoring_task())
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
logger.info("Shutting down AITBC Global AI Agent Communication Service")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8018, log_level="info")
|
||||
1
examples/stubs/global-ai-agents/tests/__init__.py
Normal file
1
examples/stubs/global-ai-agents/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Global AI agents service tests"""
|
||||
@@ -0,0 +1,186 @@
|
||||
"""Edge case and error handling tests for global AI agents service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
|
||||
from main import app, Agent, AgentMessage, CollaborationSession, AgentPerformance, global_agents, agent_messages, collaboration_sessions, agent_performance
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
global_agents.clear()
|
||||
agent_messages.clear()
|
||||
collaboration_sessions.clear()
|
||||
agent_performance.clear()
|
||||
yield
|
||||
global_agents.clear()
|
||||
agent_messages.clear()
|
||||
collaboration_sessions.clear()
|
||||
agent_performance.clear()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_empty_name():
|
||||
"""Test Agent with empty name"""
|
||||
agent = Agent(
|
||||
agent_id="agent_123",
|
||||
name="",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
assert agent.name == ""
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_negative_performance_score():
|
||||
"""Test Agent with negative performance score"""
|
||||
agent = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Test Agent",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=-4.5
|
||||
)
|
||||
assert agent.performance_score == -4.5
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_performance_out_of_range_score():
|
||||
"""Test AgentPerformance with out of range scores"""
|
||||
performance = AgentPerformance(
|
||||
agent_id="agent_123",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
tasks_completed=10,
|
||||
response_time_ms=50.5,
|
||||
accuracy_score=2.0,
|
||||
collaboration_score=2.0,
|
||||
resource_usage={}
|
||||
)
|
||||
assert performance.accuracy_score == 2.0
|
||||
assert performance.collaboration_score == 2.0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_message_empty_content():
|
||||
"""Test AgentMessage with empty content"""
|
||||
message = AgentMessage(
|
||||
message_id="msg_123",
|
||||
sender_id="agent_123",
|
||||
recipient_id="agent_456",
|
||||
message_type="request",
|
||||
content={},
|
||||
priority="high",
|
||||
language="english",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert message.content == {}
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_agents_with_no_agents():
|
||||
"""Test listing agents when no agents exist"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/agents")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_agents"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_agent_messages_agent_not_found():
|
||||
"""Test getting messages for nonexistent agent"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/messages/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_collaboration_not_found():
|
||||
"""Test getting nonexistent collaboration session"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/collaborations/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_send_collaboration_message_session_not_found():
|
||||
"""Test sending message to nonexistent collaboration session"""
|
||||
client = TestClient(app)
|
||||
response = client.post("/api/v1/collaborations/nonexistent/message", params={"sender_id": "agent_123"}, json={"content": "test"})
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_send_collaboration_message_sender_not_participant():
|
||||
"""Test sending message from non-participant"""
|
||||
client = TestClient(app)
|
||||
# Register agent and create collaboration
|
||||
agent = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Agent 1",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
client.post("/api/v1/agents/register", json=agent.model_dump())
|
||||
|
||||
session = CollaborationSession(
|
||||
session_id="session_123",
|
||||
participants=["agent_123"],
|
||||
session_type="research",
|
||||
objective="Research task",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
|
||||
status="active"
|
||||
)
|
||||
client.post("/api/v1/collaborations/create", json=session.model_dump(mode='json'))
|
||||
|
||||
response = client.post("/api/v1/collaborations/session_123/message", params={"sender_id": "nonexistent"}, json={"content": "test"})
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_agent_performance_agent_not_found():
|
||||
"""Test getting performance for nonexistent agent"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/performance/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_dashboard_with_no_data():
|
||||
"""Test dashboard with no data"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/network/dashboard")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["dashboard"]["network_overview"]["total_agents"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_optimize_network_with_no_agents():
|
||||
"""Test network optimization with no agents"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/network/optimize")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "optimization_results" in data
|
||||
@@ -0,0 +1,590 @@
|
||||
"""Integration tests for global AI agents service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
|
||||
from main import app, Agent, AgentMessage, CollaborationSession, AgentPerformance, global_agents, agent_messages, collaboration_sessions, agent_performance
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
global_agents.clear()
|
||||
agent_messages.clear()
|
||||
collaboration_sessions.clear()
|
||||
agent_performance.clear()
|
||||
yield
|
||||
global_agents.clear()
|
||||
agent_messages.clear()
|
||||
collaboration_sessions.clear()
|
||||
agent_performance.clear()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_root_endpoint():
|
||||
"""Test root endpoint"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["service"] == "AITBC Global AI Agent Communication Service"
|
||||
assert data["status"] == "running"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_health_check_endpoint():
|
||||
"""Test health check endpoint"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert "total_agents" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_register_agent():
|
||||
"""Test registering a new agent"""
|
||||
client = TestClient(app)
|
||||
agent = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Test Agent",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
response = client.post("/api/v1/agents/register", json=agent.model_dump())
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["agent_id"] == "agent_123"
|
||||
assert data["status"] == "registered"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_register_duplicate_agent():
|
||||
"""Test registering duplicate agent"""
|
||||
client = TestClient(app)
|
||||
agent = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Test Agent",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
client.post("/api/v1/agents/register", json=agent.model_dump())
|
||||
|
||||
response = client.post("/api/v1/agents/register", json=agent.model_dump())
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_agents():
|
||||
"""Test listing all agents"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/agents")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "agents" in data
|
||||
assert "total_agents" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_agents_with_filters():
|
||||
"""Test listing agents with filters"""
|
||||
client = TestClient(app)
|
||||
# Register an agent first
|
||||
agent = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Test Agent",
|
||||
type="trading",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
client.post("/api/v1/agents/register", json=agent.model_dump())
|
||||
|
||||
response = client.get("/api/v1/agents?region=us-east-1&type=trading&status=active")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "filters" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_agent():
|
||||
"""Test getting specific agent"""
|
||||
client = TestClient(app)
|
||||
agent = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Test Agent",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
client.post("/api/v1/agents/register", json=agent.model_dump())
|
||||
|
||||
response = client.get("/api/v1/agents/agent_123")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["agent_id"] == "agent_123"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_agent_not_found():
|
||||
"""Test getting nonexistent agent"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/agents/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_send_direct_message():
|
||||
"""Test sending direct message"""
|
||||
client = TestClient(app)
|
||||
# Register two agents
|
||||
agent1 = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Agent 1",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
agent2 = Agent(
|
||||
agent_id="agent_456",
|
||||
name="Agent 2",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
client.post("/api/v1/agents/register", json=agent1.model_dump())
|
||||
client.post("/api/v1/agents/register", json=agent2.model_dump())
|
||||
|
||||
message = AgentMessage(
|
||||
message_id="msg_123",
|
||||
sender_id="agent_123",
|
||||
recipient_id="agent_456",
|
||||
message_type="request",
|
||||
content={"data": "test"},
|
||||
priority="high",
|
||||
language="english",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
response = client.post("/api/v1/messages/send", json=message.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["message_id"] == "msg_123"
|
||||
assert data["status"] == "delivered"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_send_broadcast_message():
|
||||
"""Test sending broadcast message"""
|
||||
client = TestClient(app)
|
||||
# Register two agents
|
||||
agent1 = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Agent 1",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
agent2 = Agent(
|
||||
agent_id="agent_456",
|
||||
name="Agent 2",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
client.post("/api/v1/agents/register", json=agent1.model_dump())
|
||||
client.post("/api/v1/agents/register", json=agent2.model_dump())
|
||||
|
||||
message = AgentMessage(
|
||||
message_id="msg_123",
|
||||
sender_id="agent_123",
|
||||
recipient_id=None,
|
||||
message_type="broadcast",
|
||||
content={"data": "test"},
|
||||
priority="medium",
|
||||
language="english",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
response = client.post("/api/v1/messages/send", json=message.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["message_id"] == "msg_123"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_send_message_sender_not_found():
|
||||
"""Test sending message with nonexistent sender"""
|
||||
client = TestClient(app)
|
||||
message = AgentMessage(
|
||||
message_id="msg_123",
|
||||
sender_id="nonexistent",
|
||||
recipient_id="agent_456",
|
||||
message_type="request",
|
||||
content={"data": "test"},
|
||||
priority="high",
|
||||
language="english",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
response = client.post("/api/v1/messages/send", json=message.model_dump(mode='json'))
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_send_message_recipient_not_found():
|
||||
"""Test sending message with nonexistent recipient"""
|
||||
client = TestClient(app)
|
||||
agent = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Agent 1",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
client.post("/api/v1/agents/register", json=agent.model_dump())
|
||||
|
||||
message = AgentMessage(
|
||||
message_id="msg_123",
|
||||
sender_id="agent_123",
|
||||
recipient_id="nonexistent",
|
||||
message_type="request",
|
||||
content={"data": "test"},
|
||||
priority="high",
|
||||
language="english",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
response = client.post("/api/v1/messages/send", json=message.model_dump(mode='json'))
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_agent_messages():
|
||||
"""Test getting agent messages"""
|
||||
client = TestClient(app)
|
||||
agent = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Agent 1",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
client.post("/api/v1/agents/register", json=agent.model_dump())
|
||||
|
||||
response = client.get("/api/v1/messages/agent_123")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["agent_id"] == "agent_123"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_agent_messages_with_limit():
|
||||
"""Test getting agent messages with limit parameter"""
|
||||
client = TestClient(app)
|
||||
agent = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Agent 1",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
client.post("/api/v1/agents/register", json=agent.model_dump())
|
||||
|
||||
response = client.get("/api/v1/messages/agent_123?limit=10")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_collaboration():
|
||||
"""Test creating collaboration session"""
|
||||
client = TestClient(app)
|
||||
# Register two agents
|
||||
agent1 = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Agent 1",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
agent2 = Agent(
|
||||
agent_id="agent_456",
|
||||
name="Agent 2",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
client.post("/api/v1/agents/register", json=agent1.model_dump())
|
||||
client.post("/api/v1/agents/register", json=agent2.model_dump())
|
||||
|
||||
session = CollaborationSession(
|
||||
session_id="session_123",
|
||||
participants=["agent_123", "agent_456"],
|
||||
session_type="task_force",
|
||||
objective="Complete task",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
|
||||
status="active"
|
||||
)
|
||||
response = client.post("/api/v1/collaborations/create", json=session.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["session_id"] == "session_123"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_collaboration_participant_not_found():
|
||||
"""Test creating collaboration with nonexistent participant"""
|
||||
client = TestClient(app)
|
||||
session = CollaborationSession(
|
||||
session_id="session_123",
|
||||
participants=["nonexistent"],
|
||||
session_type="task_force",
|
||||
objective="Complete task",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
|
||||
status="active"
|
||||
)
|
||||
response = client.post("/api/v1/collaborations/create", json=session.model_dump(mode='json'))
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_collaboration():
|
||||
"""Test getting collaboration session"""
|
||||
client = TestClient(app)
|
||||
# Register agents and create collaboration
|
||||
agent = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Agent 1",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
client.post("/api/v1/agents/register", json=agent.model_dump())
|
||||
|
||||
session = CollaborationSession(
|
||||
session_id="session_123",
|
||||
participants=["agent_123"],
|
||||
session_type="research",
|
||||
objective="Research task",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
|
||||
status="active"
|
||||
)
|
||||
client.post("/api/v1/collaborations/create", json=session.model_dump(mode='json'))
|
||||
|
||||
response = client.get("/api/v1/collaborations/session_123")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["session_id"] == "session_123"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_send_collaboration_message():
|
||||
"""Test sending message within collaboration session"""
|
||||
client = TestClient(app)
|
||||
# Register agent and create collaboration
|
||||
agent = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Agent 1",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
client.post("/api/v1/agents/register", json=agent.model_dump())
|
||||
|
||||
session = CollaborationSession(
|
||||
session_id="session_123",
|
||||
participants=["agent_123"],
|
||||
session_type="research",
|
||||
objective="Research task",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
|
||||
status="active"
|
||||
)
|
||||
client.post("/api/v1/collaborations/create", json=session.model_dump(mode='json'))
|
||||
|
||||
response = client.post("/api/v1/collaborations/session_123/message", params={"sender_id": "agent_123"}, json={"content": "test message"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "delivered"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_record_agent_performance():
|
||||
"""Test recording agent performance"""
|
||||
client = TestClient(app)
|
||||
agent = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Agent 1",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
client.post("/api/v1/agents/register", json=agent.model_dump())
|
||||
|
||||
performance = AgentPerformance(
|
||||
agent_id="agent_123",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
tasks_completed=10,
|
||||
response_time_ms=50.5,
|
||||
accuracy_score=0.95,
|
||||
collaboration_score=0.9,
|
||||
resource_usage={"cpu": 50.0}
|
||||
)
|
||||
response = client.post("/api/v1/performance/record", json=performance.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["performance_id"]
|
||||
assert data["status"] == "recorded"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_record_performance_agent_not_found():
|
||||
"""Test recording performance for nonexistent agent"""
|
||||
client = TestClient(app)
|
||||
performance = AgentPerformance(
|
||||
agent_id="nonexistent",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
tasks_completed=10,
|
||||
response_time_ms=50.5,
|
||||
accuracy_score=0.95,
|
||||
collaboration_score=0.9,
|
||||
resource_usage={}
|
||||
)
|
||||
response = client.post("/api/v1/performance/record", json=performance.model_dump(mode='json'))
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_agent_performance():
|
||||
"""Test getting agent performance"""
|
||||
client = TestClient(app)
|
||||
agent = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Agent 1",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
client.post("/api/v1/agents/register", json=agent.model_dump())
|
||||
|
||||
response = client.get("/api/v1/performance/agent_123")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["agent_id"] == "agent_123"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_agent_performance_hours_parameter():
|
||||
"""Test getting agent performance with custom hours parameter"""
|
||||
client = TestClient(app)
|
||||
agent = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Agent 1",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading"],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
client.post("/api/v1/agents/register", json=agent.model_dump())
|
||||
|
||||
response = client.get("/api/v1/performance/agent_123?hours=12")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["period_hours"] == 12
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_network_dashboard():
|
||||
"""Test getting network dashboard"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/network/dashboard")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "dashboard" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_optimize_network():
|
||||
"""Test network optimization"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/network/optimize")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "optimization_results" in data
|
||||
@@ -0,0 +1,158 @@
|
||||
"""Unit tests for global AI agents service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, Agent, AgentMessage, CollaborationSession, AgentPerformance
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_app_initialization():
|
||||
"""Test that the FastAPI app initializes correctly"""
|
||||
assert app is not None
|
||||
assert app.title == "AITBC Global AI Agent Communication Service"
|
||||
assert app.version == "1.0.0"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_model():
|
||||
"""Test Agent model"""
|
||||
agent = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Test Agent",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=["trading", "analysis"],
|
||||
status="active",
|
||||
languages=["english", "chinese"],
|
||||
specialization="trading",
|
||||
performance_score=4.5
|
||||
)
|
||||
assert agent.agent_id == "agent_123"
|
||||
assert agent.name == "Test Agent"
|
||||
assert agent.type == "ai"
|
||||
assert agent.status == "active"
|
||||
assert agent.performance_score == 4.5
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_empty_capabilities():
|
||||
"""Test Agent with empty capabilities"""
|
||||
agent = Agent(
|
||||
agent_id="agent_123",
|
||||
name="Test Agent",
|
||||
type="ai",
|
||||
region="us-east-1",
|
||||
capabilities=[],
|
||||
status="active",
|
||||
languages=["english"],
|
||||
specialization="general",
|
||||
performance_score=4.5
|
||||
)
|
||||
assert agent.capabilities == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_message_model():
|
||||
"""Test AgentMessage model"""
|
||||
message = AgentMessage(
|
||||
message_id="msg_123",
|
||||
sender_id="agent_123",
|
||||
recipient_id="agent_456",
|
||||
message_type="request",
|
||||
content={"data": "test"},
|
||||
priority="high",
|
||||
language="english",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert message.message_id == "msg_123"
|
||||
assert message.sender_id == "agent_123"
|
||||
assert message.recipient_id == "agent_456"
|
||||
assert message.message_type == "request"
|
||||
assert message.priority == "high"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_message_broadcast():
|
||||
"""Test AgentMessage with None recipient (broadcast)"""
|
||||
message = AgentMessage(
|
||||
message_id="msg_123",
|
||||
sender_id="agent_123",
|
||||
recipient_id=None,
|
||||
message_type="broadcast",
|
||||
content={"data": "test"},
|
||||
priority="medium",
|
||||
language="english",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert message.recipient_id is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_collaboration_session_model():
|
||||
"""Test CollaborationSession model"""
|
||||
session = CollaborationSession(
|
||||
session_id="session_123",
|
||||
participants=["agent_123", "agent_456"],
|
||||
session_type="task_force",
|
||||
objective="Complete trading task",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc),
|
||||
status="active"
|
||||
)
|
||||
assert session.session_id == "session_123"
|
||||
assert session.participants == ["agent_123", "agent_456"]
|
||||
assert session.session_type == "task_force"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_collaboration_session_empty_participants():
|
||||
"""Test CollaborationSession with empty participants"""
|
||||
session = CollaborationSession(
|
||||
session_id="session_123",
|
||||
participants=[],
|
||||
session_type="research",
|
||||
objective="Research task",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc),
|
||||
status="active"
|
||||
)
|
||||
assert session.participants == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_performance_model():
|
||||
"""Test AgentPerformance model"""
|
||||
performance = AgentPerformance(
|
||||
agent_id="agent_123",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
tasks_completed=10,
|
||||
response_time_ms=50.5,
|
||||
accuracy_score=0.95,
|
||||
collaboration_score=0.9,
|
||||
resource_usage={"cpu": 50.0, "memory": 60.0}
|
||||
)
|
||||
assert performance.agent_id == "agent_123"
|
||||
assert performance.tasks_completed == 10
|
||||
assert performance.response_time_ms == 50.5
|
||||
assert performance.accuracy_score == 0.95
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_performance_negative_values():
|
||||
"""Test AgentPerformance with negative values"""
|
||||
performance = AgentPerformance(
|
||||
agent_id="agent_123",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
tasks_completed=-10,
|
||||
response_time_ms=-50.5,
|
||||
accuracy_score=-0.95,
|
||||
collaboration_score=-0.9,
|
||||
resource_usage={}
|
||||
)
|
||||
assert performance.tasks_completed == -10
|
||||
assert performance.response_time_ms == -50.5
|
||||
602
examples/stubs/global-infrastructure/main.py
Executable file
602
examples/stubs/global-infrastructure/main.py
Executable file
@@ -0,0 +1,602 @@
|
||||
"""
|
||||
Global Infrastructure Deployment Service for AITBC
|
||||
Handles multi-region deployment, load balancing, and global optimization
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from aitbc import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
app = FastAPI(
|
||||
title="AITBC Global Infrastructure Service",
|
||||
description="Global infrastructure deployment and multi-region optimization",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Data models
|
||||
class Region(BaseModel):
|
||||
region_id: str
|
||||
name: str
|
||||
location: str
|
||||
endpoint: str
|
||||
status: str # active, inactive, maintenance
|
||||
capacity: int
|
||||
current_load: float
|
||||
latency_ms: float
|
||||
compliance_level: str
|
||||
|
||||
class GlobalDeployment(BaseModel):
|
||||
deployment_id: str
|
||||
service_name: str
|
||||
target_regions: List[str]
|
||||
configuration: Dict[str, Any]
|
||||
deployment_strategy: str # blue_green, canary, rolling
|
||||
health_checks: List[str]
|
||||
|
||||
class LoadBalancer(BaseModel):
|
||||
balancer_id: str
|
||||
name: str
|
||||
algorithm: str # round_robin, weighted, least_connections
|
||||
target_regions: List[str]
|
||||
health_check_interval: int
|
||||
failover_threshold: int
|
||||
|
||||
class PerformanceMetrics(BaseModel):
|
||||
region_id: str
|
||||
timestamp: datetime
|
||||
cpu_usage: float
|
||||
memory_usage: float
|
||||
network_io: float
|
||||
disk_io: float
|
||||
active_connections: int
|
||||
response_time_ms: float
|
||||
|
||||
# In-memory storage (in production, use database)
|
||||
global_regions: Dict[str, Dict] = {}
|
||||
deployments: Dict[str, Dict] = {}
|
||||
load_balancers: Dict[str, Dict] = {}
|
||||
performance_metrics: Dict[str, List[Dict]] = {}
|
||||
compliance_data: Dict[str, Dict] = {}
|
||||
global_monitoring: Dict[str, Dict] = {}
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {
|
||||
"service": "AITBC Global Infrastructure Service",
|
||||
"status": "running",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {
|
||||
"status": "healthy",
|
||||
"total_regions": len(global_regions),
|
||||
"active_regions": len([r for r in global_regions.values() if r["status"] == "active"]),
|
||||
"total_deployments": len(deployments),
|
||||
"active_load_balancers": len([lb for lb in load_balancers.values() if lb["status"] == "active"])
|
||||
}
|
||||
|
||||
@app.post("/api/v1/regions/register")
|
||||
async def register_region(region: Region):
|
||||
"""Register a new global region"""
|
||||
if region.region_id in global_regions:
|
||||
raise HTTPException(status_code=400, detail="Region already registered")
|
||||
|
||||
# Create region record
|
||||
region_record = {
|
||||
"region_id": region.region_id,
|
||||
"name": region.name,
|
||||
"location": region.location,
|
||||
"endpoint": region.endpoint,
|
||||
"status": region.status,
|
||||
"capacity": region.capacity,
|
||||
"current_load": region.current_load,
|
||||
"latency_ms": region.latency_ms,
|
||||
"compliance_level": region.compliance_level,
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"last_health_check": None,
|
||||
"services_deployed": [],
|
||||
"performance_history": []
|
||||
}
|
||||
|
||||
global_regions[region.region_id] = region_record
|
||||
|
||||
logger.info(f"Region registered: {region.name} ({region.region_id})")
|
||||
|
||||
return {
|
||||
"region_id": region.region_id,
|
||||
"status": "registered",
|
||||
"name": region.name,
|
||||
"created_at": region_record["created_at"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/regions")
|
||||
async def list_regions():
|
||||
"""List all registered regions"""
|
||||
return {
|
||||
"regions": list(global_regions.values()),
|
||||
"total_regions": len(global_regions),
|
||||
"active_regions": len([r for r in global_regions.values() if r["status"] == "active"])
|
||||
}
|
||||
|
||||
@app.get("/api/v1/regions/{region_id}")
|
||||
async def get_region(region_id: str):
|
||||
"""Get detailed region information"""
|
||||
if region_id not in global_regions:
|
||||
raise HTTPException(status_code=404, detail="Region not found")
|
||||
|
||||
region = global_regions[region_id].copy()
|
||||
|
||||
# Add performance metrics
|
||||
region["performance_metrics"] = performance_metrics.get(region_id, [])
|
||||
|
||||
# Add compliance data
|
||||
region["compliance_data"] = compliance_data.get(region_id, {})
|
||||
|
||||
return region
|
||||
|
||||
@app.post("/api/v1/deployments/create")
|
||||
async def create_deployment(deployment: GlobalDeployment):
|
||||
"""Create a new global deployment"""
|
||||
deployment_id = f"deploy_{int(datetime.now(timezone.utc).timestamp())}"
|
||||
|
||||
# Validate target regions
|
||||
for region_id in deployment.target_regions:
|
||||
if region_id not in global_regions:
|
||||
raise HTTPException(status_code=400, detail=f"Region {region_id} not found")
|
||||
|
||||
# Create deployment record
|
||||
deployment_record = {
|
||||
"deployment_id": deployment_id,
|
||||
"service_name": deployment.service_name,
|
||||
"target_regions": deployment.target_regions,
|
||||
"configuration": deployment.configuration,
|
||||
"deployment_strategy": deployment.deployment_strategy,
|
||||
"health_checks": deployment.health_checks,
|
||||
"status": "pending",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"started_at": None,
|
||||
"completed_at": None,
|
||||
"deployment_progress": {},
|
||||
"rollback_available": False
|
||||
}
|
||||
|
||||
deployments[deployment_id] = deployment_record
|
||||
|
||||
# Start async deployment
|
||||
asyncio.create_task(execute_deployment(deployment_id))
|
||||
|
||||
logger.info(f"Deployment created: {deployment_id} for {deployment.service_name}")
|
||||
|
||||
return {
|
||||
"deployment_id": deployment_id,
|
||||
"status": "pending",
|
||||
"service_name": deployment.service_name,
|
||||
"target_regions": deployment.target_regions,
|
||||
"created_at": deployment_record["created_at"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/deployments/{deployment_id}")
|
||||
async def get_deployment(deployment_id: str):
|
||||
"""Get deployment status and details"""
|
||||
if deployment_id not in deployments:
|
||||
raise HTTPException(status_code=404, detail="Deployment not found")
|
||||
|
||||
return deployments[deployment_id]
|
||||
|
||||
@app.get("/api/v1/deployments")
|
||||
async def list_deployments(status: Optional[str] = None):
|
||||
"""List all deployments"""
|
||||
deployment_list = list(deployments.values())
|
||||
|
||||
if status:
|
||||
deployment_list = [d for d in deployment_list if d["status"] == status]
|
||||
|
||||
# Sort by creation date (most recent first)
|
||||
deployment_list.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
|
||||
return {
|
||||
"deployments": deployment_list,
|
||||
"total_deployments": len(deployment_list),
|
||||
"status_filter": status
|
||||
}
|
||||
|
||||
@app.post("/api/v1/load-balancers/create")
|
||||
async def create_load_balancer(balancer: LoadBalancer):
|
||||
"""Create a new load balancer"""
|
||||
balancer_id = f"lb_{int(datetime.now(timezone.utc).timestamp())}"
|
||||
|
||||
# Validate target regions
|
||||
for region_id in balancer.target_regions:
|
||||
if region_id not in global_regions:
|
||||
raise HTTPException(status_code=400, detail=f"Region {region_id} not found")
|
||||
|
||||
# Create load balancer record
|
||||
balancer_record = {
|
||||
"balancer_id": balancer_id,
|
||||
"name": balancer.name,
|
||||
"algorithm": balancer.algorithm,
|
||||
"target_regions": balancer.target_regions,
|
||||
"health_check_interval": balancer.health_check_interval,
|
||||
"failover_threshold": balancer.failover_threshold,
|
||||
"status": "active",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"current_weights": {region_id: 1.0 for region_id in balancer.target_regions},
|
||||
"health_status": {region_id: "healthy" for region_id in balancer.target_regions},
|
||||
"total_requests": 0,
|
||||
"failed_requests": 0
|
||||
}
|
||||
|
||||
load_balancers[balancer_id] = balancer_record
|
||||
|
||||
# Start health checking
|
||||
asyncio.create_task(start_health_monitoring(balancer_id))
|
||||
|
||||
logger.info(f"Load balancer created: {balancer_id} - {balancer.name}")
|
||||
|
||||
return {
|
||||
"balancer_id": balancer_id,
|
||||
"status": "active",
|
||||
"name": balancer.name,
|
||||
"algorithm": balancer.algorithm,
|
||||
"created_at": balancer_record["created_at"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/load-balancers")
|
||||
async def list_load_balancers():
|
||||
"""List all load balancers"""
|
||||
return {
|
||||
"load_balancers": list(load_balancers.values()),
|
||||
"total_balancers": len(load_balancers),
|
||||
"active_balancers": len([lb for lb in load_balancers.values() if lb["status"] == "active"])
|
||||
}
|
||||
|
||||
@app.post("/api/v1/performance/metrics")
|
||||
async def record_performance_metrics(metrics: PerformanceMetrics):
|
||||
"""Record performance metrics for a region"""
|
||||
metrics_record = {
|
||||
"metrics_id": f"metrics_{int(datetime.now(timezone.utc).timestamp())}",
|
||||
"region_id": metrics.region_id,
|
||||
"timestamp": metrics.timestamp.isoformat(),
|
||||
"cpu_usage": metrics.cpu_usage,
|
||||
"memory_usage": metrics.memory_usage,
|
||||
"network_io": metrics.network_io,
|
||||
"disk_io": metrics.disk_io,
|
||||
"active_connections": metrics.active_connections,
|
||||
"response_time_ms": metrics.response_time_ms
|
||||
}
|
||||
|
||||
if metrics.region_id not in performance_metrics:
|
||||
performance_metrics[metrics.region_id] = []
|
||||
|
||||
performance_metrics[metrics.region_id].append(metrics_record)
|
||||
|
||||
# Keep only last 1000 records per region
|
||||
if len(performance_metrics[metrics.region_id]) > 1000:
|
||||
performance_metrics[metrics.region_id] = performance_metrics[metrics.region_id][-1000:]
|
||||
|
||||
# Update region performance history
|
||||
if metrics.region_id in global_regions:
|
||||
global_regions[metrics.region_id]["performance_history"].append({
|
||||
"timestamp": metrics.timestamp.isoformat(),
|
||||
"cpu_usage": metrics.cpu_usage,
|
||||
"memory_usage": metrics.memory_usage,
|
||||
"response_time_ms": metrics.response_time_ms
|
||||
})
|
||||
|
||||
# Keep only last 100 records
|
||||
if len(global_regions[metrics.region_id]["performance_history"]) > 100:
|
||||
global_regions[metrics.region_id]["performance_history"] = global_regions[metrics.region_id]["performance_history"][-100:]
|
||||
|
||||
return {
|
||||
"metrics_id": metrics_record["metrics_id"],
|
||||
"status": "recorded",
|
||||
"timestamp": metrics_record["timestamp"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/performance/{region_id}")
|
||||
async def get_region_performance(region_id: str, hours: int = 24):
|
||||
"""Get performance metrics for a region"""
|
||||
if region_id not in performance_metrics:
|
||||
raise HTTPException(status_code=404, detail="No performance data for region")
|
||||
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours)
|
||||
recent_metrics = [
|
||||
m for m in performance_metrics[region_id]
|
||||
if datetime.fromisoformat(m["timestamp"]) > cutoff_time
|
||||
]
|
||||
|
||||
# Calculate statistics
|
||||
if recent_metrics:
|
||||
avg_cpu = sum(m["cpu_usage"] for m in recent_metrics) / len(recent_metrics)
|
||||
avg_memory = sum(m["memory_usage"] for m in recent_metrics) / len(recent_metrics)
|
||||
avg_response_time = sum(m["response_time_ms"] for m in recent_metrics) / len(recent_metrics)
|
||||
else:
|
||||
avg_cpu = avg_memory = avg_response_time = 0.0
|
||||
|
||||
return {
|
||||
"region_id": region_id,
|
||||
"period_hours": hours,
|
||||
"metrics": recent_metrics,
|
||||
"statistics": {
|
||||
"average_cpu_usage": round(avg_cpu, 2),
|
||||
"average_memory_usage": round(avg_memory, 2),
|
||||
"average_response_time_ms": round(avg_response_time, 2),
|
||||
"total_samples": len(recent_metrics)
|
||||
},
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.get("/api/v1/compliance/{region_id}")
|
||||
async def get_region_compliance(region_id: str):
|
||||
"""Get compliance information for a region"""
|
||||
if region_id not in global_regions:
|
||||
raise HTTPException(status_code=404, detail="Region not found")
|
||||
|
||||
# Mock compliance data (in production, this would be from actual compliance systems)
|
||||
compliance_info = {
|
||||
"region_id": region_id,
|
||||
"region_name": global_regions[region_id]["name"],
|
||||
"compliance_level": global_regions[region_id]["compliance_level"],
|
||||
"certifications": ["SOC2", "ISO27001", "GDPR"],
|
||||
"data_residency": "compliant",
|
||||
"last_audit": (datetime.now(timezone.utc) - timedelta(days=90)).isoformat(),
|
||||
"next_audit": (datetime.now(timezone.utc) + timedelta(days=275)).isoformat(),
|
||||
"regulations": ["GDPR", "CCPA", "PDPA"],
|
||||
"data_protection": "end-to-end-encryption",
|
||||
"access_controls": "role-based-access",
|
||||
"audit_logging": "enabled"
|
||||
}
|
||||
|
||||
return compliance_info
|
||||
|
||||
@app.get("/api/v1/global/dashboard")
|
||||
async def get_global_dashboard():
|
||||
"""Get global infrastructure dashboard"""
|
||||
# Calculate global statistics
|
||||
total_capacity = sum(r["capacity"] for r in global_regions.values())
|
||||
total_load = sum(r["current_load"] for r in global_regions.values())
|
||||
avg_latency = sum(r["latency_ms"] for r in global_regions.values()) / len(global_regions) if global_regions else 0
|
||||
|
||||
# Deployment statistics
|
||||
deployment_stats = {
|
||||
"total": len(deployments),
|
||||
"pending": len([d for d in deployments.values() if d["status"] == "pending"]),
|
||||
"in_progress": len([d for d in deployments.values() if d["status"] == "in_progress"]),
|
||||
"completed": len([d for d in deployments.values() if d["status"] == "completed"]),
|
||||
"failed": len([d for d in deployments.values() if d["status"] == "failed"])
|
||||
}
|
||||
|
||||
# Performance summary
|
||||
performance_summary = {}
|
||||
for region_id, metrics_list in performance_metrics.items():
|
||||
if metrics_list:
|
||||
latest_metrics = metrics_list[-1]
|
||||
performance_summary[region_id] = {
|
||||
"cpu_usage": latest_metrics["cpu_usage"],
|
||||
"memory_usage": latest_metrics["memory_usage"],
|
||||
"response_time_ms": latest_metrics["response_time_ms"],
|
||||
"active_connections": latest_metrics["active_connections"]
|
||||
}
|
||||
|
||||
return {
|
||||
"dashboard": {
|
||||
"infrastructure": {
|
||||
"total_regions": len(global_regions),
|
||||
"active_regions": len([r for r in global_regions.values() if r["status"] == "active"]),
|
||||
"total_capacity": total_capacity,
|
||||
"current_load": total_load,
|
||||
"utilization_percentage": round((total_load / total_capacity * 100) if total_capacity > 0 else 0, 2),
|
||||
"average_latency_ms": round(avg_latency, 2)
|
||||
},
|
||||
"deployments": deployment_stats,
|
||||
"load_balancers": {
|
||||
"total": len(load_balancers),
|
||||
"active": len([lb for lb in load_balancers.values() if lb["status"] == "active"])
|
||||
},
|
||||
"performance": performance_summary,
|
||||
"compliance": {
|
||||
"compliant_regions": len([r for r in global_regions.values() if r["compliance_level"] == "full"]),
|
||||
"partial_compliance": len([r for r in global_regions.values() if r["compliance_level"] == "partial"])
|
||||
}
|
||||
},
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
# Core deployment and load balancing functions
|
||||
async def execute_deployment(deployment_id: str):
|
||||
"""Execute a global deployment"""
|
||||
deployment = deployments[deployment_id]
|
||||
deployment["status"] = "in_progress"
|
||||
deployment["started_at"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
try:
|
||||
for region_id in deployment["target_regions"]:
|
||||
deployment["deployment_progress"][region_id] = {
|
||||
"status": "deploying",
|
||||
"started_at": datetime.now(timezone.utc).isoformat(),
|
||||
"progress": 0
|
||||
}
|
||||
|
||||
# Simulate deployment process
|
||||
await simulate_deployment_step(region_id, deployment_id)
|
||||
|
||||
deployment["deployment_progress"][region_id].update({
|
||||
"status": "completed",
|
||||
"completed_at": datetime.now(timezone.utc).isoformat(),
|
||||
"progress": 100
|
||||
})
|
||||
|
||||
# Update region services
|
||||
if region_id in global_regions:
|
||||
if deployment["service_name"] not in global_regions[region_id]["services_deployed"]:
|
||||
global_regions[region_id]["services_deployed"].append(deployment["service_name"])
|
||||
|
||||
deployment["status"] = "completed"
|
||||
deployment["completed_at"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
logger.info(f"Deployment completed: {deployment_id}")
|
||||
|
||||
except Exception as e:
|
||||
deployment["status"] = "failed"
|
||||
deployment["error"] = str(e)
|
||||
logger.error(f"Deployment failed: {deployment_id} - {str(e)}")
|
||||
|
||||
async def simulate_deployment_step(region_id: str, deployment_id: str):
|
||||
"""Simulate deployment step for demo"""
|
||||
deployment = deployments[deployment_id]
|
||||
|
||||
# Simulate deployment progress
|
||||
for progress in range(0, 101, 10):
|
||||
if region_id in deployment["deployment_progress"]:
|
||||
deployment["deployment_progress"][region_id]["progress"] = progress
|
||||
await asyncio.sleep(0.1) # Simulate work
|
||||
|
||||
async def start_health_monitoring(balancer_id: str):
|
||||
"""Start health monitoring for a load balancer"""
|
||||
balancer = load_balancers[balancer_id]
|
||||
|
||||
while balancer["status"] == "active":
|
||||
try:
|
||||
# Check health of target regions
|
||||
for region_id in balancer["target_regions"]:
|
||||
if region_id in global_regions:
|
||||
region = global_regions[region_id]
|
||||
|
||||
# Simulate health check (in production, this would be actual health checks)
|
||||
is_healthy = region["status"] == "active" and region["current_load"] < region["capacity"] * 0.9
|
||||
|
||||
balancer["health_status"][region_id] = "healthy" if is_healthy else "unhealthy"
|
||||
|
||||
# Update load balancer weights based on performance
|
||||
update_load_balancer_weights(balancer_id)
|
||||
|
||||
await asyncio.sleep(balancer["health_check_interval"])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Health monitoring error for {balancer_id}: {str(e)}")
|
||||
await asyncio.sleep(10)
|
||||
|
||||
def update_load_balancer_weights(balancer_id: str):
|
||||
"""Update load balancer weights based on region performance"""
|
||||
balancer = load_balancers[balancer_id]
|
||||
|
||||
if balancer["algorithm"] == "weighted":
|
||||
# Calculate weights based on capacity and current load
|
||||
for region_id in balancer["target_regions"]:
|
||||
if region_id in global_regions:
|
||||
region = global_regions[region_id]
|
||||
|
||||
# Weight based on available capacity
|
||||
available_capacity = region["capacity"] - region["current_load"]
|
||||
total_available = sum(
|
||||
global_regions[r]["capacity"] - global_regions[r]["current_load"]
|
||||
for r in balancer["target_regions"]
|
||||
if r in global_regions
|
||||
)
|
||||
|
||||
if total_available > 0:
|
||||
weight = available_capacity / total_available
|
||||
balancer["current_weights"][region_id] = round(weight, 3)
|
||||
|
||||
# Background task for global monitoring
|
||||
async def global_monitoring_task():
|
||||
"""Background task for global infrastructure monitoring"""
|
||||
while True:
|
||||
await asyncio.sleep(60) # Monitor every minute
|
||||
|
||||
# Update global monitoring data
|
||||
global_monitoring["last_update"] = datetime.now(timezone.utc).isoformat()
|
||||
global_monitoring["total_requests"] = sum(lb.get("total_requests", 0) for lb in load_balancers.values())
|
||||
global_monitoring["failed_requests"] = sum(lb.get("failed_requests", 0) for lb in load_balancers.values())
|
||||
|
||||
# Check for regions that need attention
|
||||
for region_id, region in global_regions.items():
|
||||
if region["current_load"] > region["capacity"] * 0.8:
|
||||
logger.warning(f"High load detected in region {region_id}: {region['current_load']}/{region['capacity']}")
|
||||
|
||||
if region["latency_ms"] > 500:
|
||||
logger.warning(f"High latency detected in region {region_id}: {region['latency_ms']}ms")
|
||||
|
||||
# Initialize with some default regions
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
logger.info("Starting AITBC Global Infrastructure Service")
|
||||
|
||||
# Initialize default regions
|
||||
default_regions = [
|
||||
{
|
||||
"region_id": "us-east-1",
|
||||
"name": "US East (N. Virginia)",
|
||||
"location": "North America",
|
||||
"endpoint": "https://us-east-1.api.aitbc.dev",
|
||||
"status": "active",
|
||||
"capacity": 10000,
|
||||
"current_load": 3500,
|
||||
"latency_ms": 45,
|
||||
"compliance_level": "full"
|
||||
},
|
||||
{
|
||||
"region_id": "eu-west-1",
|
||||
"name": "EU West (Ireland)",
|
||||
"location": "Europe",
|
||||
"endpoint": "https://eu-west-1.api.aitbc.dev",
|
||||
"status": "active",
|
||||
"capacity": 8000,
|
||||
"current_load": 2800,
|
||||
"latency_ms": 38,
|
||||
"compliance_level": "full"
|
||||
},
|
||||
{
|
||||
"region_id": "ap-southeast-1",
|
||||
"name": "AP Southeast (Singapore)",
|
||||
"location": "Asia Pacific",
|
||||
"endpoint": "https://ap-southeast-1.api.aitbc.dev",
|
||||
"status": "active",
|
||||
"capacity": 6000,
|
||||
"current_load": 2200,
|
||||
"latency_ms": 62,
|
||||
"compliance_level": "partial"
|
||||
}
|
||||
]
|
||||
|
||||
for region_data in default_regions:
|
||||
region = Region(**region_data)
|
||||
region_record = {
|
||||
"region_id": region.region_id,
|
||||
"name": region.name,
|
||||
"location": region.location,
|
||||
"endpoint": region.endpoint,
|
||||
"status": region.status,
|
||||
"capacity": region.capacity,
|
||||
"current_load": region.current_load,
|
||||
"latency_ms": region.latency_ms,
|
||||
"compliance_level": region.compliance_level,
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"last_health_check": None,
|
||||
"services_deployed": [],
|
||||
"performance_history": []
|
||||
}
|
||||
global_regions[region.region_id] = region_record
|
||||
|
||||
# Start global monitoring
|
||||
asyncio.create_task(global_monitoring_task())
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
logger.info("Shutting down AITBC Global Infrastructure Service")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8017, log_level="info")
|
||||
1
examples/stubs/global-infrastructure/tests/__init__.py
Normal file
1
examples/stubs/global-infrastructure/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Global infrastructure service tests"""
|
||||
@@ -0,0 +1,195 @@
|
||||
"""Edge case and error handling tests for global infrastructure service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, Region, GlobalDeployment, LoadBalancer, PerformanceMetrics, global_regions, deployments, load_balancers, performance_metrics
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
global_regions.clear()
|
||||
deployments.clear()
|
||||
load_balancers.clear()
|
||||
performance_metrics.clear()
|
||||
yield
|
||||
global_regions.clear()
|
||||
deployments.clear()
|
||||
load_balancers.clear()
|
||||
performance_metrics.clear()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_region_negative_capacity():
|
||||
"""Test Region with negative capacity"""
|
||||
region = Region(
|
||||
region_id="us-west-1",
|
||||
name="US West",
|
||||
location="North America",
|
||||
endpoint="https://us-west-1.api.aitbc.dev",
|
||||
status="active",
|
||||
capacity=-1000,
|
||||
current_load=-500,
|
||||
latency_ms=-50,
|
||||
compliance_level="full"
|
||||
)
|
||||
assert region.capacity == -1000
|
||||
assert region.current_load == -500
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_region_empty_name():
|
||||
"""Test Region with empty name"""
|
||||
region = Region(
|
||||
region_id="us-west-1",
|
||||
name="",
|
||||
location="North America",
|
||||
endpoint="https://us-west-1.api.aitbc.dev",
|
||||
status="active",
|
||||
capacity=8000,
|
||||
current_load=2000,
|
||||
latency_ms=50,
|
||||
compliance_level="full"
|
||||
)
|
||||
assert region.name == ""
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_deployment_empty_target_regions():
|
||||
"""Test GlobalDeployment with empty target regions"""
|
||||
deployment = GlobalDeployment(
|
||||
deployment_id="deploy_123",
|
||||
service_name="test-service",
|
||||
target_regions=[],
|
||||
configuration={},
|
||||
deployment_strategy="blue_green",
|
||||
health_checks=[]
|
||||
)
|
||||
assert deployment.target_regions == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_balancer_negative_health_check_interval():
|
||||
"""Test LoadBalancer with negative health check interval"""
|
||||
balancer = LoadBalancer(
|
||||
balancer_id="lb_123",
|
||||
name="Main LB",
|
||||
algorithm="round_robin",
|
||||
target_regions=["us-east-1"],
|
||||
health_check_interval=-30,
|
||||
failover_threshold=3
|
||||
)
|
||||
assert balancer.health_check_interval == -30
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_performance_metrics_negative_values():
|
||||
"""Test PerformanceMetrics with negative values"""
|
||||
metrics = PerformanceMetrics(
|
||||
region_id="us-east-1",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
cpu_usage=-50.5,
|
||||
memory_usage=-60.2,
|
||||
network_io=-1000.5,
|
||||
disk_io=-500.3,
|
||||
active_connections=-100,
|
||||
response_time_ms=-45.2
|
||||
)
|
||||
assert metrics.cpu_usage == -50.5
|
||||
assert metrics.active_connections == -100
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_regions_with_no_regions():
|
||||
"""Test listing regions when no regions exist"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/regions")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_regions"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_deployments_with_no_deployments():
|
||||
"""Test listing deployments when no deployments exist"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/deployments")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_deployments"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_load_balancers_with_no_balancers():
|
||||
"""Test listing load balancers when no balancers exist"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/load-balancers")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_balancers"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_deployment_not_found():
|
||||
"""Test getting nonexistent deployment"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/deployments/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_region_performance_no_data():
|
||||
"""Test getting region performance when no data exists"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/performance/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_region_compliance_nonexistent():
|
||||
"""Test getting compliance for nonexistent region"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/compliance/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_load_balancer_nonexistent_region():
|
||||
"""Test creating load balancer with nonexistent region"""
|
||||
client = TestClient(app)
|
||||
balancer = LoadBalancer(
|
||||
balancer_id="lb_123",
|
||||
name="Main LB",
|
||||
algorithm="round_robin",
|
||||
target_regions=["nonexistent"],
|
||||
health_check_interval=30,
|
||||
failover_threshold=3
|
||||
)
|
||||
response = client.post("/api/v1/load-balancers/create", json=balancer.model_dump())
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_deployments_with_status_filter():
|
||||
"""Test listing deployments with status filter"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/deployments?status=pending")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "status_filter" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_global_dashboard_with_no_data():
|
||||
"""Test global dashboard with no data"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/global/dashboard")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["dashboard"]["infrastructure"]["total_regions"] == 0
|
||||
@@ -0,0 +1,353 @@
|
||||
"""Integration tests for global infrastructure service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, Region, GlobalDeployment, LoadBalancer, PerformanceMetrics, global_regions, deployments, load_balancers, performance_metrics
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
global_regions.clear()
|
||||
deployments.clear()
|
||||
load_balancers.clear()
|
||||
performance_metrics.clear()
|
||||
yield
|
||||
global_regions.clear()
|
||||
deployments.clear()
|
||||
load_balancers.clear()
|
||||
performance_metrics.clear()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_root_endpoint():
|
||||
"""Test root endpoint"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["service"] == "AITBC Global Infrastructure Service"
|
||||
assert data["status"] == "running"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_health_check_endpoint():
|
||||
"""Test health check endpoint"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert "total_regions" in data
|
||||
assert "active_regions" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_register_region():
|
||||
"""Test registering a new region"""
|
||||
client = TestClient(app)
|
||||
region = Region(
|
||||
region_id="us-west-1",
|
||||
name="US West",
|
||||
location="North America",
|
||||
endpoint="https://us-west-1.api.aitbc.dev",
|
||||
status="active",
|
||||
capacity=8000,
|
||||
current_load=2000,
|
||||
latency_ms=50,
|
||||
compliance_level="full"
|
||||
)
|
||||
response = client.post("/api/v1/regions/register", json=region.model_dump())
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["region_id"] == "us-west-1"
|
||||
assert data["status"] == "registered"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_register_duplicate_region():
|
||||
"""Test registering duplicate region"""
|
||||
client = TestClient(app)
|
||||
region = Region(
|
||||
region_id="us-west-1",
|
||||
name="US West",
|
||||
location="North America",
|
||||
endpoint="https://us-west-1.api.aitbc.dev",
|
||||
status="active",
|
||||
capacity=8000,
|
||||
current_load=2000,
|
||||
latency_ms=50,
|
||||
compliance_level="full"
|
||||
)
|
||||
client.post("/api/v1/regions/register", json=region.model_dump())
|
||||
|
||||
response = client.post("/api/v1/regions/register", json=region.model_dump())
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_regions():
|
||||
"""Test listing all regions"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/regions")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "regions" in data
|
||||
assert "total_regions" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_region():
|
||||
"""Test getting specific region"""
|
||||
client = TestClient(app)
|
||||
region = Region(
|
||||
region_id="us-west-1",
|
||||
name="US West",
|
||||
location="North America",
|
||||
endpoint="https://us-west-1.api.aitbc.dev",
|
||||
status="active",
|
||||
capacity=8000,
|
||||
current_load=2000,
|
||||
latency_ms=50,
|
||||
compliance_level="full"
|
||||
)
|
||||
client.post("/api/v1/regions/register", json=region.model_dump())
|
||||
|
||||
response = client.get("/api/v1/regions/us-west-1")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["region_id"] == "us-west-1"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_region_not_found():
|
||||
"""Test getting nonexistent region"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/regions/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_deployment():
|
||||
"""Test creating a deployment"""
|
||||
client = TestClient(app)
|
||||
# Register region first
|
||||
region = Region(
|
||||
region_id="us-west-1",
|
||||
name="US West",
|
||||
location="North America",
|
||||
endpoint="https://us-west-1.api.aitbc.dev",
|
||||
status="active",
|
||||
capacity=8000,
|
||||
current_load=2000,
|
||||
latency_ms=50,
|
||||
compliance_level="full"
|
||||
)
|
||||
client.post("/api/v1/regions/register", json=region.model_dump())
|
||||
|
||||
deployment = GlobalDeployment(
|
||||
deployment_id="deploy_123",
|
||||
service_name="test-service",
|
||||
target_regions=["us-west-1"],
|
||||
configuration={"replicas": 3},
|
||||
deployment_strategy="blue_green",
|
||||
health_checks=["/health"]
|
||||
)
|
||||
response = client.post("/api/v1/deployments/create", json=deployment.model_dump())
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["deployment_id"]
|
||||
assert data["status"] == "pending"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_deployment_nonexistent_region():
|
||||
"""Test creating deployment with nonexistent region"""
|
||||
client = TestClient(app)
|
||||
deployment = GlobalDeployment(
|
||||
deployment_id="deploy_123",
|
||||
service_name="test-service",
|
||||
target_regions=["nonexistent"],
|
||||
configuration={"replicas": 3},
|
||||
deployment_strategy="blue_green",
|
||||
health_checks=["/health"]
|
||||
)
|
||||
response = client.post("/api/v1/deployments/create", json=deployment.model_dump())
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_deployment():
|
||||
"""Test getting deployment details"""
|
||||
client = TestClient(app)
|
||||
# Register region first
|
||||
region = Region(
|
||||
region_id="us-west-1",
|
||||
name="US West",
|
||||
location="North America",
|
||||
endpoint="https://us-west-1.api.aitbc.dev",
|
||||
status="active",
|
||||
capacity=8000,
|
||||
current_load=2000,
|
||||
latency_ms=50,
|
||||
compliance_level="full"
|
||||
)
|
||||
client.post("/api/v1/regions/register", json=region.model_dump())
|
||||
|
||||
deployment = GlobalDeployment(
|
||||
deployment_id="deploy_123",
|
||||
service_name="test-service",
|
||||
target_regions=["us-west-1"],
|
||||
configuration={"replicas": 3},
|
||||
deployment_strategy="blue_green",
|
||||
health_checks=["/health"]
|
||||
)
|
||||
create_response = client.post("/api/v1/deployments/create", json=deployment.model_dump())
|
||||
deployment_id = create_response.json()["deployment_id"]
|
||||
|
||||
response = client.get(f"/api/v1/deployments/{deployment_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["deployment_id"] == deployment_id
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_deployments():
|
||||
"""Test listing all deployments"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/deployments")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "deployments" in data
|
||||
assert "total_deployments" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_load_balancer():
|
||||
"""Test creating a load balancer"""
|
||||
client = TestClient(app)
|
||||
# Register region first
|
||||
region = Region(
|
||||
region_id="us-west-1",
|
||||
name="US West",
|
||||
location="North America",
|
||||
endpoint="https://us-west-1.api.aitbc.dev",
|
||||
status="active",
|
||||
capacity=8000,
|
||||
current_load=2000,
|
||||
latency_ms=50,
|
||||
compliance_level="full"
|
||||
)
|
||||
client.post("/api/v1/regions/register", json=region.model_dump())
|
||||
|
||||
balancer = LoadBalancer(
|
||||
balancer_id="lb_123",
|
||||
name="Main LB",
|
||||
algorithm="round_robin",
|
||||
target_regions=["us-west-1"],
|
||||
health_check_interval=30,
|
||||
failover_threshold=3
|
||||
)
|
||||
response = client.post("/api/v1/load-balancers/create", json=balancer.model_dump())
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["balancer_id"]
|
||||
assert data["status"] == "active"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_load_balancers():
|
||||
"""Test listing all load balancers"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/load-balancers")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "load_balancers" in data
|
||||
assert "total_balancers" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_record_performance_metrics():
|
||||
"""Test recording performance metrics"""
|
||||
client = TestClient(app)
|
||||
metrics = PerformanceMetrics(
|
||||
region_id="us-west-1",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
cpu_usage=50.5,
|
||||
memory_usage=60.2,
|
||||
network_io=1000.5,
|
||||
disk_io=500.3,
|
||||
active_connections=100,
|
||||
response_time_ms=45.2
|
||||
)
|
||||
response = client.post("/api/v1/performance/metrics", json=metrics.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["metrics_id"]
|
||||
assert data["status"] == "recorded"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_region_performance():
|
||||
"""Test getting region performance metrics"""
|
||||
client = TestClient(app)
|
||||
# Record metrics first
|
||||
metrics = PerformanceMetrics(
|
||||
region_id="us-west-1",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
cpu_usage=50.5,
|
||||
memory_usage=60.2,
|
||||
network_io=1000.5,
|
||||
disk_io=500.3,
|
||||
active_connections=100,
|
||||
response_time_ms=45.2
|
||||
)
|
||||
client.post("/api/v1/performance/metrics", json=metrics.model_dump(mode='json'))
|
||||
|
||||
response = client.get("/api/v1/performance/us-west-1")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["region_id"] == "us-west-1"
|
||||
assert "statistics" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_region_compliance():
|
||||
"""Test getting region compliance information"""
|
||||
client = TestClient(app)
|
||||
# Register region first
|
||||
region = Region(
|
||||
region_id="us-west-1",
|
||||
name="US West",
|
||||
location="North America",
|
||||
endpoint="https://us-west-1.api.aitbc.dev",
|
||||
status="active",
|
||||
capacity=8000,
|
||||
current_load=2000,
|
||||
latency_ms=50,
|
||||
compliance_level="full"
|
||||
)
|
||||
client.post("/api/v1/regions/register", json=region.model_dump())
|
||||
|
||||
response = client.get("/api/v1/compliance/us-west-1")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["region_id"] == "us-west-1"
|
||||
assert "compliance_level" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_global_dashboard():
|
||||
"""Test getting global dashboard"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/global/dashboard")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "dashboard" in data
|
||||
assert "infrastructure" in data["dashboard"]
|
||||
@@ -0,0 +1,93 @@
|
||||
"""Unit tests for global infrastructure service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, Region, GlobalDeployment, LoadBalancer, PerformanceMetrics
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_app_initialization():
|
||||
"""Test that the FastAPI app initializes correctly"""
|
||||
assert app is not None
|
||||
assert app.title == "AITBC Global Infrastructure Service"
|
||||
assert app.version == "1.0.0"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_region_model():
|
||||
"""Test Region model"""
|
||||
region = Region(
|
||||
region_id="us-east-1",
|
||||
name="US East",
|
||||
location="North America",
|
||||
endpoint="https://us-east-1.api.aitbc.dev",
|
||||
status="active",
|
||||
capacity=10000,
|
||||
current_load=3500,
|
||||
latency_ms=45,
|
||||
compliance_level="full"
|
||||
)
|
||||
assert region.region_id == "us-east-1"
|
||||
assert region.name == "US East"
|
||||
assert region.status == "active"
|
||||
assert region.capacity == 10000
|
||||
assert region.compliance_level == "full"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_global_deployment_model():
|
||||
"""Test GlobalDeployment model"""
|
||||
deployment = GlobalDeployment(
|
||||
deployment_id="deploy_123",
|
||||
service_name="test-service",
|
||||
target_regions=["us-east-1", "eu-west-1"],
|
||||
configuration={"replicas": 3},
|
||||
deployment_strategy="blue_green",
|
||||
health_checks=["/health", "/ready"]
|
||||
)
|
||||
assert deployment.deployment_id == "deploy_123"
|
||||
assert deployment.service_name == "test-service"
|
||||
assert deployment.target_regions == ["us-east-1", "eu-west-1"]
|
||||
assert deployment.deployment_strategy == "blue_green"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_balancer_model():
|
||||
"""Test LoadBalancer model"""
|
||||
balancer = LoadBalancer(
|
||||
balancer_id="lb_123",
|
||||
name="Main LB",
|
||||
algorithm="round_robin",
|
||||
target_regions=["us-east-1", "eu-west-1"],
|
||||
health_check_interval=30,
|
||||
failover_threshold=3
|
||||
)
|
||||
assert balancer.balancer_id == "lb_123"
|
||||
assert balancer.name == "Main LB"
|
||||
assert balancer.algorithm == "round_robin"
|
||||
assert balancer.health_check_interval == 30
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_performance_metrics_model():
|
||||
"""Test PerformanceMetrics model"""
|
||||
metrics = PerformanceMetrics(
|
||||
region_id="us-east-1",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
cpu_usage=50.5,
|
||||
memory_usage=60.2,
|
||||
network_io=1000.5,
|
||||
disk_io=500.3,
|
||||
active_connections=100,
|
||||
response_time_ms=45.2
|
||||
)
|
||||
assert metrics.region_id == "us-east-1"
|
||||
assert metrics.cpu_usage == 50.5
|
||||
assert metrics.memory_usage == 60.2
|
||||
assert metrics.active_connections == 100
|
||||
assert metrics.response_time_ms == 45.2
|
||||
1292
examples/stubs/hermes-service/poetry.lock
generated
Normal file
1292
examples/stubs/hermes-service/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
25
examples/stubs/hermes-service/pyproject.toml
Normal file
25
examples/stubs/hermes-service/pyproject.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[tool.poetry]
|
||||
name = "hermes-service"
|
||||
version = "0.1.0"
|
||||
description = "AITBC hermes Service for agent orchestration and edge computing"
|
||||
authors = ["AITBC Team"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.13"
|
||||
fastapi = ">=0.115.6"
|
||||
uvicorn = {extras = ["standard"], version = "^0.32.0"}
|
||||
sqlmodel = "^0.0.37"
|
||||
sqlalchemy = "^2.0.25"
|
||||
pydantic = "^2.6.0"
|
||||
pydantic-settings = "^2.1.0"
|
||||
httpx = ">=0.28.1"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = ">=9.0.3"
|
||||
pytest-asyncio = ">=1.3.0"
|
||||
black = ">=26.3.1"
|
||||
ruff = "^0.1.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
122
examples/stubs/hermes-service/src/hermes_service/main.py
Normal file
122
examples/stubs/hermes-service/src/hermes_service/main.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Hermes Service for agent orchestration and edge computing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
app = FastAPI(
|
||||
title="AITBC Hermes Service",
|
||||
description="Agent orchestration and edge computing service",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy", "service": "hermes-service"}
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint."""
|
||||
return {
|
||||
"service": "AITBC Hermes Service",
|
||||
"version": "1.0.0",
|
||||
"status": "operational"
|
||||
}
|
||||
|
||||
|
||||
@app.post("/routing/skill")
|
||||
async def route_agent_skill(request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Sophisticated agent skill routing"""
|
||||
return {
|
||||
"selected_agent": "agent_default",
|
||||
"routing_strategy": "performance_based",
|
||||
"expected_performance": "high",
|
||||
"estimated_cost": 0.5
|
||||
}
|
||||
|
||||
|
||||
@app.post("/offloading/intelligent")
|
||||
async def intelligent_job_offloading(request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Intelligent job offloading strategies"""
|
||||
return {
|
||||
"should_offload": True,
|
||||
"job_size": "medium",
|
||||
"cost_analysis": {"estimated_cost": 10.0, "savings": 5.0},
|
||||
"performance_prediction": "improved",
|
||||
"fallback_mechanism": "local_execution"
|
||||
}
|
||||
|
||||
|
||||
@app.post("/collaboration/coordinate")
|
||||
async def coordinate_agent_collaboration(request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Agent collaboration and coordination"""
|
||||
return {
|
||||
"coordination_method": "consensus",
|
||||
"selected_coordinator": "agent_coordinator_1",
|
||||
"consensus_reached": True,
|
||||
"task_distribution": {"agent_1": 0.5, "agent_2": 0.5},
|
||||
"estimated_completion_time": 300
|
||||
}
|
||||
|
||||
|
||||
@app.post("/execution/hybrid-optimize")
|
||||
async def optimize_hybrid_execution(request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Hybrid execution optimization"""
|
||||
return {
|
||||
"execution_mode": "hybrid",
|
||||
"strategy": "cost_performance_balance",
|
||||
"resource_allocation": {"cpu": 50, "memory": 40, "gpu": 10},
|
||||
"performance_tuning": "optimized",
|
||||
"expected_improvement": "30%"
|
||||
}
|
||||
|
||||
|
||||
@app.post("/edge/deploy")
|
||||
async def deploy_to_edge(request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Deploy agent to edge computing infrastructure"""
|
||||
return {
|
||||
"deployment_id": "deployment_123",
|
||||
"agent_id": request.get("agent_id", "unknown"),
|
||||
"edge_locations": request.get("edge_locations", []),
|
||||
"deployment_results": {"success": True, "deployed_count": 3},
|
||||
"status": "completed"
|
||||
}
|
||||
|
||||
|
||||
@app.post("/edge/coordinate")
|
||||
async def coordinate_edge_to_cloud(request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coordinate edge-to-cloud agent operations"""
|
||||
return {
|
||||
"coordination_id": "coord_456",
|
||||
"edge_deployment_id": request.get("edge_deployment_id", "unknown"),
|
||||
"synchronization": "enabled",
|
||||
"load_balancing": "round_robin",
|
||||
"failover": "active",
|
||||
"status": "active"
|
||||
}
|
||||
|
||||
|
||||
@app.post("/ecosystem/develop")
|
||||
async def develop_hermes_ecosystem(request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Build comprehensive Hermes ecosystem"""
|
||||
return {
|
||||
"ecosystem_id": "eco_789",
|
||||
"developer_tools": ["cli", "sdk", "dashboard"],
|
||||
"marketplace": "enabled",
|
||||
"community": "growing",
|
||||
"partnerships": ["partner_1", "partner_2"],
|
||||
"status": "developing"
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8108)
|
||||
48
examples/stubs/monitor/monitor.py
Normal file
48
examples/stubs/monitor/monitor.py
Normal file
@@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
AITBC Monitor Service
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
from pathlib import Path
|
||||
import psutil
|
||||
|
||||
from aitbc import get_logger, DATA_DIR
|
||||
|
||||
def main():
|
||||
logger = get_logger('aitbc-monitor')
|
||||
|
||||
while True:
|
||||
try:
|
||||
# System stats
|
||||
cpu_percent = psutil.cpu_percent()
|
||||
memory_percent = psutil.virtual_memory().percent
|
||||
logger.info(f'System: CPU {cpu_percent}%, Memory {memory_percent}%')
|
||||
|
||||
# Blockchain stats
|
||||
blockchain_file = DATA_DIR / 'data/blockchain/aitbc/blockchain.json'
|
||||
if blockchain_file.exists():
|
||||
with open(blockchain_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
logger.info(f'Blockchain: {len(data.get("blocks", []))} blocks')
|
||||
|
||||
# Marketplace stats
|
||||
marketplace_dir = DATA_DIR / 'data/marketplace'
|
||||
if marketplace_dir.exists():
|
||||
listings_file = marketplace_dir / 'gpu_listings.json'
|
||||
if listings_file.exists():
|
||||
with open(listings_file, 'r') as f:
|
||||
listings = json.load(f)
|
||||
logger.info(f'Marketplace: {len(listings)} GPU listings')
|
||||
|
||||
time.sleep(30)
|
||||
except (json.JSONDecodeError, FileNotFoundError, PermissionError, IOError) as e:
|
||||
logger.error(f'Monitoring error: {type(e).__name__}: {e}')
|
||||
time.sleep(60)
|
||||
except psutil.Error as e:
|
||||
logger.error(f'System monitoring error: {type(e).__name__}: {e}')
|
||||
time.sleep(60)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
examples/stubs/monitor/tests/__init__.py
Normal file
1
examples/stubs/monitor/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Monitor service tests"""
|
||||
216
examples/stubs/monitor/tests/test_edge_cases_monitor.py
Normal file
216
examples/stubs/monitor/tests/test_edge_cases_monitor.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""Edge case and error handling tests for monitor service"""
|
||||
|
||||
import sys
|
||||
import pytest
|
||||
import sys
|
||||
from unittest.mock import Mock, patch, MagicMock, mock_open
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
# Create a proper psutil mock with Error exception class
|
||||
class PsutilError(Exception):
|
||||
pass
|
||||
|
||||
mock_psutil = MagicMock()
|
||||
mock_psutil.cpu_percent = Mock(return_value=45.5)
|
||||
mock_psutil.virtual_memory = Mock(return_value=MagicMock(percent=60.2))
|
||||
mock_psutil.Error = PsutilError
|
||||
sys.modules['psutil'] = mock_psutil
|
||||
|
||||
import monitor
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_json_decode_error_handling():
|
||||
"""Test JSON decode error is handled correctly"""
|
||||
with patch('monitor.logging') as mock_logging, \
|
||||
patch('monitor.time.sleep', side_effect=[None, KeyboardInterrupt]), \
|
||||
patch('monitor.Path') as mock_path, \
|
||||
patch('builtins.open', mock_open(read_data='invalid json{')):
|
||||
|
||||
# Mock blockchain file exists
|
||||
blockchain_path = Mock()
|
||||
blockchain_path.exists.return_value = True
|
||||
marketplace_path = Mock()
|
||||
marketplace_path.exists.return_value = False
|
||||
|
||||
mock_path.side_effect = lambda x: blockchain_path if 'blockchain' in str(x) else marketplace_path
|
||||
|
||||
logger = mock_logging.getLogger.return_value
|
||||
mock_logging.basicConfig.return_value = None
|
||||
|
||||
try:
|
||||
monitor.main()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
# Verify error was logged
|
||||
error_calls = [call for call in logger.error.call_args_list if 'JSONDecodeError' in str(call)]
|
||||
assert len(error_calls) > 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_file_not_found_error_handling():
|
||||
"""Test FileNotFoundError is handled correctly"""
|
||||
with patch('monitor.logging') as mock_logging, \
|
||||
patch('monitor.time.sleep', side_effect=[None, KeyboardInterrupt]), \
|
||||
patch('monitor.Path') as mock_path, \
|
||||
patch('builtins.open', side_effect=FileNotFoundError("File not found")):
|
||||
|
||||
# Mock blockchain file exists
|
||||
blockchain_path = Mock()
|
||||
blockchain_path.exists.return_value = True
|
||||
marketplace_path = Mock()
|
||||
marketplace_path.exists.return_value = False
|
||||
|
||||
mock_path.side_effect = lambda x: blockchain_path if 'blockchain' in str(x) else marketplace_path
|
||||
|
||||
logger = mock_logging.getLogger.return_value
|
||||
mock_logging.basicConfig.return_value = None
|
||||
|
||||
try:
|
||||
monitor.main()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
# Verify error was logged
|
||||
error_calls = [call for call in logger.error.call_args_list if 'FileNotFoundError' in str(call)]
|
||||
assert len(error_calls) > 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_permission_error_handling():
|
||||
"""Test PermissionError is handled correctly"""
|
||||
with patch('monitor.logging') as mock_logging, \
|
||||
patch('monitor.time.sleep', side_effect=[None, KeyboardInterrupt]), \
|
||||
patch('monitor.Path') as mock_path, \
|
||||
patch('builtins.open', side_effect=PermissionError("Permission denied")):
|
||||
|
||||
# Mock blockchain file exists
|
||||
blockchain_path = Mock()
|
||||
blockchain_path.exists.return_value = True
|
||||
marketplace_path = Mock()
|
||||
marketplace_path.exists.return_value = False
|
||||
|
||||
mock_path.side_effect = lambda x: blockchain_path if 'blockchain' in str(x) else marketplace_path
|
||||
|
||||
logger = mock_logging.getLogger.return_value
|
||||
mock_logging.basicConfig.return_value = None
|
||||
|
||||
try:
|
||||
monitor.main()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
# Verify error was logged
|
||||
error_calls = [call for call in logger.error.call_args_list if 'PermissionError' in str(call)]
|
||||
assert len(error_calls) > 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_io_error_handling():
|
||||
"""Test IOError is handled correctly"""
|
||||
with patch('monitor.logging') as mock_logging, \
|
||||
patch('monitor.time.sleep', side_effect=[None, KeyboardInterrupt]), \
|
||||
patch('monitor.Path') as mock_path, \
|
||||
patch('builtins.open', side_effect=IOError("I/O error")):
|
||||
|
||||
# Mock blockchain file exists
|
||||
blockchain_path = Mock()
|
||||
blockchain_path.exists.return_value = True
|
||||
marketplace_path = Mock()
|
||||
marketplace_path.exists.return_value = False
|
||||
|
||||
mock_path.side_effect = lambda x: blockchain_path if 'blockchain' in str(x) else marketplace_path
|
||||
|
||||
logger = mock_logging.getLogger.return_value
|
||||
mock_logging.basicConfig.return_value = None
|
||||
|
||||
try:
|
||||
monitor.main()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
# Verify error was logged
|
||||
error_calls = [call for call in logger.error.call_args_list if 'IOError' in str(call) or 'OSError' in str(call)]
|
||||
assert len(error_calls) > 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_psutil_error_handling():
|
||||
"""Test psutil.Error is handled correctly"""
|
||||
with patch('monitor.logging') as mock_logging, \
|
||||
patch('monitor.time.sleep', side_effect=[None, KeyboardInterrupt]), \
|
||||
patch('monitor.psutil.cpu_percent', side_effect=PsutilError("psutil error")):
|
||||
|
||||
logger = mock_logging.getLogger.return_value
|
||||
mock_logging.basicConfig.return_value = None
|
||||
|
||||
try:
|
||||
monitor.main()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
# Verify error was logged
|
||||
error_calls = [call for call in logger.error.call_args_list if 'psutil error' in str(call)]
|
||||
assert len(error_calls) > 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_blocks_array():
|
||||
"""Test handling of empty blocks array in blockchain data"""
|
||||
with patch('monitor.logging') as mock_logging, \
|
||||
patch('monitor.time.sleep', side_effect=KeyboardInterrupt), \
|
||||
patch('monitor.Path') as mock_path, \
|
||||
patch('builtins.open', mock_open(read_data='{"blocks": []}')):
|
||||
|
||||
# Mock blockchain file exists
|
||||
blockchain_path = Mock()
|
||||
blockchain_path.exists.return_value = True
|
||||
marketplace_path = Mock()
|
||||
marketplace_path.exists.return_value = False
|
||||
|
||||
mock_path.side_effect = lambda x: blockchain_path if 'blockchain' in str(x) else marketplace_path
|
||||
|
||||
logger = mock_logging.getLogger.return_value
|
||||
mock_logging.basicConfig.return_value = None
|
||||
|
||||
try:
|
||||
monitor.main()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
# Verify blockchain stats were logged with 0 blocks
|
||||
blockchain_calls = [call for call in logger.info.call_args_list if 'Blockchain' in str(call)]
|
||||
assert len(blockchain_calls) > 0
|
||||
assert '0 blocks' in str(blockchain_calls[0])
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_missing_blocks_key():
|
||||
"""Test handling of missing blocks key in blockchain data"""
|
||||
with patch('monitor.logging') as mock_logging, \
|
||||
patch('monitor.time.sleep', side_effect=KeyboardInterrupt), \
|
||||
patch('monitor.Path') as mock_path, \
|
||||
patch('builtins.open', mock_open(read_data='{"height": 100}')):
|
||||
|
||||
# Mock blockchain file exists
|
||||
blockchain_path = Mock()
|
||||
blockchain_path.exists.return_value = True
|
||||
marketplace_path = Mock()
|
||||
marketplace_path.exists.return_value = False
|
||||
|
||||
mock_path.side_effect = lambda x: blockchain_path if 'blockchain' in str(x) else marketplace_path
|
||||
|
||||
logger = mock_logging.getLogger.return_value
|
||||
mock_logging.basicConfig.return_value = None
|
||||
|
||||
try:
|
||||
monitor.main()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
# Verify blockchain stats were logged with 0 blocks (default)
|
||||
blockchain_calls = [call for call in logger.info.call_args_list if 'Blockchain' in str(call)]
|
||||
assert len(blockchain_calls) > 0
|
||||
assert '0 blocks' in str(blockchain_calls[0])
|
||||
108
examples/stubs/monitor/tests/test_unit_monitor.py
Normal file
108
examples/stubs/monitor/tests/test_unit_monitor.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Unit tests for monitor service"""
|
||||
|
||||
import sys
|
||||
import pytest
|
||||
import sys
|
||||
from unittest.mock import Mock, patch, MagicMock, mock_open
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
# Create a proper psutil mock with Error exception class
|
||||
class PsutilError(Exception):
|
||||
pass
|
||||
|
||||
mock_psutil = MagicMock()
|
||||
mock_psutil.cpu_percent = Mock(return_value=45.5)
|
||||
mock_psutil.virtual_memory = Mock(return_value=MagicMock(percent=60.2))
|
||||
mock_psutil.Error = PsutilError
|
||||
sys.modules['psutil'] = mock_psutil
|
||||
|
||||
import monitor
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_main_system_stats_logging():
|
||||
"""Test that system stats are logged correctly"""
|
||||
with patch('monitor.logging') as mock_logging, \
|
||||
patch('monitor.time.sleep', side_effect=KeyboardInterrupt), \
|
||||
patch('monitor.Path') as mock_path:
|
||||
|
||||
mock_path.return_value.exists.return_value = False
|
||||
|
||||
logger = mock_logging.getLogger.return_value
|
||||
mock_logging.basicConfig.return_value = None
|
||||
|
||||
try:
|
||||
monitor.main()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
# Verify system stats were logged
|
||||
assert logger.info.call_count >= 1
|
||||
system_call = logger.info.call_args_list[0]
|
||||
assert 'CPU 45.5%' in str(system_call)
|
||||
assert 'Memory 60.2%' in str(system_call)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_main_blockchain_stats_logging():
|
||||
"""Test that blockchain stats are logged when file exists"""
|
||||
with patch('monitor.logging') as mock_logging, \
|
||||
patch('monitor.time.sleep', side_effect=KeyboardInterrupt), \
|
||||
patch('monitor.Path') as mock_path, \
|
||||
patch('builtins.open', mock_open(read_data='{"blocks": [{"height": 1}, {"height": 2}]}')):
|
||||
|
||||
# Mock blockchain file exists
|
||||
blockchain_path = Mock()
|
||||
blockchain_path.exists.return_value = True
|
||||
marketplace_path = Mock()
|
||||
marketplace_path.exists.return_value = False
|
||||
|
||||
mock_path.side_effect = lambda x: blockchain_path if 'blockchain' in str(x) else marketplace_path
|
||||
|
||||
logger = mock_logging.getLogger.return_value
|
||||
mock_logging.basicConfig.return_value = None
|
||||
|
||||
try:
|
||||
monitor.main()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
# Verify blockchain stats were logged
|
||||
blockchain_calls = [call for call in logger.info.call_args_list if 'Blockchain' in str(call)]
|
||||
assert len(blockchain_calls) > 0
|
||||
assert '2 blocks' in str(blockchain_calls[0])
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_main_marketplace_stats_logging():
|
||||
"""Test that marketplace stats are logged when file exists"""
|
||||
with patch('monitor.logging') as mock_logging, \
|
||||
patch('monitor.time.sleep', side_effect=KeyboardInterrupt), \
|
||||
patch('monitor.Path') as mock_path, \
|
||||
patch('builtins.open', mock_open(read_data='[{"id": 1, "gpu": "rtx3080"}, {"id": 2, "gpu": "rtx3090"}]')):
|
||||
|
||||
# Mock blockchain file doesn't exist, marketplace does
|
||||
blockchain_path = Mock()
|
||||
blockchain_path.exists.return_value = False
|
||||
marketplace_path = Mock()
|
||||
marketplace_path.exists.return_value = True
|
||||
listings_file = Mock()
|
||||
listings_file.exists.return_value = True
|
||||
listings_file.__truediv__ = Mock(return_value=listings_file)
|
||||
marketplace_path.__truediv__ = Mock(return_value=listings_file)
|
||||
|
||||
mock_path.side_effect = lambda x: listings_file if 'gpu_listings' in str(x) else (marketplace_path if 'marketplace' in str(x) else blockchain_path)
|
||||
|
||||
logger = mock_logging.getLogger.return_value
|
||||
mock_logging.basicConfig.return_value = None
|
||||
|
||||
try:
|
||||
monitor.main()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
# Verify marketplace stats were logged
|
||||
marketplace_calls = [call for call in logger.info.call_args_list if 'Marketplace' in str(call)]
|
||||
assert len(marketplace_calls) > 0
|
||||
assert '2 GPU listings' in str(marketplace_calls[0])
|
||||
1327
examples/stubs/monitoring-service/poetry.lock
generated
Normal file
1327
examples/stubs/monitoring-service/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
26
examples/stubs/monitoring-service/pyproject.toml
Normal file
26
examples/stubs/monitoring-service/pyproject.toml
Normal file
@@ -0,0 +1,26 @@
|
||||
[tool.poetry]
|
||||
name = "monitoring-service"
|
||||
version = "0.1.0"
|
||||
description = "AITBC Monitoring Service for system health and metrics"
|
||||
authors = ["AITBC Team"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.13"
|
||||
fastapi = ">=0.115.6"
|
||||
uvicorn = {extras = ["standard"], version = "^0.32.0"}
|
||||
sqlmodel = "^0.0.37"
|
||||
sqlalchemy = "^2.0.25"
|
||||
pydantic = "^2.6.0"
|
||||
pydantic-settings = "^2.1.0"
|
||||
httpx = ">=0.28.1"
|
||||
psutil = "^7.2.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = ">=9.0.3"
|
||||
pytest-asyncio = ">=1.3.0"
|
||||
black = ">=26.3.1"
|
||||
ruff = "^0.1.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
266
examples/stubs/monitoring-service/src/monitoring_service/main.py
Normal file
266
examples/stubs/monitoring-service/src/monitoring_service/main.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""Monitoring Service for system health and metrics."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
app = FastAPI(
|
||||
title="AITBC Monitoring Service",
|
||||
description="System health and metrics monitoring service",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Service endpoints configuration
|
||||
SERVICES = {
|
||||
"gpu": {
|
||||
"name": "GPU Service",
|
||||
"port": 8101,
|
||||
"url": "http://localhost:8101",
|
||||
"description": "GPU marketplace and miner operations",
|
||||
},
|
||||
"marketplace": {
|
||||
"name": "Marketplace Service",
|
||||
"port": 8102,
|
||||
"url": "http://localhost:8102",
|
||||
"description": "Marketplace transactions",
|
||||
},
|
||||
"trading": {
|
||||
"name": "Trading Service",
|
||||
"port": 8104,
|
||||
"url": "http://localhost:8104",
|
||||
"description": "Trading and explorer operations",
|
||||
},
|
||||
"governance": {
|
||||
"name": "Governance Service",
|
||||
"port": 8105,
|
||||
"url": "http://localhost:8105",
|
||||
"description": "Governance transactions",
|
||||
},
|
||||
"ai": {
|
||||
"name": "AI Service",
|
||||
"port": 8106,
|
||||
"url": "http://localhost:8106",
|
||||
"description": "AI job operations",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy", "service": "monitoring-service"}
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint."""
|
||||
return {
|
||||
"service": "AITBC Monitoring Service",
|
||||
"version": "1.0.0",
|
||||
"status": "operational"
|
||||
}
|
||||
|
||||
|
||||
@app.get("/dashboard")
|
||||
async def monitoring_dashboard() -> dict[str, Any]:
|
||||
"""
|
||||
Unified monitoring dashboard for all services
|
||||
"""
|
||||
try:
|
||||
# Collect health data from all services
|
||||
health_data = await collect_all_health_data()
|
||||
|
||||
# Calculate overall metrics
|
||||
overall_metrics = calculate_overall_metrics(health_data)
|
||||
|
||||
dashboard_data = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"overall_status": overall_metrics["overall_status"],
|
||||
"services": health_data,
|
||||
"metrics": overall_metrics,
|
||||
"summary": {
|
||||
"total_services": len(SERVICES),
|
||||
"healthy_services": len([s for s in health_data.values() if s.get("status") == "healthy"]),
|
||||
"degraded_services": len([s for s in health_data.values() if s.get("status") == "degraded"]),
|
||||
"unhealthy_services": len([s for s in health_data.values() if s.get("status") == "unhealthy"]),
|
||||
"last_updated": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC"),
|
||||
},
|
||||
}
|
||||
|
||||
logger.info("Monitoring dashboard data collected successfully")
|
||||
return dashboard_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate monitoring dashboard: {e}")
|
||||
return {
|
||||
"error": "Failed to generate dashboard",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"services": SERVICES,
|
||||
"overall_status": "error",
|
||||
"summary": {
|
||||
"total_services": len(SERVICES),
|
||||
"healthy_services": 0,
|
||||
"degraded_services": 0,
|
||||
"unhealthy_services": len(SERVICES),
|
||||
"last_updated": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC"),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@app.get("/dashboard/summary")
|
||||
async def services_summary() -> dict[str, Any]:
|
||||
"""
|
||||
Quick summary of all services status
|
||||
"""
|
||||
try:
|
||||
health_data = await collect_all_health_data()
|
||||
|
||||
summary = {"timestamp": datetime.now(timezone.utc).isoformat(), "services": {}}
|
||||
|
||||
for service_id, service_info in SERVICES.items():
|
||||
health = health_data.get(service_id, {})
|
||||
summary["services"][service_id] = {
|
||||
"name": service_info["name"],
|
||||
"port": service_info["port"],
|
||||
"status": health.get("status", "unknown"),
|
||||
"description": service_info["description"],
|
||||
"last_check": health.get("timestamp"),
|
||||
}
|
||||
|
||||
return summary
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate services summary: {e}")
|
||||
return {"error": "Failed to generate summary", "timestamp": datetime.now(timezone.utc).isoformat()}
|
||||
|
||||
|
||||
@app.get("/dashboard/metrics")
|
||||
async def system_metrics() -> dict[str, Any]:
|
||||
"""
|
||||
System-wide performance metrics
|
||||
"""
|
||||
try:
|
||||
import psutil
|
||||
|
||||
# System metrics
|
||||
cpu_percent = psutil.cpu_percent(interval=1)
|
||||
memory = psutil.virtual_memory()
|
||||
disk = psutil.disk_usage("/")
|
||||
|
||||
# Network metrics
|
||||
network = psutil.net_io_counters()
|
||||
|
||||
metrics = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"system": {
|
||||
"cpu_percent": cpu_percent,
|
||||
"cpu_count": psutil.cpu_count(),
|
||||
"memory_percent": memory.percent,
|
||||
"memory_total_gb": round(memory.total / (1024**3), 2),
|
||||
"memory_available_gb": round(memory.available / (1024**3), 2),
|
||||
"disk_percent": disk.percent,
|
||||
"disk_total_gb": round(disk.total / (1024**3), 2),
|
||||
"disk_free_gb": round(disk.free / (1024**3), 2),
|
||||
},
|
||||
"network": {
|
||||
"bytes_sent": network.bytes_sent,
|
||||
"bytes_recv": network.bytes_recv,
|
||||
"packets_sent": network.packets_sent,
|
||||
"packets_recv": network.packets_recv,
|
||||
},
|
||||
"services": {
|
||||
"total_services": len(SERVICES),
|
||||
"service_names": list(SERVICES.keys()),
|
||||
},
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to collect system metrics: {e}")
|
||||
return {"error": "Failed to collect metrics", "timestamp": datetime.now(timezone.utc).isoformat()}
|
||||
|
||||
|
||||
async def collect_all_health_data() -> dict[str, Any]:
|
||||
"""Collect health data from all services"""
|
||||
health_data = {}
|
||||
|
||||
tasks = []
|
||||
for service_id, service_info in SERVICES.items():
|
||||
task = check_service_health(service_id, service_info)
|
||||
tasks.append(task)
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for i, (service_id, service_info) in enumerate(SERVICES.items()):
|
||||
result = results[i]
|
||||
if isinstance(result, Exception):
|
||||
health_data[service_id] = {
|
||||
"status": "unhealthy",
|
||||
"error": str(result),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
else:
|
||||
health_data[service_id] = result
|
||||
|
||||
return health_data
|
||||
|
||||
|
||||
async def check_service_health(service_name: str, service_config: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Check health status of a specific service
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
health_url = f"{service_config['url']}/health"
|
||||
response = await client.get(health_url)
|
||||
return {
|
||||
"status": "healthy",
|
||||
"response_time": 0.1,
|
||||
"last_check": datetime.now(timezone.utc).isoformat(),
|
||||
"details": response.json(),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Service {service_name} health check failed: {e}")
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": str(e),
|
||||
"last_check": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
|
||||
def calculate_overall_metrics(health_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Calculate overall system metrics from health data"""
|
||||
|
||||
status_counts = {"healthy": 0, "degraded": 0, "unhealthy": 0, "unknown": 0}
|
||||
|
||||
for service_health in health_data.values():
|
||||
status = service_health.get("status", "unknown")
|
||||
status_counts[status] = status_counts.get(status, 0) + 1
|
||||
|
||||
# Determine overall status
|
||||
if status_counts["unhealthy"] > 0:
|
||||
overall_status = "unhealthy"
|
||||
elif status_counts["degraded"] > 0:
|
||||
overall_status = "degraded"
|
||||
else:
|
||||
overall_status = "healthy"
|
||||
|
||||
return {
|
||||
"overall_status": overall_status,
|
||||
"status_counts": status_counts,
|
||||
"health_percentage": (status_counts["healthy"] / len(health_data)) * 100 if health_data else 0,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8107)
|
||||
697
examples/stubs/multi-region-load-balancer/main.py
Executable file
697
examples/stubs/multi-region-load-balancer/main.py
Executable file
@@ -0,0 +1,697 @@
|
||||
"""
|
||||
Multi-Region Load Balancing Service for AITBC
|
||||
Handles intelligent load distribution across global regions
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from aitbc import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
app = FastAPI(
|
||||
title="AITBC Multi-Region Load Balancer",
|
||||
description="Intelligent load balancing across global regions",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Data models
|
||||
class LoadBalancingRule(BaseModel):
|
||||
rule_id: str
|
||||
name: str
|
||||
algorithm: str # weighted_round_robin, least_connections, geographic, performance_based
|
||||
target_regions: List[str]
|
||||
weights: Dict[str, float] # Region weights
|
||||
health_check_path: str
|
||||
failover_enabled: bool
|
||||
session_affinity: bool
|
||||
|
||||
class RegionHealth(BaseModel):
|
||||
region_id: str
|
||||
status: str # healthy, unhealthy, degraded
|
||||
response_time_ms: float
|
||||
success_rate: float
|
||||
active_connections: int
|
||||
last_check: datetime
|
||||
|
||||
class LoadBalancingMetrics(BaseModel):
|
||||
balancer_id: str
|
||||
timestamp: datetime
|
||||
total_requests: int
|
||||
requests_per_region: Dict[str, int]
|
||||
average_response_time: float
|
||||
error_rate: float
|
||||
throughput: float
|
||||
|
||||
class GeographicRule(BaseModel):
|
||||
rule_id: str
|
||||
source_regions: List[str]
|
||||
target_regions: List[str]
|
||||
priority: int # Lower number = higher priority
|
||||
latency_threshold_ms: float
|
||||
|
||||
# In-memory storage (in production, use database)
|
||||
load_balancing_rules: Dict[str, Dict] = {}
|
||||
region_health_status: Dict[str, RegionHealth] = {}
|
||||
balancing_metrics: Dict[str, List[Dict]] = {}
|
||||
geographic_rules: Dict[str, Dict] = {}
|
||||
session_affinity_data: Dict[str, Dict] = {}
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {
|
||||
"service": "AITBC Multi-Region Load Balancer",
|
||||
"status": "running",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {
|
||||
"status": "healthy",
|
||||
"total_rules": len(load_balancing_rules),
|
||||
"active_rules": len([r for r in load_balancing_rules.values() if r["status"] == "active"]),
|
||||
"monitored_regions": len(region_health_status),
|
||||
"healthy_regions": len([r for r in region_health_status.values() if r.status == "healthy"])
|
||||
}
|
||||
|
||||
@app.post("/api/v1/rules/create")
|
||||
async def create_load_balancing_rule(rule: LoadBalancingRule):
|
||||
"""Create a new load balancing rule"""
|
||||
if rule.rule_id in load_balancing_rules:
|
||||
raise HTTPException(status_code=400, detail="Load balancing rule already exists")
|
||||
|
||||
# Create rule record
|
||||
rule_record = {
|
||||
"rule_id": rule.rule_id,
|
||||
"name": rule.name,
|
||||
"algorithm": rule.algorithm,
|
||||
"target_regions": rule.target_regions,
|
||||
"weights": rule.weights,
|
||||
"health_check_path": rule.health_check_path,
|
||||
"failover_enabled": rule.failover_enabled,
|
||||
"session_affinity": rule.session_affinity,
|
||||
"status": "active",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"total_requests": 0,
|
||||
"failed_requests": 0,
|
||||
"last_updated": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
load_balancing_rules[rule.rule_id] = rule_record
|
||||
|
||||
# Start health monitoring for target regions
|
||||
asyncio.create_task(start_health_monitoring(rule.rule_id))
|
||||
|
||||
logger.info(f"Load balancing rule created: {rule.name} ({rule.rule_id})")
|
||||
|
||||
return {
|
||||
"rule_id": rule.rule_id,
|
||||
"status": "created",
|
||||
"name": rule.name,
|
||||
"algorithm": rule.algorithm,
|
||||
"created_at": rule_record["created_at"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/rules")
|
||||
async def list_load_balancing_rules():
|
||||
"""List all load balancing rules"""
|
||||
return {
|
||||
"rules": list(load_balancing_rules.values()),
|
||||
"total_rules": len(load_balancing_rules),
|
||||
"active_rules": len([r for r in load_balancing_rules.values() if r["status"] == "active"])
|
||||
}
|
||||
|
||||
@app.get("/api/v1/rules/{rule_id}")
|
||||
async def get_load_balancing_rule(rule_id: str):
|
||||
"""Get detailed load balancing rule information"""
|
||||
if rule_id not in load_balancing_rules:
|
||||
raise HTTPException(status_code=404, detail="Load balancing rule not found")
|
||||
|
||||
rule = load_balancing_rules[rule_id].copy()
|
||||
|
||||
# Add region health status
|
||||
rule["region_health"] = {
|
||||
region_id: region_health_status.get(region_id)
|
||||
for region_id in rule["target_regions"]
|
||||
if region_id in region_health_status
|
||||
}
|
||||
|
||||
# Add performance metrics
|
||||
rule["performance_metrics"] = balancing_metrics.get(rule_id, [])
|
||||
|
||||
return rule
|
||||
|
||||
@app.post("/api/v1/rules/{rule_id}/update-weights")
|
||||
async def update_rule_weights(rule_id: str, weights: Dict[str, float]):
|
||||
"""Update weights for a load balancing rule"""
|
||||
if rule_id not in load_balancing_rules:
|
||||
raise HTTPException(status_code=404, detail="Load balancing rule not found")
|
||||
|
||||
rule = load_balancing_rules[rule_id]
|
||||
|
||||
# Validate weights
|
||||
total_weight = sum(weights.values())
|
||||
if total_weight == 0:
|
||||
raise HTTPException(status_code=400, detail="Total weight cannot be zero")
|
||||
|
||||
# Normalize weights
|
||||
normalized_weights = {k: v / total_weight for k, v in weights.items()}
|
||||
|
||||
# Update rule weights
|
||||
rule["weights"] = normalized_weights
|
||||
rule["last_updated"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
logger.info(f"Weights updated for rule {rule_id}: {normalized_weights}")
|
||||
|
||||
return {
|
||||
"rule_id": rule_id,
|
||||
"new_weights": normalized_weights,
|
||||
"updated_at": rule["last_updated"]
|
||||
}
|
||||
|
||||
@app.post("/api/v1/health/register")
|
||||
async def register_region_health(health: RegionHealth):
|
||||
"""Register or update health status for a region"""
|
||||
region_health_status[health.region_id] = health
|
||||
|
||||
# Update load balancing rules that use this region
|
||||
for rule_id, rule in load_balancing_rules.items():
|
||||
if health.region_id in rule["target_regions"]:
|
||||
# Update rule based on health status
|
||||
if health.status == "unhealthy" and rule["failover_enabled"]:
|
||||
logger.warning(f"Region {health.region_id} unhealthy, enabling failover for rule {rule_id}")
|
||||
enable_failover(rule_id, health.region_id)
|
||||
|
||||
return {
|
||||
"region_id": health.region_id,
|
||||
"status": health.status,
|
||||
"registered_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.get("/api/v1/health")
|
||||
async def get_all_region_health():
|
||||
"""Get health status for all monitored regions"""
|
||||
return {
|
||||
"region_health": {
|
||||
region_id: health.dict()
|
||||
for region_id, health in region_health_status.items()
|
||||
},
|
||||
"total_regions": len(region_health_status),
|
||||
"healthy_regions": len([r for r in region_health_status.values() if r.status == "healthy"]),
|
||||
"unhealthy_regions": len([r for r in region_health_status.values() if r.status == "unhealthy"]),
|
||||
"degraded_regions": len([r for r in region_health_status.values() if r.status == "degraded"])
|
||||
}
|
||||
|
||||
@app.post("/api/v1/geographic-rules/create")
|
||||
async def create_geographic_rule(rule: GeographicRule):
|
||||
"""Create a geographic routing rule"""
|
||||
if rule.rule_id in geographic_rules:
|
||||
raise HTTPException(status_code=400, detail="Geographic rule already exists")
|
||||
|
||||
# Create geographic rule record
|
||||
rule_record = {
|
||||
"rule_id": rule.rule_id,
|
||||
"source_regions": rule.source_regions,
|
||||
"target_regions": rule.target_regions,
|
||||
"priority": rule.priority,
|
||||
"latency_threshold_ms": rule.latency_threshold_ms,
|
||||
"status": "active",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"usage_count": 0
|
||||
}
|
||||
|
||||
geographic_rules[rule.rule_id] = rule_record
|
||||
|
||||
logger.info(f"Geographic rule created: {rule.rule_id}")
|
||||
|
||||
return {
|
||||
"rule_id": rule.rule_id,
|
||||
"status": "created",
|
||||
"priority": rule.priority,
|
||||
"created_at": rule_record["created_at"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/route/{client_region}")
|
||||
async def get_optimal_region(client_region: str, rule_id: Optional[str] = None):
|
||||
"""Get optimal target region for a client region"""
|
||||
if rule_id and rule_id not in load_balancing_rules:
|
||||
raise HTTPException(status_code=404, detail="Load balancing rule not found")
|
||||
|
||||
# Find optimal region based on rules
|
||||
if rule_id:
|
||||
optimal_region = select_region_by_algorithm(rule_id, client_region)
|
||||
else:
|
||||
optimal_region = select_region_geographically(client_region)
|
||||
|
||||
return {
|
||||
"client_region": client_region,
|
||||
"optimal_region": optimal_region,
|
||||
"rule_id": rule_id,
|
||||
"selection_reason": get_selection_reason(optimal_region, client_region, rule_id),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.post("/api/v1/metrics/record")
|
||||
async def record_balancing_metrics(metrics: LoadBalancingMetrics):
|
||||
"""Record load balancing performance metrics"""
|
||||
metrics_record = {
|
||||
"metrics_id": f"metrics_{int(datetime.now(timezone.utc).timestamp())}",
|
||||
"balancer_id": metrics.balancer_id,
|
||||
"timestamp": metrics.timestamp.isoformat(),
|
||||
"total_requests": metrics.total_requests,
|
||||
"requests_per_region": metrics.requests_per_region,
|
||||
"average_response_time": metrics.average_response_time,
|
||||
"error_rate": metrics.error_rate,
|
||||
"throughput": metrics.throughput
|
||||
}
|
||||
|
||||
if metrics.balancer_id not in balancing_metrics:
|
||||
balancing_metrics[metrics.balancer_id] = []
|
||||
|
||||
balancing_metrics[metrics.balancer_id].append(metrics_record)
|
||||
|
||||
# Keep only last 1000 records per balancer
|
||||
if len(balancing_metrics[metrics.balancer_id]) > 1000:
|
||||
balancing_metrics[metrics.balancer_id] = balancing_metrics[metrics.balancer_id][-1000:]
|
||||
|
||||
return {
|
||||
"metrics_id": metrics_record["metrics_id"],
|
||||
"status": "recorded",
|
||||
"timestamp": metrics_record["timestamp"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/metrics/{rule_id}")
|
||||
async def get_balancing_metrics(rule_id: str, hours: int = 24):
|
||||
"""Get performance metrics for a load balancing rule"""
|
||||
if rule_id not in load_balancing_rules:
|
||||
raise HTTPException(status_code=404, detail="Load balancing rule not found")
|
||||
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours)
|
||||
recent_metrics = [
|
||||
m for m in balancing_metrics.get(rule_id, [])
|
||||
if datetime.fromisoformat(m["timestamp"]) > cutoff_time
|
||||
]
|
||||
|
||||
# Calculate statistics
|
||||
if recent_metrics:
|
||||
avg_response_time = sum(m["average_response_time"] for m in recent_metrics) / len(recent_metrics)
|
||||
avg_error_rate = sum(m["error_rate"] for m in recent_metrics) / len(recent_metrics)
|
||||
avg_throughput = sum(m["throughput"] for m in recent_metrics) / len(recent_metrics)
|
||||
total_requests = sum(m["total_requests"] for m in recent_metrics)
|
||||
else:
|
||||
avg_response_time = avg_error_rate = avg_throughput = total_requests = 0.0
|
||||
|
||||
return {
|
||||
"rule_id": rule_id,
|
||||
"period_hours": hours,
|
||||
"metrics": recent_metrics,
|
||||
"statistics": {
|
||||
"average_response_time_ms": round(avg_response_time, 3),
|
||||
"average_error_rate": round(avg_error_rate, 4),
|
||||
"average_throughput": round(avg_throughput, 2),
|
||||
"total_requests": int(total_requests),
|
||||
"total_samples": len(recent_metrics)
|
||||
},
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.get("/api/v1/dashboard")
|
||||
async def get_load_balancing_dashboard():
|
||||
"""Get comprehensive load balancing dashboard"""
|
||||
# Calculate overall statistics
|
||||
total_rules = len(load_balancing_rules)
|
||||
active_rules = len([r for r in load_balancing_rules.values() if r["status"] == "active"])
|
||||
|
||||
# Region health summary
|
||||
health_summary = {
|
||||
"total_regions": len(region_health_status),
|
||||
"healthy": len([r for r in region_health_status.values() if r.status == "healthy"]),
|
||||
"unhealthy": len([r for r in region_health_status.values() if r.status == "unhealthy"]),
|
||||
"degraded": len([r for r in region_health_status.values() if r.status == "degraded"])
|
||||
}
|
||||
|
||||
# Performance summary
|
||||
performance_summary = {}
|
||||
for rule_id, metrics_list in balancing_metrics.items():
|
||||
if metrics_list:
|
||||
latest_metrics = metrics_list[-1]
|
||||
performance_summary[rule_id] = {
|
||||
"total_requests": latest_metrics["total_requests"],
|
||||
"average_response_time": latest_metrics["average_response_time"],
|
||||
"error_rate": latest_metrics["error_rate"],
|
||||
"throughput": latest_metrics["throughput"]
|
||||
}
|
||||
|
||||
# Algorithm distribution
|
||||
algorithm_distribution = {}
|
||||
for rule in load_balancing_rules.values():
|
||||
algorithm = rule["algorithm"]
|
||||
algorithm_distribution[algorithm] = algorithm_distribution.get(algorithm, 0) + 1
|
||||
|
||||
return {
|
||||
"dashboard": {
|
||||
"overview": {
|
||||
"total_rules": total_rules,
|
||||
"active_rules": active_rules,
|
||||
"geographic_rules": len(geographic_rules),
|
||||
"algorithm_distribution": algorithm_distribution
|
||||
},
|
||||
"region_health": health_summary,
|
||||
"performance": performance_summary,
|
||||
"recent_activity": get_recent_activity()
|
||||
},
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
# Core load balancing functions
|
||||
def select_region_by_algorithm(rule_id: str, client_region: str) -> Optional[str]:
|
||||
"""Select optimal region based on load balancing algorithm"""
|
||||
if rule_id not in load_balancing_rules:
|
||||
return None
|
||||
|
||||
rule = load_balancing_rules[rule_id]
|
||||
algorithm = rule["algorithm"]
|
||||
target_regions = rule["target_regions"]
|
||||
|
||||
# Filter healthy regions
|
||||
healthy_regions = [
|
||||
region for region in target_regions
|
||||
if region in region_health_status and region_health_status[region].status == "healthy"
|
||||
]
|
||||
|
||||
if not healthy_regions:
|
||||
# Fallback to any region if no healthy ones
|
||||
healthy_regions = target_regions
|
||||
|
||||
if algorithm == "weighted_round_robin":
|
||||
return select_weighted_round_robin(rule_id, healthy_regions)
|
||||
elif algorithm == "least_connections":
|
||||
return select_least_connections(healthy_regions)
|
||||
elif algorithm == "geographic":
|
||||
return select_geographic_optimal(client_region, healthy_regions)
|
||||
elif algorithm == "performance_based":
|
||||
return select_performance_optimal(healthy_regions)
|
||||
else:
|
||||
return healthy_regions[0] if healthy_regions else None
|
||||
|
||||
def select_weighted_round_robin(rule_id: str, regions: List[str]) -> str:
|
||||
"""Select region using weighted round robin"""
|
||||
rule = load_balancing_rules[rule_id]
|
||||
weights = rule["weights"]
|
||||
|
||||
# Filter weights for available regions
|
||||
available_weights = {r: weights.get(r, 1.0) for r in regions if r in weights}
|
||||
|
||||
if not available_weights:
|
||||
return regions[0]
|
||||
|
||||
# Simple weighted selection (in production, use proper round robin state)
|
||||
import random
|
||||
total_weight = sum(available_weights.values())
|
||||
rand_val = random.uniform(0, total_weight)
|
||||
|
||||
current_weight = 0
|
||||
for region, weight in available_weights.items():
|
||||
current_weight += weight
|
||||
if rand_val <= current_weight:
|
||||
return region
|
||||
|
||||
return list(available_weights.keys())[-1]
|
||||
|
||||
def select_least_connections(regions: List[str]) -> str:
|
||||
"""Select region with least connections"""
|
||||
min_connections = float('inf')
|
||||
optimal_region = None
|
||||
|
||||
for region in regions:
|
||||
if region in region_health_status:
|
||||
connections = region_health_status[region].active_connections
|
||||
if connections < min_connections:
|
||||
min_connections = connections
|
||||
optimal_region = region
|
||||
|
||||
return optimal_region or regions[0]
|
||||
|
||||
def select_geographic_optimal(client_region: str, target_regions: List[str]) -> str:
|
||||
"""Select region based on geographic proximity"""
|
||||
# Simplified geographic mapping (in production, use actual geographic data)
|
||||
geographic_proximity = {
|
||||
"us-east": ["us-east-1", "us-west-1"],
|
||||
"us-west": ["us-west-1", "us-east-1"],
|
||||
"europe": ["eu-west-1", "eu-central-1"],
|
||||
"asia": ["ap-southeast-1", "ap-northeast-1"]
|
||||
}
|
||||
|
||||
# Find closest regions
|
||||
for geo_area, close_regions in geographic_proximity.items():
|
||||
if client_region.lower() in geo_area.lower():
|
||||
for close_region in close_regions:
|
||||
if close_region in target_regions:
|
||||
return close_region
|
||||
|
||||
# Fallback to first healthy region
|
||||
return target_regions[0]
|
||||
|
||||
def select_performance_optimal(regions: List[str]) -> str:
|
||||
"""Select region with best performance"""
|
||||
best_region = None
|
||||
best_score = float('inf')
|
||||
|
||||
for region in regions:
|
||||
if region in region_health_status:
|
||||
health = region_health_status[region]
|
||||
# Calculate performance score (lower is better)
|
||||
score = health.response_time_ms * (1 - health.success_rate)
|
||||
if score < best_score:
|
||||
best_score = score
|
||||
best_region = region
|
||||
|
||||
return best_region or regions[0]
|
||||
|
||||
def select_region_geographically(client_region: str) -> Optional[str]:
|
||||
"""Select region based on geographic rules"""
|
||||
# Apply geographic rules
|
||||
applicable_rules = [
|
||||
rule for rule in geographic_rules.values()
|
||||
if client_region in rule["source_regions"] and rule["status"] == "active"
|
||||
]
|
||||
|
||||
# Sort by priority (lower number = higher priority)
|
||||
applicable_rules.sort(key=lambda x: x["priority"])
|
||||
|
||||
for rule in applicable_rules:
|
||||
# Find best target region based on latency
|
||||
best_target = None
|
||||
best_latency = float('inf')
|
||||
|
||||
for target_region in rule["target_regions"]:
|
||||
if target_region in region_health_status:
|
||||
latency = region_health_status[target_region].response_time_ms
|
||||
if latency < best_latency and latency < rule["latency_threshold_ms"]:
|
||||
best_latency = latency
|
||||
best_target = target_region
|
||||
|
||||
if best_target:
|
||||
rule["usage_count"] += 1
|
||||
return best_target
|
||||
|
||||
# Fallback to any healthy region
|
||||
healthy_regions = [
|
||||
region for region, health in region_health_status.items()
|
||||
if health.status == "healthy"
|
||||
]
|
||||
|
||||
return healthy_regions[0] if healthy_regions else None
|
||||
|
||||
def get_selection_reason(region: str, client_region: str, rule_id: Optional[str]) -> str:
|
||||
"""Get reason for region selection"""
|
||||
if rule_id and rule_id in load_balancing_rules:
|
||||
rule = load_balancing_rules[rule_id]
|
||||
return f"Selected by {rule['algorithm']} algorithm using rule {rule['name']}"
|
||||
else:
|
||||
return f"Selected based on geographic proximity from {client_region}"
|
||||
|
||||
def enable_failover(rule_id: str, unhealthy_region: str):
|
||||
"""Enable failover for unhealthy region"""
|
||||
rule = load_balancing_rules[rule_id]
|
||||
|
||||
# Remove unhealthy region from rotation temporarily
|
||||
if unhealthy_region in rule["target_regions"]:
|
||||
rule["target_regions"].remove(unhealthy_region)
|
||||
logger.warning(f"Region {unhealthy_region} removed from load balancing rule {rule_id}")
|
||||
|
||||
def get_recent_activity() -> List[Dict]:
|
||||
"""Get recent load balancing activity"""
|
||||
activity = []
|
||||
|
||||
# Recent health changes
|
||||
for region_id, health in region_health_status.items():
|
||||
if (datetime.now(timezone.utc) - health.last_check).total_seconds() < 3600: # Last hour
|
||||
activity.append({
|
||||
"type": "health_check",
|
||||
"region": region_id,
|
||||
"status": health.status,
|
||||
"timestamp": health.last_check.isoformat()
|
||||
})
|
||||
|
||||
# Recent rule updates
|
||||
for rule_id, rule in load_balancing_rules.items():
|
||||
if (datetime.now(timezone.utc) - datetime.fromisoformat(rule["last_updated"])).total_seconds() < 3600:
|
||||
activity.append({
|
||||
"type": "rule_update",
|
||||
"rule_id": rule_id,
|
||||
"name": rule["name"],
|
||||
"timestamp": rule["last_updated"]
|
||||
})
|
||||
|
||||
# Sort by timestamp (most recent first)
|
||||
activity.sort(key=lambda x: x["timestamp"], reverse=True)
|
||||
|
||||
return activity[:20]
|
||||
|
||||
# Background task for health monitoring
|
||||
async def start_health_monitoring(rule_id: str):
|
||||
"""Start health monitoring for a load balancing rule"""
|
||||
rule = load_balancing_rules[rule_id]
|
||||
|
||||
while rule["status"] == "active":
|
||||
try:
|
||||
# Check health of all target regions
|
||||
for region_id in rule["target_regions"]:
|
||||
await check_region_health(region_id)
|
||||
|
||||
await asyncio.sleep(30) # Check every 30 seconds
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Health monitoring error for rule {rule_id}: {str(e)}")
|
||||
await asyncio.sleep(10)
|
||||
|
||||
async def check_region_health(region_id: str):
|
||||
"""Check health of a specific region"""
|
||||
# Simulate health check (in production, this would be actual health checks)
|
||||
import random
|
||||
|
||||
# Simulate health metrics
|
||||
response_time = random.uniform(20, 200)
|
||||
success_rate = random.uniform(0.95, 1.0)
|
||||
active_connections = random.randint(100, 1000)
|
||||
|
||||
# Determine health status
|
||||
if response_time < 100 and success_rate > 0.99:
|
||||
status = "healthy"
|
||||
elif response_time < 200 and success_rate > 0.95:
|
||||
status = "degraded"
|
||||
else:
|
||||
status = "unhealthy"
|
||||
|
||||
health = RegionHealth(
|
||||
region_id=region_id,
|
||||
status=status,
|
||||
response_time_ms=response_time,
|
||||
success_rate=success_rate,
|
||||
active_connections=active_connections,
|
||||
last_check=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
region_health_status[region_id] = health
|
||||
|
||||
# Initialize with some default rules
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
logger.info("Starting AITBC Multi-Region Load Balancer")
|
||||
|
||||
# Initialize default load balancing rules
|
||||
default_rules = [
|
||||
{
|
||||
"rule_id": "global-web-rule",
|
||||
"name": "Global Web Load Balancer",
|
||||
"algorithm": "weighted_round_robin",
|
||||
"target_regions": ["us-east-1", "eu-west-1", "ap-southeast-1"],
|
||||
"weights": {"us-east-1": 0.4, "eu-west-1": 0.35, "ap-southeast-1": 0.25},
|
||||
"health_check_path": "/health",
|
||||
"failover_enabled": True,
|
||||
"session_affinity": False
|
||||
},
|
||||
{
|
||||
"rule_id": "api-performance-rule",
|
||||
"name": "API Performance Optimizer",
|
||||
"algorithm": "performance_based",
|
||||
"target_regions": ["us-east-1", "eu-west-1"],
|
||||
"weights": {"us-east-1": 0.5, "eu-west-1": 0.5},
|
||||
"health_check_path": "/api/health",
|
||||
"failover_enabled": True,
|
||||
"session_affinity": True
|
||||
}
|
||||
]
|
||||
|
||||
for rule_data in default_rules:
|
||||
rule = LoadBalancingRule(**rule_data)
|
||||
rule_record = {
|
||||
"rule_id": rule.rule_id,
|
||||
"name": rule.name,
|
||||
"algorithm": rule.algorithm,
|
||||
"target_regions": rule.target_regions,
|
||||
"weights": rule.weights,
|
||||
"health_check_path": rule.health_check_path,
|
||||
"failover_enabled": rule.failover_enabled,
|
||||
"session_affinity": rule.session_affinity,
|
||||
"status": "active",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"total_requests": 0,
|
||||
"failed_requests": 0,
|
||||
"last_updated": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
load_balancing_rules[rule.rule_id] = rule_record
|
||||
|
||||
# Start health monitoring
|
||||
asyncio.create_task(start_health_monitoring(rule.rule_id))
|
||||
|
||||
# Initialize default geographic rules
|
||||
default_geo_rules = [
|
||||
{
|
||||
"rule_id": "us-to-us",
|
||||
"source_regions": ["us-east", "us-west", "north-america"],
|
||||
"target_regions": ["us-east-1", "us-west-1"],
|
||||
"priority": 1,
|
||||
"latency_threshold_ms": 50
|
||||
},
|
||||
{
|
||||
"rule_id": "eu-to-eu",
|
||||
"source_regions": ["europe", "eu-west", "eu-central"],
|
||||
"target_regions": ["eu-west-1", "eu-central-1"],
|
||||
"priority": 1,
|
||||
"latency_threshold_ms": 30
|
||||
}
|
||||
]
|
||||
|
||||
for geo_rule_data in default_geo_rules:
|
||||
geo_rule = GeographicRule(**geo_rule_data)
|
||||
geo_rule_record = {
|
||||
"rule_id": geo_rule.rule_id,
|
||||
"source_regions": geo_rule.source_regions,
|
||||
"target_regions": geo_rule.target_regions,
|
||||
"priority": geo_rule.priority,
|
||||
"latency_threshold_ms": geo_rule.latency_threshold_ms,
|
||||
"status": "active",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"usage_count": 0
|
||||
}
|
||||
geographic_rules[geo_rule.rule_id] = geo_rule_record
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
logger.info("Shutting down AITBC Multi-Region Load Balancer")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
import os
|
||||
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8019, log_level="info")
|
||||
@@ -0,0 +1 @@
|
||||
"""Multi-region load balancer service tests"""
|
||||
@@ -0,0 +1,199 @@
|
||||
"""Edge case and error handling tests for multi-region load balancer service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, LoadBalancingRule, RegionHealth, LoadBalancingMetrics, GeographicRule, load_balancing_rules, region_health_status, balancing_metrics, geographic_rules
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
load_balancing_rules.clear()
|
||||
region_health_status.clear()
|
||||
balancing_metrics.clear()
|
||||
geographic_rules.clear()
|
||||
yield
|
||||
load_balancing_rules.clear()
|
||||
region_health_status.clear()
|
||||
balancing_metrics.clear()
|
||||
geographic_rules.clear()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_balancing_rule_empty_target_regions():
|
||||
"""Test LoadBalancingRule with empty target regions"""
|
||||
rule = LoadBalancingRule(
|
||||
rule_id="rule_123",
|
||||
name="Test Rule",
|
||||
algorithm="round_robin",
|
||||
target_regions=[],
|
||||
weights={},
|
||||
health_check_path="/health",
|
||||
failover_enabled=False,
|
||||
session_affinity=False
|
||||
)
|
||||
assert rule.target_regions == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_region_health_negative_success_rate():
|
||||
"""Test RegionHealth with negative success rate"""
|
||||
health = RegionHealth(
|
||||
region_id="us-east-1",
|
||||
status="healthy",
|
||||
response_time_ms=45.5,
|
||||
success_rate=-0.5,
|
||||
active_connections=100,
|
||||
last_check=datetime.now(timezone.utc)
|
||||
)
|
||||
assert health.success_rate == -0.5
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_region_health_negative_connections():
|
||||
"""Test RegionHealth with negative connections"""
|
||||
health = RegionHealth(
|
||||
region_id="us-east-1",
|
||||
status="healthy",
|
||||
response_time_ms=45.5,
|
||||
success_rate=0.99,
|
||||
active_connections=-100,
|
||||
last_check=datetime.now(timezone.utc)
|
||||
)
|
||||
assert health.active_connections == -100
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_balancing_metrics_negative_requests():
|
||||
"""Test LoadBalancingMetrics with negative requests"""
|
||||
metrics = LoadBalancingMetrics(
|
||||
balancer_id="lb_123",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
total_requests=-1000,
|
||||
requests_per_region={},
|
||||
average_response_time=50.5,
|
||||
error_rate=0.001,
|
||||
throughput=100.0
|
||||
)
|
||||
assert metrics.total_requests == -1000
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_balancing_metrics_negative_response_time():
|
||||
"""Test LoadBalancingMetrics with negative response time"""
|
||||
metrics = LoadBalancingMetrics(
|
||||
balancer_id="lb_123",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
total_requests=1000,
|
||||
requests_per_region={},
|
||||
average_response_time=-50.5,
|
||||
error_rate=0.001,
|
||||
throughput=100.0
|
||||
)
|
||||
assert metrics.average_response_time == -50.5
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_geographic_rule_empty_source_regions():
|
||||
"""Test GeographicRule with empty source regions"""
|
||||
rule = GeographicRule(
|
||||
rule_id="geo_123",
|
||||
source_regions=[],
|
||||
target_regions=["us-east-1"],
|
||||
priority=1,
|
||||
latency_threshold_ms=50.0
|
||||
)
|
||||
assert rule.source_regions == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_geographic_rule_negative_priority():
|
||||
"""Test GeographicRule with negative priority"""
|
||||
rule = GeographicRule(
|
||||
rule_id="geo_123",
|
||||
source_regions=["us-east"],
|
||||
target_regions=["us-east-1"],
|
||||
priority=-5,
|
||||
latency_threshold_ms=50.0
|
||||
)
|
||||
assert rule.priority == -5
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_geographic_rule_negative_latency_threshold():
|
||||
"""Test GeographicRule with negative latency threshold"""
|
||||
rule = GeographicRule(
|
||||
rule_id="geo_123",
|
||||
source_regions=["us-east"],
|
||||
target_regions=["us-east-1"],
|
||||
priority=1,
|
||||
latency_threshold_ms=-50.0
|
||||
)
|
||||
assert rule.latency_threshold_ms == -50.0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_rules_with_no_rules():
|
||||
"""Test listing rules when no rules exist"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/rules")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_rules"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_region_health_with_no_regions():
|
||||
"""Test getting region health when no regions exist"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_regions"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_balancing_metrics_hours_parameter():
|
||||
"""Test getting balancing metrics with custom hours parameter"""
|
||||
client = TestClient(app)
|
||||
# Create a rule first
|
||||
rule = LoadBalancingRule(
|
||||
rule_id="rule_123",
|
||||
name="Test Rule",
|
||||
algorithm="weighted_round_robin",
|
||||
target_regions=["us-east-1"],
|
||||
weights={"us-east-1": 1.0},
|
||||
health_check_path="/health",
|
||||
failover_enabled=True,
|
||||
session_affinity=False
|
||||
)
|
||||
client.post("/api/v1/rules/create", json=rule.model_dump())
|
||||
|
||||
response = client.get("/api/v1/metrics/rule_123?hours=12")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["period_hours"] == 12
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_optimal_region_nonexistent_rule():
|
||||
"""Test getting optimal region with nonexistent rule"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/route/us-east?rule_id=nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_dashboard_with_no_data():
|
||||
"""Test dashboard with no data"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/dashboard")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["dashboard"]["overview"]["total_rules"] == 0
|
||||
@@ -0,0 +1,341 @@
|
||||
"""Integration tests for multi-region load balancer service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, LoadBalancingRule, RegionHealth, LoadBalancingMetrics, GeographicRule, load_balancing_rules, region_health_status, balancing_metrics, geographic_rules
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
load_balancing_rules.clear()
|
||||
region_health_status.clear()
|
||||
balancing_metrics.clear()
|
||||
geographic_rules.clear()
|
||||
yield
|
||||
load_balancing_rules.clear()
|
||||
region_health_status.clear()
|
||||
balancing_metrics.clear()
|
||||
geographic_rules.clear()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_root_endpoint():
|
||||
"""Test root endpoint"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["service"] == "AITBC Multi-Region Load Balancer"
|
||||
assert data["status"] == "running"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_health_check_endpoint():
|
||||
"""Test health check endpoint"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert "total_rules" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_load_balancing_rule():
|
||||
"""Test creating a load balancing rule"""
|
||||
client = TestClient(app)
|
||||
rule = LoadBalancingRule(
|
||||
rule_id="rule_123",
|
||||
name="Test Rule",
|
||||
algorithm="weighted_round_robin",
|
||||
target_regions=["us-east-1"],
|
||||
weights={"us-east-1": 1.0},
|
||||
health_check_path="/health",
|
||||
failover_enabled=True,
|
||||
session_affinity=False
|
||||
)
|
||||
response = client.post("/api/v1/rules/create", json=rule.model_dump())
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["rule_id"] == "rule_123"
|
||||
assert data["status"] == "created"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_duplicate_rule():
|
||||
"""Test creating duplicate load balancing rule"""
|
||||
client = TestClient(app)
|
||||
rule = LoadBalancingRule(
|
||||
rule_id="rule_123",
|
||||
name="Test Rule",
|
||||
algorithm="weighted_round_robin",
|
||||
target_regions=["us-east-1"],
|
||||
weights={"us-east-1": 1.0},
|
||||
health_check_path="/health",
|
||||
failover_enabled=True,
|
||||
session_affinity=False
|
||||
)
|
||||
client.post("/api/v1/rules/create", json=rule.model_dump())
|
||||
|
||||
response = client.post("/api/v1/rules/create", json=rule.model_dump())
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_load_balancing_rules():
|
||||
"""Test listing load balancing rules"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/rules")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "rules" in data
|
||||
assert "total_rules" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_load_balancing_rule():
|
||||
"""Test getting specific load balancing rule"""
|
||||
client = TestClient(app)
|
||||
rule = LoadBalancingRule(
|
||||
rule_id="rule_123",
|
||||
name="Test Rule",
|
||||
algorithm="weighted_round_robin",
|
||||
target_regions=["us-east-1"],
|
||||
weights={"us-east-1": 1.0},
|
||||
health_check_path="/health",
|
||||
failover_enabled=True,
|
||||
session_affinity=False
|
||||
)
|
||||
client.post("/api/v1/rules/create", json=rule.model_dump())
|
||||
|
||||
response = client.get("/api/v1/rules/rule_123")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["rule_id"] == "rule_123"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_load_balancing_rule_not_found():
|
||||
"""Test getting nonexistent load balancing rule"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/rules/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_update_rule_weights():
|
||||
"""Test updating rule weights"""
|
||||
client = TestClient(app)
|
||||
rule = LoadBalancingRule(
|
||||
rule_id="rule_123",
|
||||
name="Test Rule",
|
||||
algorithm="weighted_round_robin",
|
||||
target_regions=["us-east-1", "eu-west-1"],
|
||||
weights={"us-east-1": 0.5, "eu-west-1": 0.5},
|
||||
health_check_path="/health",
|
||||
failover_enabled=True,
|
||||
session_affinity=False
|
||||
)
|
||||
client.post("/api/v1/rules/create", json=rule.model_dump())
|
||||
|
||||
new_weights = {"us-east-1": 0.7, "eu-west-1": 0.3}
|
||||
response = client.post("/api/v1/rules/rule_123/update-weights", json=new_weights)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["rule_id"] == "rule_123"
|
||||
assert "new_weights" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_update_rule_weights_not_found():
|
||||
"""Test updating weights for nonexistent rule"""
|
||||
client = TestClient(app)
|
||||
new_weights = {"us-east-1": 1.0}
|
||||
response = client.post("/api/v1/rules/nonexistent/update-weights", json=new_weights)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_update_rule_weights_zero_total():
|
||||
"""Test updating weights with zero total"""
|
||||
client = TestClient(app)
|
||||
rule = LoadBalancingRule(
|
||||
rule_id="rule_123",
|
||||
name="Test Rule",
|
||||
algorithm="weighted_round_robin",
|
||||
target_regions=["us-east-1"],
|
||||
weights={"us-east-1": 1.0},
|
||||
health_check_path="/health",
|
||||
failover_enabled=True,
|
||||
session_affinity=False
|
||||
)
|
||||
client.post("/api/v1/rules/create", json=rule.model_dump())
|
||||
|
||||
new_weights = {"us-east-1": 0.0}
|
||||
response = client.post("/api/v1/rules/rule_123/update-weights", json=new_weights)
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_register_region_health():
|
||||
"""Test registering region health"""
|
||||
client = TestClient(app)
|
||||
health = RegionHealth(
|
||||
region_id="us-east-1",
|
||||
status="healthy",
|
||||
response_time_ms=45.5,
|
||||
success_rate=0.99,
|
||||
active_connections=100,
|
||||
last_check=datetime.now(timezone.utc)
|
||||
)
|
||||
response = client.post("/api/v1/health/register", json=health.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["region_id"] == "us-east-1"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_all_region_health():
|
||||
"""Test getting all region health"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "region_health" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_geographic_rule():
|
||||
"""Test creating geographic rule"""
|
||||
client = TestClient(app)
|
||||
rule = GeographicRule(
|
||||
rule_id="geo_123",
|
||||
source_regions=["us-east"],
|
||||
target_regions=["us-east-1"],
|
||||
priority=1,
|
||||
latency_threshold_ms=50.0
|
||||
)
|
||||
response = client.post("/api/v1/geographic-rules/create", json=rule.model_dump())
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["rule_id"] == "geo_123"
|
||||
assert data["status"] == "created"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_duplicate_geographic_rule():
|
||||
"""Test creating duplicate geographic rule"""
|
||||
client = TestClient(app)
|
||||
rule = GeographicRule(
|
||||
rule_id="geo_123",
|
||||
source_regions=["us-east"],
|
||||
target_regions=["us-east-1"],
|
||||
priority=1,
|
||||
latency_threshold_ms=50.0
|
||||
)
|
||||
client.post("/api/v1/geographic-rules/create", json=rule.model_dump())
|
||||
|
||||
response = client.post("/api/v1/geographic-rules/create", json=rule.model_dump())
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_optimal_region():
|
||||
"""Test getting optimal region"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/route/us-east")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "client_region" in data
|
||||
assert "optimal_region" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_optimal_region_with_rule():
|
||||
"""Test getting optimal region with specific rule"""
|
||||
client = TestClient(app)
|
||||
# Create a rule first
|
||||
rule = LoadBalancingRule(
|
||||
rule_id="rule_123",
|
||||
name="Test Rule",
|
||||
algorithm="weighted_round_robin",
|
||||
target_regions=["us-east-1"],
|
||||
weights={"us-east-1": 1.0},
|
||||
health_check_path="/health",
|
||||
failover_enabled=True,
|
||||
session_affinity=False
|
||||
)
|
||||
client.post("/api/v1/rules/create", json=rule.model_dump())
|
||||
|
||||
response = client.get("/api/v1/route/us-east?rule_id=rule_123")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["rule_id"] == "rule_123"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_record_balancing_metrics():
|
||||
"""Test recording balancing metrics"""
|
||||
client = TestClient(app)
|
||||
metrics = LoadBalancingMetrics(
|
||||
balancer_id="lb_123",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
total_requests=1000,
|
||||
requests_per_region={"us-east-1": 500},
|
||||
average_response_time=50.5,
|
||||
error_rate=0.001,
|
||||
throughput=100.0
|
||||
)
|
||||
response = client.post("/api/v1/metrics/record", json=metrics.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["metrics_id"]
|
||||
assert data["status"] == "recorded"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_balancing_metrics():
|
||||
"""Test getting balancing metrics"""
|
||||
client = TestClient(app)
|
||||
# Create a rule first
|
||||
rule = LoadBalancingRule(
|
||||
rule_id="rule_123",
|
||||
name="Test Rule",
|
||||
algorithm="weighted_round_robin",
|
||||
target_regions=["us-east-1"],
|
||||
weights={"us-east-1": 1.0},
|
||||
health_check_path="/health",
|
||||
failover_enabled=True,
|
||||
session_affinity=False
|
||||
)
|
||||
client.post("/api/v1/rules/create", json=rule.model_dump())
|
||||
|
||||
response = client.get("/api/v1/metrics/rule_123")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["rule_id"] == "rule_123"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_balancing_metrics_not_found():
|
||||
"""Test getting metrics for nonexistent rule"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/metrics/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_load_balancing_dashboard():
|
||||
"""Test getting load balancing dashboard"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/dashboard")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "dashboard" in data
|
||||
@@ -0,0 +1,120 @@
|
||||
"""Unit tests for multi-region load balancer service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, LoadBalancingRule, RegionHealth, LoadBalancingMetrics, GeographicRule
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_app_initialization():
|
||||
"""Test that the FastAPI app initializes correctly"""
|
||||
assert app is not None
|
||||
assert app.title == "AITBC Multi-Region Load Balancer"
|
||||
assert app.version == "1.0.0"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_balancing_rule_model():
|
||||
"""Test LoadBalancingRule model"""
|
||||
rule = LoadBalancingRule(
|
||||
rule_id="rule_123",
|
||||
name="Test Rule",
|
||||
algorithm="weighted_round_robin",
|
||||
target_regions=["us-east-1", "eu-west-1"],
|
||||
weights={"us-east-1": 0.5, "eu-west-1": 0.5},
|
||||
health_check_path="/health",
|
||||
failover_enabled=True,
|
||||
session_affinity=False
|
||||
)
|
||||
assert rule.rule_id == "rule_123"
|
||||
assert rule.name == "Test Rule"
|
||||
assert rule.algorithm == "weighted_round_robin"
|
||||
assert rule.failover_enabled is True
|
||||
assert rule.session_affinity is False
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_region_health_model():
|
||||
"""Test RegionHealth model"""
|
||||
health = RegionHealth(
|
||||
region_id="us-east-1",
|
||||
status="healthy",
|
||||
response_time_ms=45.5,
|
||||
success_rate=0.99,
|
||||
active_connections=100,
|
||||
last_check=datetime.now(timezone.utc)
|
||||
)
|
||||
assert health.region_id == "us-east-1"
|
||||
assert health.status == "healthy"
|
||||
assert health.response_time_ms == 45.5
|
||||
assert health.success_rate == 0.99
|
||||
assert health.active_connections == 100
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_balancing_metrics_model():
|
||||
"""Test LoadBalancingMetrics model"""
|
||||
metrics = LoadBalancingMetrics(
|
||||
balancer_id="lb_123",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
total_requests=1000,
|
||||
requests_per_region={"us-east-1": 500, "eu-west-1": 500},
|
||||
average_response_time=50.5,
|
||||
error_rate=0.001,
|
||||
throughput=100.0
|
||||
)
|
||||
assert metrics.balancer_id == "lb_123"
|
||||
assert metrics.total_requests == 1000
|
||||
assert metrics.average_response_time == 50.5
|
||||
assert metrics.error_rate == 0.001
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_geographic_rule_model():
|
||||
"""Test GeographicRule model"""
|
||||
rule = GeographicRule(
|
||||
rule_id="geo_123",
|
||||
source_regions=["us-east", "us-west"],
|
||||
target_regions=["us-east-1", "us-west-1"],
|
||||
priority=1,
|
||||
latency_threshold_ms=50.0
|
||||
)
|
||||
assert rule.rule_id == "geo_123"
|
||||
assert rule.source_regions == ["us-east", "us-west"]
|
||||
assert rule.priority == 1
|
||||
assert rule.latency_threshold_ms == 50.0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_load_balancing_rule_empty_weights():
|
||||
"""Test LoadBalancingRule with empty weights"""
|
||||
rule = LoadBalancingRule(
|
||||
rule_id="rule_123",
|
||||
name="Test Rule",
|
||||
algorithm="round_robin",
|
||||
target_regions=["us-east-1"],
|
||||
weights={},
|
||||
health_check_path="/health",
|
||||
failover_enabled=False,
|
||||
session_affinity=False
|
||||
)
|
||||
assert rule.weights == {}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_region_health_negative_response_time():
|
||||
"""Test RegionHealth with negative response time"""
|
||||
health = RegionHealth(
|
||||
region_id="us-east-1",
|
||||
status="healthy",
|
||||
response_time_ms=-45.5,
|
||||
success_rate=0.99,
|
||||
active_connections=100,
|
||||
last_check=datetime.now(timezone.utc)
|
||||
)
|
||||
assert health.response_time_ms == -45.5
|
||||
655
examples/stubs/plugin-analytics/main.py
Executable file
655
examples/stubs/plugin-analytics/main.py
Executable file
@@ -0,0 +1,655 @@
|
||||
"""
|
||||
Plugin Analytics and Usage Tracking Service for AITBC
|
||||
Handles plugin analytics, usage tracking, and performance monitoring
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from aitbc import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
app = FastAPI(
|
||||
title="AITBC Plugin Analytics Service",
|
||||
description="Plugin analytics, usage tracking, and performance monitoring",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Data models
|
||||
class PluginUsage(BaseModel):
|
||||
plugin_id: str
|
||||
user_id: str
|
||||
action: str # install, uninstall, enable, disable, use
|
||||
timestamp: datetime
|
||||
metadata: Dict[str, Any] = {}
|
||||
|
||||
class PluginPerformance(BaseModel):
|
||||
plugin_id: str
|
||||
version: str
|
||||
cpu_usage: float
|
||||
memory_usage: float
|
||||
response_time: float
|
||||
error_rate: float
|
||||
uptime: float
|
||||
timestamp: datetime
|
||||
|
||||
class PluginRating(BaseModel):
|
||||
plugin_id: str
|
||||
user_id: str
|
||||
rating: int # 1-5
|
||||
review: Optional[str] = None
|
||||
timestamp: datetime
|
||||
|
||||
class PluginEvent(BaseModel):
|
||||
event_type: str
|
||||
plugin_id: str
|
||||
user_id: Optional[str] = None
|
||||
data: Dict[str, Any] = {}
|
||||
timestamp: datetime
|
||||
|
||||
# In-memory storage (in production, use database)
|
||||
plugin_usage_data: Dict[str, List[Dict]] = {}
|
||||
plugin_performance_data: Dict[str, List[Dict]] = {}
|
||||
plugin_ratings: Dict[str, List[Dict]] = {}
|
||||
plugin_events: Dict[str, List[Dict]] = {}
|
||||
analytics_cache: Dict[str, Dict] = {}
|
||||
usage_trends: Dict[str, Dict] = {}
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {
|
||||
"service": "AITBC Plugin Analytics Service",
|
||||
"status": "running",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {
|
||||
"status": "healthy",
|
||||
"total_usage_records": sum(len(data) for data in plugin_usage_data.values()),
|
||||
"total_performance_records": sum(len(data) for data in plugin_performance_data.values()),
|
||||
"total_ratings": sum(len(data) for data in plugin_ratings.values()),
|
||||
"total_events": sum(len(data) for data in plugin_events.values()),
|
||||
"cache_size": len(analytics_cache)
|
||||
}
|
||||
|
||||
@app.post("/api/v1/analytics/usage")
|
||||
async def record_plugin_usage(usage: PluginUsage):
|
||||
"""Record plugin usage event"""
|
||||
usage_record = {
|
||||
"usage_id": f"usage_{int(datetime.now(timezone.utc).timestamp())}",
|
||||
"plugin_id": usage.plugin_id,
|
||||
"user_id": usage.user_id,
|
||||
"action": usage.action,
|
||||
"timestamp": usage.timestamp.isoformat(),
|
||||
"metadata": usage.metadata
|
||||
}
|
||||
|
||||
if usage.plugin_id not in plugin_usage_data:
|
||||
plugin_usage_data[usage.plugin_id] = []
|
||||
|
||||
plugin_usage_data[usage.plugin_id].append(usage_record)
|
||||
|
||||
# Update usage trends
|
||||
update_usage_trends(usage.plugin_id, usage.action, usage.timestamp)
|
||||
|
||||
logger.info(f"Usage recorded: {usage.plugin_id} - {usage.action} by {usage.user_id}")
|
||||
|
||||
return {
|
||||
"usage_id": usage_record["usage_id"],
|
||||
"status": "recorded",
|
||||
"timestamp": usage_record["timestamp"]
|
||||
}
|
||||
|
||||
@app.post("/api/v1/analytics/performance")
|
||||
async def record_plugin_performance(performance: PluginPerformance):
|
||||
"""Record plugin performance metrics"""
|
||||
performance_record = {
|
||||
"performance_id": f"perf_{int(datetime.now(timezone.utc).timestamp())}",
|
||||
"plugin_id": performance.plugin_id,
|
||||
"version": performance.version,
|
||||
"cpu_usage": performance.cpu_usage,
|
||||
"memory_usage": performance.memory_usage,
|
||||
"response_time": performance.response_time,
|
||||
"error_rate": performance.error_rate,
|
||||
"uptime": performance.uptime,
|
||||
"timestamp": performance.timestamp.isoformat()
|
||||
}
|
||||
|
||||
if performance.plugin_id not in plugin_performance_data:
|
||||
plugin_performance_data[performance.plugin_id] = []
|
||||
|
||||
plugin_performance_data[performance.plugin_id].append(performance_record)
|
||||
|
||||
logger.info(f"Performance recorded: {performance.plugin_id} - CPU: {performance.cpu_usage}%, Memory: {performance.memory_usage}%")
|
||||
|
||||
return {
|
||||
"performance_id": performance_record["performance_id"],
|
||||
"status": "recorded",
|
||||
"timestamp": performance_record["timestamp"]
|
||||
}
|
||||
|
||||
@app.post("/api/v1/analytics/rating")
|
||||
async def record_plugin_rating(rating: PluginRating):
|
||||
"""Record plugin rating and review"""
|
||||
rating_record = {
|
||||
"rating_id": f"rating_{int(datetime.now(timezone.utc).timestamp())}",
|
||||
"plugin_id": rating.plugin_id,
|
||||
"user_id": rating.user_id,
|
||||
"rating": rating.rating,
|
||||
"review": rating.review,
|
||||
"timestamp": rating.timestamp.isoformat()
|
||||
}
|
||||
|
||||
if rating.plugin_id not in plugin_ratings:
|
||||
plugin_ratings[rating.plugin_id] = []
|
||||
|
||||
plugin_ratings[rating.plugin_id].append(rating_record)
|
||||
|
||||
logger.info(f"Rating recorded: {rating.plugin_id} - {rating.rating} stars by {rating.user_id}")
|
||||
|
||||
return {
|
||||
"rating_id": rating_record["rating_id"],
|
||||
"status": "recorded",
|
||||
"timestamp": rating_record["timestamp"]
|
||||
}
|
||||
|
||||
@app.post("/api/v1/analytics/event")
|
||||
async def record_plugin_event(event: PluginEvent):
|
||||
"""Record generic plugin event"""
|
||||
event_record = {
|
||||
"event_id": f"event_{int(datetime.now(timezone.utc).timestamp())}",
|
||||
"event_type": event.event_type,
|
||||
"plugin_id": event.plugin_id,
|
||||
"user_id": event.user_id,
|
||||
"data": event.data,
|
||||
"timestamp": event.timestamp.isoformat()
|
||||
}
|
||||
|
||||
if event.plugin_id not in plugin_events:
|
||||
plugin_events[event.plugin_id] = []
|
||||
|
||||
plugin_events[event.plugin_id].append(event_record)
|
||||
|
||||
logger.info(f"Event recorded: {event.event_type} for {event.plugin_id}")
|
||||
|
||||
return {
|
||||
"event_id": event_record["event_id"],
|
||||
"status": "recorded",
|
||||
"timestamp": event_record["timestamp"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/analytics/usage/{plugin_id}")
|
||||
async def get_plugin_usage(plugin_id: str, days: int = 30):
|
||||
"""Get usage analytics for a specific plugin"""
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)
|
||||
|
||||
usage_records = plugin_usage_data.get(plugin_id, [])
|
||||
recent_usage = [r for r in usage_records
|
||||
if datetime.fromisoformat(r["timestamp"]) > cutoff_date]
|
||||
|
||||
# Calculate usage statistics
|
||||
usage_stats = calculate_usage_statistics(recent_usage)
|
||||
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"period_days": days,
|
||||
"usage_statistics": usage_stats,
|
||||
"total_records": len(recent_usage),
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.get("/api/v1/analytics/performance/{plugin_id}")
|
||||
async def get_plugin_performance(plugin_id: str, hours: int = 24):
|
||||
"""Get performance analytics for a specific plugin"""
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours)
|
||||
|
||||
performance_records = plugin_performance_data.get(plugin_id, [])
|
||||
recent_performance = [r for r in performance_records
|
||||
if datetime.fromisoformat(r["timestamp"]) > cutoff_time]
|
||||
|
||||
# Calculate performance statistics
|
||||
performance_stats = calculate_performance_statistics(recent_performance)
|
||||
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"period_hours": hours,
|
||||
"performance_statistics": performance_stats,
|
||||
"total_records": len(recent_performance),
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.get("/api/v1/analytics/ratings/{plugin_id}")
|
||||
async def get_plugin_ratings(plugin_id: str):
|
||||
"""Get ratings and reviews for a specific plugin"""
|
||||
rating_records = plugin_ratings.get(plugin_id, [])
|
||||
|
||||
# Calculate rating statistics
|
||||
rating_stats = calculate_rating_statistics(rating_records)
|
||||
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"rating_statistics": rating_stats,
|
||||
"total_ratings": len(rating_records),
|
||||
"recent_ratings": rating_records[-10:], # Last 10 ratings
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.get("/api/v1/analytics/dashboard")
|
||||
async def get_analytics_dashboard():
|
||||
"""Get comprehensive analytics dashboard"""
|
||||
dashboard_data = {
|
||||
"overview": get_overview_statistics(),
|
||||
"trending_plugins": get_trending_plugins(),
|
||||
"usage_trends": get_global_usage_trends(),
|
||||
"performance_summary": get_performance_summary(),
|
||||
"rating_summary": get_rating_summary(),
|
||||
"recent_events": get_recent_events()
|
||||
}
|
||||
|
||||
return {
|
||||
"dashboard": dashboard_data,
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.get("/api/v1/analytics/trends")
|
||||
async def get_usage_trends(plugin_id: Optional[str] = None, days: int = 30):
|
||||
"""Get usage trends data"""
|
||||
if plugin_id:
|
||||
return get_plugin_trends(plugin_id, days)
|
||||
else:
|
||||
return get_global_usage_trends(days)
|
||||
|
||||
@app.get("/api/v1/analytics/reports")
|
||||
async def generate_analytics_report(report_type: str, plugin_id: Optional[str] = None):
|
||||
"""Generate various analytics reports"""
|
||||
if report_type == "usage":
|
||||
return generate_usage_report(plugin_id)
|
||||
elif report_type == "performance":
|
||||
return generate_performance_report(plugin_id)
|
||||
elif report_type == "ratings":
|
||||
return generate_ratings_report(plugin_id)
|
||||
elif report_type == "summary":
|
||||
return generate_summary_report(plugin_id)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid report type")
|
||||
|
||||
# Analytics calculation functions
|
||||
def calculate_usage_statistics(usage_records: List[Dict]) -> Dict[str, Any]:
|
||||
"""Calculate usage statistics from usage records"""
|
||||
if not usage_records:
|
||||
return {
|
||||
"total_actions": 0,
|
||||
"unique_users": 0,
|
||||
"action_distribution": {},
|
||||
"daily_usage": {}
|
||||
}
|
||||
|
||||
# Basic statistics
|
||||
total_actions = len(usage_records)
|
||||
unique_users = len(set(r["user_id"] for r in usage_records))
|
||||
|
||||
# Action distribution
|
||||
action_counts = {}
|
||||
for record in usage_records:
|
||||
action = record["action"]
|
||||
action_counts[action] = action_counts.get(action, 0) + 1
|
||||
|
||||
# Daily usage
|
||||
daily_usage = {}
|
||||
for record in usage_records:
|
||||
date = datetime.fromisoformat(record["timestamp"]).date().isoformat()
|
||||
daily_usage[date] = daily_usage.get(date, 0) + 1
|
||||
|
||||
return {
|
||||
"total_actions": total_actions,
|
||||
"unique_users": unique_users,
|
||||
"action_distribution": action_counts,
|
||||
"daily_usage": daily_usage,
|
||||
"most_common_action": max(action_counts.items(), key=lambda x: x[1])[0] if action_counts else None
|
||||
}
|
||||
|
||||
def calculate_performance_statistics(performance_records: List[Dict]) -> Dict[str, Any]:
|
||||
"""Calculate performance statistics from performance records"""
|
||||
if not performance_records:
|
||||
return {
|
||||
"avg_cpu_usage": 0.0,
|
||||
"avg_memory_usage": 0.0,
|
||||
"avg_response_time": 0.0,
|
||||
"avg_error_rate": 0.0,
|
||||
"avg_uptime": 0.0
|
||||
}
|
||||
|
||||
# Calculate averages
|
||||
cpu_usage = sum(r["cpu_usage"] for r in performance_records) / len(performance_records)
|
||||
memory_usage = sum(r["memory_usage"] for r in performance_records) / len(performance_records)
|
||||
response_time = sum(r["response_time"] for r in performance_records) / len(performance_records)
|
||||
error_rate = sum(r["error_rate"] for r in performance_records) / len(performance_records)
|
||||
uptime = sum(r["uptime"] for r in performance_records) / len(performance_records)
|
||||
|
||||
# Calculate min/max
|
||||
min_cpu = min(r["cpu_usage"] for r in performance_records)
|
||||
max_cpu = max(r["cpu_usage"] for r in performance_records)
|
||||
|
||||
return {
|
||||
"avg_cpu_usage": round(cpu_usage, 2),
|
||||
"avg_memory_usage": round(memory_usage, 2),
|
||||
"avg_response_time": round(response_time, 3),
|
||||
"avg_error_rate": round(error_rate, 4),
|
||||
"avg_uptime": round(uptime, 2),
|
||||
"min_cpu_usage": round(min_cpu, 2),
|
||||
"max_cpu_usage": round(max_cpu, 2),
|
||||
"total_samples": len(performance_records)
|
||||
}
|
||||
|
||||
def calculate_rating_statistics(rating_records: List[Dict]) -> Dict[str, Any]:
|
||||
"""Calculate rating statistics from rating records"""
|
||||
if not rating_records:
|
||||
return {
|
||||
"average_rating": 0.0,
|
||||
"total_ratings": 0,
|
||||
"rating_distribution": {1: 0, 2: 0, 3: 0, 4: 0, 5: 0}
|
||||
}
|
||||
|
||||
# Calculate average rating
|
||||
total_rating = sum(r["rating"] for r in rating_records)
|
||||
average_rating = total_rating / len(rating_records)
|
||||
|
||||
# Rating distribution
|
||||
rating_distribution = {1: 0, 2: 0, 3: 0, 4: 0, 5: 0}
|
||||
for record in rating_records:
|
||||
rating_distribution[record["rating"]] += 1
|
||||
|
||||
return {
|
||||
"average_rating": round(average_rating, 2),
|
||||
"total_ratings": len(rating_records),
|
||||
"rating_distribution": rating_distribution,
|
||||
"latest_rating": rating_records[-1]["rating"] if rating_records else 0
|
||||
}
|
||||
|
||||
def update_usage_trends(plugin_id: str, action: str, timestamp: datetime):
|
||||
"""Update usage trends data"""
|
||||
if plugin_id not in usage_trends:
|
||||
usage_trends[plugin_id] = {
|
||||
"daily": {},
|
||||
"weekly": {},
|
||||
"monthly": {}
|
||||
}
|
||||
|
||||
# Update daily trends
|
||||
date_key = timestamp.date().isoformat()
|
||||
if date_key not in usage_trends[plugin_id]["daily"]:
|
||||
usage_trends[plugin_id]["daily"][date_key] = {}
|
||||
|
||||
usage_trends[plugin_id]["daily"][date_key][action] = usage_trends[plugin_id]["daily"][date_key].get(action, 0) + 1
|
||||
|
||||
def get_overview_statistics() -> Dict[str, Any]:
|
||||
"""Get overview statistics for all plugins"""
|
||||
total_plugins = len(set(plugin_usage_data.keys()) | set(plugin_performance_data.keys()) | set(plugin_ratings.keys()))
|
||||
total_usage = sum(len(data) for data in plugin_usage_data.values())
|
||||
total_ratings = sum(len(data) for data in plugin_ratings.values())
|
||||
|
||||
# Calculate active plugins (plugins with usage in last 7 days)
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=7)
|
||||
active_plugins = 0
|
||||
|
||||
for plugin_id, usage_records in plugin_usage_data.items():
|
||||
recent_usage = [r for r in usage_records
|
||||
if datetime.fromisoformat(r["timestamp"]) > cutoff_date]
|
||||
if recent_usage:
|
||||
active_plugins += 1
|
||||
|
||||
return {
|
||||
"total_plugins": total_plugins,
|
||||
"active_plugins": active_plugins,
|
||||
"total_usage_events": total_usage,
|
||||
"total_ratings": total_ratings,
|
||||
"average_ratings_per_plugin": round(total_ratings / total_plugins, 2) if total_plugins > 0 else 0
|
||||
}
|
||||
|
||||
def get_trending_plugins(limit: int = 10) -> List[Dict]:
|
||||
"""Get trending plugins based on recent usage"""
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=7)
|
||||
|
||||
plugin_scores = []
|
||||
|
||||
for plugin_id, usage_records in plugin_usage_data.items():
|
||||
recent_usage = [r for r in usage_records
|
||||
if datetime.fromisoformat(r["timestamp"]) > cutoff_date]
|
||||
|
||||
if recent_usage:
|
||||
# Calculate trend score (simplified)
|
||||
score = len(recent_usage) + len(set(r["user_id"] for r in recent_usage))
|
||||
|
||||
plugin_scores.append({
|
||||
"plugin_id": plugin_id,
|
||||
"trend_score": score,
|
||||
"recent_usage": len(recent_usage),
|
||||
"unique_users": len(set(r["user_id"] for r in recent_usage))
|
||||
})
|
||||
|
||||
# Sort by trend score
|
||||
plugin_scores.sort(key=lambda x: x["trend_score"], reverse=True)
|
||||
|
||||
return plugin_scores[:limit]
|
||||
|
||||
def get_global_usage_trends(days: int = 30) -> Dict[str, Any]:
|
||||
"""Get global usage trends"""
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)
|
||||
global_trends = {}
|
||||
|
||||
for plugin_id, usage_records in plugin_usage_data.items():
|
||||
recent_usage = [r for r in usage_records
|
||||
if datetime.fromisoformat(r["timestamp"]) > cutoff_date]
|
||||
|
||||
if recent_usage:
|
||||
daily_counts = {}
|
||||
for record in recent_usage:
|
||||
date = datetime.fromisoformat(record["timestamp"]).date().isoformat()
|
||||
daily_counts[date] = daily_counts.get(date, 0) + 1
|
||||
|
||||
global_trends[plugin_id] = daily_counts
|
||||
|
||||
return {
|
||||
"trends": global_trends,
|
||||
"period_days": days,
|
||||
"total_plugins": len(global_trends)
|
||||
}
|
||||
|
||||
def get_performance_summary() -> Dict[str, Any]:
|
||||
"""Get performance summary for all plugins"""
|
||||
all_performance = []
|
||||
|
||||
for plugin_id, performance_records in plugin_performance_data.items():
|
||||
if performance_records:
|
||||
latest_record = performance_records[-1]
|
||||
all_performance.append({
|
||||
"plugin_id": plugin_id,
|
||||
"cpu_usage": latest_record["cpu_usage"],
|
||||
"memory_usage": latest_record["memory_usage"],
|
||||
"response_time": latest_record["response_time"],
|
||||
"error_rate": latest_record["error_rate"]
|
||||
})
|
||||
|
||||
# Calculate averages
|
||||
if all_performance:
|
||||
avg_cpu = sum(p["cpu_usage"] for p in all_performance) / len(all_performance)
|
||||
avg_memory = sum(p["memory_usage"] for p in all_performance) / len(all_performance)
|
||||
avg_response = sum(p["response_time"] for p in all_performance) / len(all_performance)
|
||||
avg_error = sum(p["error_rate"] for p in all_performance) / len(all_performance)
|
||||
else:
|
||||
avg_cpu = avg_memory = avg_response = avg_error = 0.0
|
||||
|
||||
return {
|
||||
"total_plugins": len(all_performance),
|
||||
"average_cpu_usage": round(avg_cpu, 2),
|
||||
"average_memory_usage": round(avg_memory, 2),
|
||||
"average_response_time": round(avg_response, 3),
|
||||
"average_error_rate": round(avg_error, 4),
|
||||
"top_cpu_users": sorted(all_performance, key=lambda x: x["cpu_usage"], reverse=True)[:5]
|
||||
}
|
||||
|
||||
def get_rating_summary() -> Dict[str, Any]:
|
||||
"""Get rating summary for all plugins"""
|
||||
all_ratings = []
|
||||
|
||||
for plugin_id, rating_records in plugin_ratings.items():
|
||||
if rating_records:
|
||||
avg_rating = sum(r["rating"] for r in rating_records) / len(rating_records)
|
||||
all_ratings.append({
|
||||
"plugin_id": plugin_id,
|
||||
"average_rating": round(avg_rating, 2),
|
||||
"total_ratings": len(rating_records)
|
||||
})
|
||||
|
||||
# Sort by rating
|
||||
all_ratings.sort(key=lambda x: x["average_rating"], reverse=True)
|
||||
|
||||
return {
|
||||
"total_plugins": len(all_ratings),
|
||||
"top_rated": all_ratings[:10],
|
||||
"average_rating_all": round(sum(r["average_rating"] for r in all_ratings) / len(all_ratings), 2) if all_ratings else 0.0
|
||||
}
|
||||
|
||||
def get_recent_events(limit: int = 20) -> List[Dict]:
|
||||
"""Get recent plugin events"""
|
||||
all_events = []
|
||||
|
||||
for plugin_id, events in plugin_events.items():
|
||||
for event in events:
|
||||
all_events.append({
|
||||
"plugin_id": plugin_id,
|
||||
"event_type": event["event_type"],
|
||||
"timestamp": event["timestamp"],
|
||||
"user_id": event.get("user_id")
|
||||
})
|
||||
|
||||
# Sort by timestamp (most recent first)
|
||||
all_events.sort(key=lambda x: x["timestamp"], reverse=True)
|
||||
|
||||
return all_events[:limit]
|
||||
|
||||
def get_plugin_trends(plugin_id: str, days: int) -> Dict[str, Any]:
|
||||
"""Get trends for a specific plugin"""
|
||||
plugin_trends = usage_trends.get(plugin_id, {})
|
||||
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)
|
||||
date_key = cutoff_date.date().isoformat()
|
||||
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"trends": plugin_trends,
|
||||
"period_days": days,
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
# Report generation functions
|
||||
def generate_usage_report(plugin_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Generate usage report"""
|
||||
if plugin_id:
|
||||
return get_plugin_usage(plugin_id, days=30)
|
||||
else:
|
||||
return get_global_usage_trends(days=30)
|
||||
|
||||
def generate_performance_report(plugin_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Generate performance report"""
|
||||
if plugin_id:
|
||||
return get_plugin_performance(plugin_id, hours=24)
|
||||
else:
|
||||
return get_performance_summary()
|
||||
|
||||
def generate_ratings_report(plugin_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Generate ratings report"""
|
||||
if plugin_id:
|
||||
return get_plugin_ratings(plugin_id)
|
||||
else:
|
||||
return get_rating_summary()
|
||||
|
||||
def generate_summary_report(plugin_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Generate comprehensive summary report"""
|
||||
if plugin_id:
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"usage": get_plugin_usage(plugin_id, days=30),
|
||||
"performance": get_plugin_performance(plugin_id, hours=24),
|
||||
"ratings": get_plugin_ratings(plugin_id),
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
else:
|
||||
return get_analytics_dashboard()
|
||||
|
||||
# Background task for analytics processing
|
||||
async def process_analytics():
|
||||
"""Background task to process analytics data"""
|
||||
while True:
|
||||
await asyncio.sleep(3600) # Process every hour
|
||||
|
||||
# Update analytics cache
|
||||
update_analytics_cache()
|
||||
|
||||
# Clean old data (older than 90 days)
|
||||
cleanup_old_data()
|
||||
|
||||
logger.info("Analytics processing completed")
|
||||
|
||||
def update_analytics_cache():
|
||||
"""Update analytics cache with frequently accessed data"""
|
||||
# Cache trending plugins
|
||||
analytics_cache["trending_plugins"] = get_trending_plugins()
|
||||
|
||||
# Cache overview statistics
|
||||
analytics_cache["overview"] = get_overview_statistics()
|
||||
|
||||
# Cache performance summary
|
||||
analytics_cache["performance_summary"] = get_performance_summary()
|
||||
|
||||
def cleanup_old_data():
|
||||
"""Clean up old analytics data"""
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=90)
|
||||
cutoff_iso = cutoff_date.isoformat()
|
||||
|
||||
# Clean usage data
|
||||
for plugin_id in plugin_usage_data:
|
||||
plugin_usage_data[plugin_id] = [
|
||||
r for r in plugin_usage_data[plugin_id]
|
||||
if r["timestamp"] > cutoff_iso
|
||||
]
|
||||
|
||||
# Clean performance data
|
||||
for plugin_id in plugin_performance_data:
|
||||
plugin_performance_data[plugin_id] = [
|
||||
r for r in plugin_performance_data[plugin_id]
|
||||
if r["timestamp"] > cutoff_iso
|
||||
]
|
||||
|
||||
# Clean events data
|
||||
for plugin_id in plugin_events:
|
||||
plugin_events[plugin_id] = [
|
||||
r for r in plugin_events[plugin_id]
|
||||
if r["timestamp"] > cutoff_iso
|
||||
]
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
logger.info("Starting AITBC Plugin Analytics Service")
|
||||
# Initialize analytics cache
|
||||
update_analytics_cache()
|
||||
# Start analytics processing
|
||||
asyncio.create_task(process_analytics())
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
logger.info("Shutting down AITBC Plugin Analytics Service")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8016, log_level="info")
|
||||
1
examples/stubs/plugin-analytics/tests/__init__.py
Normal file
1
examples/stubs/plugin-analytics/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Plugin analytics service tests"""
|
||||
@@ -0,0 +1,168 @@
|
||||
"""Edge case and error handling tests for plugin analytics service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, PluginUsage, PluginPerformance, PluginRating, PluginEvent, plugin_usage_data, plugin_performance_data, plugin_ratings, plugin_events
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
plugin_usage_data.clear()
|
||||
plugin_performance_data.clear()
|
||||
plugin_ratings.clear()
|
||||
plugin_events.clear()
|
||||
yield
|
||||
plugin_usage_data.clear()
|
||||
plugin_performance_data.clear()
|
||||
plugin_ratings.clear()
|
||||
plugin_events.clear()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_usage_empty_plugin_id():
|
||||
"""Test PluginUsage with empty plugin_id"""
|
||||
usage = PluginUsage(
|
||||
plugin_id="",
|
||||
user_id="user_123",
|
||||
action="install",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert usage.plugin_id == ""
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_performance_negative_values():
|
||||
"""Test PluginPerformance with negative values"""
|
||||
perf = PluginPerformance(
|
||||
plugin_id="plugin_123",
|
||||
version="1.0.0",
|
||||
cpu_usage=-10.0,
|
||||
memory_usage=-5.0,
|
||||
response_time=-0.1,
|
||||
error_rate=-0.01,
|
||||
uptime=-50.0,
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert perf.cpu_usage == -10.0
|
||||
assert perf.memory_usage == -5.0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_rating_out_of_range():
|
||||
"""Test PluginRating with out of range rating"""
|
||||
rating = PluginRating(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
rating=10,
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert rating.rating == 10
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_rating_zero():
|
||||
"""Test PluginRating with zero rating"""
|
||||
rating = PluginRating(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
rating=0,
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert rating.rating == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_plugin_usage_no_data():
|
||||
"""Test getting plugin usage when no data exists"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/analytics/usage/nonexistent")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_records"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_plugin_performance_no_data():
|
||||
"""Test getting plugin performance when no data exists"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/analytics/performance/nonexistent")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_records"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_plugin_ratings_no_data():
|
||||
"""Test getting plugin ratings when no data exists"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/analytics/ratings/nonexistent")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_ratings"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_dashboard_with_no_data():
|
||||
"""Test dashboard with no data"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/analytics/dashboard")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["dashboard"]["overview"]["total_plugins"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_record_multiple_usage_events():
|
||||
"""Test recording multiple usage events for same plugin"""
|
||||
client = TestClient(app)
|
||||
|
||||
for i in range(5):
|
||||
usage = PluginUsage(
|
||||
plugin_id="plugin_123",
|
||||
user_id=f"user_{i}",
|
||||
action="use",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
client.post("/api/v1/analytics/usage", json=usage.model_dump(mode='json'))
|
||||
|
||||
response = client.get("/api/v1/analytics/usage/plugin_123")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_records"] == 5
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_usage_trends_days_parameter():
|
||||
"""Test usage trends with custom days parameter"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/analytics/trends?days=7")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "trends" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_plugin_usage_days_parameter():
|
||||
"""Test getting plugin usage with custom days parameter"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/analytics/usage/plugin_123?days=7")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["period_days"] == 7
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_plugin_performance_hours_parameter():
|
||||
"""Test getting plugin performance with custom hours parameter"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/analytics/performance/plugin_123?hours=12")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["period_hours"] == 12
|
||||
@@ -0,0 +1,253 @@
|
||||
"""Integration tests for plugin analytics service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, PluginUsage, PluginPerformance, PluginRating, PluginEvent, plugin_usage_data, plugin_performance_data, plugin_ratings, plugin_events
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
plugin_usage_data.clear()
|
||||
plugin_performance_data.clear()
|
||||
plugin_ratings.clear()
|
||||
plugin_events.clear()
|
||||
yield
|
||||
plugin_usage_data.clear()
|
||||
plugin_performance_data.clear()
|
||||
plugin_ratings.clear()
|
||||
plugin_events.clear()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_root_endpoint():
|
||||
"""Test root endpoint"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["service"] == "AITBC Plugin Analytics Service"
|
||||
assert data["status"] == "running"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_health_check_endpoint():
|
||||
"""Test health check endpoint"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert "total_usage_records" in data
|
||||
assert "total_performance_records" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_record_plugin_usage():
|
||||
"""Test recording plugin usage"""
|
||||
client = TestClient(app)
|
||||
usage = PluginUsage(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
action="install",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
response = client.post("/api/v1/analytics/usage", json=usage.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["usage_id"]
|
||||
assert data["status"] == "recorded"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_record_plugin_performance():
|
||||
"""Test recording plugin performance"""
|
||||
client = TestClient(app)
|
||||
perf = PluginPerformance(
|
||||
plugin_id="plugin_123",
|
||||
version="1.0.0",
|
||||
cpu_usage=50.5,
|
||||
memory_usage=30.2,
|
||||
response_time=0.123,
|
||||
error_rate=0.001,
|
||||
uptime=99.9,
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
response = client.post("/api/v1/analytics/performance", json=perf.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["performance_id"]
|
||||
assert data["status"] == "recorded"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_record_plugin_rating():
|
||||
"""Test recording plugin rating"""
|
||||
client = TestClient(app)
|
||||
rating = PluginRating(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
rating=5,
|
||||
review="Great plugin!",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
response = client.post("/api/v1/analytics/rating", json=rating.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["rating_id"]
|
||||
assert data["status"] == "recorded"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_record_plugin_event():
|
||||
"""Test recording plugin event"""
|
||||
client = TestClient(app)
|
||||
event = PluginEvent(
|
||||
event_type="error",
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
data={"error": "timeout"},
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
response = client.post("/api/v1/analytics/event", json=event.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["event_id"]
|
||||
assert data["status"] == "recorded"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_plugin_usage():
|
||||
"""Test getting plugin usage analytics"""
|
||||
client = TestClient(app)
|
||||
# Record usage first
|
||||
usage = PluginUsage(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
action="install",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
client.post("/api/v1/analytics/usage", json=usage.model_dump(mode='json'))
|
||||
|
||||
response = client.get("/api/v1/analytics/usage/plugin_123")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["plugin_id"] == "plugin_123"
|
||||
assert "usage_statistics" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_plugin_performance():
|
||||
"""Test getting plugin performance analytics"""
|
||||
client = TestClient(app)
|
||||
# Record performance first
|
||||
perf = PluginPerformance(
|
||||
plugin_id="plugin_123",
|
||||
version="1.0.0",
|
||||
cpu_usage=50.5,
|
||||
memory_usage=30.2,
|
||||
response_time=0.123,
|
||||
error_rate=0.001,
|
||||
uptime=99.9,
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
client.post("/api/v1/analytics/performance", json=perf.model_dump(mode='json'))
|
||||
|
||||
response = client.get("/api/v1/analytics/performance/plugin_123")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["plugin_id"] == "plugin_123"
|
||||
assert "performance_statistics" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_plugin_ratings():
|
||||
"""Test getting plugin ratings"""
|
||||
client = TestClient(app)
|
||||
# Record rating first
|
||||
rating = PluginRating(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
rating=5,
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
client.post("/api/v1/analytics/rating", json=rating.model_dump(mode='json'))
|
||||
|
||||
response = client.get("/api/v1/analytics/ratings/plugin_123")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["plugin_id"] == "plugin_123"
|
||||
assert "rating_statistics" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_analytics_dashboard():
|
||||
"""Test getting analytics dashboard"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/analytics/dashboard")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "dashboard" in data
|
||||
assert "overview" in data["dashboard"]
|
||||
assert "trending_plugins" in data["dashboard"]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_usage_trends():
|
||||
"""Test getting usage trends"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/analytics/trends")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "trends" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_usage_trends_plugin_specific():
|
||||
"""Test getting usage trends for specific plugin"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/analytics/trends?plugin_id=plugin_123")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "plugin_id" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_generate_analytics_report_usage():
|
||||
"""Test generating usage report"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/analytics/reports?report_type=usage")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_generate_analytics_report_performance():
|
||||
"""Test generating performance report"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/analytics/reports?report_type=performance")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_generate_analytics_report_ratings():
|
||||
"""Test generating ratings report"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/analytics/reports?report_type=ratings")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_generate_analytics_report_invalid():
|
||||
"""Test generating analytics report with invalid type"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/analytics/reports?report_type=invalid")
|
||||
assert response.status_code == 400
|
||||
@@ -0,0 +1,123 @@
|
||||
"""Unit tests for plugin analytics service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, PluginUsage, PluginPerformance, PluginRating, PluginEvent
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_app_initialization():
|
||||
"""Test that the FastAPI app initializes correctly"""
|
||||
assert app is not None
|
||||
assert app.title == "AITBC Plugin Analytics Service"
|
||||
assert app.version == "1.0.0"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_usage_model():
|
||||
"""Test PluginUsage model"""
|
||||
usage = PluginUsage(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
action="install",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
metadata={"source": "marketplace"}
|
||||
)
|
||||
assert usage.plugin_id == "plugin_123"
|
||||
assert usage.user_id == "user_123"
|
||||
assert usage.action == "install"
|
||||
assert usage.metadata == {"source": "marketplace"}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_usage_defaults():
|
||||
"""Test PluginUsage with default metadata"""
|
||||
usage = PluginUsage(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
action="use",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert usage.metadata == {}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_performance_model():
|
||||
"""Test PluginPerformance model"""
|
||||
perf = PluginPerformance(
|
||||
plugin_id="plugin_123",
|
||||
version="1.0.0",
|
||||
cpu_usage=50.5,
|
||||
memory_usage=30.2,
|
||||
response_time=0.123,
|
||||
error_rate=0.001,
|
||||
uptime=99.9,
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert perf.plugin_id == "plugin_123"
|
||||
assert perf.version == "1.0.0"
|
||||
assert perf.cpu_usage == 50.5
|
||||
assert perf.memory_usage == 30.2
|
||||
assert perf.response_time == 0.123
|
||||
assert perf.error_rate == 0.001
|
||||
assert perf.uptime == 99.9
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_rating_model():
|
||||
"""Test PluginRating model"""
|
||||
rating = PluginRating(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
rating=5,
|
||||
review="Great plugin!",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert rating.plugin_id == "plugin_123"
|
||||
assert rating.rating == 5
|
||||
assert rating.review == "Great plugin!"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_rating_defaults():
|
||||
"""Test PluginRating with default review"""
|
||||
rating = PluginRating(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
rating=4,
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert rating.review is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_event_model():
|
||||
"""Test PluginEvent model"""
|
||||
event = PluginEvent(
|
||||
event_type="error",
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
data={"error": "timeout"},
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert event.event_type == "error"
|
||||
assert event.plugin_id == "plugin_123"
|
||||
assert event.user_id == "user_123"
|
||||
assert event.data == {"error": "timeout"}
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_event_defaults():
|
||||
"""Test PluginEvent with default values"""
|
||||
event = PluginEvent(
|
||||
event_type="info",
|
||||
plugin_id="plugin_123",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert event.user_id is None
|
||||
assert event.data == {}
|
||||
604
examples/stubs/plugin-marketplace/main.py
Executable file
604
examples/stubs/plugin-marketplace/main.py
Executable file
@@ -0,0 +1,604 @@
|
||||
"""
|
||||
Plugin Marketplace Frontend Service for AITBC
|
||||
Provides web interface and marketplace functionality for plugins
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from pydantic import BaseModel
|
||||
|
||||
from aitbc import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
app = FastAPI(
|
||||
title="AITBC Plugin Marketplace",
|
||||
description="Plugin marketplace frontend and community features",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Data models
|
||||
class MarketplaceReview(BaseModel):
|
||||
plugin_id: str
|
||||
user_id: str
|
||||
rating: int # 1-5 stars
|
||||
title: str
|
||||
content: str
|
||||
pros: List[str] = []
|
||||
cons: List[str] = []
|
||||
|
||||
class PluginPurchase(BaseModel):
|
||||
plugin_id: str
|
||||
user_id: str
|
||||
price: float
|
||||
payment_method: str
|
||||
|
||||
class DeveloperApplication(BaseModel):
|
||||
developer_name: str
|
||||
email: str
|
||||
company: Optional[str] = None
|
||||
experience: str
|
||||
portfolio_url: Optional[str] = None
|
||||
github_username: Optional[str] = None
|
||||
description: str
|
||||
|
||||
# In-memory storage (in production, use database)
|
||||
marketplace_data: Dict[str, Dict] = {}
|
||||
reviews: Dict[str, List[Dict]] = {}
|
||||
purchases: Dict[str, List[Dict]] = {}
|
||||
developer_applications: Dict[str, Dict] = {}
|
||||
verified_developers: Dict[str, Dict] = {}
|
||||
revenue_sharing: Dict[str, Dict] = {}
|
||||
|
||||
# Static files and templates
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
templates = Jinja2Templates(directory="templates")
|
||||
|
||||
@app.get("/")
|
||||
async def marketplace_home(request: Request):
|
||||
"""Marketplace homepage"""
|
||||
return templates.TemplateResponse("index.html", {
|
||||
"request": request,
|
||||
"featured_plugins": get_featured_plugins(),
|
||||
"popular_plugins": get_popular_plugins(),
|
||||
"recent_plugins": get_recent_plugins(),
|
||||
"categories": get_categories(),
|
||||
"stats": get_marketplace_stats()
|
||||
})
|
||||
|
||||
@app.get("/plugins")
|
||||
async def plugins_page(request: Request):
|
||||
"""Plugin listing page"""
|
||||
return templates.TemplateResponse("plugins.html", {
|
||||
"request": request,
|
||||
"plugins": get_all_plugins(),
|
||||
"categories": get_categories(),
|
||||
"tags": get_all_tags()
|
||||
})
|
||||
|
||||
@app.get("/plugins/{plugin_id}")
|
||||
async def plugin_detail(request: Request, plugin_id: str):
|
||||
"""Individual plugin detail page"""
|
||||
plugin = get_plugin_details(plugin_id)
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
|
||||
return templates.TemplateResponse("plugin_detail.html", {
|
||||
"request": request,
|
||||
"plugin": plugin,
|
||||
"reviews": get_plugin_reviews(plugin_id),
|
||||
"related_plugins": get_related_plugins(plugin_id)
|
||||
})
|
||||
|
||||
@app.get("/developers")
|
||||
async def developers_page(request: Request):
|
||||
"""Developer portal page"""
|
||||
return templates.TemplateResponse("developers.html", {
|
||||
"request": request,
|
||||
"verified_developers": get_verified_developers(),
|
||||
"developer_stats": get_developer_stats()
|
||||
})
|
||||
|
||||
@app.get("/submit")
|
||||
async def submit_plugin_page(request: Request):
|
||||
"""Plugin submission page"""
|
||||
return templates.TemplateResponse("submit.html", {
|
||||
"request": request,
|
||||
"categories": get_categories(),
|
||||
"guidelines": get_submission_guidelines()
|
||||
})
|
||||
|
||||
# API endpoints
|
||||
@app.get("/api/v1/marketplace/featured")
|
||||
async def get_featured_plugins_api():
|
||||
"""Get featured plugins for marketplace"""
|
||||
return {
|
||||
"featured_plugins": get_featured_plugins(),
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.get("/api/v1/marketplace/popular")
|
||||
async def get_popular_plugins_api(limit: int = 12):
|
||||
"""Get popular plugins"""
|
||||
return {
|
||||
"popular_plugins": get_popular_plugins(limit),
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.get("/api/v1/marketplace/recent")
|
||||
async def get_recent_plugins_api(limit: int = 12):
|
||||
"""Get recently added plugins"""
|
||||
return {
|
||||
"recent_plugins": get_recent_plugins(limit),
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.get("/api/v1/marketplace/stats")
|
||||
async def get_marketplace_stats_api():
|
||||
"""Get marketplace statistics"""
|
||||
return {
|
||||
"stats": get_marketplace_stats(),
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.post("/api/v1/reviews")
|
||||
async def create_review(review: MarketplaceReview):
|
||||
"""Create a plugin review"""
|
||||
review_id = f"review_{int(datetime.now(timezone.utc).timestamp())}"
|
||||
|
||||
review_record = {
|
||||
"review_id": review_id,
|
||||
"plugin_id": review.plugin_id,
|
||||
"user_id": review.user_id,
|
||||
"rating": review.rating,
|
||||
"title": review.title,
|
||||
"content": review.content,
|
||||
"pros": review.pros,
|
||||
"cons": review.cons,
|
||||
"helpful_votes": 0,
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"verified_purchase": False
|
||||
}
|
||||
|
||||
if review.plugin_id not in reviews:
|
||||
reviews[review.plugin_id] = []
|
||||
|
||||
reviews[review.plugin_id].append(review_record)
|
||||
|
||||
logger.info(f"Review created for plugin {review.plugin_id}: {review.rating} stars")
|
||||
|
||||
return {
|
||||
"review_id": review_id,
|
||||
"status": "created",
|
||||
"rating": review.rating,
|
||||
"created_at": review_record["created_at"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/reviews/{plugin_id}")
|
||||
async def get_plugin_reviews_api(plugin_id: str):
|
||||
"""Get all reviews for a plugin"""
|
||||
plugin_reviews = reviews.get(plugin_id, [])
|
||||
|
||||
# Calculate average rating
|
||||
if plugin_reviews:
|
||||
avg_rating = sum(r["rating"] for r in plugin_reviews) / len(plugin_reviews)
|
||||
else:
|
||||
avg_rating = 0.0
|
||||
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"reviews": plugin_reviews,
|
||||
"total_reviews": len(plugin_reviews),
|
||||
"average_rating": avg_rating,
|
||||
"rating_distribution": get_rating_distribution(plugin_reviews)
|
||||
}
|
||||
|
||||
@app.post("/api/v1/purchases")
|
||||
async def create_purchase(purchase: PluginPurchase):
|
||||
"""Create a plugin purchase"""
|
||||
purchase_id = f"purchase_{int(datetime.now(timezone.utc).timestamp())}"
|
||||
|
||||
purchase_record = {
|
||||
"purchase_id": purchase_id,
|
||||
"plugin_id": purchase.plugin_id,
|
||||
"user_id": purchase.user_id,
|
||||
"price": purchase.price,
|
||||
"payment_method": purchase.payment_method,
|
||||
"status": "completed",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"refund_deadline": (datetime.now(timezone.utc) + timedelta(days=30)).isoformat()
|
||||
}
|
||||
|
||||
if purchase.plugin_id not in purchases:
|
||||
purchases[purchase.plugin_id] = []
|
||||
|
||||
purchases[purchase.plugin_id].append(purchase_record)
|
||||
|
||||
# Update revenue sharing
|
||||
update_revenue_sharing(purchase.plugin_id, purchase.price)
|
||||
|
||||
logger.info(f"Purchase created for plugin {purchase.plugin_id}: ${purchase.price}")
|
||||
|
||||
return {
|
||||
"purchase_id": purchase_id,
|
||||
"status": "completed",
|
||||
"price": purchase.price,
|
||||
"created_at": purchase_record["created_at"]
|
||||
}
|
||||
|
||||
@app.post("/api/v1/developers/apply")
|
||||
async def apply_developer(application: DeveloperApplication):
|
||||
"""Apply to become a verified developer"""
|
||||
application_id = f"dev_app_{int(datetime.now(timezone.utc).timestamp())}"
|
||||
|
||||
application_record = {
|
||||
"application_id": application_id,
|
||||
"developer_name": application.developer_name,
|
||||
"email": application.email,
|
||||
"company": application.company,
|
||||
"experience": application.experience,
|
||||
"portfolio_url": application.portfolio_url,
|
||||
"github_username": application.github_username,
|
||||
"description": application.description,
|
||||
"status": "pending",
|
||||
"submitted_at": datetime.now(timezone.utc).isoformat(),
|
||||
"reviewed_at": None,
|
||||
"reviewer_notes": None
|
||||
}
|
||||
|
||||
developer_applications[application_id] = application_record
|
||||
|
||||
logger.info(f"Developer application submitted: {application.developer_name}")
|
||||
|
||||
return {
|
||||
"application_id": application_id,
|
||||
"status": "pending",
|
||||
"submitted_at": application_record["submitted_at"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/developers/verified")
|
||||
async def get_verified_developers_api():
|
||||
"""Get list of verified developers"""
|
||||
return {
|
||||
"verified_developers": get_verified_developers(),
|
||||
"total_developers": len(verified_developers),
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.get("/api/v1/revenue/{developer_id}")
|
||||
async def get_developer_revenue(developer_id: str):
|
||||
"""Get revenue information for a developer"""
|
||||
developer_revenue = revenue_sharing.get(developer_id, {
|
||||
"total_revenue": 0.0,
|
||||
"plugin_revenue": {},
|
||||
"monthly_revenue": {},
|
||||
"last_updated": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
|
||||
return developer_revenue
|
||||
|
||||
# Helper functions
|
||||
def get_featured_plugins() -> List[Dict]:
|
||||
"""Get featured plugins"""
|
||||
# In production, this would be based on editorial selection or algorithm
|
||||
featured_plugins = []
|
||||
|
||||
# Mock data for demo
|
||||
featured_plugins = [
|
||||
{
|
||||
"plugin_id": "ai_trading_bot",
|
||||
"name": "AI Trading Bot",
|
||||
"description": "Advanced AI-powered trading automation",
|
||||
"author": "AITBC Labs",
|
||||
"category": "ai",
|
||||
"rating": 4.8,
|
||||
"downloads": 15420,
|
||||
"price": 99.99,
|
||||
"featured": True
|
||||
},
|
||||
{
|
||||
"plugin_id": "blockchain_analyzer",
|
||||
"name": "Blockchain Analyzer",
|
||||
"description": "Comprehensive blockchain analytics and monitoring",
|
||||
"author": "CryptoTools",
|
||||
"category": "blockchain",
|
||||
"rating": 4.6,
|
||||
"downloads": 12350,
|
||||
"price": 149.99,
|
||||
"featured": True
|
||||
}
|
||||
]
|
||||
|
||||
return featured_plugins
|
||||
|
||||
def get_popular_plugins(limit: int = 12) -> List[Dict]:
|
||||
"""Get popular plugins"""
|
||||
# Mock data for demo
|
||||
popular_plugins = [
|
||||
{
|
||||
"plugin_id": "cli_enhancer",
|
||||
"name": "CLI Enhancer",
|
||||
"description": "Enhanced CLI commands and shortcuts",
|
||||
"author": "DevTools",
|
||||
"category": "cli",
|
||||
"rating": 4.7,
|
||||
"downloads": 8920,
|
||||
"price": 29.99
|
||||
},
|
||||
{
|
||||
"plugin_id": "web_dashboard",
|
||||
"name": "Web Dashboard",
|
||||
"description": "Beautiful web dashboard for AITBC",
|
||||
"author": "WebCraft",
|
||||
"category": "web",
|
||||
"rating": 4.5,
|
||||
"downloads": 7650,
|
||||
"price": 79.99
|
||||
}
|
||||
]
|
||||
|
||||
return popular_plugins[:limit]
|
||||
|
||||
def get_recent_plugins(limit: int = 12) -> List[Dict]:
|
||||
"""Get recently added plugins"""
|
||||
# Mock data for demo
|
||||
recent_plugins = [
|
||||
{
|
||||
"plugin_id": "security_scanner",
|
||||
"name": "Security Scanner",
|
||||
"description": "Advanced security vulnerability scanner",
|
||||
"author": "SecureDev",
|
||||
"category": "security",
|
||||
"rating": 4.9,
|
||||
"downloads": 2340,
|
||||
"price": 199.99,
|
||||
"created_at": (datetime.now(timezone.utc) - timedelta(days=3)).isoformat()
|
||||
},
|
||||
{
|
||||
"plugin_id": "performance_monitor",
|
||||
"name": "Performance Monitor",
|
||||
"description": "Real-time performance monitoring and alerts",
|
||||
"author": "PerfTools",
|
||||
"category": "monitoring",
|
||||
"rating": 4.4,
|
||||
"downloads": 1890,
|
||||
"price": 59.99,
|
||||
"created_at": (datetime.now(timezone.utc) - timedelta(days=7)).isoformat()
|
||||
}
|
||||
]
|
||||
|
||||
return recent_plugins[:limit]
|
||||
|
||||
def get_categories() -> List[Dict]:
|
||||
"""Get plugin categories"""
|
||||
categories = [
|
||||
{"name": "ai", "display_name": "AI & Machine Learning", "count": 45},
|
||||
{"name": "blockchain", "display_name": "Blockchain", "count": 32},
|
||||
{"name": "cli", "display_name": "CLI Tools", "count": 28},
|
||||
{"name": "web", "display_name": "Web & UI", "count": 24},
|
||||
{"name": "security", "display_name": "Security", "count": 18},
|
||||
{"name": "monitoring", "display_name": "Monitoring", "count": 15}
|
||||
]
|
||||
|
||||
return categories
|
||||
|
||||
def get_all_plugins() -> List[Dict]:
|
||||
"""Get all plugins"""
|
||||
# Mock data for demo
|
||||
all_plugins = get_featured_plugins() + get_popular_plugins() + get_recent_plugins()
|
||||
return all_plugins
|
||||
|
||||
def get_all_tags() -> List[str]:
|
||||
"""Get all plugin tags"""
|
||||
tags = ["automation", "trading", "analytics", "security", "monitoring", "dashboard", "cli", "ai", "blockchain", "web"]
|
||||
return tags
|
||||
|
||||
def get_plugin_details(plugin_id: str) -> Optional[Dict]:
|
||||
"""Get detailed plugin information"""
|
||||
# Mock data for demo
|
||||
plugins = {
|
||||
"ai_trading_bot": {
|
||||
"plugin_id": "ai_trading_bot",
|
||||
"name": "AI Trading Bot",
|
||||
"description": "Advanced AI-powered trading automation with machine learning algorithms for optimal trading strategies",
|
||||
"author": "AITBC Labs",
|
||||
"category": "ai",
|
||||
"tags": ["automation", "trading", "ai", "machine-learning"],
|
||||
"rating": 4.8,
|
||||
"downloads": 15420,
|
||||
"price": 99.99,
|
||||
"version": "2.1.0",
|
||||
"last_updated": (datetime.now(timezone.utc) - timedelta(days=15)).isoformat(),
|
||||
"repository_url": "https://github.com/aitbc-labs/ai-trading-bot",
|
||||
"homepage_url": "https://aitbc-trading-bot.com",
|
||||
"license": "MIT",
|
||||
"screenshots": [
|
||||
"/static/screenshots/trading-bot-1.png",
|
||||
"/static/screenshots/trading-bot-2.png"
|
||||
],
|
||||
"changelog": "Added new ML models, improved performance, bug fixes",
|
||||
"compatibility": ["v1.0.0+", "v2.0.0+"]
|
||||
}
|
||||
}
|
||||
|
||||
return plugins.get(plugin_id)
|
||||
|
||||
def get_plugin_reviews(plugin_id: str) -> List[Dict]:
|
||||
"""Get reviews for a plugin"""
|
||||
# Mock data for demo
|
||||
mock_reviews = [
|
||||
{
|
||||
"review_id": "review_1",
|
||||
"user_id": "user123",
|
||||
"rating": 5,
|
||||
"title": "Excellent Trading Bot",
|
||||
"content": "This plugin has transformed my trading strategy. Highly recommended!",
|
||||
"pros": ["Easy to use", "Great performance", "Good documentation"],
|
||||
"cons": ["Initial setup complexity"],
|
||||
"helpful_votes": 23,
|
||||
"created_at": (datetime.now(timezone.utc) - timedelta(days=10)).isoformat()
|
||||
},
|
||||
{
|
||||
"review_id": "review_2",
|
||||
"user_id": "user456",
|
||||
"rating": 4,
|
||||
"title": "Good but needs improvements",
|
||||
"content": "Solid plugin with room for improvement in the UI.",
|
||||
"pros": ["Powerful features", "Good support"],
|
||||
"cons": ["UI could be better", "Learning curve"],
|
||||
"helpful_votes": 15,
|
||||
"created_at": (datetime.now(timezone.utc) - timedelta(days=25)).isoformat()
|
||||
}
|
||||
]
|
||||
|
||||
return mock_reviews
|
||||
|
||||
def get_related_plugins(plugin_id: str) -> List[Dict]:
|
||||
"""Get related plugins"""
|
||||
# Mock data for demo
|
||||
related_plugins = [
|
||||
{
|
||||
"plugin_id": "market_analyzer",
|
||||
"name": "Market Analyzer",
|
||||
"description": "Advanced market analysis tools",
|
||||
"rating": 4.6,
|
||||
"price": 79.99
|
||||
},
|
||||
{
|
||||
"plugin_id": "risk_manager",
|
||||
"name": "Risk Manager",
|
||||
"description": "Comprehensive risk management system",
|
||||
"rating": 4.5,
|
||||
"price": 89.99
|
||||
}
|
||||
]
|
||||
|
||||
return related_plugins
|
||||
|
||||
def get_verified_developers() -> List[Dict]:
|
||||
"""Get verified developers"""
|
||||
# Mock data for demo
|
||||
verified_devs = [
|
||||
{
|
||||
"developer_id": "aitbc_labs",
|
||||
"name": "AITBC Labs",
|
||||
"description": "Official AITBC development team",
|
||||
"plugins_count": 12,
|
||||
"total_downloads": 45680,
|
||||
"verified_since": "2025-01-15",
|
||||
"avatar": "/static/avatars/aitbc-labs.png"
|
||||
},
|
||||
{
|
||||
"developer_id": "crypto_tools",
|
||||
"name": "CryptoTools",
|
||||
"description": "Professional blockchain tools provider",
|
||||
"plugins_count": 8,
|
||||
"total_downloads": 23450,
|
||||
"verified_since": "2025-03-01",
|
||||
"avatar": "/static/avatars/crypto-tools.png"
|
||||
}
|
||||
]
|
||||
|
||||
return verified_devs
|
||||
|
||||
def get_developer_stats() -> Dict:
|
||||
"""Get developer statistics"""
|
||||
return {
|
||||
"total_developers": 156,
|
||||
"verified_developers": 23,
|
||||
"total_revenue_paid": 1250000.00,
|
||||
"active_developers": 89
|
||||
}
|
||||
|
||||
def get_submission_guidelines() -> Dict:
|
||||
"""Get plugin submission guidelines"""
|
||||
return {
|
||||
"requirements": [
|
||||
"Plugin must be compatible with AITBC v2.0+",
|
||||
"Code must be open source with appropriate license",
|
||||
"Comprehensive documentation required",
|
||||
"Security scan must pass",
|
||||
"Unit tests with 80%+ coverage"
|
||||
],
|
||||
"process": [
|
||||
"Submit plugin for review",
|
||||
"Security and quality assessment",
|
||||
"Community review period",
|
||||
"Final approval and publication"
|
||||
],
|
||||
"benefits": [
|
||||
"Revenue sharing (70% to developer)",
|
||||
"Featured placement opportunities",
|
||||
"Developer support and resources",
|
||||
"Community recognition"
|
||||
]
|
||||
}
|
||||
|
||||
def get_marketplace_stats() -> Dict:
|
||||
"""Get marketplace statistics"""
|
||||
return {
|
||||
"total_plugins": 234,
|
||||
"total_developers": 156,
|
||||
"total_downloads": 1256780,
|
||||
"total_revenue": 2345678.90,
|
||||
"active_users": 45678,
|
||||
"featured_plugins": 12,
|
||||
"categories": 8
|
||||
}
|
||||
|
||||
def get_rating_distribution(reviews: List[Dict]) -> Dict:
|
||||
"""Get rating distribution"""
|
||||
distribution = {1: 0, 2: 0, 3: 0, 4: 0, 5: 0}
|
||||
for review in reviews:
|
||||
distribution[review["rating"]] += 1
|
||||
return distribution
|
||||
|
||||
def update_revenue_sharing(plugin_id: str, price: float):
|
||||
"""Update revenue sharing records"""
|
||||
# Mock implementation - in production, this would calculate actual revenue sharing
|
||||
developer_share = price * 0.7 # 70% to developer
|
||||
platform_share = price * 0.3 # 30% to platform
|
||||
|
||||
# Update records (simplified for demo)
|
||||
if "revenue_sharing" not in revenue_sharing:
|
||||
revenue_sharing["revenue_sharing"] = {
|
||||
"total_revenue": 0.0,
|
||||
"developer_revenue": 0.0,
|
||||
"platform_revenue": 0.0
|
||||
}
|
||||
|
||||
revenue_sharing["revenue_sharing"]["total_revenue"] += price
|
||||
revenue_sharing["revenue_sharing"]["developer_revenue"] += developer_share
|
||||
revenue_sharing["revenue_sharing"]["platform_revenue"] += platform_share
|
||||
|
||||
# Background task for marketplace analytics
|
||||
async def update_marketplace_analytics():
|
||||
"""Background task to update marketplace analytics"""
|
||||
while True:
|
||||
await asyncio.sleep(3600) # Update every hour
|
||||
|
||||
# Update trending plugins
|
||||
# Update revenue calculations
|
||||
# Update user engagement metrics
|
||||
logger.info("Marketplace analytics updated")
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
logger.info("Starting AITBC Plugin Marketplace")
|
||||
# Start analytics processing
|
||||
asyncio.create_task(update_marketplace_analytics())
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
logger.info("Shutting down AITBC Plugin Marketplace")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8014, log_level="info")
|
||||
1
examples/stubs/plugin-marketplace/tests/__init__.py
Normal file
1
examples/stubs/plugin-marketplace/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Plugin marketplace service tests"""
|
||||
@@ -0,0 +1,176 @@
|
||||
"""Edge case and error handling tests for plugin marketplace service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
from main import app, MarketplaceReview, PluginPurchase, DeveloperApplication, reviews, purchases, developer_applications
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
reviews.clear()
|
||||
purchases.clear()
|
||||
developer_applications.clear()
|
||||
yield
|
||||
reviews.clear()
|
||||
purchases.clear()
|
||||
developer_applications.clear()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_marketplace_review_out_of_range_rating():
|
||||
"""Test MarketplaceReview with out of range rating"""
|
||||
review = MarketplaceReview(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
rating=10,
|
||||
title="Great plugin",
|
||||
content="Excellent"
|
||||
)
|
||||
assert review.rating == 10
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_marketplace_review_zero_rating():
|
||||
"""Test MarketplaceReview with zero rating"""
|
||||
review = MarketplaceReview(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
rating=0,
|
||||
title="Bad plugin",
|
||||
content="Poor"
|
||||
)
|
||||
assert review.rating == 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_marketplace_review_negative_rating():
|
||||
"""Test MarketplaceReview with negative rating"""
|
||||
review = MarketplaceReview(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
rating=-5,
|
||||
title="Terrible",
|
||||
content="Worst"
|
||||
)
|
||||
assert review.rating == -5
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_marketplace_review_empty_fields():
|
||||
"""Test MarketplaceReview with empty fields"""
|
||||
review = MarketplaceReview(
|
||||
plugin_id="",
|
||||
user_id="",
|
||||
rating=3,
|
||||
title="",
|
||||
content=""
|
||||
)
|
||||
assert review.plugin_id == ""
|
||||
assert review.title == ""
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_purchase_zero_price():
|
||||
"""Test PluginPurchase with zero price"""
|
||||
purchase = PluginPurchase(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
price=0.0,
|
||||
payment_method="free"
|
||||
)
|
||||
assert purchase.price == 0.0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_developer_application_empty_fields():
|
||||
"""Test DeveloperApplication with empty fields"""
|
||||
application = DeveloperApplication(
|
||||
developer_name="",
|
||||
email="",
|
||||
experience="",
|
||||
description=""
|
||||
)
|
||||
assert application.developer_name == ""
|
||||
assert application.email == ""
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_popular_plugins_with_limit():
|
||||
"""Test getting popular plugins with limit parameter"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/marketplace/popular?limit=5")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "popular_plugins" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_recent_plugins_with_limit():
|
||||
"""Test getting recent plugins with limit parameter"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/marketplace/recent?limit=5")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "recent_plugins" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_multiple_reviews():
|
||||
"""Test creating multiple reviews for same plugin"""
|
||||
client = TestClient(app)
|
||||
|
||||
for i in range(3):
|
||||
review = MarketplaceReview(
|
||||
plugin_id="plugin_123",
|
||||
user_id=f"user_{i}",
|
||||
rating=5,
|
||||
title="Great",
|
||||
content="Excellent"
|
||||
)
|
||||
client.post("/api/v1/reviews", json=review.model_dump())
|
||||
|
||||
response = client.get("/api/v1/reviews/plugin_123")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_reviews"] == 3
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_multiple_purchases():
|
||||
"""Test creating multiple purchases for same plugin"""
|
||||
client = TestClient(app)
|
||||
|
||||
for i in range(3):
|
||||
purchase = PluginPurchase(
|
||||
plugin_id="plugin_123",
|
||||
user_id=f"user_{i}",
|
||||
price=99.99,
|
||||
payment_method="credit_card"
|
||||
)
|
||||
client.post("/api/v1/purchases", json=purchase.model_dump())
|
||||
|
||||
response = client.get("/api/v1/revenue/revenue_sharing")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_developer_application_with_company():
|
||||
"""Test developer application with company"""
|
||||
client = TestClient(app)
|
||||
application = DeveloperApplication(
|
||||
developer_name="Dev Name",
|
||||
email="dev@example.com",
|
||||
company="Dev Corp",
|
||||
experience="5 years",
|
||||
description="Experienced"
|
||||
)
|
||||
response = client.post("/api/v1/developers/apply", json=application.model_dump())
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["application_id"]
|
||||
@@ -0,0 +1,165 @@
|
||||
"""Integration tests for plugin marketplace service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
from main import app, MarketplaceReview, PluginPurchase, DeveloperApplication, reviews, purchases, developer_applications
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
reviews.clear()
|
||||
purchases.clear()
|
||||
developer_applications.clear()
|
||||
yield
|
||||
reviews.clear()
|
||||
purchases.clear()
|
||||
developer_applications.clear()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_featured_plugins_api():
|
||||
"""Test getting featured plugins API"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/marketplace/featured")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "featured_plugins" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_popular_plugins_api():
|
||||
"""Test getting popular plugins API"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/marketplace/popular")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "popular_plugins" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_recent_plugins_api():
|
||||
"""Test getting recent plugins API"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/marketplace/recent")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "recent_plugins" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_marketplace_stats_api():
|
||||
"""Test getting marketplace stats API"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/marketplace/stats")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "stats" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_review():
|
||||
"""Test creating a review"""
|
||||
client = TestClient(app)
|
||||
review = MarketplaceReview(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
rating=5,
|
||||
title="Great plugin",
|
||||
content="Excellent functionality"
|
||||
)
|
||||
response = client.post("/api/v1/reviews", json=review.model_dump())
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["review_id"]
|
||||
assert data["status"] == "created"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_plugin_reviews_api():
|
||||
"""Test getting plugin reviews API"""
|
||||
client = TestClient(app)
|
||||
# Create a review first
|
||||
review = MarketplaceReview(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
rating=5,
|
||||
title="Great plugin",
|
||||
content="Excellent functionality"
|
||||
)
|
||||
client.post("/api/v1/reviews", json=review.model_dump())
|
||||
|
||||
response = client.get("/api/v1/reviews/plugin_123")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["plugin_id"] == "plugin_123"
|
||||
assert "reviews" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_plugin_reviews_no_reviews():
|
||||
"""Test getting plugin reviews when no reviews exist"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/reviews/nonexistent")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_reviews"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_purchase():
|
||||
"""Test creating a purchase"""
|
||||
client = TestClient(app)
|
||||
purchase = PluginPurchase(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
price=99.99,
|
||||
payment_method="credit_card"
|
||||
)
|
||||
response = client.post("/api/v1/purchases", json=purchase.model_dump())
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["purchase_id"]
|
||||
assert data["status"] == "completed"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_apply_developer():
|
||||
"""Test applying to become a developer"""
|
||||
client = TestClient(app)
|
||||
application = DeveloperApplication(
|
||||
developer_name="Dev Name",
|
||||
email="dev@example.com",
|
||||
experience="5 years",
|
||||
description="Experienced developer"
|
||||
)
|
||||
response = client.post("/api/v1/developers/apply", json=application.model_dump())
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["application_id"]
|
||||
assert data["status"] == "pending"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_verified_developers_api():
|
||||
"""Test getting verified developers API"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/developers/verified")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "verified_developers" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_developer_revenue():
|
||||
"""Test getting developer revenue"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/revenue/dev_123")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "total_revenue" in data
|
||||
@@ -0,0 +1,108 @@
|
||||
"""Unit tests for plugin marketplace service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
from main import app, MarketplaceReview, PluginPurchase, DeveloperApplication
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_app_initialization():
|
||||
"""Test that the FastAPI app initializes correctly"""
|
||||
assert app is not None
|
||||
assert app.title == "AITBC Plugin Marketplace"
|
||||
assert app.version == "1.0.0"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_marketplace_review_model():
|
||||
"""Test MarketplaceReview model"""
|
||||
review = MarketplaceReview(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
rating=5,
|
||||
title="Great plugin",
|
||||
content="Excellent functionality",
|
||||
pros=["Easy to use", "Fast"],
|
||||
cons=["Learning curve"]
|
||||
)
|
||||
assert review.plugin_id == "plugin_123"
|
||||
assert review.rating == 5
|
||||
assert review.title == "Great plugin"
|
||||
assert review.pros == ["Easy to use", "Fast"]
|
||||
assert review.cons == ["Learning curve"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_marketplace_review_defaults():
|
||||
"""Test MarketplaceReview with default values"""
|
||||
review = MarketplaceReview(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
rating=4,
|
||||
title="Good plugin",
|
||||
content="Nice functionality"
|
||||
)
|
||||
assert review.pros == []
|
||||
assert review.cons == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_purchase_model():
|
||||
"""Test PluginPurchase model"""
|
||||
purchase = PluginPurchase(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
price=99.99,
|
||||
payment_method="credit_card"
|
||||
)
|
||||
assert purchase.plugin_id == "plugin_123"
|
||||
assert purchase.price == 99.99
|
||||
assert purchase.payment_method == "credit_card"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_purchase_negative_price():
|
||||
"""Test PluginPurchase with negative price"""
|
||||
purchase = PluginPurchase(
|
||||
plugin_id="plugin_123",
|
||||
user_id="user_123",
|
||||
price=-99.99,
|
||||
payment_method="credit_card"
|
||||
)
|
||||
assert purchase.price == -99.99
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_developer_application_model():
|
||||
"""Test DeveloperApplication model"""
|
||||
application = DeveloperApplication(
|
||||
developer_name="Dev Name",
|
||||
email="dev@example.com",
|
||||
company="Dev Corp",
|
||||
experience="5 years",
|
||||
portfolio_url="https://portfolio.com",
|
||||
github_username="devuser",
|
||||
description="Experienced developer"
|
||||
)
|
||||
assert application.developer_name == "Dev Name"
|
||||
assert application.email == "dev@example.com"
|
||||
assert application.company == "Dev Corp"
|
||||
assert application.github_username == "devuser"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_developer_application_defaults():
|
||||
"""Test DeveloperApplication with optional fields"""
|
||||
application = DeveloperApplication(
|
||||
developer_name="Dev Name",
|
||||
email="dev@example.com",
|
||||
experience="3 years",
|
||||
description="New developer"
|
||||
)
|
||||
assert application.company is None
|
||||
assert application.portfolio_url is None
|
||||
assert application.github_username is None
|
||||
485
examples/stubs/plugin-registry/main.py
Executable file
485
examples/stubs/plugin-registry/main.py
Executable file
@@ -0,0 +1,485 @@
|
||||
"""
|
||||
Production Plugin Registry Service for AITBC
|
||||
Handles plugin registration, discovery, versioning, and security validation
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import hashlib
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File
|
||||
from pydantic import BaseModel
|
||||
|
||||
from aitbc import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
app = FastAPI(
|
||||
title="AITBC Plugin Registry",
|
||||
description="Production plugin registry for AITBC ecosystem",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Data models
|
||||
class PluginRegistration(BaseModel):
|
||||
name: str
|
||||
version: str
|
||||
description: str
|
||||
author: str
|
||||
category: str
|
||||
tags: List[str]
|
||||
repository_url: str
|
||||
homepage_url: Optional[str] = None
|
||||
license: str
|
||||
dependencies: List[str] = []
|
||||
aitbc_version: str
|
||||
plugin_type: str # cli, blockchain, ai, web, etc.
|
||||
|
||||
class PluginVersion(BaseModel):
|
||||
version: str
|
||||
changelog: str
|
||||
download_url: str
|
||||
checksum: str
|
||||
aitbc_compatibility: List[str]
|
||||
release_date: datetime
|
||||
|
||||
class SecurityScan(BaseModel):
|
||||
scan_id: str
|
||||
plugin_id: str
|
||||
version: str
|
||||
scan_date: datetime
|
||||
vulnerabilities: List[Dict[str, Any]]
|
||||
risk_score: str # low, medium, high, critical
|
||||
passed: bool
|
||||
|
||||
# In-memory storage (in production, use database)
|
||||
plugins: Dict[str, Dict] = {}
|
||||
plugin_versions: Dict[str, List[Dict]] = {}
|
||||
security_scans: Dict[str, Dict] = {}
|
||||
analytics: Dict[str, Dict] = {}
|
||||
downloads: Dict[str, List[Dict]] = {}
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {
|
||||
"service": "AITBC Plugin Registry",
|
||||
"status": "running",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {
|
||||
"status": "healthy",
|
||||
"total_plugins": len(plugins),
|
||||
"total_versions": sum(len(versions) for versions in plugin_versions.values()),
|
||||
"security_scans": len(security_scans),
|
||||
"downloads_today": len([d for downloads_list in downloads.values()
|
||||
for d in downloads_list
|
||||
if datetime.fromisoformat(d["timestamp"]).date() == datetime.now(timezone.utc).date()])
|
||||
}
|
||||
|
||||
@app.post("/api/v1/plugins/register")
|
||||
async def register_plugin(plugin: PluginRegistration):
|
||||
"""Register a new plugin"""
|
||||
plugin_id = f"{plugin.name.lower().replace(' ', '_')}"
|
||||
|
||||
if plugin_id in plugins:
|
||||
raise HTTPException(status_code=400, detail="Plugin already registered")
|
||||
|
||||
# Create plugin record
|
||||
plugin_record = {
|
||||
"plugin_id": plugin_id,
|
||||
"name": plugin.name,
|
||||
"description": plugin.description,
|
||||
"author": plugin.author,
|
||||
"category": plugin.category,
|
||||
"tags": plugin.tags,
|
||||
"repository_url": plugin.repository_url,
|
||||
"homepage_url": plugin.homepage_url,
|
||||
"license": plugin.license,
|
||||
"dependencies": plugin.dependencies,
|
||||
"aitbc_version": plugin.aitbc_version,
|
||||
"plugin_type": plugin.plugin_type,
|
||||
"status": "active",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"updated_at": datetime.now(timezone.utc).isoformat(),
|
||||
"verified": False,
|
||||
"featured": False,
|
||||
"download_count": 0,
|
||||
"rating": 0.0,
|
||||
"rating_count": 0,
|
||||
"latest_version": plugin.version
|
||||
}
|
||||
|
||||
plugins[plugin_id] = plugin_record
|
||||
plugin_versions[plugin_id] = []
|
||||
|
||||
# Initialize analytics
|
||||
analytics[plugin_id] = {
|
||||
"downloads": [],
|
||||
"views": [],
|
||||
"ratings": [],
|
||||
"daily_stats": {}
|
||||
}
|
||||
|
||||
logger.info(f"Plugin registered: {plugin.name}")
|
||||
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"status": "registered",
|
||||
"name": plugin.name,
|
||||
"created_at": plugin_record["created_at"]
|
||||
}
|
||||
|
||||
@app.post("/api/v1/plugins/{plugin_id}/versions")
|
||||
async def add_plugin_version(plugin_id: str, version: PluginVersion):
|
||||
"""Add a new version to an existing plugin"""
|
||||
if plugin_id not in plugins:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
|
||||
# Check if version already exists
|
||||
for existing_version in plugin_versions[plugin_id]:
|
||||
if existing_version["version"] == version.version:
|
||||
raise HTTPException(status_code=400, detail="Version already exists")
|
||||
|
||||
# Create version record
|
||||
version_record = {
|
||||
"version_id": f"{plugin_id}_v_{version.version}",
|
||||
"plugin_id": plugin_id,
|
||||
"version": version.version,
|
||||
"changelog": version.changelog,
|
||||
"download_url": version.download_url,
|
||||
"checksum": version.checksum,
|
||||
"aitbc_compatibility": version.aitbc_compatibility,
|
||||
"release_date": version.release_date.isoformat(),
|
||||
"downloads": 0,
|
||||
"security_scan_passed": False,
|
||||
"created_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
plugin_versions[plugin_id].append(version_record)
|
||||
|
||||
# Update plugin's latest version
|
||||
plugins[plugin_id]["latest_version"] = version.version
|
||||
plugins[plugin_id]["updated_at"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# Sort versions by version number (semantic versioning)
|
||||
plugin_versions[plugin_id].sort(key=lambda x: x["version"], reverse=True)
|
||||
|
||||
logger.info(f"Version added to plugin {plugin_id}: {version.version}")
|
||||
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"version": version.version,
|
||||
"status": "added",
|
||||
"created_at": version_record["created_at"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/plugins")
|
||||
async def list_plugins(category: Optional[str] = None, tag: Optional[str] = None,
|
||||
search: Optional[str] = None, sort_by: str = "created_at"):
|
||||
"""List all plugins with filtering and sorting"""
|
||||
filtered_plugins = []
|
||||
|
||||
for plugin in plugins.values():
|
||||
# Apply filters
|
||||
if category and plugin["category"] != category:
|
||||
continue
|
||||
if tag and tag not in plugin["tags"]:
|
||||
continue
|
||||
if search and search.lower() not in plugin["name"].lower() and search.lower() not in plugin["description"].lower():
|
||||
continue
|
||||
|
||||
filtered_plugins.append(plugin.copy())
|
||||
|
||||
# Sort plugins
|
||||
if sort_by == "created_at":
|
||||
filtered_plugins.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
elif sort_by == "updated_at":
|
||||
filtered_plugins.sort(key=lambda x: x["updated_at"], reverse=True)
|
||||
elif sort_by == "name":
|
||||
filtered_plugins.sort(key=lambda x: x["name"])
|
||||
elif sort_by == "downloads":
|
||||
filtered_plugins.sort(key=lambda x: x["download_count"], reverse=True)
|
||||
elif sort_by == "rating":
|
||||
filtered_plugins.sort(key=lambda x: x["rating"], reverse=True)
|
||||
|
||||
return {
|
||||
"plugins": filtered_plugins,
|
||||
"total_plugins": len(filtered_plugins),
|
||||
"filters": {
|
||||
"category": category,
|
||||
"tag": tag,
|
||||
"search": search,
|
||||
"sort_by": sort_by
|
||||
}
|
||||
}
|
||||
|
||||
@app.get("/api/v1/plugins/{plugin_id}")
|
||||
async def get_plugin(plugin_id: str):
|
||||
"""Get detailed plugin information"""
|
||||
if plugin_id not in plugins:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
|
||||
plugin = plugins[plugin_id].copy()
|
||||
|
||||
# Add version information
|
||||
plugin["versions"] = plugin_versions.get(plugin_id, [])
|
||||
|
||||
# Add analytics
|
||||
plugin_analytics = analytics.get(plugin_id, {})
|
||||
plugin["analytics"] = {
|
||||
"total_downloads": len(plugin_analytics.get("downloads", [])),
|
||||
"total_views": len(plugin_analytics.get("views", [])),
|
||||
"average_rating": sum(plugin_analytics.get("ratings", [])) / len(plugin_analytics.get("ratings", [])) if plugin_analytics.get("ratings") else 0.0,
|
||||
"rating_count": len(plugin_analytics.get("ratings", []))
|
||||
}
|
||||
|
||||
return plugin
|
||||
|
||||
@app.get("/api/v1/plugins/{plugin_id}/versions")
|
||||
async def get_plugin_versions(plugin_id: str):
|
||||
"""Get all versions of a plugin"""
|
||||
if plugin_id not in plugins:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"versions": plugin_versions.get(plugin_id, []),
|
||||
"total_versions": len(plugin_versions.get(plugin_id, []))
|
||||
}
|
||||
|
||||
@app.get("/api/v1/plugins/{plugin_id}/download/{version}")
|
||||
async def download_plugin(plugin_id: str, version: str):
|
||||
"""Download a specific plugin version"""
|
||||
if plugin_id not in plugins:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
|
||||
# Find the version
|
||||
version_record = None
|
||||
for v in plugin_versions.get(plugin_id, []):
|
||||
if v["version"] == version:
|
||||
version_record = v
|
||||
break
|
||||
|
||||
if not version_record:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
|
||||
# Record download
|
||||
download_record = {
|
||||
"version": version,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"ip_address": "client_ip", # In production, get actual IP
|
||||
"user_agent": "user_agent" # In production, get actual user agent
|
||||
}
|
||||
|
||||
if plugin_id not in downloads:
|
||||
downloads[plugin_id] = []
|
||||
downloads[plugin_id].append(download_record)
|
||||
|
||||
# Update analytics
|
||||
if plugin_id not in analytics:
|
||||
analytics[plugin_id] = {"downloads": [], "views": [], "ratings": []}
|
||||
analytics[plugin_id]["downloads"].append(datetime.now(timezone.utc).timestamp())
|
||||
|
||||
# Update plugin download count
|
||||
plugins[plugin_id]["download_count"] += 1
|
||||
version_record["downloads"] += 1
|
||||
|
||||
# In production, this would return the actual file
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"version": version,
|
||||
"download_url": version_record["download_url"],
|
||||
"checksum": version_record["checksum"],
|
||||
"download_count": version_record["downloads"]
|
||||
}
|
||||
|
||||
@app.post("/api/v1/plugins/{plugin_id}/security-scan")
|
||||
async def create_security_scan(plugin_id: str, scan: SecurityScan):
|
||||
"""Create a security scan record for a plugin version"""
|
||||
if plugin_id not in plugins:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
|
||||
# Verify version exists
|
||||
version_exists = any(v["version"] == scan.version for v in plugin_versions.get(plugin_id, []))
|
||||
if not version_exists:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
|
||||
# Create security scan record
|
||||
security_scans[scan.scan_id] = {
|
||||
"scan_id": scan.scan_id,
|
||||
"plugin_id": plugin_id,
|
||||
"version": scan.version,
|
||||
"scan_date": scan.scan_date.isoformat(),
|
||||
"vulnerabilities": scan.vulnerabilities,
|
||||
"risk_score": scan.risk_score,
|
||||
"passed": scan.passed,
|
||||
"created_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
# Update version security status
|
||||
for version_record in plugin_versions.get(plugin_id, []):
|
||||
if version_record["version"] == scan.version:
|
||||
version_record["security_scan_passed"] = scan.passed
|
||||
break
|
||||
|
||||
logger.info(f"Security scan created for {plugin_id} v{scan.version}: {scan.risk_score}")
|
||||
|
||||
return {
|
||||
"scan_id": scan.scan_id,
|
||||
"plugin_id": plugin_id,
|
||||
"version": scan.version,
|
||||
"risk_score": scan.risk_score,
|
||||
"passed": scan.passed,
|
||||
"scan_date": scan.scan_date.isoformat()
|
||||
}
|
||||
|
||||
@app.get("/api/v1/plugins/{plugin_id}/security")
|
||||
async def get_plugin_security(plugin_id: str):
|
||||
"""Get security information for a plugin"""
|
||||
if plugin_id not in plugins:
|
||||
raise HTTPException(status_code=404, detail="Plugin not found")
|
||||
|
||||
plugin_scans = []
|
||||
for scan_id, scan in security_scans.items():
|
||||
if scan["plugin_id"] == plugin_id:
|
||||
plugin_scans.append(scan)
|
||||
|
||||
# Sort by scan date
|
||||
plugin_scans.sort(key=lambda x: x["scan_date"], reverse=True)
|
||||
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"security_scans": plugin_scans,
|
||||
"total_scans": len(plugin_scans),
|
||||
"latest_scan": plugin_scans[0] if plugin_scans else None
|
||||
}
|
||||
|
||||
@app.get("/api/v1/categories")
|
||||
async def get_categories():
|
||||
"""Get all plugin categories"""
|
||||
categories = {}
|
||||
for plugin in plugins.values():
|
||||
category = plugin["category"]
|
||||
if category not in categories:
|
||||
categories[category] = {
|
||||
"name": category,
|
||||
"plugin_count": 0,
|
||||
"description": f"Plugins in {category} category"
|
||||
}
|
||||
categories[category]["plugin_count"] += 1
|
||||
|
||||
return {
|
||||
"categories": list(categories.values()),
|
||||
"total_categories": len(categories)
|
||||
}
|
||||
|
||||
@app.get("/api/v1/tags")
|
||||
async def get_tags():
|
||||
"""Get all plugin tags"""
|
||||
tag_counts = {}
|
||||
for plugin in plugins.values():
|
||||
for tag in plugin["tags"]:
|
||||
tag_counts[tag] = tag_counts.get(tag, 0) + 1
|
||||
|
||||
return {
|
||||
"tags": [{"tag": tag, "count": count} for tag, count in sorted(tag_counts.items(), key=lambda x: x[1], reverse=True)],
|
||||
"total_tags": len(tag_counts)
|
||||
}
|
||||
|
||||
@app.get("/api/v1/analytics/popular")
|
||||
async def get_popular_plugins(limit: int = 10):
|
||||
"""Get most popular plugins by downloads"""
|
||||
popular_plugins = sorted(plugins.values(), key=lambda x: x["download_count"], reverse=True)[:limit]
|
||||
|
||||
return {
|
||||
"popular_plugins": popular_plugins,
|
||||
"limit": limit,
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.get("/api/v1/analytics/recent")
|
||||
async def get_recent_plugins(limit: int = 10):
|
||||
"""Get recently updated plugins"""
|
||||
recent_plugins = sorted(plugins.values(), key=lambda x: x["updated_at"], reverse=True)[:limit]
|
||||
|
||||
return {
|
||||
"recent_plugins": recent_plugins,
|
||||
"limit": limit,
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.get("/api/v1/analytics/dashboard")
|
||||
async def get_analytics_dashboard():
|
||||
"""Get registry analytics dashboard"""
|
||||
total_plugins = len(plugins)
|
||||
total_versions = sum(len(versions) for versions in plugin_versions.values())
|
||||
total_downloads = sum(plugin["download_count"] for plugin in plugins.values())
|
||||
|
||||
# Category distribution
|
||||
category_stats = {}
|
||||
for plugin in plugins.values():
|
||||
category = plugin["category"]
|
||||
category_stats[category] = category_stats.get(category, 0) + 1
|
||||
|
||||
# Recent activity
|
||||
recent_downloads = 0
|
||||
today = datetime.now(timezone.utc).date()
|
||||
for download_list in downloads.values():
|
||||
recent_downloads += len([d for d in download_list
|
||||
if datetime.fromisoformat(d["timestamp"]).date() == today])
|
||||
|
||||
return {
|
||||
"dashboard": {
|
||||
"total_plugins": total_plugins,
|
||||
"total_versions": total_versions,
|
||||
"total_downloads": total_downloads,
|
||||
"recent_downloads_today": recent_downloads,
|
||||
"categories": category_stats,
|
||||
"security_scans": len(security_scans),
|
||||
"passed_scans": len([s for s in security_scans.values() if s["passed"]])
|
||||
},
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
# Background task for analytics processing
|
||||
async def process_analytics():
|
||||
"""Background task to process analytics data"""
|
||||
while True:
|
||||
await asyncio.sleep(3600) # Process every hour
|
||||
|
||||
# Update daily statistics
|
||||
current_date = datetime.now(timezone.utc).date()
|
||||
|
||||
for plugin_id, plugin_analytics in analytics.items():
|
||||
daily_key = current_date.isoformat()
|
||||
|
||||
if daily_key not in plugin_analytics["daily_stats"]:
|
||||
plugin_analytics["daily_stats"][daily_key] = {
|
||||
"downloads": len([d for d in plugin_analytics.get("downloads", [])
|
||||
if datetime.fromtimestamp(d).date() == current_date]),
|
||||
"views": len([v for v in plugin_analytics.get("views", [])
|
||||
if datetime.fromtimestamp(v).date() == current_date]),
|
||||
"ratings": len([r for r in plugin_analytics.get("ratings", [])
|
||||
if datetime.fromtimestamp(r).date() == current_date])
|
||||
}
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
logger.info("Starting AITBC Plugin Registry")
|
||||
# Start analytics processing
|
||||
asyncio.create_task(process_analytics())
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
logger.info("Shutting down AITBC Plugin Registry")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8013, log_level="info")
|
||||
1
examples/stubs/plugin-registry/tests/__init__.py
Normal file
1
examples/stubs/plugin-registry/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Plugin registry service tests"""
|
||||
@@ -0,0 +1,317 @@
|
||||
"""Edge case and error handling tests for plugin registry service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, PluginRegistration, PluginVersion, SecurityScan, plugins, plugin_versions, security_scans, analytics, downloads
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
plugins.clear()
|
||||
plugin_versions.clear()
|
||||
security_scans.clear()
|
||||
analytics.clear()
|
||||
downloads.clear()
|
||||
yield
|
||||
plugins.clear()
|
||||
plugin_versions.clear()
|
||||
security_scans.clear()
|
||||
analytics.clear()
|
||||
downloads.clear()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_registration_empty_name():
|
||||
"""Test PluginRegistration with empty name"""
|
||||
plugin = PluginRegistration(
|
||||
name="",
|
||||
version="1.0.0",
|
||||
description="A test plugin",
|
||||
author="Test Author",
|
||||
category="testing",
|
||||
tags=[],
|
||||
repository_url="https://github.com/test/plugin",
|
||||
license="MIT",
|
||||
dependencies=[],
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="cli"
|
||||
)
|
||||
assert plugin.name == ""
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_registration_empty_tags():
|
||||
"""Test PluginRegistration with empty tags"""
|
||||
plugin = PluginRegistration(
|
||||
name="Test Plugin",
|
||||
version="1.0.0",
|
||||
description="A test plugin",
|
||||
author="Test Author",
|
||||
category="testing",
|
||||
tags=[],
|
||||
repository_url="https://github.com/test/plugin",
|
||||
license="MIT",
|
||||
dependencies=[],
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="cli"
|
||||
)
|
||||
assert plugin.tags == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_version_empty_changelog():
|
||||
"""Test PluginVersion with empty changelog"""
|
||||
version = PluginVersion(
|
||||
version="1.0.0",
|
||||
changelog="",
|
||||
download_url="https://github.com/test/plugin/archive/v1.0.0.tar.gz",
|
||||
checksum="abc123",
|
||||
aitbc_compatibility=["1.0.0"],
|
||||
release_date=datetime.now(timezone.utc)
|
||||
)
|
||||
assert version.changelog == ""
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_security_scan_empty_vulnerabilities():
|
||||
"""Test SecurityScan with empty vulnerabilities"""
|
||||
scan = SecurityScan(
|
||||
scan_id="scan_123",
|
||||
plugin_id="test_plugin",
|
||||
version="1.0.0",
|
||||
scan_date=datetime.now(timezone.utc),
|
||||
vulnerabilities=[],
|
||||
risk_score="low",
|
||||
passed=True
|
||||
)
|
||||
assert scan.vulnerabilities == []
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_add_version_nonexistent_plugin():
|
||||
"""Test adding version to nonexistent plugin"""
|
||||
client = TestClient(app)
|
||||
version = PluginVersion(
|
||||
version="1.0.0",
|
||||
changelog="Initial release",
|
||||
download_url="https://github.com/test/plugin/archive/v1.0.0.tar.gz",
|
||||
checksum="abc123",
|
||||
aitbc_compatibility=["1.0.0"],
|
||||
release_date=datetime.now(timezone.utc)
|
||||
)
|
||||
response = client.post("/api/v1/plugins/nonexistent/versions", json=version.model_dump(mode='json'))
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_download_nonexistent_plugin():
|
||||
"""Test downloading nonexistent plugin"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/plugins/nonexistent/download/1.0.0")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_download_nonexistent_version():
|
||||
"""Test downloading nonexistent version"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Register plugin first
|
||||
plugin = PluginRegistration(
|
||||
name="Test Plugin",
|
||||
version="1.0.0",
|
||||
description="A test plugin",
|
||||
author="Test Author",
|
||||
category="testing",
|
||||
tags=[],
|
||||
repository_url="https://github.com/test/plugin",
|
||||
license="MIT",
|
||||
dependencies=[],
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="cli"
|
||||
)
|
||||
client.post("/api/v1/plugins/register", json=plugin.model_dump())
|
||||
|
||||
# Try to download nonexistent version
|
||||
response = client.get("/api/v1/plugins/test_plugin/download/2.0.0")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_security_scan_nonexistent_plugin():
|
||||
"""Test creating security scan for nonexistent plugin"""
|
||||
client = TestClient(app)
|
||||
scan = SecurityScan(
|
||||
scan_id="scan_123",
|
||||
plugin_id="nonexistent",
|
||||
version="1.0.0",
|
||||
scan_date=datetime.now(timezone.utc),
|
||||
vulnerabilities=[],
|
||||
risk_score="low",
|
||||
passed=True
|
||||
)
|
||||
response = client.post("/api/v1/plugins/nonexistent/security-scan", json=scan.model_dump(mode='json'))
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_security_scan_nonexistent_version():
|
||||
"""Test creating security scan for nonexistent version"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Register plugin first
|
||||
plugin = PluginRegistration(
|
||||
name="Test Plugin",
|
||||
version="1.0.0",
|
||||
description="A test plugin",
|
||||
author="Test Author",
|
||||
category="testing",
|
||||
tags=[],
|
||||
repository_url="https://github.com/test/plugin",
|
||||
license="MIT",
|
||||
dependencies=[],
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="cli"
|
||||
)
|
||||
client.post("/api/v1/plugins/register", json=plugin.model_dump())
|
||||
|
||||
# Create security scan for nonexistent version
|
||||
scan = SecurityScan(
|
||||
scan_id="scan_123",
|
||||
plugin_id="test_plugin",
|
||||
version="2.0.0",
|
||||
scan_date=datetime.now(timezone.utc),
|
||||
vulnerabilities=[],
|
||||
risk_score="low",
|
||||
passed=True
|
||||
)
|
||||
response = client.post("/api/v1/plugins/test_plugin/security-scan", json=scan.model_dump(mode='json'))
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_plugins_with_filters():
|
||||
"""Test listing plugins with filters"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Register multiple plugins
|
||||
plugin1 = PluginRegistration(
|
||||
name="Test Plugin 1",
|
||||
version="1.0.0",
|
||||
description="A test plugin",
|
||||
author="Test Author",
|
||||
category="testing",
|
||||
tags=["test"],
|
||||
repository_url="https://github.com/test/plugin1",
|
||||
license="MIT",
|
||||
dependencies=[],
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="cli"
|
||||
)
|
||||
client.post("/api/v1/plugins/register", json=plugin1.model_dump())
|
||||
|
||||
plugin2 = PluginRegistration(
|
||||
name="Production Plugin",
|
||||
version="1.0.0",
|
||||
description="A production plugin",
|
||||
author="Test Author",
|
||||
category="production",
|
||||
tags=["prod"],
|
||||
repository_url="https://github.com/test/plugin2",
|
||||
license="MIT",
|
||||
dependencies=[],
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="web"
|
||||
)
|
||||
client.post("/api/v1/plugins/register", json=plugin2.model_dump())
|
||||
|
||||
# Filter by category
|
||||
response = client.get("/api/v1/plugins?category=testing")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_plugins"] == 1
|
||||
assert data["plugins"][0]["category"] == "testing"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_plugins_with_search():
|
||||
"""Test listing plugins with search"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Register plugin
|
||||
plugin = PluginRegistration(
|
||||
name="Test Plugin",
|
||||
version="1.0.0",
|
||||
description="A test plugin for testing",
|
||||
author="Test Author",
|
||||
category="testing",
|
||||
tags=["test"],
|
||||
repository_url="https://github.com/test/plugin",
|
||||
license="MIT",
|
||||
dependencies=[],
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="cli"
|
||||
)
|
||||
client.post("/api/v1/plugins/register", json=plugin.model_dump())
|
||||
|
||||
# Search for plugin
|
||||
response = client.get("/api/v1/plugins?search=test")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_plugins"] == 1
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_security_scan_failed():
|
||||
"""Test security scan that failed"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Register plugin first
|
||||
plugin = PluginRegistration(
|
||||
name="Test Plugin",
|
||||
version="1.0.0",
|
||||
description="A test plugin",
|
||||
author="Test Author",
|
||||
category="testing",
|
||||
tags=[],
|
||||
repository_url="https://github.com/test/plugin",
|
||||
license="MIT",
|
||||
dependencies=[],
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="cli"
|
||||
)
|
||||
client.post("/api/v1/plugins/register", json=plugin.model_dump())
|
||||
|
||||
# Add version first
|
||||
version = PluginVersion(
|
||||
version="1.0.0",
|
||||
changelog="Initial release",
|
||||
download_url="https://github.com/test/plugin/archive/v1.0.0.tar.gz",
|
||||
checksum="abc123",
|
||||
aitbc_compatibility=["1.0.0"],
|
||||
release_date=datetime.now(timezone.utc)
|
||||
)
|
||||
client.post("/api/v1/plugins/test_plugin/versions", json=version.model_dump(mode='json'))
|
||||
|
||||
# Create failed security scan
|
||||
scan = SecurityScan(
|
||||
scan_id="scan_123",
|
||||
plugin_id="test_plugin",
|
||||
version="1.0.0",
|
||||
scan_date=datetime.now(timezone.utc),
|
||||
vulnerabilities=[{"severity": "high", "description": "Critical issue"}],
|
||||
risk_score="high",
|
||||
passed=False
|
||||
)
|
||||
response = client.post("/api/v1/plugins/test_plugin/security-scan", json=scan.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["passed"] is False
|
||||
assert data["risk_score"] == "high"
|
||||
@@ -0,0 +1,422 @@
|
||||
"""Integration tests for plugin registry service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, PluginRegistration, PluginVersion, SecurityScan, plugins, plugin_versions, security_scans, analytics, downloads
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
plugins.clear()
|
||||
plugin_versions.clear()
|
||||
security_scans.clear()
|
||||
analytics.clear()
|
||||
downloads.clear()
|
||||
yield
|
||||
plugins.clear()
|
||||
plugin_versions.clear()
|
||||
security_scans.clear()
|
||||
analytics.clear()
|
||||
downloads.clear()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_root_endpoint():
|
||||
"""Test root endpoint"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["service"] == "AITBC Plugin Registry"
|
||||
assert data["status"] == "running"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_health_check_endpoint():
|
||||
"""Test health check endpoint"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert "total_plugins" in data
|
||||
assert "total_versions" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_register_plugin():
|
||||
"""Test plugin registration"""
|
||||
client = TestClient(app)
|
||||
plugin = PluginRegistration(
|
||||
name="Test Plugin",
|
||||
version="1.0.0",
|
||||
description="A test plugin",
|
||||
author="Test Author",
|
||||
category="testing",
|
||||
tags=["test", "demo"],
|
||||
repository_url="https://github.com/test/plugin",
|
||||
license="MIT",
|
||||
dependencies=[],
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="cli"
|
||||
)
|
||||
response = client.post("/api/v1/plugins/register", json=plugin.model_dump())
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["plugin_id"] == "test_plugin"
|
||||
assert data["status"] == "registered"
|
||||
assert data["name"] == "Test Plugin"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_register_duplicate_plugin():
|
||||
"""Test registering duplicate plugin"""
|
||||
client = TestClient(app)
|
||||
plugin = PluginRegistration(
|
||||
name="Test Plugin",
|
||||
version="1.0.0",
|
||||
description="A test plugin",
|
||||
author="Test Author",
|
||||
category="testing",
|
||||
tags=[],
|
||||
repository_url="https://github.com/test/plugin",
|
||||
license="MIT",
|
||||
dependencies=[],
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="cli"
|
||||
)
|
||||
|
||||
# First registration
|
||||
client.post("/api/v1/plugins/register", json=plugin.model_dump())
|
||||
|
||||
# Second registration should fail
|
||||
response = client.post("/api/v1/plugins/register", json=plugin.model_dump())
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_add_plugin_version():
|
||||
"""Test adding plugin version"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Register plugin first
|
||||
plugin = PluginRegistration(
|
||||
name="Test Plugin",
|
||||
version="1.0.0",
|
||||
description="A test plugin",
|
||||
author="Test Author",
|
||||
category="testing",
|
||||
tags=[],
|
||||
repository_url="https://github.com/test/plugin",
|
||||
license="MIT",
|
||||
dependencies=[],
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="cli"
|
||||
)
|
||||
client.post("/api/v1/plugins/register", json=plugin.model_dump())
|
||||
|
||||
# Add version
|
||||
version = PluginVersion(
|
||||
version="1.1.0",
|
||||
changelog="Bug fixes",
|
||||
download_url="https://github.com/test/plugin/archive/v1.1.0.tar.gz",
|
||||
checksum="def456",
|
||||
aitbc_compatibility=["1.0.0"],
|
||||
release_date=datetime.now(timezone.utc)
|
||||
)
|
||||
response = client.post("/api/v1/plugins/test_plugin/versions", json=version.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["version"] == "1.1.0"
|
||||
assert data["status"] == "added"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_add_duplicate_version():
|
||||
"""Test adding duplicate version"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Register plugin first
|
||||
plugin = PluginRegistration(
|
||||
name="Test Plugin",
|
||||
version="1.0.0",
|
||||
description="A test plugin",
|
||||
author="Test Author",
|
||||
category="testing",
|
||||
tags=[],
|
||||
repository_url="https://github.com/test/plugin",
|
||||
license="MIT",
|
||||
dependencies=[],
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="cli"
|
||||
)
|
||||
client.post("/api/v1/plugins/register", json=plugin.model_dump())
|
||||
|
||||
# Add version
|
||||
version = PluginVersion(
|
||||
version="1.1.0",
|
||||
changelog="Bug fixes",
|
||||
download_url="https://github.com/test/plugin/archive/v1.1.0.tar.gz",
|
||||
checksum="def456",
|
||||
aitbc_compatibility=["1.0.0"],
|
||||
release_date=datetime.now(timezone.utc)
|
||||
)
|
||||
client.post("/api/v1/plugins/test_plugin/versions", json=version.model_dump(mode='json'))
|
||||
|
||||
# Add same version again should fail
|
||||
response = client.post("/api/v1/plugins/test_plugin/versions", json=version.model_dump(mode='json'))
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_plugins():
|
||||
"""Test listing plugins"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/plugins")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "plugins" in data
|
||||
assert "total_plugins" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_plugin():
|
||||
"""Test getting specific plugin"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Register plugin first
|
||||
plugin = PluginRegistration(
|
||||
name="Test Plugin",
|
||||
version="1.0.0",
|
||||
description="A test plugin",
|
||||
author="Test Author",
|
||||
category="testing",
|
||||
tags=[],
|
||||
repository_url="https://github.com/test/plugin",
|
||||
license="MIT",
|
||||
dependencies=[],
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="cli"
|
||||
)
|
||||
client.post("/api/v1/plugins/register", json=plugin.model_dump())
|
||||
|
||||
# Get plugin
|
||||
response = client.get("/api/v1/plugins/test_plugin")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["plugin_id"] == "test_plugin"
|
||||
assert data["name"] == "Test Plugin"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_plugin_not_found():
|
||||
"""Test getting nonexistent plugin"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/plugins/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_plugin_versions():
|
||||
"""Test getting plugin versions"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Register plugin first
|
||||
plugin = PluginRegistration(
|
||||
name="Test Plugin",
|
||||
version="1.0.0",
|
||||
description="A test plugin",
|
||||
author="Test Author",
|
||||
category="testing",
|
||||
tags=[],
|
||||
repository_url="https://github.com/test/plugin",
|
||||
license="MIT",
|
||||
dependencies=[],
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="cli"
|
||||
)
|
||||
client.post("/api/v1/plugins/register", json=plugin.model_dump())
|
||||
|
||||
# Get versions
|
||||
response = client.get("/api/v1/plugins/test_plugin/versions")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["plugin_id"] == "test_plugin"
|
||||
assert "versions" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_download_plugin():
|
||||
"""Test downloading plugin"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Register plugin first
|
||||
plugin = PluginRegistration(
|
||||
name="Test Plugin",
|
||||
version="1.0.0",
|
||||
description="A test plugin",
|
||||
author="Test Author",
|
||||
category="testing",
|
||||
tags=[],
|
||||
repository_url="https://github.com/test/plugin",
|
||||
license="MIT",
|
||||
dependencies=[],
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="cli"
|
||||
)
|
||||
client.post("/api/v1/plugins/register", json=plugin.model_dump())
|
||||
|
||||
# Add version first
|
||||
version = PluginVersion(
|
||||
version="1.0.0",
|
||||
changelog="Initial release",
|
||||
download_url="https://github.com/test/plugin/archive/v1.0.0.tar.gz",
|
||||
checksum="abc123",
|
||||
aitbc_compatibility=["1.0.0"],
|
||||
release_date=datetime.now(timezone.utc)
|
||||
)
|
||||
client.post("/api/v1/plugins/test_plugin/versions", json=version.model_dump(mode='json'))
|
||||
|
||||
# Download plugin
|
||||
response = client.get("/api/v1/plugins/test_plugin/download/1.0.0")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["plugin_id"] == "test_plugin"
|
||||
assert data["version"] == "1.0.0"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_security_scan():
|
||||
"""Test creating security scan"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Register plugin first
|
||||
plugin = PluginRegistration(
|
||||
name="Test Plugin",
|
||||
version="1.0.0",
|
||||
description="A test plugin",
|
||||
author="Test Author",
|
||||
category="testing",
|
||||
tags=[],
|
||||
repository_url="https://github.com/test/plugin",
|
||||
license="MIT",
|
||||
dependencies=[],
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="cli"
|
||||
)
|
||||
client.post("/api/v1/plugins/register", json=plugin.model_dump())
|
||||
|
||||
# Add version first
|
||||
version = PluginVersion(
|
||||
version="1.0.0",
|
||||
changelog="Initial release",
|
||||
download_url="https://github.com/test/plugin/archive/v1.0.0.tar.gz",
|
||||
checksum="abc123",
|
||||
aitbc_compatibility=["1.0.0"],
|
||||
release_date=datetime.now(timezone.utc)
|
||||
)
|
||||
client.post("/api/v1/plugins/test_plugin/versions", json=version.model_dump(mode='json'))
|
||||
|
||||
# Create security scan
|
||||
scan = SecurityScan(
|
||||
scan_id="scan_123",
|
||||
plugin_id="test_plugin",
|
||||
version="1.0.0",
|
||||
scan_date=datetime.now(timezone.utc),
|
||||
vulnerabilities=[],
|
||||
risk_score="low",
|
||||
passed=True
|
||||
)
|
||||
response = client.post("/api/v1/plugins/test_plugin/security-scan", json=scan.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["scan_id"] == "scan_123"
|
||||
assert data["passed"] is True
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_plugin_security():
|
||||
"""Test getting plugin security info"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Register plugin first
|
||||
plugin = PluginRegistration(
|
||||
name="Test Plugin",
|
||||
version="1.0.0",
|
||||
description="A test plugin",
|
||||
author="Test Author",
|
||||
category="testing",
|
||||
tags=[],
|
||||
repository_url="https://github.com/test/plugin",
|
||||
license="MIT",
|
||||
dependencies=[],
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="cli"
|
||||
)
|
||||
client.post("/api/v1/plugins/register", json=plugin.model_dump())
|
||||
|
||||
# Get security info
|
||||
response = client.get("/api/v1/plugins/test_plugin/security")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["plugin_id"] == "test_plugin"
|
||||
assert "security_scans" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_categories():
|
||||
"""Test getting categories"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/categories")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "categories" in data
|
||||
assert "total_categories" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_tags():
|
||||
"""Test getting tags"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/tags")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "tags" in data
|
||||
assert "total_tags" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_popular_plugins():
|
||||
"""Test getting popular plugins"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/analytics/popular")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "popular_plugins" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_recent_plugins():
|
||||
"""Test getting recent plugins"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/analytics/recent")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "recent_plugins" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_analytics_dashboard():
|
||||
"""Test getting analytics dashboard"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/analytics/dashboard")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "dashboard" in data
|
||||
@@ -0,0 +1,101 @@
|
||||
"""Unit tests for plugin registry service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, PluginRegistration, PluginVersion, SecurityScan
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_app_initialization():
|
||||
"""Test that the FastAPI app initializes correctly"""
|
||||
assert app is not None
|
||||
assert app.title == "AITBC Plugin Registry"
|
||||
assert app.version == "1.0.0"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_registration_model():
|
||||
"""Test PluginRegistration model"""
|
||||
plugin = PluginRegistration(
|
||||
name="Test Plugin",
|
||||
version="1.0.0",
|
||||
description="A test plugin",
|
||||
author="Test Author",
|
||||
category="testing",
|
||||
tags=["test", "demo"],
|
||||
repository_url="https://github.com/test/plugin",
|
||||
homepage_url="https://test.com",
|
||||
license="MIT",
|
||||
dependencies=["dependency1"],
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="cli"
|
||||
)
|
||||
assert plugin.name == "Test Plugin"
|
||||
assert plugin.version == "1.0.0"
|
||||
assert plugin.author == "Test Author"
|
||||
assert plugin.category == "testing"
|
||||
assert plugin.tags == ["test", "demo"]
|
||||
assert plugin.license == "MIT"
|
||||
assert plugin.plugin_type == "cli"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_registration_defaults():
|
||||
"""Test PluginRegistration default values"""
|
||||
plugin = PluginRegistration(
|
||||
name="Test Plugin",
|
||||
version="1.0.0",
|
||||
description="A test plugin",
|
||||
author="Test Author",
|
||||
category="testing",
|
||||
tags=[],
|
||||
repository_url="https://github.com/test/plugin",
|
||||
license="MIT",
|
||||
aitbc_version="1.0.0",
|
||||
plugin_type="cli"
|
||||
)
|
||||
assert plugin.homepage_url is None
|
||||
assert plugin.dependencies == []
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_plugin_version_model():
|
||||
"""Test PluginVersion model"""
|
||||
version = PluginVersion(
|
||||
version="1.0.0",
|
||||
changelog="Initial release",
|
||||
download_url="https://github.com/test/plugin/archive/v1.0.0.tar.gz",
|
||||
checksum="abc123",
|
||||
aitbc_compatibility=["1.0.0", "1.1.0"],
|
||||
release_date=datetime.now(timezone.utc)
|
||||
)
|
||||
assert version.version == "1.0.0"
|
||||
assert version.changelog == "Initial release"
|
||||
assert version.download_url == "https://github.com/test/plugin/archive/v1.0.0.tar.gz"
|
||||
assert version.checksum == "abc123"
|
||||
assert version.aitbc_compatibility == ["1.0.0", "1.1.0"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_security_scan_model():
|
||||
"""Test SecurityScan model"""
|
||||
scan = SecurityScan(
|
||||
scan_id="scan_123",
|
||||
plugin_id="test_plugin",
|
||||
version="1.0.0",
|
||||
scan_date=datetime.now(timezone.utc),
|
||||
vulnerabilities=[{"severity": "low", "description": "Test"}],
|
||||
risk_score="low",
|
||||
passed=True
|
||||
)
|
||||
assert scan.scan_id == "scan_123"
|
||||
assert scan.plugin_id == "test_plugin"
|
||||
assert scan.version == "1.0.0"
|
||||
assert scan.risk_score == "low"
|
||||
assert scan.passed is True
|
||||
assert len(scan.vulnerabilities) == 1
|
||||
659
examples/stubs/plugin-security/main.py
Executable file
659
examples/stubs/plugin-security/main.py
Executable file
@@ -0,0 +1,659 @@
|
||||
"""
|
||||
Plugin Security Validation Service for AITBC
|
||||
Handles plugin security scanning, vulnerability detection, and validation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import subprocess
|
||||
import tempfile
|
||||
import os
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File
|
||||
from pydantic import BaseModel
|
||||
|
||||
from aitbc import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
app = FastAPI(
|
||||
title="AITBC Plugin Security Service",
|
||||
description="Security validation and vulnerability scanning for AITBC plugins",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Data models
|
||||
class SecurityScan(BaseModel):
|
||||
plugin_id: str
|
||||
version: str
|
||||
plugin_type: str
|
||||
scan_type: str # basic, comprehensive, deep
|
||||
priority: str # low, medium, high, critical
|
||||
|
||||
class Vulnerability(BaseModel):
|
||||
cve_id: Optional[str]
|
||||
severity: str # low, medium, high, critical
|
||||
title: str
|
||||
description: str
|
||||
affected_file: str
|
||||
line_number: Optional[int]
|
||||
recommendation: str
|
||||
|
||||
class SecurityReport(BaseModel):
|
||||
scan_id: str
|
||||
plugin_id: str
|
||||
version: str
|
||||
scan_date: datetime
|
||||
scan_duration: float
|
||||
overall_score: str # passed, warning, failed, critical
|
||||
vulnerabilities: List[Vulnerability]
|
||||
security_metrics: Dict[str, Any]
|
||||
recommendations: List[str]
|
||||
|
||||
# In-memory storage (in production, use database)
|
||||
scan_reports: Dict[str, Dict] = {}
|
||||
security_policies: Dict[str, Dict] = {}
|
||||
scan_queue: List[Dict] = []
|
||||
vulnerability_database: Dict[str, Dict] = {}
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {
|
||||
"service": "AITBC Plugin Security Service",
|
||||
"status": "running",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {
|
||||
"status": "healthy",
|
||||
"total_scans": len(scan_reports),
|
||||
"queue_size": len(scan_queue),
|
||||
"vulnerabilities_db": len(vulnerability_database),
|
||||
"active_policies": len(security_policies)
|
||||
}
|
||||
|
||||
@app.post("/api/v1/security/scan")
|
||||
async def initiate_security_scan(scan: SecurityScan):
|
||||
"""Initiate a security scan for a plugin"""
|
||||
scan_id = f"scan_{int(datetime.now(timezone.utc).timestamp())}"
|
||||
|
||||
# Create scan record
|
||||
scan_record = {
|
||||
"scan_id": scan_id,
|
||||
"plugin_id": scan.plugin_id,
|
||||
"version": scan.version,
|
||||
"plugin_type": scan.plugin_type,
|
||||
"scan_type": scan.scan_type,
|
||||
"priority": scan.priority,
|
||||
"status": "queued",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"started_at": None,
|
||||
"completed_at": None,
|
||||
"duration": None,
|
||||
"result": None
|
||||
}
|
||||
|
||||
scan_queue.append(scan_record)
|
||||
|
||||
# Sort queue by priority
|
||||
priority_order = {"critical": 0, "high": 1, "medium": 2, "low": 3}
|
||||
scan_queue.sort(key=lambda x: priority_order.get(x["priority"], 4))
|
||||
|
||||
logger.info(f"Security scan queued: {scan_id} for {scan.plugin_id} v{scan.version}")
|
||||
|
||||
return {
|
||||
"scan_id": scan_id,
|
||||
"status": "queued",
|
||||
"queue_position": scan_queue.index(scan_record) + 1,
|
||||
"estimated_time": estimate_scan_time(scan.scan_type)
|
||||
}
|
||||
|
||||
@app.get("/api/v1/security/scan/{scan_id}")
|
||||
async def get_scan_status(scan_id: str):
|
||||
"""Get scan status and results"""
|
||||
if scan_id not in scan_reports and not any(s["scan_id"] == scan_id for s in scan_queue):
|
||||
raise HTTPException(status_code=404, detail="Scan not found")
|
||||
|
||||
# Check if scan is in queue
|
||||
for scan_record in scan_queue:
|
||||
if scan_record["scan_id"] == scan_id:
|
||||
return {
|
||||
"scan_id": scan_id,
|
||||
"status": scan_record["status"],
|
||||
"queue_position": scan_queue.index(scan_record) + 1,
|
||||
"created_at": scan_record["created_at"]
|
||||
}
|
||||
|
||||
# Return completed scan results
|
||||
return scan_reports.get(scan_id, {"status": "not_found"})
|
||||
|
||||
@app.get("/api/v1/security/reports")
|
||||
async def list_security_reports(plugin_id: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
limit: int = 50):
|
||||
"""List security scan reports"""
|
||||
reports = list(scan_reports.values())
|
||||
|
||||
# Apply filters
|
||||
if plugin_id:
|
||||
reports = [r for r in reports if r.get("plugin_id") == plugin_id]
|
||||
if status:
|
||||
reports = [r for r in reports if r.get("status") == status]
|
||||
|
||||
# Sort by scan date (most recent first)
|
||||
reports.sort(key=lambda x: x.get("scan_date", ""), reverse=True)
|
||||
|
||||
return {
|
||||
"reports": reports[:limit],
|
||||
"total_reports": len(reports),
|
||||
"filters": {
|
||||
"plugin_id": plugin_id,
|
||||
"status": status,
|
||||
"limit": limit
|
||||
}
|
||||
}
|
||||
|
||||
@app.get("/api/v1/security/vulnerabilities")
|
||||
async def list_vulnerabilities(severity: Optional[str] = None,
|
||||
plugin_id: Optional[str] = None):
|
||||
"""List known vulnerabilities"""
|
||||
vulnerabilities = list(vulnerability_database.values())
|
||||
|
||||
# Apply filters
|
||||
if severity:
|
||||
vulnerabilities = [v for v in vulnerabilities if v["severity"] == severity]
|
||||
if plugin_id:
|
||||
vulnerabilities = [v for v in vulnerabilities if v.get("plugin_id") == plugin_id]
|
||||
|
||||
return {
|
||||
"vulnerabilities": vulnerabilities,
|
||||
"total_vulnerabilities": len(vulnerabilities),
|
||||
"filters": {
|
||||
"severity": severity,
|
||||
"plugin_id": plugin_id
|
||||
}
|
||||
}
|
||||
|
||||
@app.post("/api/v1/security/policies")
|
||||
async def create_security_policy(policy: Dict[str, Any]):
|
||||
"""Create a new security policy"""
|
||||
policy_id = f"policy_{int(datetime.now(timezone.utc).timestamp())}"
|
||||
|
||||
policy_record = {
|
||||
"policy_id": policy_id,
|
||||
"name": policy.get("name"),
|
||||
"description": policy.get("description"),
|
||||
"rules": policy.get("rules", []),
|
||||
"severity_thresholds": policy.get("severity_thresholds", {
|
||||
"critical": 0,
|
||||
"high": 0,
|
||||
"medium": 5,
|
||||
"low": 10
|
||||
}),
|
||||
"plugin_types": policy.get("plugin_types", []),
|
||||
"active": True,
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"updated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
security_policies[policy_id] = policy_record
|
||||
|
||||
logger.info(f"Security policy created: {policy_id} - {policy.get('name')}")
|
||||
|
||||
return {
|
||||
"policy_id": policy_id,
|
||||
"name": policy.get("name"),
|
||||
"status": "created",
|
||||
"active": True
|
||||
}
|
||||
|
||||
@app.get("/api/v1/security/policies")
|
||||
async def list_security_policies():
|
||||
"""List all security policies"""
|
||||
return {
|
||||
"policies": list(security_policies.values()),
|
||||
"total_policies": len(security_policies),
|
||||
"active_policies": len([p for p in security_policies.values() if p["active"]])
|
||||
}
|
||||
|
||||
@app.post("/api/v1/security/upload")
|
||||
async def upload_plugin_for_scan(plugin_id: str, version: str,
|
||||
file: UploadFile = File(...)):
|
||||
"""Upload plugin file for security scanning"""
|
||||
# Validate file
|
||||
if not file.filename.endswith(('.py', '.zip', '.tar.gz')):
|
||||
raise HTTPException(status_code=400, detail="Invalid file type")
|
||||
|
||||
# Save uploaded file temporarily
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=file.filename) as tmp_file:
|
||||
content = await file.read()
|
||||
tmp_file.write(content)
|
||||
tmp_file_path = tmp_file.name
|
||||
|
||||
# Initiate scan
|
||||
scan = SecurityScan(
|
||||
plugin_id=plugin_id,
|
||||
version=version,
|
||||
plugin_type="uploaded",
|
||||
scan_type="comprehensive",
|
||||
priority="medium"
|
||||
)
|
||||
|
||||
scan_result = await initiate_security_scan(scan)
|
||||
|
||||
# Start async scan process
|
||||
asyncio.create_task(process_scan_file(scan_result["scan_id"], tmp_file_path, file.filename))
|
||||
|
||||
return {
|
||||
"scan_id": scan_result["scan_id"],
|
||||
"filename": file.filename,
|
||||
"file_size": len(content),
|
||||
"status": "uploaded_and_queued"
|
||||
}
|
||||
|
||||
@app.get("/api/v1/security/dashboard")
|
||||
async def get_security_dashboard():
|
||||
"""Get security dashboard data"""
|
||||
total_scans = len(scan_reports)
|
||||
recent_scans = [r for r in scan_reports.values()
|
||||
if datetime.fromisoformat(r["scan_date"]) > datetime.now(timezone.utc) - timedelta(days=7)]
|
||||
|
||||
# Calculate statistics
|
||||
scan_results = list(scan_reports.values())
|
||||
passed_scans = len([r for r in scan_results if r.get("overall_score") == "passed"])
|
||||
warning_scans = len([r for r in scan_results if r.get("overall_score") == "warning"])
|
||||
failed_scans = len([r for r in scan_results if r.get("overall_score") in ["failed", "critical"]])
|
||||
|
||||
# Vulnerability statistics
|
||||
all_vulnerabilities = []
|
||||
for report in scan_results:
|
||||
all_vulnerabilities.extend(report.get("vulnerabilities", []))
|
||||
|
||||
vuln_by_severity = {"critical": 0, "high": 0, "medium": 0, "low": 0}
|
||||
for vuln in all_vulnerabilities:
|
||||
vuln_by_severity[vuln["severity"]] = vuln_by_severity.get(vuln["severity"], 0) + 1
|
||||
|
||||
return {
|
||||
"dashboard": {
|
||||
"total_scans": total_scans,
|
||||
"recent_scans": len(recent_scans),
|
||||
"scan_results": {
|
||||
"passed": passed_scans,
|
||||
"warning": warning_scans,
|
||||
"failed": failed_scans
|
||||
},
|
||||
"vulnerabilities": {
|
||||
"total": len(all_vulnerabilities),
|
||||
"by_severity": vuln_by_severity
|
||||
},
|
||||
"queue_size": len(scan_queue),
|
||||
"active_policies": len([p for p in security_policies.values() if p["active"]])
|
||||
},
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
# Core security scanning functions
|
||||
async def process_scan_file(scan_id: str, file_path: str, filename: str):
|
||||
"""Process uploaded file for security scanning"""
|
||||
try:
|
||||
# Update scan status
|
||||
for scan_record in scan_queue:
|
||||
if scan_record["scan_id"] == scan_id:
|
||||
scan_record["status"] = "running"
|
||||
scan_record["started_at"] = datetime.now(timezone.utc).isoformat()
|
||||
break
|
||||
|
||||
start_time = datetime.now(timezone.utc)
|
||||
|
||||
# Perform security scan
|
||||
scan_result = await perform_security_scan(file_path, filename)
|
||||
|
||||
end_time = datetime.now(timezone.utc)
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
|
||||
# Create security report
|
||||
security_report = SecurityReport(
|
||||
scan_id=scan_id,
|
||||
plugin_id=scan_record["plugin_id"],
|
||||
version=scan_record["version"],
|
||||
scan_date=end_time,
|
||||
scan_duration=duration,
|
||||
overall_score=calculate_overall_score(scan_result),
|
||||
vulnerabilities=scan_result["vulnerabilities"],
|
||||
security_metrics=scan_result["metrics"],
|
||||
recommendations=scan_result["recommendations"]
|
||||
)
|
||||
|
||||
# Save report
|
||||
report_data = {
|
||||
"scan_id": scan_id,
|
||||
"plugin_id": scan_record["plugin_id"],
|
||||
"version": scan_record["version"],
|
||||
"scan_date": security_report.scan_date.isoformat(),
|
||||
"scan_duration": security_report.scan_duration,
|
||||
"overall_score": security_report.overall_score,
|
||||
"vulnerabilities": [v.dict() for v in security_report.vulnerabilities],
|
||||
"security_metrics": security_report.security_metrics,
|
||||
"recommendations": security_report.recommendations,
|
||||
"status": "completed",
|
||||
"completed_at": security_report.scan_date.isoformat()
|
||||
}
|
||||
|
||||
scan_reports[scan_id] = report_data
|
||||
|
||||
# Remove from queue
|
||||
scan_queue[:] = [s for s in scan_queue if s["scan_id"] != scan_id]
|
||||
|
||||
# Clean up temporary file
|
||||
os.unlink(file_path)
|
||||
|
||||
logger.info(f"Security scan completed: {scan_id} - {security_report.overall_score}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing scan {scan_id}: {str(e)}")
|
||||
# Update scan status to failed
|
||||
for scan_record in scan_queue:
|
||||
if scan_record["scan_id"] == scan_id:
|
||||
scan_record["status"] = "failed"
|
||||
scan_record["completed_at"] = datetime.now(timezone.utc).isoformat()
|
||||
break
|
||||
|
||||
async def perform_security_scan(file_path: str, filename: str) -> Dict[str, Any]:
|
||||
"""Perform actual security scanning"""
|
||||
vulnerabilities = []
|
||||
metrics = {}
|
||||
recommendations = []
|
||||
|
||||
# File analysis
|
||||
try:
|
||||
# Basic file checks
|
||||
file_size = os.path.getsize(file_path)
|
||||
metrics["file_size"] = file_size
|
||||
|
||||
# Check for suspicious patterns (simplified for demo)
|
||||
if filename.endswith('.py'):
|
||||
vulnerabilities.extend(scan_python_file(file_path))
|
||||
elif filename.endswith('.zip'):
|
||||
vulnerabilities.extend(scan_zip_file(file_path))
|
||||
|
||||
# Check common vulnerabilities
|
||||
vulnerabilities.extend(check_common_vulnerabilities(file_path))
|
||||
|
||||
# Generate recommendations
|
||||
recommendations = generate_recommendations(vulnerabilities)
|
||||
|
||||
# Calculate metrics
|
||||
metrics.update({
|
||||
"vulnerability_count": len(vulnerabilities),
|
||||
"severity_distribution": get_severity_distribution(vulnerabilities),
|
||||
"file_type": filename.split('.')[-1],
|
||||
"scan_timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during security scan: {str(e)}")
|
||||
vulnerabilities.append({
|
||||
"severity": "medium",
|
||||
"title": "Scan Error",
|
||||
"description": f"Error during scanning: {str(e)}",
|
||||
"affected_file": filename,
|
||||
"recommendation": "Review file and rescan"
|
||||
})
|
||||
|
||||
return {
|
||||
"vulnerabilities": vulnerabilities,
|
||||
"metrics": metrics,
|
||||
"recommendations": recommendations
|
||||
}
|
||||
|
||||
async def scan_python_file(file_path: str) -> List[Dict]:
|
||||
"""Scan Python file for security issues"""
|
||||
vulnerabilities = []
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
lines = content.split('\n')
|
||||
|
||||
# Check for suspicious patterns
|
||||
suspicious_patterns = {
|
||||
"eval": "Use of eval() function",
|
||||
"exec": "Use of exec() function",
|
||||
"subprocess.call": "Unsafe subprocess usage",
|
||||
"os.system": "Use of os.system() function",
|
||||
"pickle.loads": "Unsafe pickle deserialization",
|
||||
"input(": "Use of input() function"
|
||||
}
|
||||
|
||||
for i, line in enumerate(lines, 1):
|
||||
for pattern, description in suspicious_patterns.items():
|
||||
if pattern in line:
|
||||
vulnerabilities.append({
|
||||
"severity": "medium",
|
||||
"title": "Suspicious Code Pattern",
|
||||
"description": description,
|
||||
"affected_file": file_path,
|
||||
"line_number": i,
|
||||
"recommendation": f"Review usage of {pattern} and consider safer alternatives"
|
||||
})
|
||||
|
||||
# Check for hardcoded credentials
|
||||
if any('password' in line.lower() or 'secret' in line.lower() or 'key' in line.lower()
|
||||
for line in lines):
|
||||
vulnerabilities.append({
|
||||
"severity": "high",
|
||||
"title": "Potential Hardcoded Credentials",
|
||||
"description": "Possible hardcoded sensitive information detected",
|
||||
"affected_file": file_path,
|
||||
"recommendation": "Use environment variables or secure configuration management"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning Python file: {str(e)}")
|
||||
|
||||
return vulnerabilities
|
||||
|
||||
async def scan_zip_file(file_path: str) -> List[Dict]:
|
||||
"""Scan ZIP file for security issues"""
|
||||
vulnerabilities = []
|
||||
|
||||
try:
|
||||
import zipfile
|
||||
|
||||
with zipfile.ZipFile(file_path, 'r') as zip_file:
|
||||
# Check for suspicious files
|
||||
for file_info in zip_file.filelist:
|
||||
filename = file_info.filename.lower()
|
||||
|
||||
# Check for suspicious file types
|
||||
suspicious_extensions = ['.exe', '.bat', '.cmd', '.scr', '.dll', '.so']
|
||||
if any(filename.endswith(ext) for ext in suspicious_extensions):
|
||||
vulnerabilities.append({
|
||||
"severity": "high",
|
||||
"title": "Suspicious File Type",
|
||||
"description": f"Suspicious file found in archive: {filename}",
|
||||
"affected_file": file_path,
|
||||
"recommendation": "Review file contents and ensure they are safe"
|
||||
})
|
||||
|
||||
# Check for large files (potential data exfiltration)
|
||||
if file_info.file_size > 100 * 1024 * 1024: # 100MB
|
||||
vulnerabilities.append({
|
||||
"severity": "medium",
|
||||
"title": "Large File Detected",
|
||||
"description": f"Large file detected: {filename} ({file_info.file_size} bytes)",
|
||||
"affected_file": file_path,
|
||||
"recommendation": "Verify file contents and necessity"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning ZIP file: {str(e)}")
|
||||
vulnerabilities.append({
|
||||
"severity": "medium",
|
||||
"title": "ZIP Scan Error",
|
||||
"description": f"Error scanning ZIP file: {str(e)}",
|
||||
"affected_file": file_path,
|
||||
"recommendation": "Verify ZIP file integrity"
|
||||
})
|
||||
|
||||
return vulnerabilities
|
||||
|
||||
async def check_common_vulnerabilities(file_path: str) -> List[Dict]:
|
||||
"""Check for common security vulnerabilities"""
|
||||
vulnerabilities = []
|
||||
|
||||
# Mock vulnerability database check
|
||||
known_vulnerabilities = {
|
||||
"requests": "Check for outdated requests library",
|
||||
"urllib": "Check for urllib security issues",
|
||||
"socket": "Check for unsafe socket usage"
|
||||
}
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
for lib, issue in known_vulnerabilities.items():
|
||||
if lib in content:
|
||||
vulnerabilities.append({
|
||||
"severity": "low",
|
||||
"title": f"Library Security Check",
|
||||
"description": issue,
|
||||
"affected_file": file_path,
|
||||
"recommendation": f"Update {lib} to latest secure version"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking common vulnerabilities: {str(e)}")
|
||||
|
||||
return vulnerabilities
|
||||
|
||||
def calculate_overall_score(scan_result: Dict[str, Any]) -> str:
|
||||
"""Calculate overall security score"""
|
||||
vulnerabilities = scan_result["vulnerabilities"]
|
||||
|
||||
if not vulnerabilities:
|
||||
return "passed"
|
||||
|
||||
# Count by severity
|
||||
critical_count = len([v for v in vulnerabilities if v["severity"] == "critical"])
|
||||
high_count = len([v for v in vulnerabilities if v["severity"] == "high"])
|
||||
medium_count = len([v for v in vulnerabilities if v["severity"] == "medium"])
|
||||
low_count = len([v for v in vulnerabilities if v["severity"] == "low"])
|
||||
|
||||
# Determine overall score
|
||||
if critical_count > 0:
|
||||
return "critical"
|
||||
elif high_count > 2:
|
||||
return "failed"
|
||||
elif high_count > 0 or medium_count > 5:
|
||||
return "warning"
|
||||
else:
|
||||
return "passed"
|
||||
|
||||
def generate_recommendations(vulnerabilities: List[Dict]) -> List[str]:
|
||||
"""Generate security recommendations"""
|
||||
recommendations = []
|
||||
|
||||
if not vulnerabilities:
|
||||
recommendations.append("No security issues detected. Plugin appears secure.")
|
||||
return recommendations
|
||||
|
||||
# Generate recommendations based on vulnerabilities
|
||||
severity_counts = {}
|
||||
for vuln in vulnerabilities:
|
||||
severity = vuln["severity"]
|
||||
severity_counts[severity] = severity_counts.get(severity, 0) + 1
|
||||
|
||||
if severity_counts.get("critical", 0) > 0:
|
||||
recommendations.append("CRITICAL: Address critical security vulnerabilities immediately.")
|
||||
|
||||
if severity_counts.get("high", 0) > 0:
|
||||
recommendations.append("HIGH: Review and fix high-severity security issues.")
|
||||
|
||||
if severity_counts.get("medium", 0) > 3:
|
||||
recommendations.append("MEDIUM: Consider addressing medium-severity issues.")
|
||||
|
||||
recommendations.append("Regular security scans recommended for ongoing protection.")
|
||||
recommendations.append("Keep all dependencies updated to latest secure versions.")
|
||||
|
||||
return recommendations
|
||||
|
||||
def get_severity_distribution(vulnerabilities: List[Dict]) -> Dict[str, int]:
|
||||
"""Get vulnerability severity distribution"""
|
||||
distribution = {"critical": 0, "high": 0, "medium": 0, "low": 0}
|
||||
for vuln in vulnerabilities:
|
||||
severity = vuln["severity"]
|
||||
distribution[severity] = distribution.get(severity, 0) + 1
|
||||
return distribution
|
||||
|
||||
def estimate_scan_time(scan_type: str) -> str:
|
||||
"""Estimate scan time based on scan type"""
|
||||
estimates = {
|
||||
"basic": "1-2 minutes",
|
||||
"comprehensive": "5-10 minutes",
|
||||
"deep": "15-30 minutes"
|
||||
}
|
||||
return estimates.get(scan_type, "5-10 minutes")
|
||||
|
||||
# Background task for processing scan queue
|
||||
async def process_scan_queue():
|
||||
"""Background task to process security scan queue"""
|
||||
while True:
|
||||
await asyncio.sleep(10) # Check queue every 10 seconds
|
||||
|
||||
if scan_queue:
|
||||
# Get next scan from queue
|
||||
scan_record = scan_queue[0]
|
||||
|
||||
# Process scan (in production, this would be more sophisticated)
|
||||
logger.info(f"Processing scan from queue: {scan_record['scan_id']}")
|
||||
|
||||
# Simulate processing time
|
||||
await asyncio.sleep(2)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
logger.info("Starting AITBC Plugin Security Service")
|
||||
# Initialize vulnerability database
|
||||
initialize_vulnerability_database()
|
||||
# Start queue processing
|
||||
asyncio.create_task(process_scan_queue())
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
logger.info("Shutting down AITBC Plugin Security Service")
|
||||
|
||||
def initialize_vulnerability_database():
|
||||
"""Initialize vulnerability database with known issues"""
|
||||
# Mock data for demo
|
||||
vulnerabilities = [
|
||||
{
|
||||
"vuln_id": "CVE-2023-1234",
|
||||
"severity": "high",
|
||||
"title": "Buffer Overflow in Library X",
|
||||
"description": "Buffer overflow vulnerability in commonly used library",
|
||||
"affected_plugins": ["plugin1", "plugin2"],
|
||||
"recommendation": "Update to latest version"
|
||||
},
|
||||
{
|
||||
"vuln_id": "CVE-2023-5678",
|
||||
"severity": "medium",
|
||||
"title": "Information Disclosure",
|
||||
"description": "Potential information disclosure in logging",
|
||||
"affected_plugins": ["plugin3"],
|
||||
"recommendation": "Review logging implementation"
|
||||
}
|
||||
]
|
||||
|
||||
for vuln in vulnerabilities:
|
||||
vulnerability_database[vuln["vuln_id"]] = vuln
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8015, log_level="info")
|
||||
1
examples/stubs/plugin-security/tests/__init__.py
Normal file
1
examples/stubs/plugin-security/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Plugin security service tests"""
|
||||
@@ -0,0 +1,159 @@
|
||||
"""Edge case and error handling tests for plugin security service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
from main import app, SecurityScan, scan_reports, security_policies, scan_queue, vulnerability_database
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
scan_reports.clear()
|
||||
security_policies.clear()
|
||||
scan_queue.clear()
|
||||
vulnerability_database.clear()
|
||||
yield
|
||||
scan_reports.clear()
|
||||
security_policies.clear()
|
||||
scan_queue.clear()
|
||||
vulnerability_database.clear()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_security_scan_empty_fields():
|
||||
"""Test SecurityScan with empty fields"""
|
||||
scan = SecurityScan(
|
||||
plugin_id="",
|
||||
version="",
|
||||
plugin_type="",
|
||||
scan_type="",
|
||||
priority=""
|
||||
)
|
||||
assert scan.plugin_id == ""
|
||||
assert scan.version == ""
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_vulnerability_empty_description():
|
||||
"""Test Vulnerability with empty description"""
|
||||
vuln = {
|
||||
"severity": "low",
|
||||
"title": "Test",
|
||||
"description": "",
|
||||
"affected_file": "file.py",
|
||||
"recommendation": "Fix"
|
||||
}
|
||||
assert vuln["description"] == ""
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_security_policy_minimal():
|
||||
"""Test creating security policy with minimal fields"""
|
||||
client = TestClient(app)
|
||||
policy = {
|
||||
"name": "Minimal Policy"
|
||||
}
|
||||
response = client.post("/api/v1/security/policies", json=policy)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["policy_id"]
|
||||
assert data["name"] == "Minimal Policy"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_security_policy_empty_name():
|
||||
"""Test creating security policy with empty name"""
|
||||
client = TestClient(app)
|
||||
policy = {}
|
||||
response = client.post("/api/v1/security/policies", json=policy)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_security_reports_with_no_reports():
|
||||
"""Test listing security reports when no reports exist"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/security/reports")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_reports"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_vulnerabilities_with_no_vulnerabilities():
|
||||
"""Test listing vulnerabilities when no vulnerabilities exist"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/security/vulnerabilities")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_vulnerabilities"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_security_policies_with_no_policies():
|
||||
"""Test listing security policies when no policies exist"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/security/policies")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_policies"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_scan_priority_ordering():
|
||||
"""Test that scan queue respects priority ordering"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Add scans in random priority order
|
||||
priorities = ["low", "critical", "medium", "high"]
|
||||
for priority in priorities:
|
||||
scan = SecurityScan(
|
||||
plugin_id=f"plugin_{priority}",
|
||||
version="1.0.0",
|
||||
plugin_type="cli",
|
||||
scan_type="basic",
|
||||
priority=priority
|
||||
)
|
||||
client.post("/api/v1/security/scan", json=scan.model_dump())
|
||||
|
||||
# Critical should be first, low should be last
|
||||
response = client.get("/api/v1/security/scan/nonexistent")
|
||||
# This will fail, but we can check queue size
|
||||
assert len(scan_queue) == 4
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_security_dashboard_with_no_data():
|
||||
"""Test security dashboard with no data"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/security/dashboard")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["dashboard"]["total_scans"] == 0
|
||||
assert data["dashboard"]["queue_size"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_reports_limit_parameter():
|
||||
"""Test listing reports with limit parameter"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/security/reports?limit=5")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "reports" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_vulnerabilities_invalid_filter():
|
||||
"""Test listing vulnerabilities with invalid filter"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/security/vulnerabilities?severity=invalid")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_vulnerabilities"] == 0
|
||||
@@ -0,0 +1,217 @@
|
||||
"""Integration tests for plugin security service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
from main import app, SecurityScan, scan_reports, security_policies, scan_queue, vulnerability_database
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
scan_reports.clear()
|
||||
security_policies.clear()
|
||||
scan_queue.clear()
|
||||
vulnerability_database.clear()
|
||||
yield
|
||||
scan_reports.clear()
|
||||
security_policies.clear()
|
||||
scan_queue.clear()
|
||||
vulnerability_database.clear()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_root_endpoint():
|
||||
"""Test root endpoint"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["service"] == "AITBC Plugin Security Service"
|
||||
assert data["status"] == "running"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_health_check_endpoint():
|
||||
"""Test health check endpoint"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert "total_scans" in data
|
||||
assert "queue_size" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_initiate_security_scan():
|
||||
"""Test initiating a security scan"""
|
||||
client = TestClient(app)
|
||||
scan = SecurityScan(
|
||||
plugin_id="plugin_123",
|
||||
version="1.0.0",
|
||||
plugin_type="cli",
|
||||
scan_type="comprehensive",
|
||||
priority="high"
|
||||
)
|
||||
response = client.post("/api/v1/security/scan", json=scan.model_dump())
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["scan_id"]
|
||||
assert data["status"] == "queued"
|
||||
assert "queue_position" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_scan_status_queued():
|
||||
"""Test getting scan status for queued scan"""
|
||||
client = TestClient(app)
|
||||
scan = SecurityScan(
|
||||
plugin_id="plugin_123",
|
||||
version="1.0.0",
|
||||
plugin_type="cli",
|
||||
scan_type="basic",
|
||||
priority="medium"
|
||||
)
|
||||
scan_response = client.post("/api/v1/security/scan", json=scan.model_dump())
|
||||
scan_id = scan_response.json()["scan_id"]
|
||||
|
||||
response = client.get(f"/api/v1/security/scan/{scan_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["scan_id"] == scan_id
|
||||
assert data["status"] == "queued"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_scan_status_not_found():
|
||||
"""Test getting scan status for nonexistent scan"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/security/scan/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_security_reports():
|
||||
"""Test listing security reports"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/security/reports")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "reports" in data
|
||||
assert "total_reports" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_security_reports_with_filters():
|
||||
"""Test listing security reports with filters"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/security/reports?plugin_id=plugin_123&status=completed")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "reports" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_vulnerabilities():
|
||||
"""Test listing vulnerabilities"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/security/vulnerabilities")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "vulnerabilities" in data
|
||||
assert "total_vulnerabilities" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_vulnerabilities_with_filters():
|
||||
"""Test listing vulnerabilities with filters"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/security/vulnerabilities?severity=high&plugin_id=plugin_123")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "vulnerabilities" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_create_security_policy():
|
||||
"""Test creating a security policy"""
|
||||
client = TestClient(app)
|
||||
policy = {
|
||||
"name": "Test Policy",
|
||||
"description": "A test security policy",
|
||||
"rules": ["rule1", "rule2"],
|
||||
"severity_thresholds": {
|
||||
"critical": 0,
|
||||
"high": 0,
|
||||
"medium": 5,
|
||||
"low": 10
|
||||
},
|
||||
"plugin_types": ["cli", "web"]
|
||||
}
|
||||
response = client.post("/api/v1/security/policies", json=policy)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["policy_id"]
|
||||
assert data["name"] == "Test Policy"
|
||||
assert data["active"] is True
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_security_policies():
|
||||
"""Test listing security policies"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/security/policies")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "policies" in data
|
||||
assert "total_policies" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_security_dashboard():
|
||||
"""Test getting security dashboard"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/security/dashboard")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "dashboard" in data
|
||||
assert "total_scans" in data["dashboard"]
|
||||
assert "vulnerabilities" in data["dashboard"]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_scan_priority_queueing():
|
||||
"""Test that scans are queued by priority"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Add low priority scan
|
||||
scan_low = SecurityScan(
|
||||
plugin_id="plugin_low",
|
||||
version="1.0.0",
|
||||
plugin_type="cli",
|
||||
scan_type="basic",
|
||||
priority="low"
|
||||
)
|
||||
client.post("/api/v1/security/scan", json=scan_low.model_dump())
|
||||
|
||||
# Add critical priority scan
|
||||
scan_critical = SecurityScan(
|
||||
plugin_id="plugin_critical",
|
||||
version="1.0.0",
|
||||
plugin_type="cli",
|
||||
scan_type="basic",
|
||||
priority="critical"
|
||||
)
|
||||
response = client.post("/api/v1/security/scan", json=scan_critical.model_dump())
|
||||
scan_id = response.json()["scan_id"]
|
||||
|
||||
# Critical scan should be at position 1
|
||||
response = client.get(f"/api/v1/security/scan/{scan_id}")
|
||||
data = response.json()
|
||||
assert data["queue_position"] == 1
|
||||
@@ -0,0 +1,205 @@
|
||||
"""Unit tests for plugin security service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, SecurityScan, Vulnerability, SecurityReport, calculate_overall_score, generate_recommendations, get_severity_distribution, estimate_scan_time
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_app_initialization():
|
||||
"""Test that the FastAPI app initializes correctly"""
|
||||
assert app is not None
|
||||
assert app.title == "AITBC Plugin Security Service"
|
||||
assert app.version == "1.0.0"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_security_scan_model():
|
||||
"""Test SecurityScan model"""
|
||||
scan = SecurityScan(
|
||||
plugin_id="plugin_123",
|
||||
version="1.0.0",
|
||||
plugin_type="cli",
|
||||
scan_type="comprehensive",
|
||||
priority="high"
|
||||
)
|
||||
assert scan.plugin_id == "plugin_123"
|
||||
assert scan.version == "1.0.0"
|
||||
assert scan.plugin_type == "cli"
|
||||
assert scan.scan_type == "comprehensive"
|
||||
assert scan.priority == "high"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_vulnerability_model():
|
||||
"""Test Vulnerability model"""
|
||||
vuln = Vulnerability(
|
||||
cve_id="CVE-2023-1234",
|
||||
severity="high",
|
||||
title="Buffer Overflow",
|
||||
description="Buffer overflow vulnerability",
|
||||
affected_file="file.py",
|
||||
line_number=42,
|
||||
recommendation="Update to latest version"
|
||||
)
|
||||
assert vuln.cve_id == "CVE-2023-1234"
|
||||
assert vuln.severity == "high"
|
||||
assert vuln.title == "Buffer Overflow"
|
||||
assert vuln.line_number == 42
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_vulnerability_model_optional_fields():
|
||||
"""Test Vulnerability model with optional fields"""
|
||||
vuln = Vulnerability(
|
||||
cve_id=None,
|
||||
severity="low",
|
||||
title="Minor issue",
|
||||
description="Description",
|
||||
affected_file="file.py",
|
||||
line_number=None,
|
||||
recommendation="Fix it"
|
||||
)
|
||||
assert vuln.cve_id is None
|
||||
assert vuln.line_number is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_security_report_model():
|
||||
"""Test SecurityReport model"""
|
||||
report = SecurityReport(
|
||||
scan_id="scan_123",
|
||||
plugin_id="plugin_123",
|
||||
version="1.0.0",
|
||||
scan_date=datetime.now(timezone.utc),
|
||||
scan_duration=120.5,
|
||||
overall_score="passed",
|
||||
vulnerabilities=[],
|
||||
security_metrics={},
|
||||
recommendations=[]
|
||||
)
|
||||
assert report.scan_id == "scan_123"
|
||||
assert report.overall_score == "passed"
|
||||
assert report.scan_duration == 120.5
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_overall_score_passed():
|
||||
"""Test calculate overall score with no vulnerabilities"""
|
||||
scan_result = {"vulnerabilities": []}
|
||||
score = calculate_overall_score(scan_result)
|
||||
assert score == "passed"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_overall_score_critical():
|
||||
"""Test calculate overall score with critical vulnerability"""
|
||||
scan_result = {
|
||||
"vulnerabilities": [
|
||||
{"severity": "critical"},
|
||||
{"severity": "low"}
|
||||
]
|
||||
}
|
||||
score = calculate_overall_score(scan_result)
|
||||
assert score == "critical"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_overall_score_failed():
|
||||
"""Test calculate overall score with multiple high vulnerabilities"""
|
||||
scan_result = {
|
||||
"vulnerabilities": [
|
||||
{"severity": "high"},
|
||||
{"severity": "high"},
|
||||
{"severity": "high"}
|
||||
]
|
||||
}
|
||||
score = calculate_overall_score(scan_result)
|
||||
assert score == "failed"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_overall_score_warning():
|
||||
"""Test calculate overall score with high and medium vulnerabilities"""
|
||||
scan_result = {
|
||||
"vulnerabilities": [
|
||||
{"severity": "high"},
|
||||
{"severity": "medium"},
|
||||
{"severity": "medium"},
|
||||
{"severity": "medium"},
|
||||
{"severity": "medium"},
|
||||
{"severity": "medium"}
|
||||
]
|
||||
}
|
||||
score = calculate_overall_score(scan_result)
|
||||
assert score == "warning"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_generate_recommendations_no_vulnerabilities():
|
||||
"""Test generate recommendations with no vulnerabilities"""
|
||||
recommendations = generate_recommendations([])
|
||||
assert len(recommendations) == 1
|
||||
assert "No security issues detected" in recommendations[0]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_generate_recommendations_critical():
|
||||
"""Test generate recommendations with critical vulnerabilities"""
|
||||
vulnerabilities = [
|
||||
{"severity": "critical"},
|
||||
{"severity": "high"}
|
||||
]
|
||||
recommendations = generate_recommendations(vulnerabilities)
|
||||
assert any("CRITICAL" in r for r in recommendations)
|
||||
assert any("HIGH" in r for r in recommendations)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_severity_distribution():
|
||||
"""Test get severity distribution"""
|
||||
vulnerabilities = [
|
||||
{"severity": "critical"},
|
||||
{"severity": "high"},
|
||||
{"severity": "high"},
|
||||
{"severity": "medium"},
|
||||
{"severity": "low"}
|
||||
]
|
||||
distribution = get_severity_distribution(vulnerabilities)
|
||||
assert distribution["critical"] == 1
|
||||
assert distribution["high"] == 2
|
||||
assert distribution["medium"] == 1
|
||||
assert distribution["low"] == 1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_estimate_scan_time_basic():
|
||||
"""Test estimate scan time for basic scan"""
|
||||
time = estimate_scan_time("basic")
|
||||
assert time == "1-2 minutes"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_estimate_scan_time_comprehensive():
|
||||
"""Test estimate scan time for comprehensive scan"""
|
||||
time = estimate_scan_time("comprehensive")
|
||||
assert time == "5-10 minutes"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_estimate_scan_time_deep():
|
||||
"""Test estimate scan time for deep scan"""
|
||||
time = estimate_scan_time("deep")
|
||||
assert time == "15-30 minutes"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_estimate_scan_time_unknown():
|
||||
"""Test estimate scan time for unknown scan type"""
|
||||
time = estimate_scan_time("unknown")
|
||||
assert time == "5-10 minutes"
|
||||
1292
examples/stubs/plugin-service/poetry.lock
generated
Normal file
1292
examples/stubs/plugin-service/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
25
examples/stubs/plugin-service/pyproject.toml
Normal file
25
examples/stubs/plugin-service/pyproject.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[tool.poetry]
|
||||
name = "plugin-service"
|
||||
version = "0.1.0"
|
||||
description = "AITBC Plugin Service for plugin registration, marketplace, and analytics"
|
||||
authors = ["AITBC Team"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.13"
|
||||
fastapi = ">=0.115.6"
|
||||
uvicorn = {extras = ["standard"], version = "^0.32.0"}
|
||||
sqlmodel = "^0.0.37"
|
||||
sqlalchemy = "^2.0.25"
|
||||
pydantic = "^2.6.0"
|
||||
pydantic-settings = "^2.1.0"
|
||||
httpx = ">=0.28.1"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = ">=9.0.3"
|
||||
pytest-asyncio = ">=1.3.0"
|
||||
black = ">=26.3.1"
|
||||
ruff = "^0.1.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
71
examples/stubs/plugin-service/src/plugin_service/main.py
Normal file
71
examples/stubs/plugin-service/src/plugin_service/main.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Plugin Service for plugin registration, marketplace, and analytics."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
app = FastAPI(
|
||||
title="AITBC Plugin Service",
|
||||
description="Plugin registration, marketplace, and analytics service",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy", "service": "plugin-service"}
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint."""
|
||||
return {
|
||||
"service": "AITBC Plugin Service",
|
||||
"version": "1.0.0",
|
||||
"status": "operational"
|
||||
}
|
||||
|
||||
|
||||
@app.post("/register")
|
||||
async def register_plugin(request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Register a new plugin"""
|
||||
return {
|
||||
"plugin_id": "plugin_123",
|
||||
"name": request.get("name", "unknown"),
|
||||
"version": request.get("version", "1.0.0"),
|
||||
"status": "registered"
|
||||
}
|
||||
|
||||
|
||||
@app.get("/marketplace/plugins")
|
||||
async def list_marketplace_plugins() -> dict[str, Any]:
|
||||
"""List all plugins in marketplace"""
|
||||
return {
|
||||
"plugins": [
|
||||
{"id": "plugin_1", "name": "GPU Optimizer", "version": "1.0.0", "category": "performance"},
|
||||
{"id": "plugin_2", "name": "Analytics Dashboard", "version": "2.0.0", "category": "analytics"},
|
||||
],
|
||||
"total": 2
|
||||
}
|
||||
|
||||
|
||||
@app.get("/analytics/plugins")
|
||||
async def get_plugin_analytics() -> dict[str, Any]:
|
||||
"""Get plugin analytics data"""
|
||||
return {
|
||||
"total_plugins": 2,
|
||||
"active_installs": 150,
|
||||
"downloads": 500,
|
||||
"popular_categories": ["performance", "analytics"]
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8109)
|
||||
238
examples/stubs/simple-explorer/main.py
Normal file
238
examples/stubs/simple-explorer/main.py
Normal file
@@ -0,0 +1,238 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple AITBC Blockchain Explorer - Demonstrating the issues described in the analysis
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import HTMLResponse
|
||||
import uvicorn
|
||||
|
||||
from aitbc.network.http_client import AsyncAITBCHTTPClient
|
||||
from aitbc.aitbc_logging import get_logger
|
||||
from aitbc.exceptions import NetworkError
|
||||
|
||||
app = FastAPI(title="Simple AITBC Explorer", version="0.1.0")
|
||||
|
||||
# Initialize logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Configuration
|
||||
BLOCKCHAIN_RPC_URL = "http://localhost:8025"
|
||||
|
||||
# Validation patterns for user inputs to prevent SSRF
|
||||
TX_HASH_PATTERN = re.compile(r'^[a-fA-F0-9]{64}$') # 64-character hex string for transaction hash
|
||||
|
||||
|
||||
def validate_tx_hash(tx_hash: str) -> bool:
|
||||
"""Validate transaction hash to prevent SSRF"""
|
||||
if not tx_hash:
|
||||
return False
|
||||
# Check for path traversal or URL manipulation
|
||||
if any(char in tx_hash for char in ['/', '\\', '..', '\n', '\r', '\t', '?', '&']):
|
||||
return False
|
||||
# Validate against hash pattern
|
||||
return bool(TX_HASH_PATTERN.match(tx_hash))
|
||||
|
||||
# HTML Template with the problematic frontend
|
||||
HTML_TEMPLATE = """
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Simple AITBC Explorer</title>
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
</head>
|
||||
<body class="bg-gray-50">
|
||||
<div class="container mx-auto px-4 py-8">
|
||||
<h1 class="text-3xl font-bold mb-8">AITBC Blockchain Explorer</h1>
|
||||
|
||||
<!-- Search -->
|
||||
<div class="bg-white rounded-lg shadow p-6 mb-8">
|
||||
<h2 class="text-xl font-semibold mb-4">Search</h2>
|
||||
<div class="flex space-x-4">
|
||||
<input type="text" id="search-input" placeholder="Search by transaction hash (64 chars)"
|
||||
class="flex-1 px-4 py-2 border rounded-lg">
|
||||
<button onclick="performSearch()" class="bg-blue-600 text-white px-6 py-2 rounded-lg">
|
||||
Search
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Results -->
|
||||
<div id="results" class="hidden bg-white rounded-lg shadow p-6">
|
||||
<h2 class="text-xl font-semibold mb-4">Transaction Details</h2>
|
||||
<div id="tx-details"></div>
|
||||
</div>
|
||||
|
||||
<!-- Latest Blocks -->
|
||||
<div class="bg-white rounded-lg shadow p-6">
|
||||
<h2 class="text-xl font-semibold mb-4">Latest Blocks</h2>
|
||||
<div id="blocks-list"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// Problem 1: Frontend calls /api/transactions/{hash} but backend doesn't have it
|
||||
async function performSearch() {
|
||||
const query = document.getElementById('search-input').value.trim();
|
||||
if (!query) return;
|
||||
|
||||
if (/^[a-fA-F0-9]{64}$/.test(query)) {
|
||||
try {
|
||||
const tx = await fetch(`/api/transactions/${query}`).then(r => {
|
||||
if (!r.ok) throw new Error('Transaction not found');
|
||||
return r.json();
|
||||
});
|
||||
showTransactionDetails(tx);
|
||||
} catch (error) {
|
||||
alert('Transaction not found');
|
||||
}
|
||||
} else {
|
||||
alert('Please enter a valid 64-character hex transaction hash');
|
||||
}
|
||||
}
|
||||
|
||||
// Problem 2: UI expects tx.hash, tx.from, tx.to, tx.amount, tx.fee
|
||||
// But RPC returns tx_hash, sender, recipient, payload, created_at
|
||||
function showTransactionDetails(tx) {
|
||||
const resultsDiv = document.getElementById('results');
|
||||
const detailsDiv = document.getElementById('tx-details');
|
||||
|
||||
detailsDiv.innerHTML = `
|
||||
<div class="space-y-4">
|
||||
<div><strong>Hash:</strong> ${tx.hash || 'N/A'}</div>
|
||||
<div><strong>From:</strong> ${tx.from || 'N/A'}</div>
|
||||
<div><strong>To:</strong> ${tx.to || 'N/A'}</div>
|
||||
<div><strong>Amount:</strong> ${tx.amount || 'N/A'}</div>
|
||||
<div><strong>Fee:</strong> ${tx.fee || 'N/A'}</div>
|
||||
<div><strong>Timestamp:</strong> ${formatTimestamp(tx.timestamp)}</div>
|
||||
</div>
|
||||
`;
|
||||
|
||||
resultsDiv.classList.remove('hidden');
|
||||
}
|
||||
|
||||
// Problem 3: formatTimestamp now handles both numeric and ISO string timestamps
|
||||
function formatTimestamp(timestamp) {
|
||||
if (!timestamp) return 'N/A';
|
||||
|
||||
// Handle ISO string timestamps
|
||||
if (typeof timestamp === 'string') {
|
||||
try {
|
||||
return new Date(timestamp).toLocaleString();
|
||||
} catch (e) {
|
||||
return 'Invalid timestamp';
|
||||
}
|
||||
}
|
||||
|
||||
// Handle numeric timestamps (Unix seconds)
|
||||
if (typeof timestamp === 'number') {
|
||||
try {
|
||||
return new Date(timestamp * 1000).toLocaleString();
|
||||
} catch (e) {
|
||||
return 'Invalid timestamp';
|
||||
}
|
||||
}
|
||||
|
||||
return 'Invalid timestamp format';
|
||||
}
|
||||
|
||||
// Load latest blocks
|
||||
async function loadBlocks() {
|
||||
try {
|
||||
const head = await fetch('/api/chain/head').then(r => r.json());
|
||||
const blocksList = document.getElementById('blocks-list');
|
||||
|
||||
let html = '<div class="space-y-4">';
|
||||
for (let i = 0; i < 5 && head.height - i >= 0; i++) {
|
||||
const block = await fetch(`/api/blocks/${head.height - i}`).then(r => r.json());
|
||||
html += `
|
||||
<div class="border rounded p-4">
|
||||
<div><strong>Height:</strong> ${block.height}</div>
|
||||
<div><strong>Hash:</strong> ${block.hash ? block.hash.substring(0, 16) + '...' : 'N/A'}</div>
|
||||
<div><strong>Time:</strong> ${formatTimestamp(block.timestamp)}</div>
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
html += '</div>';
|
||||
blocksList.innerHTML = html;
|
||||
} catch (error) {
|
||||
console.error('Failed to load blocks:', error);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize
|
||||
document.addEventListener('DOMContentLoaded', () => {
|
||||
loadBlocks();
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
# Problem 1: Only /api/chain/head and /api/blocks/{height} defined, missing /api/transactions/{hash}
|
||||
@app.get("/api/chain/head")
|
||||
async def get_chain_head():
|
||||
"""Get current chain head"""
|
||||
try:
|
||||
client = AsyncAITBCHTTPClient(base_url=BLOCKCHAIN_RPC_URL, timeout=10)
|
||||
response = await client.async_get("/rpc/head")
|
||||
if response:
|
||||
return response
|
||||
except NetworkError as e:
|
||||
logger.error(f"Error getting chain head: {e}")
|
||||
return {"height": 0, "hash": "", "timestamp": None}
|
||||
|
||||
@app.get("/api/blocks/{height}")
|
||||
async def get_block(height: int):
|
||||
"""Get block by height"""
|
||||
# Validate height is non-negative and reasonable
|
||||
if height < 0 or height > 10000000:
|
||||
return {"height": height, "hash": "", "timestamp": None, "transactions": []}
|
||||
try:
|
||||
client = AsyncAITBCHTTPClient(base_url=BLOCKCHAIN_RPC_URL, timeout=10)
|
||||
response = await client.async_get(f"/rpc/blocks/{height}")
|
||||
if response:
|
||||
return response
|
||||
except NetworkError as e:
|
||||
logger.error(f"Error getting block: {e}")
|
||||
return {"height": height, "hash": "", "timestamp": None, "transactions": []}
|
||||
|
||||
@app.get("/api/transactions/{tx_hash}")
|
||||
async def get_transaction(tx_hash: str):
|
||||
"""Get transaction by hash - Problem 1: This endpoint was missing"""
|
||||
if not validate_tx_hash(tx_hash):
|
||||
return {"hash": tx_hash, "from": "unknown", "to": "unknown", "amount": 0, "timestamp": None}
|
||||
try:
|
||||
client = AsyncAITBCHTTPClient(base_url=BLOCKCHAIN_RPC_URL, timeout=10)
|
||||
response = await client.async_get(f"/rpc/tx/{tx_hash}")
|
||||
if response:
|
||||
# Problem 2: Map RPC schema to UI schema
|
||||
return {
|
||||
"hash": response.get("tx_hash", tx_hash), # tx_hash -> hash
|
||||
"from": response.get("sender", "unknown"), # sender -> from
|
||||
"to": response.get("recipient", "unknown"), # recipient -> to
|
||||
"amount": response.get("payload", {}).get("value", "0"), # payload.value -> amount
|
||||
"fee": response.get("payload", {}).get("fee", "0"), # payload.fee -> fee
|
||||
"timestamp": response.get("created_at"), # created_at -> timestamp
|
||||
"block_height": response.get("block_height", "pending")
|
||||
}
|
||||
except NetworkError as e:
|
||||
logger.error(f"Error getting transaction {tx_hash}: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to fetch transaction: {str(e)}")
|
||||
|
||||
# Missing: @app.get("/api/transactions/{tx_hash}") - THIS IS THE PROBLEM
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def root():
|
||||
"""Serve the explorer UI"""
|
||||
return HTML_TEMPLATE
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="127.0.0.1", port=8017)
|
||||
1
examples/stubs/simple-explorer/tests/__init__.py
Normal file
1
examples/stubs/simple-explorer/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Simple explorer service tests"""
|
||||
@@ -0,0 +1,221 @@
|
||||
"""Edge case and error handling tests for simple explorer service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# Mock httpx before importing
|
||||
sys.modules['httpx'] = Mock()
|
||||
|
||||
from main import app
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_transaction_missing_fields():
|
||||
"""Test transaction mapping with missing fields"""
|
||||
client = TestClient(app)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"tx_hash": "0x" + "a" * 64,
|
||||
# Missing sender, recipient, payload
|
||||
"created_at": "2026-01-01T00:00:00"
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch('main.httpx.AsyncClient', return_value=mock_client):
|
||||
response = client.get("/api/transactions/" + "a" * 64)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["from"] == "unknown"
|
||||
assert data["to"] == "unknown"
|
||||
assert data["amount"] == "0"
|
||||
assert data["fee"] == "0"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_transaction_empty_payload():
|
||||
"""Test transaction mapping with empty payload"""
|
||||
client = TestClient(app)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"tx_hash": "0x" + "a" * 64,
|
||||
"sender": "0xsender",
|
||||
"recipient": "0xrecipient",
|
||||
"payload": {},
|
||||
"created_at": "2026-01-01T00:00:00"
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch('main.httpx.AsyncClient', return_value=mock_client):
|
||||
response = client.get("/api/transactions/" + "a" * 64)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["amount"] == "0"
|
||||
assert data["fee"] == "0"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_transaction_missing_created_at():
|
||||
"""Test transaction mapping with missing created_at"""
|
||||
client = TestClient(app)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"tx_hash": "0x" + "a" * 64,
|
||||
"sender": "0xsender",
|
||||
"recipient": "0xrecipient",
|
||||
"payload": {"value": "1000", "fee": "10"}
|
||||
# Missing created_at
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch('main.httpx.AsyncClient', return_value=mock_client):
|
||||
response = client.get("/api/transactions/" + "a" * 64)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["timestamp"] is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_transaction_missing_block_height():
|
||||
"""Test transaction mapping with missing block_height"""
|
||||
client = TestClient(app)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"tx_hash": "0x" + "a" * 64,
|
||||
"sender": "0xsender",
|
||||
"recipient": "0xrecipient",
|
||||
"payload": {"value": "1000", "fee": "10"},
|
||||
"created_at": "2026-01-01T00:00:00"
|
||||
# Missing block_height
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch('main.httpx.AsyncClient', return_value=mock_client):
|
||||
response = client.get("/api/transactions/" + "a" * 64)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["block_height"] == "pending"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_block_negative_height():
|
||||
"""Test /api/blocks/{height} with negative height"""
|
||||
client = TestClient(app)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"height": -1,
|
||||
"hash": "0xblock",
|
||||
"timestamp": 1234567890,
|
||||
"transactions": []
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch('main.httpx.AsyncClient', return_value=mock_client):
|
||||
response = client.get("/api/blocks/-1")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["height"] == -1
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_block_zero_height():
|
||||
"""Test /api/blocks/{height} with zero height"""
|
||||
client = TestClient(app)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"height": 0,
|
||||
"hash": "0xgenesis",
|
||||
"timestamp": 1234567890,
|
||||
"transactions": []
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch('main.httpx.AsyncClient', return_value=mock_client):
|
||||
response = client.get("/api/blocks/0")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["height"] == 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_transaction_short_hash():
|
||||
"""Test /api/transactions/{tx_hash} with short hash"""
|
||||
client = TestClient(app)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"tx_hash": "0x" + "a" * 64,
|
||||
"sender": "0xsender",
|
||||
"recipient": "0xrecipient",
|
||||
"payload": {"value": "1000", "fee": "10"},
|
||||
"created_at": "2026-01-01T00:00:00",
|
||||
"block_height": 100
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch('main.httpx.AsyncClient', return_value=mock_client):
|
||||
response = client.get("/api/transactions/abc")
|
||||
assert response.status_code in [200, 404, 500] # Any valid response
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_transaction_invalid_hex_hash():
|
||||
"""Test /api/transactions/{tx_hash} with invalid hex characters"""
|
||||
client = TestClient(app)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"tx_hash": "0x" + "a" * 64,
|
||||
"sender": "0xsender",
|
||||
"recipient": "0xrecipient",
|
||||
"payload": {"value": "1000", "fee": "10"},
|
||||
"created_at": "2026-01-01T00:00:00",
|
||||
"block_height": 100
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch('main.httpx.AsyncClient', return_value=mock_client):
|
||||
response = client.get("/api/transactions/" + "z" * 64)
|
||||
assert response.status_code in [200, 404, 500]
|
||||
@@ -0,0 +1,170 @@
|
||||
"""Integration tests for simple explorer service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# Mock httpx before importing
|
||||
sys.modules['httpx'] = Mock()
|
||||
|
||||
from main import app
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_root_endpoint():
|
||||
"""Test root endpoint serves HTML"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert "text/html" in response.headers["content-type"]
|
||||
assert "AITBC Blockchain Explorer" in response.text
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_chain_head_success():
|
||||
"""Test /api/chain/head endpoint with successful response"""
|
||||
client = TestClient(app)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"height": 100, "hash": "0xabc123", "timestamp": 1234567890}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch('main.httpx.AsyncClient', return_value=mock_client):
|
||||
response = client.get("/api/chain/head")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["height"] == 100
|
||||
assert data["hash"] == "0xabc123"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_chain_head_error():
|
||||
"""Test /api/chain/head endpoint with error"""
|
||||
client = TestClient(app)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.get.side_effect = Exception("RPC error")
|
||||
|
||||
with patch('main.httpx.AsyncClient', return_value=mock_client):
|
||||
response = client.get("/api/chain/head")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["height"] == 0
|
||||
assert data["hash"] == ""
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_block_success():
|
||||
"""Test /api/blocks/{height} endpoint with successful response"""
|
||||
client = TestClient(app)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"height": 50,
|
||||
"hash": "0xblock50",
|
||||
"timestamp": 1234567890,
|
||||
"transactions": []
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch('main.httpx.AsyncClient', return_value=mock_client):
|
||||
response = client.get("/api/blocks/50")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["height"] == 50
|
||||
assert data["hash"] == "0xblock50"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_block_error():
|
||||
"""Test /api/blocks/{height} endpoint with error"""
|
||||
client = TestClient(app)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.get.side_effect = Exception("RPC error")
|
||||
|
||||
with patch('main.httpx.AsyncClient', return_value=mock_client):
|
||||
response = client.get("/api/blocks/50")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["height"] == 50
|
||||
assert data["hash"] == ""
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_transaction_success():
|
||||
"""Test /api/transactions/{tx_hash} endpoint with successful response"""
|
||||
client = TestClient(app)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"tx_hash": "0x" + "a" * 64,
|
||||
"sender": "0xsender",
|
||||
"recipient": "0xrecipient",
|
||||
"payload": {
|
||||
"value": "1000",
|
||||
"fee": "10"
|
||||
},
|
||||
"created_at": "2026-01-01T00:00:00",
|
||||
"block_height": 100
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch('main.httpx.AsyncClient', return_value=mock_client):
|
||||
response = client.get("/api/transactions/" + "a" * 64)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["hash"] == "0x" + "a" * 64
|
||||
assert data["from"] == "0xsender"
|
||||
assert data["to"] == "0xrecipient"
|
||||
assert data["amount"] == "1000"
|
||||
assert data["fee"] == "10"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_transaction_not_found():
|
||||
"""Test /api/transactions/{tx_hash} endpoint with 404 response"""
|
||||
client = TestClient(app)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 404
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch('main.httpx.AsyncClient', return_value=mock_client):
|
||||
response = client.get("/api/transactions/" + "a" * 64)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_transaction_error():
|
||||
"""Test /api/transactions/{tx_hash} endpoint with error"""
|
||||
client = TestClient(app)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.get.side_effect = Exception("RPC error")
|
||||
|
||||
with patch('main.httpx.AsyncClient', return_value=mock_client):
|
||||
response = client.get("/api/transactions/" + "a" * 64)
|
||||
assert response.status_code == 500
|
||||
@@ -0,0 +1,70 @@
|
||||
"""Unit tests for simple explorer service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
# Mock httpx before importing
|
||||
sys.modules['httpx'] = Mock()
|
||||
|
||||
from main import app, BLOCKCHAIN_RPC_URL, HTML_TEMPLATE
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_app_initialization():
|
||||
"""Test that the FastAPI app initializes correctly"""
|
||||
assert app is not None
|
||||
assert app.title == "Simple AITBC Explorer"
|
||||
assert app.version == "0.1.0"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_blockchain_rpc_url():
|
||||
"""Test that the blockchain RPC URL is configured"""
|
||||
assert BLOCKCHAIN_RPC_URL == "http://localhost:8025"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_html_template_exists():
|
||||
"""Test that the HTML template is defined"""
|
||||
assert HTML_TEMPLATE is not None
|
||||
assert "<!DOCTYPE html>" in HTML_TEMPLATE
|
||||
assert "AITBC Blockchain Explorer" in HTML_TEMPLATE
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_html_template_has_search():
|
||||
"""Test that the HTML template has search functionality"""
|
||||
assert "search-input" in HTML_TEMPLATE
|
||||
assert "performSearch()" in HTML_TEMPLATE
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_html_template_has_blocks_section():
|
||||
"""Test that the HTML template has blocks section"""
|
||||
assert "Latest Blocks" in HTML_TEMPLATE
|
||||
assert "blocks-list" in HTML_TEMPLATE
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_html_template_has_results_section():
|
||||
"""Test that the HTML template has results section"""
|
||||
assert "Transaction Details" in HTML_TEMPLATE
|
||||
assert "tx-details" in HTML_TEMPLATE
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_html_template_has_tailwind():
|
||||
"""Test that the HTML template includes Tailwind CSS"""
|
||||
assert "tailwindcss" in HTML_TEMPLATE
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_html_template_format_timestamp_function():
|
||||
"""Test that the HTML template has formatTimestamp function"""
|
||||
assert "formatTimestamp" in HTML_TEMPLATE
|
||||
assert "toLocaleString" in HTML_TEMPLATE
|
||||
584
examples/stubs/trading-engine/main.py
Executable file
584
examples/stubs/trading-engine/main.py
Executable file
@@ -0,0 +1,584 @@
|
||||
"""
|
||||
Production Trading Engine for AITBC
|
||||
Handles order matching, trade execution, and settlement
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
from collections import defaultdict, deque
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from aitbc import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup
|
||||
logger.info("Starting AITBC Trading Engine")
|
||||
# Start background market simulation
|
||||
asyncio.create_task(simulate_market_activity())
|
||||
yield
|
||||
# Shutdown
|
||||
logger.info("Shutting down AITBC Trading Engine")
|
||||
|
||||
app = FastAPI(
|
||||
title="AITBC Trading Engine",
|
||||
description="High-performance order matching and trade execution",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# Data models
|
||||
class Order(BaseModel):
|
||||
order_id: str
|
||||
symbol: str
|
||||
side: str # buy/sell
|
||||
type: str # market/limit
|
||||
quantity: float
|
||||
price: Optional[float] = None
|
||||
user_id: str
|
||||
timestamp: datetime
|
||||
|
||||
class Trade(BaseModel):
|
||||
trade_id: str
|
||||
symbol: str
|
||||
buy_order_id: str
|
||||
sell_order_id: str
|
||||
quantity: float
|
||||
price: float
|
||||
timestamp: datetime
|
||||
|
||||
class OrderBookEntry(BaseModel):
|
||||
price: float
|
||||
quantity: float
|
||||
orders_count: int
|
||||
|
||||
# In-memory order books (in production, use more sophisticated data structures)
|
||||
order_books: Dict[str, Dict] = {}
|
||||
orders: Dict[str, Dict] = {}
|
||||
trades: Dict[str, Dict] = {}
|
||||
market_data: Dict[str, Dict] = {}
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {
|
||||
"service": "AITBC Trading Engine",
|
||||
"status": "running",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {
|
||||
"status": "healthy",
|
||||
"active_order_books": len(order_books),
|
||||
"total_orders": len(orders),
|
||||
"total_trades": len(trades),
|
||||
"uptime": "running"
|
||||
}
|
||||
|
||||
@app.post("/api/v1/orders/submit")
|
||||
async def submit_order(order: Order):
|
||||
"""Submit a new order to the trading engine"""
|
||||
symbol = order.symbol
|
||||
|
||||
# Initialize order book if not exists
|
||||
if symbol not in order_books:
|
||||
order_books[symbol] = {
|
||||
"bids": defaultdict(list), # buy orders
|
||||
"asks": defaultdict(list), # sell orders
|
||||
"last_price": None,
|
||||
"volume_24h": 0.0,
|
||||
"high_24h": None,
|
||||
"low_24h": None,
|
||||
"created_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
# Store order
|
||||
order_data = {
|
||||
"order_id": order.order_id,
|
||||
"symbol": order.symbol,
|
||||
"side": order.side,
|
||||
"type": order.type,
|
||||
"quantity": order.quantity,
|
||||
"remaining_quantity": order.quantity,
|
||||
"price": order.price,
|
||||
"user_id": order.user_id,
|
||||
"timestamp": order.timestamp.isoformat(),
|
||||
"status": "open",
|
||||
"filled_quantity": 0.0,
|
||||
"average_price": None
|
||||
}
|
||||
|
||||
orders[order.order_id] = order_data
|
||||
|
||||
# Process order
|
||||
trades_executed = await process_order(order_data)
|
||||
|
||||
logger.info(f"Order submitted: {order.order_id} - {order.side} {order.quantity} {order.symbol}")
|
||||
|
||||
return {
|
||||
"order_id": order.order_id,
|
||||
"status": order_data["status"],
|
||||
"filled_quantity": order_data["filled_quantity"],
|
||||
"remaining_quantity": order_data["remaining_quantity"],
|
||||
"trades_executed": len(trades_executed),
|
||||
"average_price": order_data["average_price"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/orders/{order_id}")
|
||||
async def get_order(order_id: str):
|
||||
"""Get order details"""
|
||||
if order_id not in orders:
|
||||
raise HTTPException(status_code=404, detail="Order not found")
|
||||
|
||||
return orders[order_id]
|
||||
|
||||
@app.get("/api/v1/orders")
|
||||
async def list_orders():
|
||||
"""List all orders"""
|
||||
return {
|
||||
"orders": list(orders.values()),
|
||||
"total_orders": len(orders),
|
||||
"open_orders": len([o for o in orders.values() if o["status"] == "open"]),
|
||||
"filled_orders": len([o for o in orders.values() if o["status"] == "filled"])
|
||||
}
|
||||
|
||||
@app.get("/api/v1/orderbook/{symbol}")
|
||||
async def get_order_book(symbol: str, depth: int = 10):
|
||||
"""Get order book for a trading pair"""
|
||||
if symbol not in order_books:
|
||||
raise HTTPException(status_code=404, detail="Order book not found")
|
||||
|
||||
book = order_books[symbol]
|
||||
|
||||
# Get best bids and asks
|
||||
bids = sorted(book["bids"].items(), reverse=True)[:depth]
|
||||
asks = sorted(book["asks"].items())[:depth]
|
||||
|
||||
return {
|
||||
"symbol": symbol,
|
||||
"bids": [
|
||||
{
|
||||
"price": price,
|
||||
"quantity": sum(order["remaining_quantity"] for order in orders_list),
|
||||
"orders_count": len(orders_list)
|
||||
}
|
||||
for price, orders_list in bids
|
||||
],
|
||||
"asks": [
|
||||
{
|
||||
"price": price,
|
||||
"quantity": sum(order["remaining_quantity"] for order in orders_list),
|
||||
"orders_count": len(orders_list)
|
||||
}
|
||||
for price, orders_list in asks
|
||||
],
|
||||
"last_price": book["last_price"],
|
||||
"volume_24h": book["volume_24h"],
|
||||
"high_24h": book["high_24h"],
|
||||
"low_24h": book["low_24h"],
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.get("/api/v1/trades")
|
||||
async def list_trades(symbol: Optional[str] = None, limit: int = 100):
|
||||
"""List recent trades"""
|
||||
all_trades = list(trades.values())
|
||||
|
||||
if symbol:
|
||||
all_trades = [t for t in all_trades if t["symbol"] == symbol]
|
||||
|
||||
# Sort by timestamp (most recent first)
|
||||
all_trades.sort(key=lambda x: x["timestamp"], reverse=True)
|
||||
|
||||
return {
|
||||
"trades": all_trades[:limit],
|
||||
"total_trades": len(all_trades)
|
||||
}
|
||||
|
||||
@app.get("/api/v1/ticker/{symbol}")
|
||||
async def get_ticker(symbol: str):
|
||||
"""Get ticker information for a trading pair"""
|
||||
if symbol not in order_books:
|
||||
raise HTTPException(status_code=404, detail="Trading pair not found")
|
||||
|
||||
book = order_books[symbol]
|
||||
|
||||
# Calculate 24h statistics
|
||||
trades_24h = [t for t in trades.values()
|
||||
if t["symbol"] == symbol and
|
||||
datetime.fromisoformat(t["timestamp"]) >
|
||||
datetime.now(timezone.utc) - timedelta(hours=24)]
|
||||
|
||||
if trades_24h:
|
||||
prices = [t["price"] for t in trades_24h]
|
||||
volume = sum(t["quantity"] for t in trades_24h)
|
||||
|
||||
ticker = {
|
||||
"symbol": symbol,
|
||||
"last_price": book["last_price"],
|
||||
"bid_price": max(book["bids"].keys()) if book["bids"] else None,
|
||||
"ask_price": min(book["asks"].keys()) if book["asks"] else None,
|
||||
"high_24h": max(prices),
|
||||
"low_24h": min(prices),
|
||||
"volume_24h": volume,
|
||||
"change_24h": prices[-1] - prices[0] if len(prices) > 1 else 0,
|
||||
"change_percent_24h": ((prices[-1] - prices[0]) / prices[0] * 100) if len(prices) > 1 else 0
|
||||
}
|
||||
else:
|
||||
ticker = {
|
||||
"symbol": symbol,
|
||||
"last_price": book["last_price"],
|
||||
"bid_price": None,
|
||||
"ask_price": None,
|
||||
"high_24h": None,
|
||||
"low_24h": None,
|
||||
"volume_24h": 0.0,
|
||||
"change_24h": 0.0,
|
||||
"change_percent_24h": 0.0
|
||||
}
|
||||
|
||||
return ticker
|
||||
|
||||
@app.delete("/api/v1/orders/{order_id}")
|
||||
async def cancel_order(order_id: str):
|
||||
"""Cancel an order"""
|
||||
if order_id not in orders:
|
||||
raise HTTPException(status_code=404, detail="Order not found")
|
||||
|
||||
order = orders[order_id]
|
||||
|
||||
if order["status"] != "open":
|
||||
raise HTTPException(status_code=400, detail="Order cannot be cancelled")
|
||||
|
||||
# Remove from order book
|
||||
symbol = order["symbol"]
|
||||
if symbol in order_books:
|
||||
book = order_books[symbol]
|
||||
price_key = str(order["price"])
|
||||
|
||||
if order["side"] == "buy" and price_key in book["bids"]:
|
||||
book["bids"][price_key] = [o for o in book["bids"][price_key] if o["order_id"] != order_id]
|
||||
if not book["bids"][price_key]:
|
||||
del book["bids"][price_key]
|
||||
elif order["side"] == "sell" and price_key in book["asks"]:
|
||||
book["asks"][price_key] = [o for o in book["asks"][price_key] if o["order_id"] != order_id]
|
||||
if not book["asks"][price_key]:
|
||||
del book["asks"][price_key]
|
||||
|
||||
# Update order status
|
||||
order["status"] = "cancelled"
|
||||
order["cancelled_at"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
logger.info(f"Order cancelled: {order_id}")
|
||||
|
||||
return {
|
||||
"order_id": order_id,
|
||||
"status": "cancelled",
|
||||
"cancelled_at": order["cancelled_at"]
|
||||
}
|
||||
|
||||
@app.get("/api/v1/market-data")
|
||||
async def get_market_data():
|
||||
"""Get market data for all symbols"""
|
||||
market_summary = {}
|
||||
|
||||
for symbol, book in order_books.items():
|
||||
trades_24h = [t for t in trades.values()
|
||||
if t["symbol"] == symbol and
|
||||
datetime.fromisoformat(t["timestamp"]) >
|
||||
datetime.now(timezone.utc) - timedelta(hours=24)]
|
||||
|
||||
market_summary[symbol] = {
|
||||
"last_price": book["last_price"],
|
||||
"volume_24h": book["volume_24h"],
|
||||
"high_24h": book["high_24h"],
|
||||
"low_24h": book["low_24h"],
|
||||
"trades_count_24h": len(trades_24h),
|
||||
"bid_price": max(book["bids"].keys()) if book["bids"] else None,
|
||||
"ask_price": min(book["asks"].keys()) if book["asks"] else None
|
||||
}
|
||||
|
||||
return {
|
||||
"market_data": market_summary,
|
||||
"total_symbols": len(market_summary),
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
@app.get("/api/v1/engine/stats")
|
||||
async def get_engine_stats():
|
||||
"""Get trading engine statistics"""
|
||||
total_orders = len(orders)
|
||||
total_trades = len(trades)
|
||||
total_volume = sum(t["quantity"] * t["price"] for t in trades.values())
|
||||
|
||||
orders_by_status = defaultdict(int)
|
||||
for order in orders.values():
|
||||
orders_by_status[order["status"]] += 1
|
||||
|
||||
trades_by_symbol = defaultdict(int)
|
||||
for trade in trades.values():
|
||||
trades_by_symbol[trade["symbol"]] += 1
|
||||
|
||||
return {
|
||||
"engine_stats": {
|
||||
"total_orders": total_orders,
|
||||
"total_trades": total_trades,
|
||||
"total_volume": total_volume,
|
||||
"orders_by_status": dict(orders_by_status),
|
||||
"trades_by_symbol": dict(trades_by_symbol),
|
||||
"active_order_books": len(order_books),
|
||||
"uptime": "running"
|
||||
},
|
||||
"generated_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
# Core trading engine logic
|
||||
async def process_order(order: Dict) -> List[Dict]:
|
||||
"""Process an order and execute trades"""
|
||||
symbol = order["symbol"]
|
||||
book = order_books[symbol]
|
||||
trades_executed = []
|
||||
|
||||
if order["type"] == "market":
|
||||
trades_executed = await process_market_order(order, book)
|
||||
else:
|
||||
trades_executed = await process_limit_order(order, book)
|
||||
|
||||
# Update market data
|
||||
update_market_data(symbol, trades_executed)
|
||||
|
||||
return trades_executed
|
||||
|
||||
async def process_market_order(order: Dict, book: Dict) -> List[Dict]:
|
||||
"""Process a market order"""
|
||||
trades_executed = []
|
||||
|
||||
if order["side"] == "buy":
|
||||
# Match against asks (sell orders)
|
||||
ask_prices = sorted(book["asks"].keys())
|
||||
|
||||
for price in ask_prices:
|
||||
if order["remaining_quantity"] <= 0:
|
||||
break
|
||||
|
||||
orders_at_price = book["asks"][price][:]
|
||||
for matching_order in orders_at_price:
|
||||
if order["remaining_quantity"] <= 0:
|
||||
break
|
||||
|
||||
trade = await execute_trade(order, matching_order, price)
|
||||
if trade:
|
||||
trades_executed.append(trade)
|
||||
|
||||
else: # sell order
|
||||
# Match against bids (buy orders)
|
||||
bid_prices = sorted(book["bids"].keys(), reverse=True)
|
||||
|
||||
for price in bid_prices:
|
||||
if order["remaining_quantity"] <= 0:
|
||||
break
|
||||
|
||||
orders_at_price = book["bids"][price][:]
|
||||
for matching_order in orders_at_price:
|
||||
if order["remaining_quantity"] <= 0:
|
||||
break
|
||||
|
||||
trade = await execute_trade(order, matching_order, price)
|
||||
if trade:
|
||||
trades_executed.append(trade)
|
||||
|
||||
return trades_executed
|
||||
|
||||
async def process_limit_order(order: Dict, book: Dict) -> List[Dict]:
|
||||
"""Process a limit order"""
|
||||
trades_executed = []
|
||||
|
||||
if order["side"] == "buy":
|
||||
# Match against asks at or below the limit price
|
||||
ask_prices = sorted([p for p in book["asks"].keys() if float(p) <= order["price"]])
|
||||
|
||||
for price in ask_prices:
|
||||
if order["remaining_quantity"] <= 0:
|
||||
break
|
||||
|
||||
orders_at_price = book["asks"][price][:]
|
||||
for matching_order in orders_at_price:
|
||||
if order["remaining_quantity"] <= 0:
|
||||
break
|
||||
|
||||
trade = await execute_trade(order, matching_order, price)
|
||||
if trade:
|
||||
trades_executed.append(trade)
|
||||
|
||||
# Add remaining quantity to order book
|
||||
if order["remaining_quantity"] > 0:
|
||||
price_key = str(order["price"])
|
||||
book["bids"][price_key].append(order)
|
||||
|
||||
else: # sell order
|
||||
# Match against bids at or above the limit price
|
||||
bid_prices = sorted([p for p in book["bids"].keys() if float(p) >= order["price"]], reverse=True)
|
||||
|
||||
for price in bid_prices:
|
||||
if order["remaining_quantity"] <= 0:
|
||||
break
|
||||
|
||||
orders_at_price = book["bids"][price][:]
|
||||
for matching_order in orders_at_price:
|
||||
if order["remaining_quantity"] <= 0:
|
||||
break
|
||||
|
||||
trade = await execute_trade(order, matching_order, price)
|
||||
if trade:
|
||||
trades_executed.append(trade)
|
||||
|
||||
# Add remaining quantity to order book
|
||||
if order["remaining_quantity"] > 0:
|
||||
price_key = str(order["price"])
|
||||
book["asks"][price_key].append(order)
|
||||
|
||||
return trades_executed
|
||||
|
||||
async def execute_trade(order1: Dict, order2: Dict, price: float) -> Optional[Dict]:
|
||||
"""Execute a trade between two orders"""
|
||||
# Determine trade quantity
|
||||
trade_quantity = min(order1["remaining_quantity"], order2["remaining_quantity"])
|
||||
|
||||
if trade_quantity <= 0:
|
||||
return None
|
||||
|
||||
# Create trade record
|
||||
trade_id = f"trade_{int(datetime.now(timezone.utc).timestamp())}_{len(trades)}"
|
||||
|
||||
trade = {
|
||||
"trade_id": trade_id,
|
||||
"symbol": order1["symbol"],
|
||||
"buy_order_id": order1["order_id"] if order1["side"] == "buy" else order2["order_id"],
|
||||
"sell_order_id": order2["order_id"] if order2["side"] == "sell" else order1["order_id"],
|
||||
"quantity": trade_quantity,
|
||||
"price": price,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
|
||||
trades[trade_id] = trade
|
||||
|
||||
# Update orders
|
||||
for order in [order1, order2]:
|
||||
order["filled_quantity"] += trade_quantity
|
||||
order["remaining_quantity"] -= trade_quantity
|
||||
|
||||
if order["remaining_quantity"] <= 0:
|
||||
order["status"] = "filled"
|
||||
order["filled_at"] = trade["timestamp"]
|
||||
else:
|
||||
order["status"] = "partially_filled"
|
||||
|
||||
# Update average price
|
||||
if order["average_price"] is None:
|
||||
order["average_price"] = price
|
||||
else:
|
||||
total_value = order["average_price"] * (order["filled_quantity"] - trade_quantity) + price * trade_quantity
|
||||
order["average_price"] = total_value / order["filled_quantity"]
|
||||
|
||||
# Remove filled orders from order book
|
||||
symbol = order1["symbol"]
|
||||
book = order_books[symbol]
|
||||
price_key = str(price)
|
||||
|
||||
for order in [order1, order2]:
|
||||
if order["remaining_quantity"] <= 0:
|
||||
if order["side"] == "buy" and price_key in book["bids"]:
|
||||
book["bids"][price_key] = [o for o in book["bids"][price_key] if o["order_id"] != order["order_id"]]
|
||||
if not book["bids"][price_key]:
|
||||
del book["bids"][price_key]
|
||||
elif order["side"] == "sell" and price_key in book["asks"]:
|
||||
book["asks"][price_key] = [o for o in book["asks"][price_key] if o["order_id"] != order["order_id"]]
|
||||
if not book["asks"][price_key]:
|
||||
del book["asks"][price_key]
|
||||
|
||||
logger.info(f"Trade executed: {trade_id} - {trade_quantity} @ {price}")
|
||||
|
||||
return trade
|
||||
|
||||
def update_market_data(symbol: str, trades_executed: List[Dict]):
|
||||
"""Update market data after trades"""
|
||||
if not trades_executed:
|
||||
return
|
||||
|
||||
book = order_books[symbol]
|
||||
|
||||
# Update last price
|
||||
last_trade = trades_executed[-1]
|
||||
book["last_price"] = last_trade["price"]
|
||||
|
||||
# Update 24h high/low
|
||||
trades_24h = [t for t in trades.values()
|
||||
if t["symbol"] == symbol and
|
||||
datetime.fromisoformat(t["timestamp"]) >
|
||||
datetime.now(timezone.utc) - timedelta(hours=24)]
|
||||
|
||||
if trades_24h:
|
||||
prices = [t["price"] for t in trades_24h]
|
||||
book["high_24h"] = max(prices)
|
||||
book["low_24h"] = min(prices)
|
||||
book["volume_24h"] = sum(t["quantity"] for t in trades_24h)
|
||||
|
||||
# Background task for market data simulation
|
||||
async def simulate_market_activity():
|
||||
"""Background task to simulate market activity"""
|
||||
while True:
|
||||
await asyncio.sleep(60) # Simulate activity every minute
|
||||
|
||||
# Create some random market orders for demo
|
||||
if len(order_books) > 0:
|
||||
import random
|
||||
|
||||
for symbol in list(order_books.keys())[:3]: # Limit to 3 symbols
|
||||
if random.random() < 0.3: # 30% chance of market activity
|
||||
# Create random market order
|
||||
side = random.choice(["buy", "sell"])
|
||||
quantity = random.uniform(10, 1000)
|
||||
|
||||
order_id = f"sim_order_{int(datetime.now(timezone.utc).timestamp())}"
|
||||
order = Order(
|
||||
order_id=order_id,
|
||||
symbol=symbol,
|
||||
side=side,
|
||||
type="market",
|
||||
quantity=quantity,
|
||||
user_id="sim_user",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
order_data = {
|
||||
"order_id": order.order_id,
|
||||
"symbol": order.symbol,
|
||||
"side": order.side,
|
||||
"type": order.type,
|
||||
"quantity": order.quantity,
|
||||
"remaining_quantity": order.quantity,
|
||||
"price": order.price,
|
||||
"user_id": order.user_id,
|
||||
"timestamp": order.timestamp.isoformat(),
|
||||
"status": "open",
|
||||
"filled_quantity": 0.0,
|
||||
"average_price": None
|
||||
}
|
||||
|
||||
orders[order_id] = order_data
|
||||
await process_order(order_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host=os.getenv("BIND_HOST", "127.0.0.1"), port=8012, log_level="info")
|
||||
1
examples/stubs/trading-engine/tests/__init__.py
Normal file
1
examples/stubs/trading-engine/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Trading engine service tests"""
|
||||
@@ -0,0 +1,208 @@
|
||||
"""Edge case and error handling tests for trading engine service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, Order, order_books, orders, trades
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
order_books.clear()
|
||||
orders.clear()
|
||||
trades.clear()
|
||||
yield
|
||||
order_books.clear()
|
||||
orders.clear()
|
||||
trades.clear()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_order_zero_quantity():
|
||||
"""Test Order with zero quantity"""
|
||||
order = Order(
|
||||
order_id="order_123",
|
||||
symbol="AITBC/BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=0.0,
|
||||
price=0.00001,
|
||||
user_id="user_123",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert order.quantity == 0.0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_order_negative_quantity():
|
||||
"""Test Order with negative quantity"""
|
||||
order = Order(
|
||||
order_id="order_123",
|
||||
symbol="AITBC/BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=-100.0,
|
||||
price=0.00001,
|
||||
user_id="user_123",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert order.quantity == -100.0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_order_negative_price():
|
||||
"""Test Order with negative price"""
|
||||
order = Order(
|
||||
order_id="order_123",
|
||||
symbol="AITBC/BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=100.0,
|
||||
price=-0.00001,
|
||||
user_id="user_123",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert order.price == -0.00001
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_order_empty_symbol():
|
||||
"""Test Order with empty symbol"""
|
||||
order = Order(
|
||||
order_id="order_123",
|
||||
symbol="",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=100.0,
|
||||
price=0.00001,
|
||||
user_id="user_123",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert order.symbol == ""
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_cancel_filled_order():
|
||||
"""Test cancelling a filled order"""
|
||||
client = TestClient(app)
|
||||
order = Order(
|
||||
order_id="order_129",
|
||||
symbol="AITBC/BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=100.0,
|
||||
price=0.00001,
|
||||
user_id="user_123",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
client.post("/api/v1/orders/submit", json=order.model_dump(mode='json'))
|
||||
|
||||
# Manually mark as filled
|
||||
orders["order_129"]["status"] = "filled"
|
||||
|
||||
response = client.delete("/api/v1/orders/order_129")
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_submit_order_with_slash_in_symbol():
|
||||
"""Test submitting order with slash in symbol"""
|
||||
client = TestClient(app)
|
||||
order = Order(
|
||||
order_id="order_130",
|
||||
symbol="AITBC/BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=100.0,
|
||||
price=0.00001,
|
||||
user_id="user_123",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
response = client.post("/api/v1/orders/submit", json=order.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_submit_order_with_hyphen_in_symbol():
|
||||
"""Test submitting order with hyphen in symbol"""
|
||||
client = TestClient(app)
|
||||
order = Order(
|
||||
order_id="order_131",
|
||||
symbol="AITBC-BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=100.0,
|
||||
price=0.00001,
|
||||
user_id="user_123",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
response = client.post("/api/v1/orders/submit", json=order.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_orders_with_no_orders():
|
||||
"""Test listing orders when no orders exist"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/orders")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_orders"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_trades_with_no_trades():
|
||||
"""Test listing trades when no trades exist"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/trades")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_trades"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_market_data_with_no_symbols():
|
||||
"""Test getting market data when no symbols exist"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/market-data")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_symbols"] == 0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_order_book_depth_parameter():
|
||||
"""Test order book with depth parameter"""
|
||||
client = TestClient(app)
|
||||
order = Order(
|
||||
order_id="order_132",
|
||||
symbol="AITBC-BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=100.0,
|
||||
price=0.00001,
|
||||
user_id="user_123",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
client.post("/api/v1/orders/submit", json=order.model_dump(mode='json'))
|
||||
|
||||
response = client.get("/api/v1/orderbook/AITBC-BTC?depth=5")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["symbol"] == "AITBC-BTC"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_trades_limit_parameter():
|
||||
"""Test listing trades with limit parameter"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/trades?limit=10")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "trades" in data
|
||||
@@ -0,0 +1,264 @@
|
||||
"""Integration tests for trading engine service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, Order, order_books, orders, trades
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""Reset global state before each test"""
|
||||
order_books.clear()
|
||||
orders.clear()
|
||||
trades.clear()
|
||||
yield
|
||||
order_books.clear()
|
||||
orders.clear()
|
||||
trades.clear()
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_root_endpoint():
|
||||
"""Test root endpoint"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["service"] == "AITBC Trading Engine"
|
||||
assert data["status"] == "running"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_health_check_endpoint():
|
||||
"""Test health check endpoint"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert "active_order_books" in data
|
||||
assert "total_orders" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_submit_market_order():
|
||||
"""Test submitting a market order"""
|
||||
client = TestClient(app)
|
||||
order = Order(
|
||||
order_id="order_123",
|
||||
symbol="AITBC/BTC",
|
||||
side="buy",
|
||||
type="market",
|
||||
quantity=100.0,
|
||||
user_id="user_123",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
response = client.post("/api/v1/orders/submit", json=order.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["order_id"] == "order_123"
|
||||
assert "status" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_submit_limit_order():
|
||||
"""Test submitting a limit order"""
|
||||
client = TestClient(app)
|
||||
order = Order(
|
||||
order_id="order_124",
|
||||
symbol="AITBC/BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=100.0,
|
||||
price=0.00001,
|
||||
user_id="user_123",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
response = client.post("/api/v1/orders/submit", json=order.model_dump(mode='json'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["order_id"] == "order_124"
|
||||
assert "status" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_order():
|
||||
"""Test getting order details"""
|
||||
client = TestClient(app)
|
||||
order = Order(
|
||||
order_id="order_125",
|
||||
symbol="AITBC/BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=100.0,
|
||||
price=0.00001,
|
||||
user_id="user_123",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
client.post("/api/v1/orders/submit", json=order.model_dump(mode='json'))
|
||||
|
||||
response = client.get("/api/v1/orders/order_125")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["order_id"] == "order_125"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_order_not_found():
|
||||
"""Test getting nonexistent order"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/orders/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_orders():
|
||||
"""Test listing all orders"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/orders")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "orders" in data
|
||||
assert "total_orders" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_order_book():
|
||||
"""Test getting order book"""
|
||||
client = TestClient(app)
|
||||
# Create some orders first
|
||||
order1 = Order(
|
||||
order_id="order_126",
|
||||
symbol="AITBC-BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=100.0,
|
||||
price=0.00001,
|
||||
user_id="user_123",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
client.post("/api/v1/orders/submit", json=order1.model_dump(mode='json'))
|
||||
|
||||
response = client.get("/api/v1/orderbook/AITBC-BTC")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["symbol"] == "AITBC-BTC"
|
||||
assert "bids" in data
|
||||
assert "asks" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_order_book_not_found():
|
||||
"""Test getting order book for nonexistent symbol"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/orderbook/NONEXISTENT")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_trades():
|
||||
"""Test listing trades"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/trades")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "trades" in data
|
||||
assert "total_trades" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_list_trades_by_symbol():
|
||||
"""Test listing trades by symbol"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/trades?symbol=AITBC-BTC")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "trades" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_ticker():
|
||||
"""Test getting ticker information"""
|
||||
client = TestClient(app)
|
||||
# Create order book first
|
||||
order = Order(
|
||||
order_id="order_127",
|
||||
symbol="AITBC-BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=100.0,
|
||||
price=0.00001,
|
||||
user_id="user_123",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
client.post("/api/v1/orders/submit", json=order.model_dump(mode='json'))
|
||||
|
||||
response = client.get("/api/v1/ticker/AITBC-BTC")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["symbol"] == "AITBC-BTC"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_ticker_not_found():
|
||||
"""Test getting ticker for nonexistent symbol"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/ticker/NONEXISTENT")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_cancel_order():
|
||||
"""Test cancelling an order"""
|
||||
client = TestClient(app)
|
||||
order = Order(
|
||||
order_id="order_128",
|
||||
symbol="AITBC/BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=100.0,
|
||||
price=0.00001,
|
||||
user_id="user_123",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
client.post("/api/v1/orders/submit", json=order.model_dump(mode='json'))
|
||||
|
||||
response = client.delete("/api/v1/orders/order_128")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "cancelled"
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_cancel_order_not_found():
|
||||
"""Test cancelling nonexistent order"""
|
||||
client = TestClient(app)
|
||||
response = client.delete("/api/v1/orders/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_market_data():
|
||||
"""Test getting market data"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/market-data")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "market_data" in data
|
||||
assert "total_symbols" in data
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_engine_stats():
|
||||
"""Test getting engine statistics"""
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/engine/stats")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "engine_stats" in data
|
||||
@@ -0,0 +1,89 @@
|
||||
"""Unit tests for trading engine service"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
from main import app, Order, Trade, OrderBookEntry
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_app_initialization():
|
||||
"""Test that the FastAPI app initializes correctly"""
|
||||
assert app is not None
|
||||
assert app.title == "AITBC Trading Engine"
|
||||
assert app.version == "1.0.0"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_order_model():
|
||||
"""Test Order model"""
|
||||
order = Order(
|
||||
order_id="order_123",
|
||||
symbol="AITBC/BTC",
|
||||
side="buy",
|
||||
type="limit",
|
||||
quantity=100.0,
|
||||
price=0.00001,
|
||||
user_id="user_123",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert order.order_id == "order_123"
|
||||
assert order.symbol == "AITBC/BTC"
|
||||
assert order.side == "buy"
|
||||
assert order.type == "limit"
|
||||
assert order.quantity == 100.0
|
||||
assert order.price == 0.00001
|
||||
assert order.user_id == "user_123"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_order_model_market_order():
|
||||
"""Test Order model for market order"""
|
||||
order = Order(
|
||||
order_id="order_123",
|
||||
symbol="AITBC/BTC",
|
||||
side="sell",
|
||||
type="market",
|
||||
quantity=50.0,
|
||||
user_id="user_123",
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert order.type == "market"
|
||||
assert order.price is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_trade_model():
|
||||
"""Test Trade model"""
|
||||
trade = Trade(
|
||||
trade_id="trade_123",
|
||||
symbol="AITBC/BTC",
|
||||
buy_order_id="buy_order_123",
|
||||
sell_order_id="sell_order_123",
|
||||
quantity=100.0,
|
||||
price=0.00001,
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
assert trade.trade_id == "trade_123"
|
||||
assert trade.symbol == "AITBC/BTC"
|
||||
assert trade.buy_order_id == "buy_order_123"
|
||||
assert trade.sell_order_id == "sell_order_123"
|
||||
assert trade.quantity == 100.0
|
||||
assert trade.price == 0.00001
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_order_book_entry_model():
|
||||
"""Test OrderBookEntry model"""
|
||||
entry = OrderBookEntry(
|
||||
price=0.00001,
|
||||
quantity=1000.0,
|
||||
orders_count=5
|
||||
)
|
||||
assert entry.price == 0.00001
|
||||
assert entry.quantity == 1000.0
|
||||
assert entry.orders_count == 5
|
||||
Reference in New Issue
Block a user