feat: add foreign key constraints and metrics for blockchain node

oib
2025-09-28 06:04:30 +02:00
parent c1926136fb
commit fb60505cdf
189 changed files with 15678 additions and 158 deletions

View File

@@ -0,0 +1,47 @@
from __future__ import annotations
import asyncio
from logging.config import fileConfig
from alembic import context
from sqlalchemy.ext.asyncio import create_async_engine
from poolhub.models import Base
from poolhub.settings import settings
config = context.config
if config.config_file_name is not None:
fileConfig(config.config_file_name)
target_metadata = Base.metadata
def _configure_context(connection=None, *, url: str | None = None) -> None:
context.configure(
connection=connection,
url=url,
target_metadata=target_metadata,
literal_binds=url is not None,  # render literal SQL only in offline mode
dialect_opts={"paramstyle": "named"},
)
def run_migrations_offline() -> None:
_configure_context(url=settings.postgres_dsn)
with context.begin_transaction():
context.run_migrations()
def _run_migrations(connection) -> None:
_configure_context(connection)
with context.begin_transaction():
context.run_migrations()
async def run_migrations_online() -> None:
connectable = create_async_engine(settings.postgres_dsn, pool_pre_ping=True)
async with connectable.connect() as connection:
await connection.run_sync(_run_migrations)
await connectable.dispose()
if context.is_offline_mode():
run_migrations_offline()
else:
asyncio.run(run_migrations_online())
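
With this env.py in place, migrations can also be driven programmatically through Alembic's command API; a minimal sketch, assuming an alembic.ini at the project root that points at this script:

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # hypothetical path to the project's Alembic config
command.upgrade(cfg, "head")             # online: uses the async engine above
command.upgrade(cfg, "head", sql=True)   # offline: renders SQL via run_migrations_offline()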

View File

@@ -0,0 +1,104 @@
"""initial schema
Revision ID: a58c1f3b3e87
Revises:
Create Date: 2025-09-27 12:07:40.000000
"""
from __future__ import annotations
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a58c1f3b3e87"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"miners",
sa.Column("miner_id", sa.String(length=64), primary_key=True),
sa.Column("api_key_hash", sa.String(length=128), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("NOW()")),
sa.Column("last_seen_at", sa.DateTime(timezone=True)),
sa.Column("addr", sa.String(length=256)),
sa.Column("proto", sa.String(length=32)),
sa.Column("gpu_vram_gb", sa.Float()),
sa.Column("gpu_name", sa.String(length=128)),
sa.Column("cpu_cores", sa.Integer()),
sa.Column("ram_gb", sa.Float()),
sa.Column("max_parallel", sa.Integer()),
sa.Column("base_price", sa.Float()),
sa.Column("tags", postgresql.JSONB(astext_type=sa.Text())),
sa.Column("capabilities", postgresql.JSONB(astext_type=sa.Text())),
sa.Column("trust_score", sa.Float(), server_default="0.5"),
sa.Column("region", sa.String(length=64)),
)
op.create_table(
"miner_status",
sa.Column("miner_id", sa.String(length=64), sa.ForeignKey("miners.miner_id", ondelete="CASCADE"), primary_key=True),
sa.Column("queue_len", sa.Integer(), server_default="0"),
sa.Column("busy", sa.Boolean(), server_default=sa.text("false")),
sa.Column("avg_latency_ms", sa.Integer()),
sa.Column("temp_c", sa.Integer()),
sa.Column("mem_free_gb", sa.Float()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("NOW()")),
)
op.create_table(
"match_requests",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("job_id", sa.String(length=64), nullable=False),
sa.Column("requirements", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.Column("hints", postgresql.JSONB(astext_type=sa.Text()), server_default=sa.text("'{}'::jsonb")),
sa.Column("top_k", sa.Integer(), server_default="1"),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("NOW()")),
)
op.create_table(
"match_results",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("request_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("match_requests.id", ondelete="CASCADE"), nullable=False),
sa.Column("miner_id", sa.String(length=64), nullable=False),
sa.Column("score", sa.Float(), nullable=False),
sa.Column("explain", sa.Text()),
sa.Column("eta_ms", sa.Integer()),
sa.Column("price", sa.Float()),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("NOW()")),
)
op.create_index("ix_match_results_request_id", "match_results", ["request_id"])
op.create_table(
"feedback",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("job_id", sa.String(length=64), nullable=False),
sa.Column("miner_id", sa.String(length=64), sa.ForeignKey("miners.miner_id", ondelete="CASCADE"), nullable=False),
sa.Column("outcome", sa.String(length=32), nullable=False),
sa.Column("latency_ms", sa.Integer()),
sa.Column("fail_code", sa.String(length=64)),
sa.Column("tokens_spent", sa.Float()),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("NOW()")),
)
op.create_index("ix_feedback_miner_id", "feedback", ["miner_id"])
op.create_index("ix_feedback_job_id", "feedback", ["job_id"])
def downgrade() -> None:
op.drop_index("ix_feedback_job_id", table_name="feedback")
op.drop_index("ix_feedback_miner_id", table_name="feedback")
op.drop_table("feedback")
op.drop_index("ix_match_results_request_id", table_name="match_results")
op.drop_table("match_results")
op.drop_table("match_requests")
op.drop_table("miner_status")
op.drop_table("miners")

View File

@@ -0,0 +1,13 @@
"""AITBC Pool Hub service package."""
from .settings import Settings, settings
from .database import create_engine, get_session
from .redis_cache import get_redis
__all__ = [
"Settings",
"settings",
"create_engine",
"get_session",
"get_redis",
]

View File

@@ -0,0 +1,5 @@
"""FastAPI application wiring for the AITBC Pool Hub."""
from .main import create_app, app
__all__ = ["create_app", "app"]

View File

@@ -0,0 +1,27 @@
from __future__ import annotations
from fastapi import Depends
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from ..database import get_session
from ..redis_cache import get_redis
# FastAPI dependency wrappers. FastAPI resolves the underlying generator
# dependencies itself, so each wrapper receives the already-yielded value
# and simply passes it through.
async def db_session_dep(session: AsyncSession = Depends(get_session)) -> AsyncSession:
return session
async def redis_dep(client: Redis = Depends(get_redis)) -> Redis:
return client

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from contextlib import asynccontextmanager
from fastapi import FastAPI
from ..database import close_engine, create_engine
from ..redis_cache import close_redis, create_redis
from ..settings import settings
from .routers import health_router, match_router, metrics_router
@asynccontextmanager
async def lifespan(_: FastAPI):
create_engine()
create_redis()
try:
yield
finally:
await close_engine()
await close_redis()
app = FastAPI(**settings.asgi_kwargs(), lifespan=lifespan)
app.include_router(match_router, prefix="/v1")
app.include_router(health_router)
app.include_router(metrics_router)
def create_app() -> FastAPI:
return app
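
A minimal way to serve the app locally, assuming uvicorn is installed and reusing the bind settings:

import uvicorn
from poolhub.app.main import app
from poolhub.settings import settings

if __name__ == "__main__":
    uvicorn.run(app, host=settings.bind_host, port=settings.bind_port)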

View File

@@ -0,0 +1,39 @@
from __future__ import annotations
from prometheus_client import CONTENT_TYPE_LATEST, Counter, Gauge, Histogram, generate_latest
match_requests_total = Counter(
"poolhub_match_requests_total",
"Total number of match requests received",
)
match_candidates_returned = Counter(
"poolhub_match_candidates_total",
"Total number of candidates returned",
)
match_failures_total = Counter(
"poolhub_match_failures_total",
"Total number of match request failures",
)
match_latency_seconds = Histogram(
"poolhub_match_latency_seconds",
"Latency of match processing",
buckets=(0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0),
)
miners_online_gauge = Gauge(
"poolhub_miners_online",
"Number of miners considered online",
)
def render_metrics() -> tuple[bytes, str]:
return generate_latest(), CONTENT_TYPE_LATEST
def reset_metrics() -> None:
"""Best-effort reset for tests; relies on prometheus_client internals."""
match_requests_total._value.set(0)  # type: ignore[attr-defined]
match_candidates_returned._value.set(0)  # type: ignore[attr-defined]
match_failures_total._value.set(0)  # type: ignore[attr-defined]
match_latency_seconds._sum.set(0)  # type: ignore[attr-defined]
for bucket in match_latency_seconds._buckets:  # type: ignore[attr-defined]
bucket.set(0)
miners_online_gauge.set(0)
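
For reference, handlers record into these collectors through the standard prometheus_client API; a sketch, where do_work is a hypothetical coroutine standing in for the real handler body:

from poolhub.app.prometheus import match_latency_seconds, match_requests_total

async def handle() -> None:
    match_requests_total.inc()  # bump the counter sample
    with match_latency_seconds.time():  # observe elapsed seconds into the buckets
        await do_work()  # hypothetical work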

View File

@@ -0,0 +1,7 @@
"""FastAPI routers for Pool Hub."""
from .match import router as match_router
from .health import router as health_router
from .metrics import router as metrics_router
__all__ = ["match_router", "health_router", "metrics_router"]

View File

@@ -0,0 +1,50 @@
from __future__ import annotations
from fastapi import APIRouter, Depends
from redis.asyncio import Redis
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from ..deps import db_session_dep, redis_dep
from ..prometheus import miners_online_gauge
from poolhub.repositories.miner_repository import MinerRepository
from ..schemas import HealthResponse
router = APIRouter(tags=["health"], prefix="/v1")
@router.get("/health", response_model=HealthResponse, summary="Pool Hub health status")
async def health_endpoint(
session: AsyncSession = Depends(db_session_dep),
redis: Redis = Depends(redis_dep),
) -> HealthResponse:
db_ok = True
redis_ok = True
db_error: str | None = None
redis_error: str | None = None
try:
await session.execute("SELECT 1")
except Exception as exc: # pragma: no cover
db_ok = False
db_error = str(exc)
try:
await redis.ping()
except Exception as exc: # pragma: no cover
redis_ok = False
redis_error = str(exc)
miners_online = 0
if db_ok:
miner_repo = MinerRepository(session, redis)
active_miners = await miner_repo.list_active_miners()
miners_online = len(active_miners)
miners_online_gauge.set(miners_online)
status = "ok" if db_ok and redis_ok else "degraded"
return HealthResponse(
status=status,
db=db_ok,
redis=redis_ok,
miners_online=miners_online,
db_error=db_error,
redis_error=redis_error,
)
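
A quick manual check against a running instance; the host and port are the settings defaults, and the shown payload is illustrative:

import httpx

resp = httpx.get("http://127.0.0.1:8203/v1/health")
print(resp.json())  # e.g. {"status": "ok", "db": true, "redis": true, "miners_online": 1, ...}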

View File

@@ -0,0 +1,116 @@
from __future__ import annotations
import time
from typing import Any, Dict, List
from fastapi import APIRouter, Depends, HTTPException, status
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from ..deps import db_session_dep, redis_dep
from ..prometheus import (
match_candidates_returned,
match_failures_total,
match_latency_seconds,
match_requests_total,
)
from poolhub.repositories.match_repository import MatchRepository
from poolhub.repositories.miner_repository import MinerRepository
from ..schemas import MatchCandidate, MatchRequestPayload, MatchResponse
router = APIRouter(tags=["match"])
def _normalize_requirements(requirements: Dict[str, Any]) -> Dict[str, Any]:
return requirements or {}
def _candidate_from_payload(payload: Dict[str, Any]) -> MatchCandidate:
return MatchCandidate(**payload)
@router.post("/match", response_model=MatchResponse, summary="Find top miners for a job")
async def match_endpoint(
payload: MatchRequestPayload,
session: AsyncSession = Depends(db_session_dep),
redis: Redis = Depends(redis_dep),
) -> MatchResponse:
start = time.perf_counter()
match_requests_total.inc()
miner_repo = MinerRepository(session, redis)
match_repo = MatchRepository(session, redis)
requirements = _normalize_requirements(payload.requirements)
top_k = payload.top_k
try:
request = await match_repo.create_request(
job_id=payload.job_id,
requirements=requirements,
hints=payload.hints,
top_k=top_k,
)
active_miners = await miner_repo.list_active_miners()
candidates = _select_candidates(requirements, payload.hints, active_miners, top_k)
await match_repo.add_results(
request_id=request.id,
candidates=candidates,
)
match_candidates_returned.inc(len(candidates))
duration = time.perf_counter() - start
match_latency_seconds.observe(duration)
return MatchResponse(
job_id=payload.job_id,
candidates=[_candidate_from_payload(candidate) for candidate in candidates],
)
except Exception as exc: # pragma: no cover - safeguards unexpected failures
match_failures_total.inc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="match_failed") from exc
def _select_candidates(
requirements: Dict[str, Any],
hints: Dict[str, Any],
active_miners: List[tuple],
top_k: int,
) -> List[Dict[str, Any]]:
min_vram = float(requirements.get("min_vram_gb", 0))
min_ram = float(requirements.get("min_ram_gb", 0))
capabilities_required = set(requirements.get("capabilities_any", []))
region_hint = hints.get("region")
ranked: List[Dict[str, Any]] = []
for miner, miner_status, score in active_miners:
# Treat missing hardware specs as 0 so miners that never reported
# VRAM or RAM cannot satisfy an explicit minimum requirement.
if min_vram and (miner.gpu_vram_gb or 0.0) < min_vram:
continue
if min_ram and (miner.ram_gb or 0.0) < min_ram:
continue
if capabilities_required and not capabilities_required.issubset(set(miner.capabilities or [])):
continue
if region_hint and miner.region and miner.region != region_hint:
continue
candidate = {
"miner_id": miner.miner_id,
"addr": miner.addr,
"proto": miner.proto,
"score": float(score),
"explain": _compose_explain(score, miner, miner_status),
"eta_ms": miner_status.avg_latency_ms if miner_status else None,
"price": miner.base_price,
}
ranked.append(candidate)
ranked.sort(key=lambda item: item["score"], reverse=True)
return ranked[:top_k]
def _compose_explain(score: float, miner, miner_status) -> str:
load = miner_status.queue_len if miner_status else 0
latency = miner_status.avg_latency_ms if miner_status else "n/a"
return f"score={score:.3f} load={load} latency={latency}"

View File

@@ -0,0 +1,13 @@
from __future__ import annotations
from fastapi import APIRouter, Response
from ..prometheus import render_metrics
router = APIRouter(tags=["metrics"])
@router.get("/metrics", summary="Prometheus metrics")
async def metrics_endpoint() -> Response:
payload, content_type = render_metrics()
return Response(content=payload, media_type=content_type)

View File

@@ -0,0 +1,40 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class MatchRequestPayload(BaseModel):
job_id: str
requirements: Dict[str, Any] = Field(default_factory=dict)
hints: Dict[str, Any] = Field(default_factory=dict)
top_k: int = Field(default=1, ge=1, le=50)
class MatchCandidate(BaseModel):
miner_id: str
addr: str
proto: str
score: float
explain: Optional[str] = None
eta_ms: Optional[int] = None
price: Optional[float] = None
class MatchResponse(BaseModel):
job_id: str
candidates: List[MatchCandidate]
class HealthResponse(BaseModel):
status: str
db: bool
redis: bool
miners_online: int
db_error: Optional[str] = None
redis_error: Optional[str] = None
class MetricsResponse(BaseModel):
detail: str = "Prometheus metrics output"

View File

@@ -0,0 +1,54 @@
from __future__ import annotations
from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from .settings import settings
_engine: AsyncEngine | None = None
_session_factory: async_sessionmaker[AsyncSession] | None = None
def create_engine() -> AsyncEngine:
global _engine, _session_factory
if _engine is None:
_engine = create_async_engine(
settings.postgres_dsn,
pool_size=settings.postgres_pool_max,
max_overflow=0,
pool_pre_ping=True,
)
_session_factory = async_sessionmaker(
bind=_engine,
expire_on_commit=False,
autoflush=False,
)
return _engine
def get_engine() -> AsyncEngine:
if _engine is None:
return create_engine()
return _engine
def get_session_factory() -> async_sessionmaker[AsyncSession]:
if _session_factory is None:
create_engine()
assert _session_factory is not None
return _session_factory
async def get_session() -> AsyncGenerator[AsyncSession, None]:
session_factory = get_session_factory()
async with session_factory() as session:
yield session
async def close_engine() -> None:
global _engine, _session_factory
if _engine is not None:
await _engine.dispose()
_engine = None
_session_factory = None
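
Outside FastAPI, the same factory can be used directly; a minimal sketch:

from sqlalchemy import text
from poolhub.database import get_session_factory

async def count_miners() -> int:
    factory = get_session_factory()
    async with factory() as session:
        result = await session.execute(text("SELECT COUNT(*) FROM miners"))
        return int(result.scalar_one())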

View File

@@ -0,0 +1,95 @@
from __future__ import annotations
import datetime as dt
from typing import Dict, List, Optional
from sqlalchemy import Boolean, Column, DateTime, Float, ForeignKey, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID as PGUUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from uuid import UUID, uuid4
class Base(DeclarativeBase):
pass
class Miner(Base):
__tablename__ = "miners"
miner_id: Mapped[str] = mapped_column(String(64), primary_key=True)
api_key_hash: Mapped[str] = mapped_column(String(128), nullable=False)
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=dt.datetime.utcnow)
last_seen_at: Mapped[Optional[dt.datetime]] = mapped_column(DateTime(timezone=True))
addr: Mapped[str] = mapped_column(String(256))
proto: Mapped[str] = mapped_column(String(32))
gpu_vram_gb: Mapped[float] = mapped_column(Float)
gpu_name: Mapped[Optional[str]] = mapped_column(String(128))
cpu_cores: Mapped[int] = mapped_column(Integer)
ram_gb: Mapped[float] = mapped_column(Float)
max_parallel: Mapped[int] = mapped_column(Integer)
base_price: Mapped[float] = mapped_column(Float)
tags: Mapped[Dict[str, str]] = mapped_column(JSONB, default=dict)
capabilities: Mapped[List[str]] = mapped_column(JSONB, default=list)
trust_score: Mapped[float] = mapped_column(Float, default=0.5)
region: Mapped[Optional[str]] = mapped_column(String(64))
status: Mapped["MinerStatus"] = relationship(back_populates="miner", cascade="all, delete-orphan", uselist=False)
feedback: Mapped[List["Feedback"]] = relationship(back_populates="miner", cascade="all, delete-orphan")
class MinerStatus(Base):
__tablename__ = "miner_status"
miner_id: Mapped[str] = mapped_column(ForeignKey("miners.miner_id", ondelete="CASCADE"), primary_key=True)
queue_len: Mapped[int] = mapped_column(Integer, default=0)
busy: Mapped[bool] = mapped_column(Boolean, default=False)
avg_latency_ms: Mapped[Optional[int]] = mapped_column(Integer)
temp_c: Mapped[Optional[int]] = mapped_column(Integer)
mem_free_gb: Mapped[Optional[float]] = mapped_column(Float)
updated_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=dt.datetime.utcnow, onupdate=dt.datetime.utcnow)
miner: Mapped[Miner] = relationship(back_populates="status")
class MatchRequest(Base):
__tablename__ = "match_requests"
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True, default=uuid4)
job_id: Mapped[str] = mapped_column(String(64), nullable=False)
requirements: Mapped[Dict[str, object]] = mapped_column(JSONB, nullable=False)
hints: Mapped[Dict[str, object]] = mapped_column(JSONB, default=dict)
top_k: Mapped[int] = mapped_column(Integer, default=1)
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=dt.datetime.utcnow)
results: Mapped[List["MatchResult"]] = relationship(back_populates="request", cascade="all, delete-orphan")
class MatchResult(Base):
__tablename__ = "match_results"
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True, default=uuid4)
request_id: Mapped[UUID] = mapped_column(ForeignKey("match_requests.id", ondelete="CASCADE"), index=True)
miner_id: Mapped[str] = mapped_column(String(64))
score: Mapped[float] = mapped_column(Float)
explain: Mapped[Optional[str]] = mapped_column(Text)
eta_ms: Mapped[Optional[int]] = mapped_column(Integer)
price: Mapped[Optional[float]] = mapped_column(Float)
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=dt.datetime.utcnow)
request: Mapped[MatchRequest] = relationship(back_populates="results")
class Feedback(Base):
__tablename__ = "feedback"
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True, default=uuid4)
job_id: Mapped[str] = mapped_column(String(64), nullable=False)
miner_id: Mapped[str] = mapped_column(ForeignKey("miners.miner_id", ondelete="CASCADE"), nullable=False)
outcome: Mapped[str] = mapped_column(String(32), nullable=False)
latency_ms: Mapped[Optional[int]] = mapped_column(Integer)
fail_code: Mapped[Optional[str]] = mapped_column(String(64))
tokens_spent: Mapped[Optional[float]] = mapped_column(Float)
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=dt.datetime.utcnow)
miner: Mapped[Miner] = relationship(back_populates="feedback")

View File

@@ -0,0 +1,39 @@
from __future__ import annotations
from collections.abc import AsyncGenerator
import redis.asyncio as redis
from .settings import settings
_redis_client: redis.Redis | None = None
def create_redis() -> redis.Redis:
global _redis_client
if _redis_client is None:
_redis_client = redis.from_url(
settings.redis_url,
max_connections=settings.redis_max_connections,
encoding="utf-8",
decode_responses=True,
)
return _redis_client
def get_redis_client() -> redis.Redis:
if _redis_client is None:
return create_redis()
return _redis_client
async def get_redis() -> AsyncGenerator[redis.Redis, None]:
client = get_redis_client()
yield client
async def close_redis() -> None:
global _redis_client
if _redis_client is not None:
await _redis_client.close()
_redis_client = None

View File

@@ -0,0 +1,11 @@
"""Repository layer for Pool Hub."""
from .miner_repository import MinerRepository
from .match_repository import MatchRepository
from .feedback_repository import FeedbackRepository
__all__ = [
"MinerRepository",
"MatchRepository",
"FeedbackRepository",
]

View File

@@ -0,0 +1,81 @@
from __future__ import annotations
import datetime as dt
import json
import logging
from typing import List, Optional
from redis.asyncio import Redis
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import Feedback
from ..storage.redis_keys import RedisKeys
logger = logging.getLogger(__name__)
class FeedbackRepository:
"""Persists coordinator feedback and emits Redis notifications."""
def __init__(self, session: AsyncSession, redis: Redis) -> None:
self._session = session
self._redis = redis
async def add_feedback(
self,
*,
job_id: str,
miner_id: str,
outcome: str,
latency_ms: Optional[int] = None,
fail_code: Optional[str] = None,
tokens_spent: Optional[float] = None,
) -> Feedback:
feedback = Feedback(
job_id=job_id,
miner_id=miner_id,
outcome=outcome,
latency_ms=latency_ms,
fail_code=fail_code,
tokens_spent=tokens_spent,
created_at=dt.datetime.utcnow(),
)
self._session.add(feedback)
await self._session.flush()
payload = {
"job_id": job_id,
"miner_id": miner_id,
"outcome": outcome,
"latency_ms": latency_ms,
"fail_code": fail_code,
"tokens_spent": tokens_spent,
"created_at": feedback.created_at.isoformat() if feedback.created_at else None,
}
try:
await self._redis.publish(RedisKeys.feedback_channel(), json.dumps(payload))
except Exception as exc: # pragma: no cover - defensive
logger.warning("Failed to publish feedback event for job %s: %s", job_id, exc)
return feedback
async def list_feedback_for_miner(self, miner_id: str, limit: int = 50) -> List[Feedback]:
stmt = (
select(Feedback)
.where(Feedback.miner_id == miner_id)
.order_by(Feedback.created_at.desc())
.limit(limit)
)
result = await self._session.execute(stmt)
return list(result.scalars().all())
async def list_feedback_for_job(self, job_id: str, limit: int = 50) -> List[Feedback]:
stmt = (
select(Feedback)
.where(Feedback.job_id == job_id)
.order_by(Feedback.created_at.desc())
.limit(limit)
)
result = await self._session.execute(stmt)
return list(result.scalars().all())
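
Downstream services can consume the published events over Redis pub/sub; a minimal consumer sketch:

import json
from redis.asyncio import Redis
from poolhub.storage.redis_keys import RedisKeys

async def consume_feedback(redis: Redis) -> None:
    pubsub = redis.pubsub()
    await pubsub.subscribe(RedisKeys.feedback_channel())
    async for message in pubsub.listen():
        if message["type"] != "message":
            continue  # skip subscribe confirmations
        event = json.loads(message["data"])
        print(event["job_id"], event["outcome"])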

View File

@@ -0,0 +1,122 @@
from __future__ import annotations
import datetime as dt
import json
from typing import List, Optional, Sequence
from uuid import UUID
from redis.asyncio import Redis
from sqlalchemy import Select, select
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import MatchRequest, MatchResult
from ..storage.redis_keys import RedisKeys
class MatchRepository:
"""Handles match request logging, result persistence, and Redis fan-out."""
def __init__(self, session: AsyncSession, redis: Redis) -> None:
self._session = session
self._redis = redis
async def create_request(
self,
*,
job_id: str,
requirements: dict[str, object],
hints: Optional[dict[str, object]] = None,
top_k: int = 1,
enqueue: bool = True,
) -> MatchRequest:
request = MatchRequest(
job_id=job_id,
requirements=requirements,
hints=hints or {},
top_k=top_k,
created_at=dt.datetime.utcnow(),
)
self._session.add(request)
await self._session.flush()
if enqueue:
payload = {
"request_id": str(request.id),
"job_id": request.job_id,
"requirements": request.requirements,
"hints": request.hints,
"top_k": request.top_k,
}
await self._redis.rpush(RedisKeys.match_requests(), json.dumps(payload))
return request
async def add_results(
self,
*,
request_id: UUID,
candidates: Sequence[dict[str, object]],
publish: bool = True,
) -> List[MatchResult]:
results: List[MatchResult] = []
created_at = dt.datetime.utcnow()
for candidate in candidates:
result = MatchResult(
request_id=request_id,
miner_id=str(candidate.get("miner_id")),
score=float(candidate.get("score", 0.0)),
explain=candidate.get("explain"),
eta_ms=candidate.get("eta_ms"),
price=candidate.get("price"),
created_at=created_at,
)
self._session.add(result)
results.append(result)
await self._session.flush()
if publish:
request = await self._session.get(MatchRequest, request_id)
if request:
redis_key = RedisKeys.match_results(request.job_id)
await self._redis.delete(redis_key)
if results:
payloads = [json.dumps(self._result_payload(result)) for result in results]
await self._redis.rpush(redis_key, *payloads)
await self._redis.expire(redis_key, 300)
channel = RedisKeys.match_results_channel(request.job_id)
for payload in payloads:
await self._redis.publish(channel, payload)
return results
async def get_request(self, request_id: UUID) -> Optional[MatchRequest]:
return await self._session.get(MatchRequest, request_id)
async def list_recent_requests(self, limit: int = 20) -> List[MatchRequest]:
stmt: Select[tuple[MatchRequest]] = (
select(MatchRequest)
.order_by(MatchRequest.created_at.desc())
.limit(limit)
)
result = await self._session.execute(stmt)
return list(result.scalars().all())
async def list_results_for_job(self, job_id: str, limit: int = 10) -> List[MatchResult]:
stmt: Select[tuple[MatchResult]] = (
select(MatchResult)
.join(MatchRequest)
.where(MatchRequest.job_id == job_id)
.order_by(MatchResult.created_at.desc())
.limit(limit)
)
result = await self._session.execute(stmt)
return list(result.scalars().all())
def _result_payload(self, result: MatchResult) -> dict[str, object]:
return {
"request_id": str(result.request_id),
"miner_id": result.miner_id,
"score": result.score,
"explain": result.explain,
"eta_ms": result.eta_ms,
"price": result.price,
"created_at": result.created_at.isoformat() if result.created_at else None,
}
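
The enqueue path pushes requests onto a Redis list, so a worker can pop them with a blocking read; a sketch:

import json
from redis.asyncio import Redis
from poolhub.storage.redis_keys import RedisKeys

async def pop_match_request(redis: Redis, timeout: int = 5) -> dict | None:
    entry = await redis.blpop(RedisKeys.match_requests(), timeout=timeout)
    if entry is None:
        return None  # timed out with nothing queued
    _key, raw = entry
    return json.loads(raw)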

View File

@@ -0,0 +1,181 @@
from __future__ import annotations
import datetime as dt
from typing import List, Optional, Tuple
from redis.asyncio import Redis
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import Miner, MinerStatus
from ..settings import settings
from ..storage.redis_keys import RedisKeys
class MinerRepository:
"""Coordinates miner registry persistence across PostgreSQL and Redis."""
def __init__(self, session: AsyncSession, redis: Redis) -> None:
self._session = session
self._redis = redis
async def register_miner(
self,
miner_id: str,
api_key_hash: str,
*,
addr: str,
proto: str,
gpu_vram_gb: float,
gpu_name: Optional[str],
cpu_cores: int,
ram_gb: float,
max_parallel: int,
base_price: float,
tags: dict[str, str],
capabilities: list[str],
region: Optional[str],
) -> Miner:
miner = await self._session.get(Miner, miner_id)
if miner is None:
miner = Miner(
miner_id=miner_id,
api_key_hash=api_key_hash,
addr=addr,
proto=proto,
gpu_vram_gb=gpu_vram_gb,
gpu_name=gpu_name,
cpu_cores=cpu_cores,
ram_gb=ram_gb,
max_parallel=max_parallel,
base_price=base_price,
tags=tags,
capabilities=capabilities,
region=region,
)
self._session.add(miner)
status = MinerStatus(miner_id=miner_id)
self._session.add(status)
else:
miner.addr = addr
miner.proto = proto
miner.gpu_vram_gb = gpu_vram_gb
miner.gpu_name = gpu_name
miner.cpu_cores = cpu_cores
miner.ram_gb = ram_gb
miner.max_parallel = max_parallel
miner.base_price = base_price
miner.tags = tags
miner.capabilities = capabilities
miner.region = region
miner.last_seen_at = dt.datetime.utcnow()
await self._session.flush()
await self._sync_miner_to_redis(miner_id)
return miner
async def update_status(
self,
miner_id: str,
*,
queue_len: Optional[int] = None,
busy: Optional[bool] = None,
avg_latency_ms: Optional[int] = None,
temp_c: Optional[int] = None,
mem_free_gb: Optional[float] = None,
) -> None:
stmt = (
update(MinerStatus)
.where(MinerStatus.miner_id == miner_id)
.values(
{
k: v
for k, v in {
"queue_len": queue_len,
"busy": busy,
"avg_latency_ms": avg_latency_ms,
"temp_c": temp_c,
"mem_free_gb": mem_free_gb,
"updated_at": dt.datetime.utcnow(),
}.items()
if v is not None
}
)
)
await self._session.execute(stmt)
miner = await self._session.get(Miner, miner_id)
if miner:
miner.last_seen_at = dt.datetime.utcnow()
await self._session.flush()
await self._sync_miner_to_redis(miner_id)
async def touch_heartbeat(self, miner_id: str) -> None:
miner = await self._session.get(Miner, miner_id)
if miner is None:
return
miner.last_seen_at = dt.datetime.utcnow()
await self._session.flush()
await self._sync_miner_to_redis(miner_id)
async def get_miner(self, miner_id: str) -> Optional[Miner]:
return await self._session.get(Miner, miner_id)
async def iter_miners(self) -> List[Miner]:
result = await self._session.execute(select(Miner))
return list(result.scalars().all())
async def get_status(self, miner_id: str) -> Optional[MinerStatus]:
return await self._session.get(MinerStatus, miner_id)
async def list_active_miners(self) -> List[Tuple[Miner, Optional[MinerStatus], float]]:
stmt = select(Miner, MinerStatus).join(MinerStatus, MinerStatus.miner_id == Miner.miner_id, isouter=True)
result = await self._session.execute(stmt)
records: List[Tuple[Miner, Optional[MinerStatus], float]] = []
for miner, status in result.all():
score = self._compute_score(miner, status)
records.append((miner, status, score))
return records
async def _sync_miner_to_redis(self, miner_id: str) -> None:
miner = await self._session.get(Miner, miner_id)
if miner is None:
return
status = await self._session.get(MinerStatus, miner_id)
payload = {
"miner_id": miner.miner_id,
"addr": miner.addr,
"proto": miner.proto,
"region": miner.region or "",
"gpu_vram_gb": str(miner.gpu_vram_gb),
"ram_gb": str(miner.ram_gb),
"max_parallel": str(miner.max_parallel),
"base_price": str(miner.base_price),
"trust_score": str(miner.trust_score),
"queue_len": str(status.queue_len if status else 0),
"busy": str(status.busy if status else False),
}
redis_key = RedisKeys.miner_hash(miner_id)
await self._redis.hset(redis_key, mapping=payload)
await self._redis.expire(redis_key, settings.session_ttl_seconds + settings.heartbeat_grace_seconds)
score = self._compute_score(miner, status)
ranking_key = RedisKeys.miner_rankings(miner.region)
await self._redis.zadd(ranking_key, {miner_id: score})
await self._redis.expire(ranking_key, settings.session_ttl_seconds + settings.heartbeat_grace_seconds)
def _compute_score(self, miner: Miner, status: Optional[MinerStatus]) -> float:
load_factor = 1.0
if status and miner.max_parallel:
utilization = min(status.queue_len / max(miner.max_parallel, 1), 1.0)
load_factor = 1.0 - utilization
price_factor = 1.0 if miner.base_price <= 0 else min(1.0, 1.0 / miner.base_price)
trust_factor = max(miner.trust_score, 0.0)
weights = settings.default_score_weights
return (
weights.capability * 1.0
+ weights.price * price_factor
+ weights.load * load_factor
+ weights.trust * trust_factor
)
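
As a worked example with the default weights (capability 0.40, price 0.20, load 0.05, trust 0.15; note the latency weight is currently unused by _compute_score): an idle miner with base_price 0.8, max_parallel 4, queue_len 0, and trust_score 0.5 gets price_factor = min(1.0, 1/0.8) = 1.0 and load_factor = 1.0, so its score is 0.40*1.0 + 0.20*1.0 + 0.05*1.0 + 0.15*0.5 = 0.725.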

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from functools import lru_cache
from typing import Any, Dict, List
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class ScoreWeights(BaseModel):
capability: float = Field(default=0.40, alias="cap")
price: float = Field(default=0.20)
latency: float = Field(default=0.20)
trust: float = Field(default=0.15)
load: float = Field(default=0.05)
model_config = ConfigDict(populate_by_name=True)
def as_vector(self) -> List[float]:
return [self.capability, self.price, self.latency, self.trust, self.load]
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="poolhub_", env_file=".env", case_sensitive=False)
app_name: str = "AITBC Pool Hub"
bind_host: str = Field(default="127.0.0.1")
bind_port: int = Field(default=8203)
coordinator_shared_secret: str = Field(default="changeme")
postgres_dsn: str = Field(default="postgresql+asyncpg://poolhub:poolhub@127.0.0.1:5432/aitbc")
postgres_pool_min: int = Field(default=1)
postgres_pool_max: int = Field(default=10)
redis_url: str = Field(default="redis://127.0.0.1:6379/4")
redis_max_connections: int = Field(default=32)
session_ttl_seconds: int = Field(default=60)
heartbeat_grace_seconds: int = Field(default=120)
default_score_weights: ScoreWeights = Field(default_factory=ScoreWeights)
allowed_origins: List[AnyHttpUrl] = Field(default_factory=list)
prometheus_namespace: str = Field(default="poolhub")
def asgi_kwargs(self) -> Dict[str, Any]:
return {
"title": self.app_name,
}
@lru_cache(maxsize=1)
def get_settings() -> Settings:
return Settings()
settings = get_settings()
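
Because env_prefix is "poolhub_" and matching is case-insensitive, any field can be overridden via the environment or the .env file, e.g. POOLHUB_BIND_PORT=9000 or POOLHUB_REDIS_URL=redis://127.0.0.1:6379/1. A sketch of verifying an override (the variable must be set before Settings is instantiated):

import os
os.environ["POOLHUB_BIND_PORT"] = "9000"

from poolhub.settings import Settings
assert Settings().bind_port == 9000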

View File

@@ -0,0 +1,5 @@
"""Storage utilities for the Pool Hub service."""
from .redis_keys import RedisKeys
__all__ = ["RedisKeys"]

View File

@@ -0,0 +1,40 @@
from __future__ import annotations
from typing import Final
class RedisKeys:
namespace: Final[str] = "poolhub"
@classmethod
def miner_hash(cls, miner_id: str) -> str:
return f"{cls.namespace}:miner:{miner_id}"
@classmethod
def miner_rankings(cls, region: str | None = None) -> str:
suffix = region or "global"
return f"{cls.namespace}:rankings:{suffix}"
@classmethod
def miner_session(cls, session_token: str) -> str:
return f"{cls.namespace}:session:{session_token}"
@classmethod
def heartbeat_stream(cls) -> str:
return f"{cls.namespace}:heartbeat-stream"
@classmethod
def match_requests(cls) -> str:
return f"{cls.namespace}:match-requests"
@classmethod
def match_results(cls, job_id: str) -> str:
return f"{cls.namespace}:match-results:{job_id}"
@classmethod
def feedback_channel(cls) -> str:
return f"{cls.namespace}:events:feedback"
@classmethod
def match_results_channel(cls, job_id: str) -> str:
return f"{cls.namespace}:events:match-results:{job_id}"

View File

@@ -0,0 +1,63 @@
from __future__ import annotations
import os
import sys
from collections.abc import AsyncGenerator
from pathlib import Path
import pytest
import pytest_asyncio
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
BASE_DIR = Path(__file__).resolve().parents[2]
POOLHUB_SRC = BASE_DIR / "pool-hub" / "src"
if str(POOLHUB_SRC) not in sys.path:
sys.path.insert(0, str(POOLHUB_SRC))
from poolhub.models import Base
def _get_required_env(name: str) -> str:
value = os.getenv(name)
if not value:
pytest.skip(f"Set {name} to run Pool Hub integration tests")
return value
@pytest_asyncio.fixture()
async def db_engine() -> AsyncGenerator[AsyncEngine, None]:
dsn = _get_required_env("POOLHUB_TEST_POSTGRES_DSN")
engine = create_async_engine(dsn, pool_pre_ping=True)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
yield engine
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await engine.dispose()
@pytest_asyncio.fixture
async def db_session(db_engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
session_factory = async_sessionmaker(db_engine, expire_on_commit=False, autoflush=False)
async with session_factory() as session:
yield session
await session.rollback()
@pytest_asyncio.fixture()
async def redis_client() -> AsyncGenerator[Redis, None]:
redis_url = _get_required_env("POOLHUB_TEST_REDIS_URL")
client = Redis.from_url(redis_url, encoding="utf-8", decode_responses=True)
await client.flushdb()
yield client
await client.flushdb()
await client.close()
@pytest_asyncio.fixture(autouse=True)
async def _clear_redis(redis_client: Redis) -> None:
await redis_client.flushdb()
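
These fixtures skip unless POOLHUB_TEST_POSTGRES_DSN and POOLHUB_TEST_REDIS_URL are exported, so a typical invocation exports both, pointing at disposable databases, before running pytest; the schema and the Redis DB are dropped and recreated around each test.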

View File

@@ -0,0 +1,153 @@
from __future__ import annotations
import uuid
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import async_sessionmaker
from poolhub.app import deps
from poolhub.app.main import create_app
from poolhub.app.prometheus import reset_metrics
from poolhub.repositories.miner_repository import MinerRepository
@pytest_asyncio.fixture()
async def async_client(db_engine, redis_client): # noqa: F811
async def _session_override():
factory = async_sessionmaker(db_engine, expire_on_commit=False, autoflush=False)
async with factory() as session:
yield session
async def _redis_override():
yield redis_client
app = create_app()
app.dependency_overrides.clear()
app.dependency_overrides[deps.db_session_dep] = _session_override
app.dependency_overrides[deps.redis_dep] = _redis_override
reset_metrics()
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
yield client
app.dependency_overrides.clear()
@pytest.mark.asyncio
async def test_match_endpoint(async_client, db_session, redis_client): # noqa: F811
repo = MinerRepository(db_session, redis_client)
await repo.register_miner(
miner_id="miner-1",
api_key_hash="hash",
addr="127.0.0.1",
proto="grpc",
gpu_vram_gb=16,
gpu_name="A100",
cpu_cores=32,
ram_gb=128,
max_parallel=4,
base_price=0.8,
tags={"tier": "gold"},
capabilities=["embedding"],
region="eu",
)
await db_session.commit()
response = await async_client.post(
"/v1/match",
json={
"job_id": "job-123",
"requirements": {"min_vram_gb": 8},
"hints": {"region": "eu"},
"top_k": 1,
},
)
assert response.status_code == 200
payload = response.json()
assert payload["job_id"] == "job-123"
assert len(payload["candidates"]) == 1
@pytest.mark.asyncio
async def test_match_endpoint_no_miners(async_client):
response = await async_client.post(
"/v1/match",
json={"job_id": "empty", "requirements": {}, "hints": {}, "top_k": 2},
)
assert response.status_code == 200
payload = response.json()
assert payload["candidates"] == []
@pytest.mark.asyncio
async def test_health_endpoint(async_client): # noqa: F811
response = await async_client.get("/v1/health")
assert response.status_code == 200
data = response.json()
assert data["status"] in {"ok", "degraded"}
assert "db_error" in data
assert "redis_error" in data
@pytest.mark.asyncio
async def test_health_endpoint_degraded(db_engine, redis_client): # noqa: F811
async def _session_override():
factory = async_sessionmaker(db_engine, expire_on_commit=False, autoflush=False)
async with factory() as session:
yield session
class FailingRedis:
async def ping(self) -> None:
raise RuntimeError("redis down")
def __getattr__(self, _: str) -> None: # pragma: no cover - minimal stub
raise RuntimeError("redis down")
async def _redis_override():
yield FailingRedis()
app = create_app()
app.dependency_overrides.clear()
app.dependency_overrides[deps.db_session_dep] = _session_override
app.dependency_overrides[deps.redis_dep] = _redis_override
reset_metrics()
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client:
response = await client.get("/v1/health")
assert response.status_code == 200
payload = response.json()
assert payload["status"] == "degraded"
assert payload["redis_error"]
assert payload["db_error"] is None
app.dependency_overrides.clear()
@pytest.mark.asyncio
async def test_metrics_endpoint(async_client):
baseline = await async_client.get("/metrics")
before = _extract_counter(baseline.text, "poolhub_match_requests_total")
for _ in range(2):
await async_client.post(
"/v1/match",
json={"job_id": str(uuid.uuid4()), "requirements": {}, "hints": {}, "top_k": 1},
)
updated = await async_client.get("/metrics")
after = _extract_counter(updated.text, "poolhub_match_requests_total")
assert after >= before + 2
def _extract_counter(metrics_text: str, metric: str) -> float:
for line in metrics_text.splitlines():
if line.startswith(metric):
parts = line.split()
if len(parts) >= 2:
try:
return float(parts[1])
except ValueError: # pragma: no cover
return 0.0
return 0.0

View File

@@ -0,0 +1,96 @@
from __future__ import annotations
import json
import uuid
import pytest
from poolhub.repositories.feedback_repository import FeedbackRepository
from poolhub.repositories.match_repository import MatchRepository
from poolhub.repositories.miner_repository import MinerRepository
from poolhub.storage.redis_keys import RedisKeys
@pytest.mark.asyncio
async def test_register_miner_persists_and_syncs(db_session, redis_client):
repo = MinerRepository(db_session, redis_client)
await repo.register_miner(
miner_id="miner-1",
api_key_hash="hash",
addr="127.0.0.1",
proto="grpc",
gpu_vram_gb=16,
gpu_name="A100",
cpu_cores=32,
ram_gb=128,
max_parallel=4,
base_price=0.8,
tags={"tier": "gold"},
capabilities=["embedding"],
region="eu",
)
miner = await repo.get_miner("miner-1")
assert miner is not None
assert miner.addr == "127.0.0.1"
redis_hash = await redis_client.hgetall(RedisKeys.miner_hash("miner-1"))
assert redis_hash["miner_id"] == "miner-1"
ranking = await redis_client.zscore(RedisKeys.miner_rankings("eu"), "miner-1")
assert ranking is not None
@pytest.mark.asyncio
async def test_match_request_flow(db_session, redis_client):
match_repo = MatchRepository(db_session, redis_client)
req = await match_repo.create_request(
job_id="job-123",
requirements={"min_vram_gb": 8},
hints={"region": "eu"},
top_k=2,
)
await db_session.commit()
queue_entry = await redis_client.lpop(RedisKeys.match_requests())
assert queue_entry is not None
payload = json.loads(queue_entry)
assert payload["job_id"] == "job-123"
await match_repo.add_results(
request_id=req.id,
candidates=[
{"miner_id": "miner-1", "score": 0.9, "explain": "fit"},
{"miner_id": "miner-2", "score": 0.8, "explain": "backup"},
],
)
await db_session.commit()
results = await match_repo.list_results_for_job("job-123")
assert len(results) == 2
redis_results = await redis_client.lrange(RedisKeys.match_results("job-123"), 0, -1)
assert len(redis_results) == 2
@pytest.mark.asyncio
async def test_feedback_repository(db_session, redis_client):
feedback_repo = FeedbackRepository(db_session, redis_client)
feedback = await feedback_repo.add_feedback(
job_id="job-321",
miner_id="miner-1",
outcome="completed",
latency_ms=1200,
tokens_spent=1.5,
)
await db_session.commit()
rows = await feedback_repo.list_feedback_for_job("job-321")
assert len(rows) == 1
assert rows[0].outcome == "completed"
# Redis pub/sub does not buffer messages for later inspection, so the publish
# itself cannot be asserted here; this is a smoke check on the returned object.
assert feedback.miner_id == "miner-1"