chore: initialize monorepo with project scaffolding, configs, and CI setup

This commit is contained in:
oib
2025-09-27 06:05:25 +02:00
commit fe29631a86
170 changed files with 13708 additions and 0 deletions

View File

@@ -0,0 +1 @@
"""AITBC Coordinator API package."""

View File

@@ -0,0 +1,32 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import List, Optional
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", case_sensitive=False)
app_env: str = "dev"
app_host: str = "127.0.0.1"
app_port: int = 8011
database_url: str = "sqlite:///./coordinator.db"
client_api_keys: List[str] = ["REDACTED_CLIENT_KEY"]
miner_api_keys: List[str] = ["REDACTED_MINER_KEY"]
admin_api_keys: List[str] = ["REDACTED_ADMIN_KEY"]
hmac_secret: Optional[str] = None
allow_origins: List[str] = ["*"]
job_ttl_seconds: int = 900
heartbeat_interval_seconds: int = 10
heartbeat_timeout_seconds: int = 30
rate_limit_requests: int = 60
rate_limit_window_seconds: int = 60
receipt_signing_key_hex: Optional[str] = None
receipt_attestation_key_hex: Optional[str] = None
settings = Settings()

View File

@@ -0,0 +1,26 @@
from typing import Callable
from fastapi import Depends, Header, HTTPException
from .config import settings
class APIKeyValidator:
def __init__(self, allowed_keys: list[str]):
self.allowed_keys = {key.strip() for key in allowed_keys if key}
def __call__(self, api_key: str | None = Header(default=None, alias="X-Api-Key")) -> str:
if not api_key or api_key not in self.allowed_keys:
raise HTTPException(status_code=401, detail="invalid api key")
return api_key
def require_client_key() -> Callable[[str | None], str]:
return APIKeyValidator(settings.client_api_keys)
def require_miner_key() -> Callable[[str | None], str]:
return APIKeyValidator(settings.miner_api_keys)
def require_admin_key() -> Callable[[str | None], str]:
return APIKeyValidator(settings.admin_api_keys)

View File

@@ -0,0 +1,7 @@
"""Domain models for the coordinator API."""
from .job import Job
from .miner import Miner
from .job_receipt import JobReceipt
__all__ = ["Job", "Miner", "JobReceipt"]

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
from datetime import datetime
from typing import Optional
from uuid import uuid4
from sqlalchemy import Column, JSON
from sqlmodel import Field, SQLModel
from ..models import JobState
class Job(SQLModel, table=True):
id: str = Field(default_factory=lambda: uuid4().hex, primary_key=True, index=True)
client_id: str = Field(index=True)
state: JobState = Field(default=JobState.queued, sa_column_kwargs={"nullable": False})
payload: dict = Field(sa_column=Column(JSON, nullable=False))
constraints: dict = Field(default_factory=dict, sa_column=Column(JSON, nullable=False))
ttl_seconds: int = Field(default=900)
requested_at: datetime = Field(default_factory=datetime.utcnow)
expires_at: datetime = Field(default_factory=datetime.utcnow)
assigned_miner_id: Optional[str] = Field(default=None, index=True)
result: Optional[dict] = Field(default=None, sa_column=Column(JSON, nullable=True))
receipt: Optional[dict] = Field(default=None, sa_column=Column(JSON, nullable=True))
receipt_id: Optional[str] = Field(default=None, index=True)
error: Optional[str] = None

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
from datetime import datetime
from uuid import uuid4
from sqlalchemy import Column, JSON
from sqlmodel import Field, SQLModel
class JobReceipt(SQLModel, table=True):
id: str = Field(default_factory=lambda: uuid4().hex, primary_key=True, index=True)
job_id: str = Field(index=True, foreign_key="job.id")
receipt_id: str = Field(index=True)
payload: dict = Field(sa_column=Column(JSON, nullable=False))
created_at: datetime = Field(default_factory=datetime.utcnow, index=True)

View File

