123 lines
4.2 KiB
Python
123 lines
4.2 KiB
Python
from __future__ import annotations
|
|
|
|
import datetime as dt
|
|
import json
|
|
from typing import Iterable, List, Optional, Sequence
|
|
from uuid import UUID
|
|
|
|
from redis.asyncio import Redis
|
|
from sqlalchemy import Select, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from ..models import MatchRequest, MatchResult
|
|
from ..storage.redis_keys import RedisKeys
|
|
|
|
|
|
class MatchRepository:
|
|
"""Handles match request logging, result persistence, and Redis fan-out."""
|
|
|
|
def __init__(self, session: AsyncSession, redis: Redis) -> None:
|
|
self._session = session
|
|
self._redis = redis
|
|
|
|
async def create_request(
|
|
self,
|
|
*,
|
|
job_id: str,
|
|
requirements: dict[str, object],
|
|
hints: Optional[dict[str, object]] = None,
|
|
top_k: int = 1,
|
|
enqueue: bool = True,
|
|
) -> MatchRequest:
|
|
request = MatchRequest(
|
|
job_id=job_id,
|
|
requirements=requirements,
|
|
hints=hints or {},
|
|
top_k=top_k,
|
|
created_at=dt.datetime.utcnow(),
|
|
)
|
|
self._session.add(request)
|
|
await self._session.flush()
|
|
|
|
if enqueue:
|
|
payload = {
|
|
"request_id": str(request.id),
|
|
"job_id": request.job_id,
|
|
"requirements": request.requirements,
|
|
"hints": request.hints,
|
|
"top_k": request.top_k,
|
|
}
|
|
await self._redis.rpush(RedisKeys.match_requests(), json.dumps(payload))
|
|
return request
|
|
|
|
async def add_results(
|
|
self,
|
|
*,
|
|
request_id: UUID,
|
|
candidates: Sequence[dict[str, object]],
|
|
publish: bool = True,
|
|
) -> List[MatchResult]:
|
|
results: List[MatchResult] = []
|
|
created_at = dt.datetime.utcnow()
|
|
for candidate in candidates:
|
|
result = MatchResult(
|
|
request_id=request_id,
|
|
miner_id=str(candidate.get("miner_id")),
|
|
score=float(candidate.get("score", 0.0)),
|
|
explain=candidate.get("explain"),
|
|
eta_ms=candidate.get("eta_ms"),
|
|
price=candidate.get("price"),
|
|
created_at=created_at,
|
|
)
|
|
self._session.add(result)
|
|
results.append(result)
|
|
await self._session.flush()
|
|
|
|
if publish:
|
|
request = await self._session.get(MatchRequest, request_id)
|
|
if request:
|
|
redis_key = RedisKeys.match_results(request.job_id)
|
|
await self._redis.delete(redis_key)
|
|
if results:
|
|
payloads = [json.dumps(self._result_payload(result)) for result in results]
|
|
await self._redis.rpush(redis_key, *payloads)
|
|
await self._redis.expire(redis_key, 300)
|
|
channel = RedisKeys.match_results_channel(request.job_id)
|
|
for payload in payloads:
|
|
await self._redis.publish(channel, payload)
|
|
return results
|
|
|
|
async def get_request(self, request_id: UUID) -> Optional[MatchRequest]:
|
|
return await self._session.get(MatchRequest, request_id)
|
|
|
|
async def list_recent_requests(self, limit: int = 20) -> List[MatchRequest]:
|
|
stmt: Select[MatchRequest] = (
|
|
select(MatchRequest)
|
|
.order_by(MatchRequest.created_at.desc())
|
|
.limit(limit)
|
|
)
|
|
result = await self._session.execute(stmt)
|
|
return list(result.scalars().all())
|
|
|
|
async def list_results_for_job(self, job_id: str, limit: int = 10) -> List[MatchResult]:
|
|
stmt: Select[MatchResult] = (
|
|
select(MatchResult)
|
|
.join(MatchRequest)
|
|
.where(MatchRequest.job_id == job_id)
|
|
.order_by(MatchResult.created_at.desc())
|
|
.limit(limit)
|
|
)
|
|
result = await self._session.execute(stmt)
|
|
return list(result.scalars().all())
|
|
|
|
def _result_payload(self, result: MatchResult) -> dict[str, object]:
|
|
return {
|
|
"request_id": str(result.request_id),
|
|
"miner_id": result.miner_id,
|
|
"score": result.score,
|
|
"explain": result.explain,
|
|
"eta_ms": result.eta_ms,
|
|
"price": result.price,
|
|
"created_at": result.created_at.isoformat() if result.created_at else None,
|
|
}
|