feat: add marketplace metrics, privacy features, and service registry endpoints
- Add Prometheus metrics for marketplace API throughput and error rates with new dashboard panels - Implement confidential transaction models with encryption support and access control - Add key management system with registration, rotation, and audit logging - Create services and registry routers for service discovery and management - Integrate ZK proof generation for privacy-preserving receipts - Add metrics instru
This commit is contained in:
@@ -5,5 +5,7 @@ from .miner import router as miner
|
||||
from .admin import router as admin
|
||||
from .marketplace import router as marketplace
|
||||
from .explorer import router as explorer
|
||||
from .services import router as services
|
||||
from .registry import router as registry
|
||||
|
||||
__all__ = ["client", "miner", "admin", "marketplace", "explorer"]
|
||||
__all__ = ["client", "miner", "admin", "marketplace", "explorer", "services", "registry"]
|
||||
|
||||
423
apps/coordinator-api/src/app/routers/confidential.py
Normal file
423
apps/coordinator-api/src/app/routers/confidential.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""
|
||||
API endpoints for confidential transactions
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
import json
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from ..models import (
|
||||
ConfidentialTransaction,
|
||||
ConfidentialTransactionCreate,
|
||||
ConfidentialTransactionView,
|
||||
ConfidentialAccessRequest,
|
||||
ConfidentialAccessResponse,
|
||||
KeyRegistrationRequest,
|
||||
KeyRegistrationResponse,
|
||||
AccessLogQuery,
|
||||
AccessLogResponse
|
||||
)
|
||||
from ..services.encryption import EncryptionService, EncryptedData
|
||||
from ..services.key_management import KeyManager, KeyManagementError
|
||||
from ..services.access_control import AccessController
|
||||
from ..auth import get_api_key
|
||||
from ..logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Initialize router and security
|
||||
router = APIRouter(prefix="/confidential", tags=["confidential"])
|
||||
security = HTTPBearer()
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Global instances (in production, inject via DI)
|
||||
encryption_service: Optional[EncryptionService] = None
|
||||
key_manager: Optional[KeyManager] = None
|
||||
access_controller: Optional[AccessController] = None
|
||||
|
||||
|
||||
def get_encryption_service() -> EncryptionService:
|
||||
"""Get encryption service instance"""
|
||||
global encryption_service
|
||||
if encryption_service is None:
|
||||
# Initialize with key manager
|
||||
from ..services.key_management import FileKeyStorage
|
||||
key_storage = FileKeyStorage("/tmp/aitbc_keys")
|
||||
key_manager = KeyManager(key_storage)
|
||||
encryption_service = EncryptionService(key_manager)
|
||||
return encryption_service
|
||||
|
||||
|
||||
def get_key_manager() -> KeyManager:
|
||||
"""Get key manager instance"""
|
||||
global key_manager
|
||||
if key_manager is None:
|
||||
from ..services.key_management import FileKeyStorage
|
||||
key_storage = FileKeyStorage("/tmp/aitbc_keys")
|
||||
key_manager = KeyManager(key_storage)
|
||||
return key_manager
|
||||
|
||||
|
||||
def get_access_controller() -> AccessController:
|
||||
"""Get access controller instance"""
|
||||
global access_controller
|
||||
if access_controller is None:
|
||||
from ..services.access_control import PolicyStore
|
||||
policy_store = PolicyStore()
|
||||
access_controller = AccessController(policy_store)
|
||||
return access_controller
|
||||
|
||||
|
||||
@router.post("/transactions", response_model=ConfidentialTransactionView)
|
||||
async def create_confidential_transaction(
|
||||
request: ConfidentialTransactionCreate,
|
||||
api_key: str = Depends(get_api_key)
|
||||
):
|
||||
"""Create a new confidential transaction with optional encryption"""
|
||||
try:
|
||||
# Generate transaction ID
|
||||
transaction_id = f"ctx-{datetime.utcnow().timestamp()}"
|
||||
|
||||
# Create base transaction
|
||||
transaction = ConfidentialTransaction(
|
||||
transaction_id=transaction_id,
|
||||
job_id=request.job_id,
|
||||
timestamp=datetime.utcnow(),
|
||||
status="created",
|
||||
amount=request.amount,
|
||||
pricing=request.pricing,
|
||||
settlement_details=request.settlement_details,
|
||||
confidential=request.confidential,
|
||||
participants=request.participants,
|
||||
access_policies=request.access_policies
|
||||
)
|
||||
|
||||
# Encrypt sensitive data if requested
|
||||
if request.confidential and request.participants:
|
||||
# Prepare data for encryption
|
||||
sensitive_data = {
|
||||
"amount": request.amount,
|
||||
"pricing": request.pricing,
|
||||
"settlement_details": request.settlement_details
|
||||
}
|
||||
|
||||
# Remove None values
|
||||
sensitive_data = {k: v for k, v in sensitive_data.items() if v is not None}
|
||||
|
||||
if sensitive_data:
|
||||
# Encrypt data
|
||||
enc_service = get_encryption_service()
|
||||
encrypted = enc_service.encrypt(
|
||||
data=sensitive_data,
|
||||
participants=request.participants,
|
||||
include_audit=True
|
||||
)
|
||||
|
||||
# Update transaction with encrypted data
|
||||
transaction.encrypted_data = encrypted.to_dict()["ciphertext"]
|
||||
transaction.encrypted_keys = encrypted.to_dict()["encrypted_keys"]
|
||||
transaction.algorithm = encrypted.algorithm
|
||||
|
||||
# Clear plaintext fields
|
||||
transaction.amount = None
|
||||
transaction.pricing = None
|
||||
transaction.settlement_details = None
|
||||
|
||||
# Store transaction (in production, save to database)
|
||||
logger.info(f"Created confidential transaction: {transaction_id}")
|
||||
|
||||
# Return view
|
||||
return ConfidentialTransactionView(
|
||||
transaction_id=transaction.transaction_id,
|
||||
job_id=transaction.job_id,
|
||||
timestamp=transaction.timestamp,
|
||||
status=transaction.status,
|
||||
amount=transaction.amount, # Will be None if encrypted
|
||||
pricing=transaction.pricing,
|
||||
settlement_details=transaction.settlement_details,
|
||||
confidential=transaction.confidential,
|
||||
participants=transaction.participants,
|
||||
has_encrypted_data=transaction.encrypted_data is not None
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create confidential transaction: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/transactions/{transaction_id}", response_model=ConfidentialTransactionView)
|
||||
async def get_confidential_transaction(
|
||||
transaction_id: str,
|
||||
api_key: str = Depends(get_api_key)
|
||||
):
|
||||
"""Get confidential transaction metadata (without decrypting sensitive data)"""
|
||||
try:
|
||||
# Retrieve transaction (in production, query from database)
|
||||
# For now, return error as we don't have storage
|
||||
raise HTTPException(status_code=404, detail="Transaction not found")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get transaction {transaction_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/transactions/{transaction_id}/access", response_model=ConfidentialAccessResponse)
|
||||
@limiter.limit("10/minute") # Rate limit decryption requests
|
||||
async def access_confidential_data(
|
||||
request: ConfidentialAccessRequest,
|
||||
transaction_id: str,
|
||||
api_key: str = Depends(get_api_key)
|
||||
):
|
||||
"""Request access to decrypt confidential transaction data"""
|
||||
try:
|
||||
# Validate request
|
||||
if request.transaction_id != transaction_id:
|
||||
raise HTTPException(status_code=400, detail="Transaction ID mismatch")
|
||||
|
||||
# Get transaction (in production, query from database)
|
||||
# For now, create mock transaction
|
||||
transaction = ConfidentialTransaction(
|
||||
transaction_id=transaction_id,
|
||||
job_id="test-job",
|
||||
timestamp=datetime.utcnow(),
|
||||
status="completed",
|
||||
confidential=True,
|
||||
participants=["client-456", "miner-789"]
|
||||
)
|
||||
|
||||
if not transaction.confidential:
|
||||
raise HTTPException(status_code=400, detail="Transaction is not confidential")
|
||||
|
||||
# Check access authorization
|
||||
acc_controller = get_access_controller()
|
||||
if not acc_controller.verify_access(request):
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# Decrypt data
|
||||
enc_service = get_encryption_service()
|
||||
|
||||
# Reconstruct encrypted data
|
||||
if not transaction.encrypted_data or not transaction.encrypted_keys:
|
||||
raise HTTPException(status_code=404, detail="No encrypted data found")
|
||||
|
||||
encrypted_data = EncryptedData.from_dict({
|
||||
"ciphertext": transaction.encrypted_data,
|
||||
"encrypted_keys": transaction.encrypted_keys,
|
||||
"algorithm": transaction.algorithm or "AES-256-GCM+X25519"
|
||||
})
|
||||
|
||||
# Decrypt for requester
|
||||
try:
|
||||
decrypted_data = enc_service.decrypt(
|
||||
encrypted_data=encrypted_data,
|
||||
participant_id=request.requester,
|
||||
purpose=request.purpose
|
||||
)
|
||||
|
||||
return ConfidentialAccessResponse(
|
||||
success=True,
|
||||
data=decrypted_data,
|
||||
access_id=f"access-{datetime.utcnow().timestamp()}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Decryption failed: {e}")
|
||||
return ConfidentialAccessResponse(
|
||||
success=False,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to access confidential data: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/transactions/{transaction_id}/audit", response_model=ConfidentialAccessResponse)
|
||||
async def audit_access_confidential_data(
|
||||
transaction_id: str,
|
||||
authorization: str,
|
||||
purpose: str = "compliance",
|
||||
api_key: str = Depends(get_api_key)
|
||||
):
|
||||
"""Audit access to confidential transaction data"""
|
||||
try:
|
||||
# Get transaction
|
||||
transaction = ConfidentialTransaction(
|
||||
transaction_id=transaction_id,
|
||||
job_id="test-job",
|
||||
timestamp=datetime.utcnow(),
|
||||
status="completed",
|
||||
confidential=True
|
||||
)
|
||||
|
||||
if not transaction.confidential:
|
||||
raise HTTPException(status_code=400, detail="Transaction is not confidential")
|
||||
|
||||
# Decrypt with audit key
|
||||
enc_service = get_encryption_service()
|
||||
|
||||
if not transaction.encrypted_data or not transaction.encrypted_keys:
|
||||
raise HTTPException(status_code=404, detail="No encrypted data found")
|
||||
|
||||
encrypted_data = EncryptedData.from_dict({
|
||||
"ciphertext": transaction.encrypted_data,
|
||||
"encrypted_keys": transaction.encrypted_keys,
|
||||
"algorithm": transaction.algorithm or "AES-256-GCM+X25519"
|
||||
})
|
||||
|
||||
# Decrypt for audit
|
||||
try:
|
||||
decrypted_data = enc_service.audit_decrypt(
|
||||
encrypted_data=encrypted_data,
|
||||
audit_authorization=authorization,
|
||||
purpose=purpose
|
||||
)
|
||||
|
||||
return ConfidentialAccessResponse(
|
||||
success=True,
|
||||
data=decrypted_data,
|
||||
access_id=f"audit-{datetime.utcnow().timestamp()}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Audit decryption failed: {e}")
|
||||
return ConfidentialAccessResponse(
|
||||
success=False,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed audit access: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/keys/register", response_model=KeyRegistrationResponse)
|
||||
async def register_encryption_key(
|
||||
request: KeyRegistrationRequest,
|
||||
api_key: str = Depends(get_api_key)
|
||||
):
|
||||
"""Register public key for confidential transactions"""
|
||||
try:
|
||||
# Get key manager
|
||||
km = get_key_manager()
|
||||
|
||||
# Check if participant already has keys
|
||||
try:
|
||||
existing_key = km.get_public_key(request.participant_id)
|
||||
if existing_key:
|
||||
# Key exists, return version
|
||||
return KeyRegistrationResponse(
|
||||
success=True,
|
||||
participant_id=request.participant_id,
|
||||
key_version=1, # Would get from storage
|
||||
registered_at=datetime.utcnow(),
|
||||
error=None
|
||||
)
|
||||
except:
|
||||
pass # Key doesn't exist, continue
|
||||
|
||||
# Generate new key pair
|
||||
key_pair = await km.generate_key_pair(request.participant_id)
|
||||
|
||||
return KeyRegistrationResponse(
|
||||
success=True,
|
||||
participant_id=request.participant_id,
|
||||
key_version=key_pair.version,
|
||||
registered_at=key_pair.created_at,
|
||||
error=None
|
||||
)
|
||||
|
||||
except KeyManagementError as e:
|
||||
logger.error(f"Key registration failed: {e}")
|
||||
return KeyRegistrationResponse(
|
||||
success=False,
|
||||
participant_id=request.participant_id,
|
||||
key_version=0,
|
||||
registered_at=datetime.utcnow(),
|
||||
error=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register key: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/keys/rotate")
|
||||
async def rotate_encryption_key(
|
||||
participant_id: str,
|
||||
api_key: str = Depends(get_api_key)
|
||||
):
|
||||
"""Rotate encryption keys for participant"""
|
||||
try:
|
||||
km = get_key_manager()
|
||||
|
||||
# Rotate keys
|
||||
new_key_pair = await km.rotate_keys(participant_id)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"participant_id": participant_id,
|
||||
"new_version": new_key_pair.version,
|
||||
"rotated_at": new_key_pair.created_at
|
||||
}
|
||||
|
||||
except KeyManagementError as e:
|
||||
logger.error(f"Key rotation failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to rotate keys: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/access/logs", response_model=AccessLogResponse)
|
||||
async def get_access_logs(
|
||||
query: AccessLogQuery = Depends(),
|
||||
api_key: str = Depends(get_api_key)
|
||||
):
|
||||
"""Get access logs for confidential transactions"""
|
||||
try:
|
||||
# Query logs (in production, query from database)
|
||||
# For now, return empty response
|
||||
return AccessLogResponse(
|
||||
logs=[],
|
||||
total_count=0,
|
||||
has_more=False
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get access logs: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_confidential_status(
|
||||
api_key: str = Depends(get_api_key)
|
||||
):
|
||||
"""Get status of confidential transaction system"""
|
||||
try:
|
||||
km = get_key_manager()
|
||||
enc_service = get_encryption_service()
|
||||
|
||||
# Get system status
|
||||
participants = await km.list_participants()
|
||||
|
||||
return {
|
||||
"enabled": True,
|
||||
"algorithm": "AES-256-GCM+X25519",
|
||||
"participants_count": len(participants),
|
||||
"transactions_count": 0, # Would query from database
|
||||
"audit_enabled": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get status: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -6,6 +6,7 @@ from fastapi import status as http_status
|
||||
from ..models import MarketplaceBidRequest, MarketplaceOfferView, MarketplaceStatsView
|
||||
from ..services import MarketplaceService
|
||||
from ..storage import SessionDep
|
||||
from ..metrics import marketplace_requests_total, marketplace_errors_total
|
||||
|
||||
router = APIRouter(tags=["marketplace"])
|
||||
|
||||
@@ -26,11 +27,16 @@ async def list_marketplace_offers(
|
||||
limit: int = Query(default=100, ge=1, le=500),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
) -> list[MarketplaceOfferView]:
|
||||
marketplace_requests_total.labels(endpoint="/marketplace/offers", method="GET").inc()
|
||||
service = _get_service(session)
|
||||
try:
|
||||
return service.list_offers(status=status_filter, limit=limit, offset=offset)
|
||||
except ValueError:
|
||||
marketplace_errors_total.labels(endpoint="/marketplace/offers", method="GET", error_type="invalid_request").inc()
|
||||
raise HTTPException(status_code=http_status.HTTP_400_BAD_REQUEST, detail="invalid status filter") from None
|
||||
except Exception:
|
||||
marketplace_errors_total.labels(endpoint="/marketplace/offers", method="GET", error_type="internal").inc()
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -39,8 +45,13 @@ async def list_marketplace_offers(
|
||||
summary="Get marketplace summary statistics",
|
||||
)
|
||||
async def get_marketplace_stats(*, session: SessionDep) -> MarketplaceStatsView:
|
||||
marketplace_requests_total.labels(endpoint="/marketplace/stats", method="GET").inc()
|
||||
service = _get_service(session)
|
||||
return service.get_stats()
|
||||
try:
|
||||
return service.get_stats()
|
||||
except Exception:
|
||||
marketplace_errors_total.labels(endpoint="/marketplace/stats", method="GET", error_type="internal").inc()
|
||||
raise
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -52,6 +63,14 @@ async def submit_marketplace_bid(
|
||||
payload: MarketplaceBidRequest,
|
||||
session: SessionDep,
|
||||
) -> dict[str, str]:
|
||||
marketplace_requests_total.labels(endpoint="/marketplace/bids", method="POST").inc()
|
||||
service = _get_service(session)
|
||||
bid = service.create_bid(payload)
|
||||
return {"id": bid.id}
|
||||
try:
|
||||
bid = service.create_bid(payload)
|
||||
return {"id": bid.id}
|
||||
except ValueError:
|
||||
marketplace_errors_total.labels(endpoint="/marketplace/bids", method="POST", error_type="invalid_request").inc()
|
||||
raise HTTPException(status_code=http_status.HTTP_400_BAD_REQUEST, detail="invalid bid data") from None
|
||||
except Exception:
|
||||
marketplace_errors_total.labels(endpoint="/marketplace/bids", method="POST", error_type="internal").inc()
|
||||
raise
|
||||
|
||||
303
apps/coordinator-api/src/app/routers/registry.py
Normal file
303
apps/coordinator-api/src/app/routers/registry.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
Service registry router for dynamic service management
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any, Optional
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from ..models.registry import (
|
||||
ServiceRegistry,
|
||||
ServiceDefinition,
|
||||
ServiceCategory
|
||||
)
|
||||
from ..models.registry_media import MEDIA_PROCESSING_SERVICES
|
||||
from ..models.registry_scientific import SCIENTIFIC_COMPUTING_SERVICES
|
||||
from ..models.registry_data import DATA_ANALYTICS_SERVICES
|
||||
from ..models.registry_gaming import GAMING_SERVICES
|
||||
from ..models.registry_devtools import DEVTOOLS_SERVICES
|
||||
from ..models.registry import AI_ML_SERVICES
|
||||
|
||||
router = APIRouter(prefix="/registry", tags=["service-registry"])
|
||||
|
||||
# Initialize service registry with all services
|
||||
def create_service_registry() -> ServiceRegistry:
|
||||
"""Create and populate the service registry"""
|
||||
all_services = {}
|
||||
|
||||
# Add all service categories
|
||||
all_services.update(AI_ML_SERVICES)
|
||||
all_services.update(MEDIA_PROCESSING_SERVICES)
|
||||
all_services.update(SCIENTIFIC_COMPUTING_SERVICES)
|
||||
all_services.update(DATA_ANALYTICS_SERVICES)
|
||||
all_services.update(GAMING_SERVICES)
|
||||
all_services.update(DEVTOOLS_SERVICES)
|
||||
|
||||
return ServiceRegistry(
|
||||
version="1.0.0",
|
||||
services=all_services
|
||||
)
|
||||
|
||||
# Global registry instance
|
||||
service_registry = create_service_registry()
|
||||
|
||||
|
||||
@router.get("/", response_model=ServiceRegistry)
|
||||
async def get_registry() -> ServiceRegistry:
|
||||
"""Get the complete service registry"""
|
||||
return service_registry
|
||||
|
||||
|
||||
@router.get("/services", response_model=List[ServiceDefinition])
|
||||
async def list_services(
|
||||
category: Optional[ServiceCategory] = None,
|
||||
search: Optional[str] = None
|
||||
) -> List[ServiceDefinition]:
|
||||
"""List all available services with optional filtering"""
|
||||
services = list(service_registry.services.values())
|
||||
|
||||
# Filter by category
|
||||
if category:
|
||||
services = [s for s in services if s.category == category]
|
||||
|
||||
# Search by name, description, or tags
|
||||
if search:
|
||||
search = search.lower()
|
||||
services = [
|
||||
s for s in services
|
||||
if (search in s.name.lower() or
|
||||
search in s.description.lower() or
|
||||
any(search in tag.lower() for tag in s.tags))
|
||||
]
|
||||
|
||||
return services
|
||||
|
||||
|
||||
@router.get("/services/{service_id}", response_model=ServiceDefinition)
|
||||
async def get_service(service_id: str) -> ServiceDefinition:
|
||||
"""Get a specific service definition"""
|
||||
service = service_registry.get_service(service_id)
|
||||
if not service:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Service {service_id} not found"
|
||||
)
|
||||
return service
|
||||
|
||||
|
||||
@router.get("/categories", response_model=List[Dict[str, Any]])
|
||||
async def list_categories() -> List[Dict[str, Any]]:
|
||||
"""List all service categories with counts"""
|
||||
category_counts = {}
|
||||
for service in service_registry.services.values():
|
||||
category = service.category.value
|
||||
if category not in category_counts:
|
||||
category_counts[category] = 0
|
||||
category_counts[category] += 1
|
||||
|
||||
return [
|
||||
{"category": cat, "count": count}
|
||||
for cat, count in category_counts.items()
|
||||
]
|
||||
|
||||
|
||||
@router.get("/categories/{category}", response_model=List[ServiceDefinition])
|
||||
async def get_services_by_category(category: ServiceCategory) -> List[ServiceDefinition]:
|
||||
"""Get all services in a specific category"""
|
||||
return service_registry.get_services_by_category(category)
|
||||
|
||||
|
||||
@router.get("/services/{service_id}/schema")
|
||||
async def get_service_schema(service_id: str) -> Dict[str, Any]:
|
||||
"""Get JSON schema for a service's input parameters"""
|
||||
service = service_registry.get_service(service_id)
|
||||
if not service:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Service {service_id} not found"
|
||||
)
|
||||
|
||||
# Convert input parameters to JSON schema
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for param in service.input_parameters:
|
||||
prop = {
|
||||
"type": param.type.value,
|
||||
"description": param.description
|
||||
}
|
||||
|
||||
if param.default is not None:
|
||||
prop["default"] = param.default
|
||||
if param.min_value is not None:
|
||||
prop["minimum"] = param.min_value
|
||||
if param.max_value is not None:
|
||||
prop["maximum"] = param.max_value
|
||||
if param.options:
|
||||
prop["enum"] = param.options
|
||||
if param.validation:
|
||||
prop.update(param.validation)
|
||||
|
||||
properties[param.name] = prop
|
||||
if param.required:
|
||||
required.append(param.name)
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}
|
||||
|
||||
|
||||
@router.get("/services/{service_id}/requirements")
|
||||
async def get_service_requirements(service_id: str) -> Dict[str, Any]:
|
||||
"""Get hardware requirements for a service"""
|
||||
service = service_registry.get_service(service_id)
|
||||
if not service:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Service {service_id} not found"
|
||||
)
|
||||
|
||||
return {
|
||||
"requirements": [
|
||||
{
|
||||
"component": req.component,
|
||||
"minimum": req.min_value,
|
||||
"recommended": req.recommended,
|
||||
"unit": req.unit
|
||||
}
|
||||
for req in service.requirements
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/services/{service_id}/pricing")
|
||||
async def get_service_pricing(service_id: str) -> Dict[str, Any]:
|
||||
"""Get pricing information for a service"""
|
||||
service = service_registry.get_service(service_id)
|
||||
if not service:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Service {service_id} not found"
|
||||
)
|
||||
|
||||
return {
|
||||
"pricing": [
|
||||
{
|
||||
"tier": tier.name,
|
||||
"model": tier.model.value,
|
||||
"unit_price": tier.unit_price,
|
||||
"min_charge": tier.min_charge,
|
||||
"currency": tier.currency,
|
||||
"description": tier.description
|
||||
}
|
||||
for tier in service.pricing
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.post("/services/validate")
|
||||
async def validate_service_request(
|
||||
service_id: str,
|
||||
request_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Validate a service request against the service schema"""
|
||||
service = service_registry.get_service(service_id)
|
||||
if not service:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Service {service_id} not found"
|
||||
)
|
||||
|
||||
# Validate request data
|
||||
validation_result = {
|
||||
"valid": True,
|
||||
"errors": [],
|
||||
"warnings": []
|
||||
}
|
||||
|
||||
# Check required parameters
|
||||
provided_params = set(request_data.keys())
|
||||
required_params = {p.name for p in service.input_parameters if p.required}
|
||||
missing_params = required_params - provided_params
|
||||
|
||||
if missing_params:
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].extend([
|
||||
f"Missing required parameter: {param}"
|
||||
for param in missing_params
|
||||
])
|
||||
|
||||
# Validate parameter types and constraints
|
||||
for param in service.input_parameters:
|
||||
if param.name in request_data:
|
||||
value = request_data[param.name]
|
||||
|
||||
# Type validation (simplified)
|
||||
if param.type == "integer" and not isinstance(value, int):
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be an integer"
|
||||
)
|
||||
elif param.type == "float" and not isinstance(value, (int, float)):
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be a number"
|
||||
)
|
||||
elif param.type == "boolean" and not isinstance(value, bool):
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be a boolean"
|
||||
)
|
||||
elif param.type == "array" and not isinstance(value, list):
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be an array"
|
||||
)
|
||||
|
||||
# Value constraints
|
||||
if param.min_value is not None and value < param.min_value:
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be >= {param.min_value}"
|
||||
)
|
||||
|
||||
if param.max_value is not None and value > param.max_value:
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be <= {param.max_value}"
|
||||
)
|
||||
|
||||
# Enum options
|
||||
if param.options and value not in param.options:
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be one of: {', '.join(param.options)}"
|
||||
)
|
||||
|
||||
return validation_result
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_registry_stats() -> Dict[str, Any]:
|
||||
"""Get registry statistics"""
|
||||
total_services = len(service_registry.services)
|
||||
category_counts = {}
|
||||
|
||||
for service in service_registry.services.values():
|
||||
category = service.category.value
|
||||
if category not in category_counts:
|
||||
category_counts[category] = 0
|
||||
category_counts[category] += 1
|
||||
|
||||
# Count unique pricing models
|
||||
pricing_models = set()
|
||||
for service in service_registry.services.values():
|
||||
for tier in service.pricing:
|
||||
pricing_models.add(tier.model.value)
|
||||
|
||||
return {
|
||||
"total_services": total_services,
|
||||
"categories": category_counts,
|
||||
"pricing_models": list(pricing_models),
|
||||
"last_updated": service_registry.last_updated.isoformat()
|
||||
}
|
||||
612
apps/coordinator-api/src/app/routers/services.py
Normal file
612
apps/coordinator-api/src/app/routers/services.py
Normal file
@@ -0,0 +1,612 @@
|
||||
"""
|
||||
Services router for specific GPU workloads
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Union
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from ..deps import require_client_key
|
||||
from ..models import JobCreate, JobView, JobResult
|
||||
from ..models.services import (
|
||||
ServiceType,
|
||||
ServiceRequest,
|
||||
ServiceResponse,
|
||||
WhisperRequest,
|
||||
StableDiffusionRequest,
|
||||
LLMRequest,
|
||||
FFmpegRequest,
|
||||
BlenderRequest,
|
||||
)
|
||||
from ..models.registry import ServiceRegistry, service_registry
|
||||
from ..services import JobService
|
||||
from ..storage import SessionDep
|
||||
|
||||
router = APIRouter(tags=["services"])
|
||||
|
||||
|
||||
@router.post(
|
||||
"/services/{service_type}",
|
||||
response_model=ServiceResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Submit a service-specific job",
|
||||
deprecated=True
|
||||
)
|
||||
async def submit_service_job(
|
||||
service_type: ServiceType,
|
||||
request_data: Dict[str, Any],
|
||||
session: SessionDep,
|
||||
client_id: str = Depends(require_client_key()),
|
||||
user_agent: str = Header(None),
|
||||
) -> ServiceResponse:
|
||||
"""Submit a job for a specific service type
|
||||
|
||||
DEPRECATED: Use /v1/registry/services/{service_id} endpoint instead.
|
||||
This endpoint will be removed in version 2.0.
|
||||
"""
|
||||
|
||||
# Add deprecation warning header
|
||||
from fastapi import Response
|
||||
response = Response()
|
||||
response.headers["X-Deprecated"] = "true"
|
||||
response.headers["X-Deprecation-Message"] = "Use /v1/registry/services/{service_id} instead"
|
||||
|
||||
# Check if service exists in registry
|
||||
service = service_registry.get_service(service_type.value)
|
||||
if not service:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Service {service_type} not found"
|
||||
)
|
||||
|
||||
# Validate request against service schema
|
||||
validation_result = await validate_service_request(service_type.value, request_data)
|
||||
if not validation_result["valid"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid request: {', '.join(validation_result['errors'])}"
|
||||
)
|
||||
|
||||
# Create service request wrapper
|
||||
service_request = ServiceRequest(
|
||||
service_type=service_type,
|
||||
request_data=request_data
|
||||
)
|
||||
|
||||
# Validate and parse service-specific request
|
||||
try:
|
||||
typed_request = service_request.get_service_request()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid request for {service_type}: {str(e)}"
|
||||
)
|
||||
|
||||
# Get constraints from service request
|
||||
constraints = typed_request.get_constraints()
|
||||
|
||||
# Create job with service-specific payload
|
||||
job_payload = {
|
||||
"service_type": service_type.value,
|
||||
"service_request": request_data,
|
||||
}
|
||||
|
||||
job_create = JobCreate(
|
||||
payload=job_payload,
|
||||
constraints=constraints,
|
||||
ttl_seconds=900 # Default 15 minutes
|
||||
)
|
||||
|
||||
# Submit job
|
||||
service = JobService(session)
|
||||
job = service.create_job(client_id, job_create)
|
||||
|
||||
return ServiceResponse(
|
||||
job_id=job.job_id,
|
||||
service_type=service_type,
|
||||
status=job.state.value,
|
||||
estimated_completion=job.expires_at.isoformat()
|
||||
)
|
||||
|
||||
|
||||
# Whisper endpoints
|
||||
@router.post(
|
||||
"/services/whisper/transcribe",
|
||||
response_model=ServiceResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Transcribe audio using Whisper"
|
||||
)
|
||||
async def whisper_transcribe(
|
||||
request: WhisperRequest,
|
||||
session: SessionDep,
|
||||
client_id: str = Depends(require_client_key()),
|
||||
) -> ServiceResponse:
|
||||
"""Transcribe audio file using Whisper"""
|
||||
|
||||
job_payload = {
|
||||
"service_type": ServiceType.WHISPER.value,
|
||||
"service_request": request.dict(),
|
||||
}
|
||||
|
||||
job_create = JobCreate(
|
||||
payload=job_payload,
|
||||
constraints=request.get_constraints(),
|
||||
ttl_seconds=900
|
||||
)
|
||||
|
||||
service = JobService(session)
|
||||
job = service.create_job(client_id, job_create)
|
||||
|
||||
return ServiceResponse(
|
||||
job_id=job.job_id,
|
||||
service_type=ServiceType.WHISPER,
|
||||
status=job.state.value,
|
||||
estimated_completion=job.expires_at.isoformat()
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/services/whisper/translate",
|
||||
response_model=ServiceResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Translate audio using Whisper"
|
||||
)
|
||||
async def whisper_translate(
|
||||
request: WhisperRequest,
|
||||
session: SessionDep,
|
||||
client_id: str = Depends(require_client_key()),
|
||||
) -> ServiceResponse:
|
||||
"""Translate audio file using Whisper"""
|
||||
# Force task to be translate
|
||||
request.task = "translate"
|
||||
|
||||
job_payload = {
|
||||
"service_type": ServiceType.WHISPER.value,
|
||||
"service_request": request.dict(),
|
||||
}
|
||||
|
||||
job_create = JobCreate(
|
||||
payload=job_payload,
|
||||
constraints=request.get_constraints(),
|
||||
ttl_seconds=900
|
||||
)
|
||||
|
||||
service = JobService(session)
|
||||
job = service.create_job(client_id, job_create)
|
||||
|
||||
return ServiceResponse(
|
||||
job_id=job.job_id,
|
||||
service_type=ServiceType.WHISPER,
|
||||
status=job.state.value,
|
||||
estimated_completion=job.expires_at.isoformat()
|
||||
)
|
||||
|
||||
|
||||
# Stable Diffusion endpoints
|
||||
@router.post(
|
||||
"/services/stable-diffusion/generate",
|
||||
response_model=ServiceResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Generate images using Stable Diffusion"
|
||||
)
|
||||
async def stable_diffusion_generate(
|
||||
request: StableDiffusionRequest,
|
||||
session: SessionDep,
|
||||
client_id: str = Depends(require_client_key()),
|
||||
) -> ServiceResponse:
|
||||
"""Generate images using Stable Diffusion"""
|
||||
|
||||
job_payload = {
|
||||
"service_type": ServiceType.STABLE_DIFFUSION.value,
|
||||
"service_request": request.dict(),
|
||||
}
|
||||
|
||||
job_create = JobCreate(
|
||||
payload=job_payload,
|
||||
constraints=request.get_constraints(),
|
||||
ttl_seconds=600 # 10 minutes for image generation
|
||||
)
|
||||
|
||||
service = JobService(session)
|
||||
job = service.create_job(client_id, job_create)
|
||||
|
||||
return ServiceResponse(
|
||||
job_id=job.job_id,
|
||||
service_type=ServiceType.STABLE_DIFFUSION,
|
||||
status=job.state.value,
|
||||
estimated_completion=job.expires_at.isoformat()
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/services/stable-diffusion/img2img",
|
||||
response_model=ServiceResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Image-to-image generation"
|
||||
)
|
||||
async def stable_diffusion_img2img(
|
||||
request: StableDiffusionRequest,
|
||||
session: SessionDep,
|
||||
client_id: str = Depends(require_client_key()),
|
||||
) -> ServiceResponse:
|
||||
"""Image-to-image generation using Stable Diffusion"""
|
||||
# Add img2img specific parameters
|
||||
request_data = request.dict()
|
||||
request_data["mode"] = "img2img"
|
||||
|
||||
job_payload = {
|
||||
"service_type": ServiceType.STABLE_DIFFUSION.value,
|
||||
"service_request": request_data,
|
||||
}
|
||||
|
||||
job_create = JobCreate(
|
||||
payload=job_payload,
|
||||
constraints=request.get_constraints(),
|
||||
ttl_seconds=600
|
||||
)
|
||||
|
||||
service = JobService(session)
|
||||
job = service.create_job(client_id, job_create)
|
||||
|
||||
return ServiceResponse(
|
||||
job_id=job.job_id,
|
||||
service_type=ServiceType.STABLE_DIFFUSION,
|
||||
status=job.state.value,
|
||||
estimated_completion=job.expires_at.isoformat()
|
||||
)
|
||||
|
||||
|
||||
# LLM Inference endpoints
|
||||
@router.post(
|
||||
"/services/llm/inference",
|
||||
response_model=ServiceResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Run LLM inference"
|
||||
)
|
||||
async def llm_inference(
|
||||
request: LLMRequest,
|
||||
session: SessionDep,
|
||||
client_id: str = Depends(require_client_key()),
|
||||
) -> ServiceResponse:
|
||||
"""Run inference on a language model"""
|
||||
|
||||
job_payload = {
|
||||
"service_type": ServiceType.LLM_INFERENCE.value,
|
||||
"service_request": request.dict(),
|
||||
}
|
||||
|
||||
job_create = JobCreate(
|
||||
payload=job_payload,
|
||||
constraints=request.get_constraints(),
|
||||
ttl_seconds=300 # 5 minutes for text generation
|
||||
)
|
||||
|
||||
service = JobService(session)
|
||||
job = service.create_job(client_id, job_create)
|
||||
|
||||
return ServiceResponse(
|
||||
job_id=job.job_id,
|
||||
service_type=ServiceType.LLM_INFERENCE,
|
||||
status=job.state.value,
|
||||
estimated_completion=job.expires_at.isoformat()
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/services/llm/stream",
|
||||
summary="Stream LLM inference"
|
||||
)
|
||||
async def llm_stream(
|
||||
request: LLMRequest,
|
||||
session: SessionDep,
|
||||
client_id: str = Depends(require_client_key()),
|
||||
):
|
||||
"""Stream LLM inference response"""
|
||||
# Force streaming mode
|
||||
request.stream = True
|
||||
|
||||
job_payload = {
|
||||
"service_type": ServiceType.LLM_INFERENCE.value,
|
||||
"service_request": request.dict(),
|
||||
}
|
||||
|
||||
job_create = JobCreate(
|
||||
payload=job_payload,
|
||||
constraints=request.get_constraints(),
|
||||
ttl_seconds=300
|
||||
)
|
||||
|
||||
service = JobService(session)
|
||||
job = service.create_job(client_id, job_create)
|
||||
|
||||
# Return streaming response
|
||||
# This would implement WebSocket or Server-Sent Events
|
||||
return ServiceResponse(
|
||||
job_id=job.job_id,
|
||||
service_type=ServiceType.LLM_INFERENCE,
|
||||
status=job.state.value,
|
||||
estimated_completion=job.expires_at.isoformat()
|
||||
)
|
||||
|
||||
|
||||
# FFmpeg endpoints
|
||||
@router.post(
|
||||
"/services/ffmpeg/transcode",
|
||||
response_model=ServiceResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Transcode video using FFmpeg"
|
||||
)
|
||||
async def ffmpeg_transcode(
|
||||
request: FFmpegRequest,
|
||||
session: SessionDep,
|
||||
client_id: str = Depends(require_client_key()),
|
||||
) -> ServiceResponse:
|
||||
"""Transcode video using FFmpeg"""
|
||||
|
||||
job_payload = {
|
||||
"service_type": ServiceType.FFMPEG.value,
|
||||
"service_request": request.dict(),
|
||||
}
|
||||
|
||||
# Adjust TTL based on video length (would need to probe video)
|
||||
job_create = JobCreate(
|
||||
payload=job_payload,
|
||||
constraints=request.get_constraints(),
|
||||
ttl_seconds=1800 # 30 minutes for video transcoding
|
||||
)
|
||||
|
||||
service = JobService(session)
|
||||
job = service.create_job(client_id, job_create)
|
||||
|
||||
return ServiceResponse(
|
||||
job_id=job.job_id,
|
||||
service_type=ServiceType.FFMPEG,
|
||||
status=job.state.value,
|
||||
estimated_completion=job.expires_at.isoformat()
|
||||
)
|
||||
|
||||
|
||||
# Blender endpoints
|
||||
@router.post(
|
||||
"/services/blender/render",
|
||||
response_model=ServiceResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Render using Blender"
|
||||
)
|
||||
async def blender_render(
|
||||
request: BlenderRequest,
|
||||
session: SessionDep,
|
||||
client_id: str = Depends(require_client_key()),
|
||||
) -> ServiceResponse:
|
||||
"""Render scene using Blender"""
|
||||
|
||||
job_payload = {
|
||||
"service_type": ServiceType.BLENDER.value,
|
||||
"service_request": request.dict(),
|
||||
}
|
||||
|
||||
# Adjust TTL based on frame count
|
||||
frame_count = request.frame_end - request.frame_start + 1
|
||||
estimated_time = frame_count * 30 # 30 seconds per frame estimate
|
||||
ttl_seconds = max(600, estimated_time) # Minimum 10 minutes
|
||||
|
||||
job_create = JobCreate(
|
||||
payload=job_payload,
|
||||
constraints=request.get_constraints(),
|
||||
ttl_seconds=ttl_seconds
|
||||
)
|
||||
|
||||
service = JobService(session)
|
||||
job = service.create_job(client_id, job_create)
|
||||
|
||||
return ServiceResponse(
|
||||
job_id=job.job_id,
|
||||
service_type=ServiceType.BLENDER,
|
||||
status=job.state.value,
|
||||
estimated_completion=job.expires_at.isoformat()
|
||||
)
|
||||
|
||||
|
||||
# Utility endpoints
|
||||
@router.get(
|
||||
"/services",
|
||||
summary="List available services"
|
||||
)
|
||||
async def list_services() -> Dict[str, Any]:
|
||||
"""List all available service types and their capabilities"""
|
||||
return {
|
||||
"services": [
|
||||
{
|
||||
"type": ServiceType.WHISPER.value,
|
||||
"name": "Whisper Speech Recognition",
|
||||
"description": "Transcribe and translate audio files",
|
||||
"models": [m.value for m in WhisperModel],
|
||||
"constraints": {
|
||||
"gpu": "nvidia",
|
||||
"min_vram_gb": 1,
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": ServiceType.STABLE_DIFFUSION.value,
|
||||
"name": "Stable Diffusion",
|
||||
"description": "Generate images from text prompts",
|
||||
"models": [m.value for m in SDModel],
|
||||
"constraints": {
|
||||
"gpu": "nvidia",
|
||||
"min_vram_gb": 4,
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": ServiceType.LLM_INFERENCE.value,
|
||||
"name": "LLM Inference",
|
||||
"description": "Run inference on large language models",
|
||||
"models": [m.value for m in LLMModel],
|
||||
"constraints": {
|
||||
"gpu": "nvidia",
|
||||
"min_vram_gb": 8,
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": ServiceType.FFMPEG.value,
|
||||
"name": "FFmpeg Video Processing",
|
||||
"description": "Transcode and process video files",
|
||||
"codecs": [c.value for c in FFmpegCodec],
|
||||
"constraints": {
|
||||
"gpu": "any",
|
||||
"min_vram_gb": 0,
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": ServiceType.BLENDER.value,
|
||||
"name": "Blender Rendering",
|
||||
"description": "Render 3D scenes using Blender",
|
||||
"engines": [e.value for e in BlenderEngine],
|
||||
"constraints": {
|
||||
"gpu": "any",
|
||||
"min_vram_gb": 4,
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/services/{service_type}/schema",
|
||||
summary="Get service request schema",
|
||||
deprecated=True
|
||||
)
|
||||
async def get_service_schema(service_type: ServiceType) -> Dict[str, Any]:
|
||||
"""Get the JSON schema for a specific service type
|
||||
|
||||
DEPRECATED: Use /v1/registry/services/{service_id}/schema instead.
|
||||
This endpoint will be removed in version 2.0.
|
||||
"""
|
||||
# Get service from registry
|
||||
service = service_registry.get_service(service_type.value)
|
||||
if not service:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Service {service_type} not found"
|
||||
)
|
||||
|
||||
# Build schema from service definition
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for param in service.input_parameters:
|
||||
prop = {
|
||||
"type": param.type.value,
|
||||
"description": param.description
|
||||
}
|
||||
|
||||
if param.default is not None:
|
||||
prop["default"] = param.default
|
||||
if param.min_value is not None:
|
||||
prop["minimum"] = param.min_value
|
||||
if param.max_value is not None:
|
||||
prop["maximum"] = param.max_value
|
||||
if param.options:
|
||||
prop["enum"] = param.options
|
||||
if param.validation:
|
||||
prop.update(param.validation)
|
||||
|
||||
properties[param.name] = prop
|
||||
if param.required:
|
||||
required.append(param.name)
|
||||
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}
|
||||
|
||||
return {
|
||||
"service_type": service_type.value,
|
||||
"schema": schema
|
||||
}
|
||||
|
||||
|
||||
async def validate_service_request(service_id: str, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate a service request against the service schema"""
|
||||
service = service_registry.get_service(service_id)
|
||||
if not service:
|
||||
return {"valid": False, "errors": [f"Service {service_id} not found"]}
|
||||
|
||||
validation_result = {
|
||||
"valid": True,
|
||||
"errors": [],
|
||||
"warnings": []
|
||||
}
|
||||
|
||||
# Check required parameters
|
||||
provided_params = set(request_data.keys())
|
||||
required_params = {p.name for p in service.input_parameters if p.required}
|
||||
missing_params = required_params - provided_params
|
||||
|
||||
if missing_params:
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].extend([
|
||||
f"Missing required parameter: {param}"
|
||||
for param in missing_params
|
||||
])
|
||||
|
||||
# Validate parameter types and constraints
|
||||
for param in service.input_parameters:
|
||||
if param.name in request_data:
|
||||
value = request_data[param.name]
|
||||
|
||||
# Type validation (simplified)
|
||||
if param.type == "integer" and not isinstance(value, int):
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be an integer"
|
||||
)
|
||||
elif param.type == "float" and not isinstance(value, (int, float)):
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be a number"
|
||||
)
|
||||
elif param.type == "boolean" and not isinstance(value, bool):
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be a boolean"
|
||||
)
|
||||
elif param.type == "array" and not isinstance(value, list):
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be an array"
|
||||
)
|
||||
|
||||
# Value constraints
|
||||
if param.min_value is not None and value < param.min_value:
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be >= {param.min_value}"
|
||||
)
|
||||
|
||||
if param.max_value is not None and value > param.max_value:
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be <= {param.max_value}"
|
||||
)
|
||||
|
||||
# Enum options
|
||||
if param.options and value not in param.options:
|
||||
validation_result["valid"] = False
|
||||
validation_result["errors"].append(
|
||||
f"Parameter {param.name} must be one of: {', '.join(param.options)}"
|
||||
)
|
||||
|
||||
return validation_result
|
||||
|
||||
|
||||
# Import models for type hints
|
||||
from ..models.services import (
|
||||
WhisperModel,
|
||||
SDModel,
|
||||
LLMModel,
|
||||
FFmpegCodec,
|
||||
FFmpegPreset,
|
||||
BlenderEngine,
|
||||
BlenderFormat,
|
||||
)
|
||||
Reference in New Issue
Block a user