@@ -0,0 +1,25 @@
from __future__ import annotations
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, JSON
from sqlmodel import Field, SQLModel
class Miner(SQLModel, table=True):
id: str = Field(primary_key=True, index=True)
region: Optional[str] = Field(default=None, index=True)
capabilities: dict = Field(default_factory=dict, sa_column=Column(JSON, nullable=False))
concurrency: int = Field(default=1)
status: str = Field(default="ONLINE", index=True)
inflight: int = Field(default=0)
extra_metadata: dict = Field(default_factory=dict, sa_column=Column(JSON, nullable=False))
last_heartbeat: datetime = Field(default_factory=datetime.utcnow, index=True)
session_token: Optional[str] = None
last_job_at: Optional[datetime] = Field(default=None, index=True)
jobs_completed: int = Field(default=0)
jobs_failed: int = Field(default=0)
total_job_duration_ms: int = Field(default=0)
average_job_duration_ms: float = Field(default=0.0)
last_receipt_id: Optional[str] = Field(default=None, index=True)

View File

@@ -0,0 +1,34 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from .config import settings
from .routers import client, miner, admin
def create_app() -> FastAPI:
app = FastAPI(
title="AITBC Coordinator API",
version="0.1.0",
description="Stage 1 coordinator service handling job orchestration between clients and miners.",
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.allow_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"]
)
app.include_router(client.router, prefix="/v1")
app.include_router(miner.router, prefix="/v1")
app.include_router(admin.router, prefix="/v1")
@app.get("/v1/health", tags=["health"], summary="Service healthcheck")
async def health() -> dict[str, str]:
return {"status": "ok", "env": settings.app_env}
return app
app = create_app()

View File

@@ -0,0 +1,78 @@
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
class JobState(str, Enum):
queued = "QUEUED"
running = "RUNNING"
completed = "COMPLETED"
failed = "FAILED"
canceled = "CANCELED"
expired = "EXPIRED"
class Constraints(BaseModel):
gpu: Optional[str] = None
cuda: Optional[str] = None
min_vram_gb: Optional[int] = None
models: Optional[list[str]] = None
region: Optional[str] = None
max_price: Optional[float] = None
class JobCreate(BaseModel):
payload: Dict[str, Any]
constraints: Constraints = Field(default_factory=Constraints)
ttl_seconds: int = 900
class JobView(BaseModel):
job_id: str
state: JobState
assigned_miner_id: Optional[str] = None
requested_at: datetime
expires_at: datetime
error: Optional[str] = None
class JobResult(BaseModel):
result: Optional[Dict[str, Any]] = None
receipt: Optional[Dict[str, Any]] = None
class MinerRegister(BaseModel):
capabilities: Dict[str, Any]
concurrency: int = 1
region: Optional[str] = None
class MinerHeartbeat(BaseModel):
inflight: int = 0
status: str = "ONLINE"
metadata: Dict[str, Any] = Field(default_factory=dict)
class PollRequest(BaseModel):
max_wait_seconds: int = 15
class AssignedJob(BaseModel):
job_id: str
payload: Dict[str, Any]
constraints: Constraints
class JobResultSubmit(BaseModel):
result: Dict[str, Any]
metrics: Dict[str, Any] = Field(default_factory=dict)
class JobFailSubmit(BaseModel):
error_code: str
error_message: str
metrics: Dict[str, Any] = Field(default_factory=dict)

View File

@@ -0,0 +1 @@
"""Router modules for the coordinator API."""

View File

