diff --git a/apps/edge-api/src/edge_api/routers/serve.py b/apps/edge-api/src/edge_api/routers/serve.py index 214dc2b0..2511273f 100644 --- a/apps/edge-api/src/edge_api/routers/serve.py +++ b/apps/edge-api/src/edge_api/routers/serve.py @@ -1,35 +1,63 @@ """Edge serve operations router for Edge API Service""" -from fastapi import APIRouter +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel, Field + +from ..services.serve_service import ServeService router = APIRouter() -@router.post("/start") -async def start_serve(): - """Start serving edge compute requests - TODO: Implement in Phase 5""" - return {"message": "Start serve endpoint - to be implemented in Phase 5"} +class SubmitComputeRequest(BaseModel): + """Request model for submitting compute request""" + gpu_id: str + model_name: str + input_data: dict + priority: str = Field(default="normal") -@router.post("/stop") -async def stop_serve(): - """Stop serving edge compute requests - TODO: Implement in Phase 5""" - return {"message": "Stop serve endpoint - to be implemented in Phase 5"} +def get_serve_service() -> ServeService: + """Dependency injection for serve service""" + return ServeService() -@router.get("/status") -async def get_serve_status(): - """Get serve status - TODO: Implement in Phase 5""" - return {"message": "Get serve status endpoint - to be implemented in Phase 5"} +@router.post("/requests") +async def submit_compute_request(request: SubmitComputeRequest, svc: ServeService = Depends(get_serve_service)): + """Submit compute request""" + result = await svc.submit_compute_request(request.gpu_id, request.model_name, request.input_data, request.priority) + return result @router.get("/requests") -async def get_pending_requests(): - """Get pending compute requests - TODO: Implement in Phase 5""" - return {"message": "Get pending requests endpoint - to be implemented in Phase 5"} +async def list_compute_requests(gpu_id: str = Query(None), status: str = Query(None), svc: ServeService = Depends(get_serve_service)): + """List compute requests, optionally filtered""" + requests = await svc.list_compute_requests(gpu_id, status) + return {"requests": requests, "total": len(requests)} -@router.post("/requests/{request_id}/complete") -async def complete_request(request_id: str): - """Complete a compute request - TODO: Implement in Phase 5""" - return {"message": f"Complete request {request_id} - to be implemented in Phase 5"} +@router.get("/requests/{request_id}") +async def get_compute_request(request_id: str, svc: ServeService = Depends(get_serve_service)): + """Get compute request details""" + req = await svc.get_compute_request(request_id) + if req is None: + raise HTTPException(status_code=404, detail=f"Request {request_id} not found") + return req + + +@router.post("/requests/{request_id}/cancel") +async def cancel_compute_request(request_id: str, svc: ServeService = Depends(get_serve_service)): + """Cancel compute request""" + success = await svc.cancel_compute_request(request_id) + if success: + return {"message": f"Request {request_id} cancelled"} + else: + raise HTTPException(status_code=400, detail=f"Request {request_id} cannot be cancelled") + + +@router.get("/requests/{request_id}/result") +async def get_compute_result(request_id: str, svc: ServeService = Depends(get_serve_service)): + """Get compute result""" + result = await svc.get_compute_result(request_id) + if result is None: + raise HTTPException(status_code=404, detail=f"Result for request {request_id} not found") + return result diff --git a/apps/edge-api/src/edge_api/schemas/serve.py b/apps/edge-api/src/edge_api/schemas/serve.py index 52e0a9a1..014bf2e3 100644 --- a/apps/edge-api/src/edge_api/schemas/serve.py +++ b/apps/edge-api/src/edge_api/schemas/serve.py @@ -1,6 +1,6 @@ """Edge serve-related schemas for Edge API Service""" -from datetime import datetime, timezone +from datetime import datetime from uuid import uuid4 from sqlalchemy import JSON, Column @@ -15,24 +15,17 @@ class ComputeRequest(SQLModel, table=True): id: str = Field(default_factory=lambda: f"compute_req_{uuid4().hex[:8]}", primary_key=True) request_id: str = Field(index=True) - island_id: str = Field(index=True) - gpu_type: str - status: str = Field(default="pending", index=True) # pending, processing, completed, failed - created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) - updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) - - # Request parameters + gpu_id: str = Field(index=True) model_name: str input_data: dict = Field(default_factory=dict, sa_column=Column(JSON, nullable=True)) - - # Processing info - assigned_gpu_id: str | None = Field(default=None) + priority: str = Field(default="normal") + status: str = Field(default="queued", index=True) # queued, running, completed, failed, cancelled + created_at: datetime = Field(default_factory=datetime.utcnow) + updated_at: datetime = Field(default_factory=datetime.utcnow) started_at: datetime | None = Field(default=None) completed_at: datetime | None = Field(default=None) - - # Result - result: dict | None = Field(default=None, sa_column=Column(JSON, nullable=True)) error: str | None = Field(default=None) + extra_data: dict = Field(default_factory=dict, sa_column=Column(JSON, nullable=True)) class ComputeResult(SQLModel, table=True): @@ -48,5 +41,5 @@ class ComputeResult(SQLModel, table=True): gpu_id: str result: dict = Field(default_factory=dict, sa_column=Column(JSON, nullable=False)) cache_ttl: int = Field(default=3600) # 1 hour default - created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) - expires_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + created_at: datetime = Field(default_factory=datetime.utcnow) + expires_at: datetime = Field(default_factory=datetime.utcnow) diff --git a/apps/edge-api/src/edge_api/services/serve_service.py b/apps/edge-api/src/edge_api/services/serve_service.py index c8e8ac04..2255a06e 100644 --- a/apps/edge-api/src/edge_api/services/serve_service.py +++ b/apps/edge-api/src/edge_api/services/serve_service.py @@ -1,33 +1,114 @@ """Edge serve service for Edge API Service""" -from typing import Dict, List +from typing import Dict, Optional, List +from datetime import datetime +from uuid import uuid4 +from ..storage import get_session from ..schemas.serve import ComputeRequest, ComputeResult +from sqlmodel import select, delete class ServeService: """Service for edge serve operations""" - def __init__(self): - # TODO: Initialize serve queue in Phase 5 - pass + async def submit_compute_request(self, gpu_id: str, model_name: str, input_data: dict, priority: str = "normal") -> Dict: + """Submit compute request""" + async with get_session() as session: + request_id = f"req_{uuid4().hex[:8]}" + + request = ComputeRequest( + request_id=request_id, + gpu_id=gpu_id, + model_name=model_name, + input_data=input_data, + priority=priority, + status="queued" + ) + session.add(request) + await session.commit() + + return { + "success": True, + "request_id": request_id, + "status": "queued", + "message": f"Compute request {request_id} submitted" + } - async def start_serve(self, island_id: str) -> Dict: - """Start serving edge compute requests - TODO: Implement in Phase 5""" - return {"message": "start_serve - to be implemented in Phase 5"} + async def get_compute_request(self, request_id: str) -> Optional[Dict]: + """Get compute request details""" + async with get_session() as session: + result = await session.execute(select(ComputeRequest).where(ComputeRequest.request_id == request_id)) + req = result.scalar_one_or_none() + + if req: + return { + "request_id": req.request_id, + "gpu_id": req.gpu_id, + "model_name": req.model_name, + "input_data": req.input_data, + "priority": req.priority, + "status": req.status, + "created_at": req.created_at.isoformat() if req.created_at else None, + "started_at": req.started_at.isoformat() if req.started_at else None, + "completed_at": req.completed_at.isoformat() if req.completed_at else None, + "error": req.error, + "extra_data": req.extra_data + } + return None - async def stop_serve(self, island_id: str) -> Dict: - """Stop serving edge compute requests - TODO: Implement in Phase 5""" - return {"message": "stop_serve - to be implemented in Phase 5"} + async def cancel_compute_request(self, request_id: str) -> bool: + """Cancel compute request""" + async with get_session() as session: + result = await session.execute(select(ComputeRequest).where(ComputeRequest.request_id == request_id)) + req = result.scalar_one_or_none() + + if req and req.status in ["queued", "running"]: + req.status = "cancelled" + req.completed_at = datetime.utcnow() + await session.commit() + return True + return False - async def get_serve_status(self, island_id: str) -> Dict: - """Get serve status - TODO: Implement in Phase 5""" - return {"message": "get_serve_status - to be implemented in Phase 5"} + async def list_compute_requests(self, gpu_id: str = None, status: str = None) -> List[Dict]: + """List compute requests, optionally filtered""" + async with get_session() as session: + query = select(ComputeRequest) + + if gpu_id: + query = query.where(ComputeRequest.gpu_id == gpu_id) + if status: + query = query.where(ComputeRequest.status == status) + + result = await session.execute(query) + requests = result.scalars().all() + + return [ + { + "request_id": req.request_id, + "gpu_id": req.gpu_id, + "model_name": req.model_name, + "priority": req.priority, + "status": req.status, + "created_at": req.created_at.isoformat() if req.created_at else None + } + for req in requests + ] - async def get_pending_requests(self, island_id: str) -> List[Dict]: - """Get pending compute requests - TODO: Implement in Phase 5""" - return [{"message": "get_pending_requests - to be implemented in Phase 5"}] - - async def complete_request(self, request_id: str, result: Dict) -> Dict: - """Complete a compute request - TODO: Implement in Phase 5""" - return {"message": f"complete_request {request_id} - to be implemented in Phase 5"} + async def get_compute_result(self, request_id: str) -> Optional[Dict]: + """Get compute result""" + async with get_session() as session: + result = await session.execute(select(ComputeResult).where(ComputeResult.request_id == request_id)) + res = result.scalar_one_or_none() + + if res: + return { + "result_id": res.result_id, + "request_id": res.request_id, + "output_data": res.output_data, + "metrics": res.metrics, + "status": res.status, + "created_at": res.created_at.isoformat() if res.created_at else None, + "extra_data": res.extra_data + } + return None