Files
aitbc/apps/pool-hub/src/poolhub/repositories/match_repository.py

123 lines
4.2 KiB
Python

from __future__ import annotations
import datetime as dt
import json
from typing import Iterable, List, Optional, Sequence
from uuid import UUID
from redis.asyncio import Redis
from sqlalchemy import Select, select
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import MatchRequest, MatchResult
from ..storage.redis_keys import RedisKeys
class MatchRepository:
    """Handles match request logging, result persistence, and Redis fan-out.

    Rows are added to the supplied ``AsyncSession`` and flushed; this class
    never commits, so transaction boundaries stay with the caller.
    """

    #: Lifetime, in seconds, applied to the per-job Redis list of results
    #: written by ``add_results``. Override on a subclass to change it.
    RESULTS_TTL_SECONDS: int = 300

    def __init__(self, session: AsyncSession, redis: Redis) -> None:
        """Store the database session and Redis client used by every method."""
        self._session = session
        self._redis = redis

    @staticmethod
    def _utcnow() -> dt.datetime:
        """Return the current UTC time as a naive ``datetime``.

        Produces the same kind of value as ``datetime.utcnow()`` (no
        ``tzinfo``) without calling that API, which is deprecated since
        Python 3.12.
        """
        return dt.datetime.now(dt.timezone.utc).replace(tzinfo=None)

    async def create_request(
        self,
        *,
        job_id: str,
        requirements: dict[str, object],
        hints: Optional[dict[str, object]] = None,
        top_k: int = 1,
        enqueue: bool = True,
    ) -> MatchRequest:
        """Record a match request and optionally push it onto the Redis queue.

        Args:
            job_id: Identifier of the job the request belongs to.
            requirements: Requirements stored on the request and echoed in
                the queued payload.
            hints: Optional hints; ``None`` or an empty mapping is stored
                as ``{}``.
            top_k: Value stored in the request's ``top_k`` field.
            enqueue: When true, a JSON payload describing the request is
                appended to the list at ``RedisKeys.match_requests()``.

        Returns:
            The flushed (not committed) ``MatchRequest``.
        """
        request = MatchRequest(
            job_id=job_id,
            requirements=requirements,
            hints=hints or {},
            top_k=top_k,
            created_at=self._utcnow(),
        )
        self._session.add(request)
        # Flush before building the payload below, which reads request.id
        # (presumably assigned during the flush -- confirm against the model).
        await self._session.flush()

        if enqueue:
            payload = {
                "request_id": str(request.id),
                "job_id": request.job_id,
                "requirements": request.requirements,
                "hints": request.hints,
                "top_k": request.top_k,
            }
            await self._redis.rpush(RedisKeys.match_requests(), json.dumps(payload))
        return request

    async def add_results(
        self,
        *,
        request_id: UUID,
        candidates: Sequence[dict[str, object]],
        publish: bool = True,
    ) -> List[MatchResult]:
        """Persist one ``MatchResult`` per candidate and optionally fan out.

        Args:
            request_id: Primary key of the ``MatchRequest`` the results
                belong to.
            candidates: Mappings read with the keys ``miner_id``, ``score``
                (defaults to ``0.0``), ``explain``, ``eta_ms`` and ``price``.
            publish: When true and the request row can be loaded, the
                per-job Redis list is replaced with the new results and
                each result is published on the per-job channel.

        Returns:
            The flushed (not committed) ``MatchResult`` objects, in
            candidate order.
        """
        # One timestamp for the whole batch so all rows share created_at.
        created_at = self._utcnow()
        results: List[MatchResult] = []
        for candidate in candidates:
            # NOTE(review): a candidate without "miner_id" is stored with the
            # literal string "None" because of the str() call; confirm
            # whether such candidates should be rejected instead.
            result = MatchResult(
                request_id=request_id,
                miner_id=str(candidate.get("miner_id")),
                score=float(candidate.get("score", 0.0)),
                explain=candidate.get("explain"),
                eta_ms=candidate.get("eta_ms"),
                price=candidate.get("price"),
                created_at=created_at,
            )
            self._session.add(result)
            results.append(result)
        await self._session.flush()

        if not publish:
            return results

        # The Redis key and channel are derived from the job id, which lives
        # on the request row; without the row nothing is written to Redis.
        request = await self._session.get(MatchRequest, request_id)
        if not request:
            return results

        redis_key = RedisKeys.match_results(request.job_id)
        # Always drop the previous list, even when there are no new results.
        await self._redis.delete(redis_key)
        if not results:
            return results

        payloads = [json.dumps(self._result_payload(result)) for result in results]
        await self._redis.rpush(redis_key, *payloads)
        await self._redis.expire(redis_key, self.RESULTS_TTL_SECONDS)
        channel = RedisKeys.match_results_channel(request.job_id)
        for payload in payloads:
            await self._redis.publish(channel, payload)
        return results

    async def get_request(self, request_id: UUID) -> Optional[MatchRequest]:
        """Return the ``MatchRequest`` with this primary key, or ``None``."""
        return await self._session.get(MatchRequest, request_id)

    async def list_recent_requests(self, limit: int = 20) -> List[MatchRequest]:
        """Return up to ``limit`` requests, newest ``created_at`` first."""
        stmt: Select[MatchRequest] = (
            select(MatchRequest)
            .order_by(MatchRequest.created_at.desc())
            .limit(limit)
        )
        result = await self._session.execute(stmt)
        return list(result.scalars().all())

    async def list_results_for_job(self, job_id: str, limit: int = 10) -> List[MatchResult]:
        """Return up to ``limit`` results for a job, newest ``created_at`` first.

        Results are joined to their ``MatchRequest`` and filtered on the
        request's ``job_id``.
        """
        stmt: Select[MatchResult] = (
            select(MatchResult)
            .join(MatchRequest)
            .where(MatchRequest.job_id == job_id)
            .order_by(MatchResult.created_at.desc())
            .limit(limit)
        )
        result = await self._session.execute(stmt)
        return list(result.scalars().all())

    def _result_payload(self, result: MatchResult) -> dict[str, object]:
        """Build the dict that is JSON-encoded for Redis from one result.

        ``request_id`` is stringified and ``created_at`` is rendered with
        ``isoformat()`` (or ``None`` when unset); the other fields are
        passed through unchanged.
        """
        created_at = result.created_at.isoformat() if result.created_at else None
        return {
            "request_id": str(result.request_id),
            "miner_id": result.miner_id,
            "score": result.score,
            "explain": result.explain,
            "eta_ms": result.eta_ms,
            "price": result.price,
            "created_at": created_at,
        }