@@ -0,0 +1,69 @@
from fastapi import APIRouter, Depends, HTTPException, status
from ..deps import require_admin_key
from ..services import JobService, MinerService
from ..storage import SessionDep
router = APIRouter(prefix="/admin", tags=["admin"])
@router.get("/stats", summary="Get coordinator stats")
async def get_stats(session: SessionDep, admin_key: str = Depends(require_admin_key())) -> dict[str, int]: # type: ignore[arg-type]
service = JobService(session)
from sqlmodel import func, select
from ..domain import Job
total_jobs = session.exec(select(func.count()).select_from(Job)).one()
active_jobs = session.exec(select(func.count()).select_from(Job).where(Job.state.in_(["QUEUED", "RUNNING"]))).one()
miner_service = MinerService(session)
miners = miner_service.list_records()
avg_job_duration = (
sum(miner.average_job_duration_ms for miner in miners if miner.average_job_duration_ms) / max(len(miners), 1)
)
return {
"total_jobs": int(total_jobs or 0),
"active_jobs": int(active_jobs or 0),
"online_miners": miner_service.online_count(),
"avg_miner_job_duration_ms": avg_job_duration,
}
@router.get("/jobs", summary="List jobs")
async def list_jobs(session: SessionDep, admin_key: str = Depends(require_admin_key())) -> dict[str, list[dict]]: # type: ignore[arg-type]
from ..domain import Job
jobs = session.exec(select(Job).order_by(Job.requested_at.desc()).limit(100)).all()
return {
"items": [
{
"job_id": job.id,
"state": job.state,
"client_id": job.client_id,
"assigned_miner_id": job.assigned_miner_id,
"requested_at": job.requested_at.isoformat(),
}
for job in jobs
]
}
@router.get("/miners", summary="List miners")
async def list_miners(session: SessionDep, admin_key: str = Depends(require_admin_key())) -> dict[str, list[dict]]: # type: ignore[arg-type]
miner_service = MinerService(session)
miners = [
{
"miner_id": record.miner_id,
"status": record.status,
"inflight": record.inflight,
"concurrency": record.concurrency,
"region": record.region,
"last_heartbeat": record.last_heartbeat.isoformat(),
"average_job_duration_ms": record.average_job_duration_ms,
"jobs_completed": record.jobs_completed,
"jobs_failed": record.jobs_failed,
"last_receipt_id": record.last_receipt_id,
}
for record in miner_service.list_records()
]
return {"items": miners}

View File

@@ -0,0 +1,97 @@
from fastapi import APIRouter, Depends, HTTPException, status
from ..deps import require_client_key
from ..models import JobCreate, JobView, JobResult
from ..services import JobService
from ..storage import SessionDep
router = APIRouter(tags=["client"])
@router.post("/jobs", response_model=JobView, status_code=status.HTTP_201_CREATED, summary="Submit a job")
async def submit_job(
req: JobCreate,
session: SessionDep,
client_id: str = Depends(require_client_key()),
) -> JobView: # type: ignore[arg-type]
service = JobService(session)
job = service.create_job(client_id, req)
return service.to_view(job)
@router.get("/jobs/{job_id}", response_model=JobView, summary="Get job status")
async def get_job(
job_id: str,
session: SessionDep,
client_id: str = Depends(require_client_key()),
) -> JobView: # type: ignore[arg-type]
service = JobService(session)
try:
job = service.get_job(job_id, client_id=client_id)
except KeyError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="job not found")
return service.to_view(job)
@router.get("/jobs/{job_id}/result", response_model=JobResult, summary="Get job result")
async def get_job_result(
job_id: str,
session: SessionDep,
client_id: str = Depends(require_client_key()),
) -> JobResult: # type: ignore[arg-type]
service = JobService(session)
try:
job = service.get_job(job_id, client_id=client_id)
except KeyError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="job not found")
if job.state not in {JobState.completed, JobState.failed, JobState.canceled, JobState.expired}:
raise HTTPException(status_code=status.HTTP_425_TOO_EARLY, detail="job not ready")
if job.result is None and job.receipt is None:
raise HTTPException(status_code=status.HTTP_425_TOO_EARLY, detail="job not ready")
return service.to_result(job)
@router.post("/jobs/{job_id}/cancel", response_model=JobView, summary="Cancel job")
async def cancel_job(
job_id: str,
session: SessionDep,
client_id: str = Depends(require_client_key()),
) -> JobView: # type: ignore[arg-type]
service = JobService(session)
try:
job = service.get_job(job_id, client_id=client_id)
except KeyError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="job not found")
if job.state not in {JobState.queued, JobState.running}:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="job not cancelable")
job = service.cancel_job(job)
return service.to_view(job)
@router.get("/jobs/{job_id}/receipt", summary="Get latest signed receipt")
async def get_job_receipt(
job_id: str,
session: SessionDep,
client_id: str = Depends(require_client_key()),
) -> dict: # type: ignore[arg-type]
service = JobService(session)
try:
job = service.get_job(job_id, client_id=client_id)
except KeyError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="job not found")
if not job.receipt:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="receipt not available")
return job.receipt
@router.get("/jobs/{job_id}/receipts", summary="List signed receipts")
async def list_job_receipts(
job_id: str,
session: SessionDep,
client_id: str = Depends(require_client_key()),
) -> dict: # type: ignore[arg-type]
service = JobService(session)
receipts = service.list_receipts(job_id, client_id=client_id)
return {"items": [row.payload for row in receipts]}

View File

@@ -0,0 +1,110 @@
from datetime import datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Response, status
from ..deps import require_miner_key
from ..models import AssignedJob, JobFailSubmit, JobResultSubmit, JobState, MinerHeartbeat, MinerRegister, PollRequest
from ..services import JobService, MinerService
from ..services.receipts import ReceiptService
from ..storage import SessionDep
router = APIRouter(tags=["miner"])
@router.post("/miners/register", summary="Register or update miner")
async def register(
req: MinerRegister,
session: SessionDep,
miner_id: str = Depends(require_miner_key()),
) -> dict[str, Any]: # type: ignore[arg-type]
service = MinerService(session)
record = service.register(miner_id, req)
return {"status": "ok", "session_token": record.session_token}
@router.post("/miners/heartbeat", summary="Send miner heartbeat")
async def heartbeat(
req: MinerHeartbeat,
session: SessionDep,
miner_id: str = Depends(require_miner_key()),
) -> dict[str, str]: # type: ignore[arg-type]
try:
MinerService(session).heartbeat(miner_id, req)
except KeyError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="miner not registered")
return {"status": "ok"}
# NOTE: until scheduling is fully implemented the poll endpoint performs a simple FIFO assignment.
@router.post("/miners/poll", response_model=AssignedJob, summary="Poll for next job")
async def poll(
req: PollRequest,
session: SessionDep,
miner_id: str = Depends(require_miner_key()),
) -> AssignedJob | Response: # type: ignore[arg-type]
job = MinerService(session).poll(miner_id, req.max_wait_seconds)
if job is None:
return Response(status_code=status.HTTP_204_NO_CONTENT)
return job
@router.post("/miners/{job_id}/result", summary="Submit job result")
async def submit_result(
job_id: str,
req: JobResultSubmit,
session: SessionDep,
miner_id: str = Depends(require_miner_key()),
) -> dict[str, Any]: # type: ignore[arg-type]
job_service = JobService(session)
miner_service = MinerService(session)
receipt_service = ReceiptService(session)
try:
job = job_service.get_job(job_id)
except KeyError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="job not found")
job.result = req.result
job.state = JobState.completed
job.error = None
metrics = dict(req.metrics or {})
duration_ms = metrics.get("duration_ms")
if duration_ms is None and job.requested_at:
duration_ms = int((datetime.utcnow() - job.requested_at).total_seconds() * 1000)
metrics["duration_ms"] = duration_ms
receipt = receipt_service.create_receipt(job, miner_id, req.result, metrics)
job.receipt = receipt
job.receipt_id = receipt["receipt_id"] if receipt else None
session.add(job)
session.commit()
miner_service.release(
miner_id,
success=True,
duration_ms=duration_ms,
receipt_id=receipt["receipt_id"] if receipt else None,
)
return {"status": "ok", "receipt": receipt}
@router.post("/miners/{job_id}/fail", summary="Submit job failure")
async def submit_failure(
job_id: str,
req: JobFailSubmit,
session: SessionDep,
miner_id: str = Depends(require_miner_key()),
) -> dict[str, str]: # type: ignore[arg-type]
job_service = JobService(session)
miner_service = MinerService(session)
try:
job = job_service.get_job(job_id)
except KeyError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="job not found")
job.state = JobState.failed
job.error = f"{req.error_code}: {req.error_message}"
job.assigned_miner_id = miner_id
session.add(job)
session.commit()
miner_service.release(miner_id, success=False)
return {"status": "ok"}

View File

@@ -0,0 +1,6 @@
"""Service layer for coordinator business logic."""
from .jobs import JobService
from .miners import MinerService
__all__ = ["JobService", "MinerService"]

View File

@@ -0,0 +1,156 @@
from __future__ import annotations
from datetime import datetime, timedelta
from typing import Optional
from sqlmodel import Session, select
from ..domain import Job, Miner, JobReceipt
from ..models import AssignedJob, Constraints, JobCreate, JobResult, JobState, JobView
class JobService:
def __init__(self, session: Session):
self.session = session
def create_job(self, client_id: str, req: JobCreate) -> Job:
ttl = max(req.ttl_seconds, 1)
now = datetime.utcnow()
job = Job(
client_id=client_id,
payload=req.payload,
constraints=req.constraints.model_dump(exclude_none=True),
ttl_seconds=ttl,
requested_at=now,
expires_at=now + timedelta(seconds=ttl),
)
self.session.add(job)
self.session.commit()
self.session.refresh(job)
return job
def get_job(self, job_id: str, client_id: Optional[str] = None) -> Job:
query = select(Job).where(Job.id == job_id)
if client_id:
query = query.where(Job.client_id == client_id)
job = self.session.exec(query).one_or_none()
if not job:
raise KeyError("job not found")
return self._ensure_not_expired(job)
def list_receipts(self, job_id: str, client_id: Optional[str] = None) -> list[JobReceipt]:
job = self.get_job(job_id, client_id=client_id)
receipts = self.session.exec(
select(JobReceipt)
.where(JobReceipt.job_id == job.id)
.order_by(JobReceipt.created_at.asc())
).all()
return receipts
def cancel_job(self, job: Job) -> Job:
if job.state not in {JobState.queued, JobState.running}:
return job
job.state = JobState.canceled
job.error = "canceled by client"
job.assigned_miner_id = None
self.session.add(job)
self.session.commit()
self.session.refresh(job)
return job
def to_view(self, job: Job) -> JobView:
return JobView(
job_id=job.id,
state=job.state,
assigned_miner_id=job.assigned_miner_id,
requested_at=job.requested_at,
expires_at=job.expires_at,
error=job.error,
)
def to_result(self, job: Job) -> JobResult:
return JobResult(result=job.result, receipt=job.receipt)
def to_assigned(self, job: Job) -> AssignedJob:
constraints = Constraints(**job.constraints) if isinstance(job.constraints, dict) else Constraints()
return AssignedJob(job_id=job.id, payload=job.payload, constraints=constraints)
def acquire_next_job(self, miner: Miner) -> Optional[Job]:
now = datetime.utcnow()
statement = (
select(Job)
.where(Job.state == JobState.queued)
.order_by(Job.requested_at.asc())
)
jobs = self.session.exec(statement).all()
for job in jobs:
job = self._ensure_not_expired(job)
if job.state != JobState.queued:
continue
if job.expires_at <= now:
continue
if not self._satisfies_constraints(job, miner):
continue
job.state = JobState.running
job.assigned_miner_id = miner.id
self.session.add(job)
self.session.commit()
self.session.refresh(job)
return job
return None
def _ensure_not_expired(self, job: Job) -> Job:
if job.state == JobState.queued and job.expires_at <= datetime.utcnow():
job.state = JobState.expired
job.error = "job expired"
self.session.add(job)
self.session.commit()
self.session.refresh(job)
return job
def _satisfies_constraints(self, job: Job, miner: Miner) -> bool:
if not job.constraints:
return True
constraints = Constraints(**job.constraints)
capabilities = miner.capabilities or {}
# Region matching
if constraints.region and constraints.region != miner.region:
return False
gpu_specs = capabilities.get("gpus", []) or []
has_gpu = bool(gpu_specs)
if constraints.gpu:
if not has_gpu:
return False
names = [gpu.get("name") for gpu in gpu_specs]
if constraints.gpu not in names:
return False
if constraints.min_vram_gb:
required_mb = constraints.min_vram_gb * 1024
if not any((gpu.get("memory_mb") or 0) >= required_mb for gpu in gpu_specs):
return False
if constraints.cuda:
cuda_info = capabilities.get("cuda")
if not cuda_info or constraints.cuda not in str(cuda_info):
return False
if constraints.models:
available_models = capabilities.get("models", [])
if not set(constraints.models).issubset(set(available_models)):
return False
if constraints.max_price is not None:
price = capabilities.get("price")
try:
price_value = float(price)
except (TypeError, ValueError):
return False
if price_value > constraints.max_price:
return False
return True

View File

@@ -0,0 +1,110 @@
from __future__ import annotations
from datetime import datetime
from typing import Optional
from uuid import uuid4
from sqlmodel import Session, select
from ..domain import Miner
from ..models import AssignedJob, MinerHeartbeat, MinerRegister
from .jobs import JobService
class MinerService:
def __init__(self, session: Session):
self.session = session
def register(self, miner_id: str, payload: MinerRegister) -> Miner:
miner = self.session.get(Miner, miner_id)
session_token = uuid4().hex
if miner is None:
miner = Miner(
id=miner_id,
capabilities=payload.capabilities,
concurrency=payload.concurrency,
region=payload.region,
session_token=session_token,
)
self.session.add(miner)
else:
miner.capabilities = payload.capabilities
miner.concurrency = payload.concurrency
miner.region = payload.region
miner.session_token = session_token
miner.last_heartbeat = datetime.utcnow()
miner.status = "ONLINE"
self.session.commit()
self.session.refresh(miner)
return miner
def heartbeat(self, miner_id: str, payload: MinerHeartbeat | dict) -> Miner:
if not isinstance(payload, MinerHeartbeat):
payload = MinerHeartbeat.model_validate(payload)
miner = self.session.get(Miner, miner_id)
if miner is None:
raise KeyError("miner not registered")
miner.inflight = payload.inflight
miner.status = payload.status
miner.extra_metadata = payload.metadata
miner.last_heartbeat = datetime.utcnow()
self.session.add(miner)
self.session.commit()
self.session.refresh(miner)
return miner
def poll(self, miner_id: str, max_wait_seconds: int) -> Optional[AssignedJob]:
miner = self.session.get(Miner, miner_id)
if miner is None:
raise KeyError("miner not registered")
if miner.concurrency and miner.inflight >= miner.concurrency:
return None
job_service = JobService(self.session)
job = job_service.acquire_next_job(miner)
if not job:
return None
miner.inflight += 1
miner.last_heartbeat = datetime.utcnow()
miner.last_job_at = datetime.utcnow()
self.session.add(miner)
self.session.commit()
return job_service.to_assigned(job)
def release(
self,
miner_id: str,
success: bool | None = None,
duration_ms: int | None = None,
receipt_id: str | None = None,
) -> None:
miner = self.session.get(Miner, miner_id)
if miner:
miner.inflight = max(0, miner.inflight - 1)
if success is True:
miner.jobs_completed += 1
if duration_ms is not None:
miner.total_job_duration_ms += duration_ms
miner.average_job_duration_ms = (
miner.total_job_duration_ms / max(miner.jobs_completed, 1)
)
elif success is False:
miner.jobs_failed += 1
if receipt_id:
miner.last_receipt_id = receipt_id
self.session.add(miner)
self.session.commit()
def get(self, miner_id: str) -> Miner:
miner = self.session.get(Miner, miner_id)
if miner is None:
raise KeyError("miner not registered")
return miner
def list_records(self) -> list[Miner]:
return list(self.session.exec(select(Miner)).all())
def online_count(self) -> int:
result = self.session.exec(select(Miner).where(Miner.status == "ONLINE"))
return len(result.all())

View File

@@ -0,0 +1,79 @@
from __future__ import annotations
from typing import Any, Dict, Optional
from secrets import token_hex
from datetime import datetime
from aitbc_crypto.signing import ReceiptSigner
from sqlmodel import Session
from ..config import settings
from ..domain import Job, JobReceipt
class ReceiptService:
def __init__(self, session: Session) -> None:
self.session = session
self._signer: Optional[ReceiptSigner] = None
self._attestation_signer: Optional[ReceiptSigner] = None
if settings.receipt_signing_key_hex:
key_bytes = bytes.fromhex(settings.receipt_signing_key_hex)
self._signer = ReceiptSigner(key_bytes)
if settings.receipt_attestation_key_hex:
attest_bytes = bytes.fromhex(settings.receipt_attestation_key_hex)
self._attestation_signer = ReceiptSigner(attest_bytes)
def create_receipt(
self,
job: Job,
miner_id: str,
job_result: Dict[str, Any] | None,
result_metrics: Dict[str, Any] | None,
) -> Dict[str, Any] | None:
if self._signer is None:
return None
payload = {
"version": "1.0",
"receipt_id": token_hex(16),
"job_id": job.id,
"provider": miner_id,
"client": job.client_id,
"units": _first_present([
(result_metrics or {}).get("units"),
(job_result or {}).get("units"),
], default=0.0),
"unit_type": _first_present([
(result_metrics or {}).get("unit_type"),
(job_result or {}).get("unit_type"),
], default="gpu_seconds"),
"price": _first_present([
(result_metrics or {}).get("price"),
(job_result or {}).get("price"),
]),
"started_at": int(job.requested_at.timestamp()) if job.requested_at else int(datetime.utcnow().timestamp()),
"completed_at": int(datetime.utcnow().timestamp()),
"metadata": {
"job_payload": job.payload,
"job_constraints": job.constraints,
"result": job_result,
"metrics": result_metrics,
},
}
payload["signature"] = self._signer.sign(payload)
if self._attestation_signer:
payload.setdefault("attestations", [])
attestation_payload = dict(payload)
attestation_payload.pop("attestations", None)
attestation_payload.pop("signature", None)
payload["attestations"].append(self._attestation_signer.sign(attestation_payload))
receipt_row = JobReceipt(job_id=job.id, receipt_id=payload["receipt_id"], payload=payload)
self.session.add(receipt_row)
return payload
def _first_present(values: list[Optional[Any]], default: Optional[Any] = None) -> Optional[Any]:
for value in values:
if value is not None:
return value
return default

View File

@@ -0,0 +1,5 @@
"""Persistence helpers for the coordinator API."""
from .db import SessionDep, get_session, init_db
__all__ = ["SessionDep", "get_session", "init_db"]

View File

@@ -0,0 +1,42 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import Annotated, Generator
from fastapi import Depends
from sqlalchemy.engine import Engine
from sqlmodel import Session, SQLModel, create_engine
from ..config import settings
from ..domain import Job, Miner
_engine: Engine | None = None
def get_engine() -> Engine:
global _engine
if _engine is None:
connect_args = {"check_same_thread": False} if settings.database_url.startswith("sqlite") else {}
_engine = create_engine(settings.database_url, echo=False, connect_args=connect_args)
return _engine
def init_db() -> None:
engine = get_engine()
SQLModel.metadata.create_all(engine)
@contextmanager
def session_scope() -> Generator[Session, None, None]:
engine = get_engine()
with Session(engine) as session:
yield session
def get_session() -> Generator[Session, None, None]:
with session_scope() as session:
yield session
SessionDep = Annotated[Session, Depends(get_session)]