feat: add foreign key constraints and metrics for blockchain node

This commit is contained in:
oib
2025-09-28 06:04:30 +02:00
parent c1926136fb
commit fb60505cdf
189 changed files with 15678 additions and 158 deletions

Binary file not shown.

View File

@ -12,23 +12,85 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '80bc0020bde2'
down_revision: Union[str, Sequence[str], None] = 'e31f486f1484'
revision: str = "80bc0020bde2"
down_revision: Union[str, Sequence[str], None] = "e31f486f1484"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_foreign_key(None, 'receipt', 'block', ['block_height'], ['height'])
op.create_foreign_key(None, 'transaction', 'block', ['block_height'], ['height'])
# ### end Alembic commands ###
# Recreate transaction table with foreign key to block.height
op.drop_table("transaction")
op.create_table(
"transaction",
sa.Column("id", sa.Integer(), primary_key=True, nullable=False),
sa.Column("tx_hash", sa.String(), nullable=False),
sa.Column("block_height", sa.Integer(), sa.ForeignKey("block.height"), nullable=True),
sa.Column("sender", sa.String(), nullable=False),
sa.Column("recipient", sa.String(), nullable=False),
sa.Column("payload", sa.JSON(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False),
)
op.create_index("ix_transaction_tx_hash", "transaction", ["tx_hash"], unique=True)
op.create_index("ix_transaction_block_height", "transaction", ["block_height"], unique=False)
op.create_index("ix_transaction_created_at", "transaction", ["created_at"], unique=False)
# Recreate receipt table with foreign key to block.height
op.drop_table("receipt")
op.create_table(
"receipt",
sa.Column("id", sa.Integer(), primary_key=True, nullable=False),
sa.Column("job_id", sa.String(), nullable=False),
sa.Column("receipt_id", sa.String(), nullable=False),
sa.Column("block_height", sa.Integer(), sa.ForeignKey("block.height"), nullable=True),
sa.Column("payload", sa.JSON(), nullable=False),
sa.Column("miner_signature", sa.JSON(), nullable=False),
sa.Column("coordinator_attestations", sa.JSON(), nullable=False),
sa.Column("minted_amount", sa.Integer(), nullable=True),
sa.Column("recorded_at", sa.DateTime(), nullable=False),
)
op.create_index("ix_receipt_job_id", "receipt", ["job_id"], unique=False)
op.create_index("ix_receipt_receipt_id", "receipt", ["receipt_id"], unique=True)
op.create_index("ix_receipt_block_height", "receipt", ["block_height"], unique=False)
op.create_index("ix_receipt_recorded_at", "receipt", ["recorded_at"], unique=False)
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, 'transaction', type_='foreignkey')
op.drop_constraint(None, 'receipt', type_='foreignkey')
# ### end Alembic commands ###
# Revert receipt table without foreign key
op.drop_table("receipt")
op.create_table(
"receipt",
sa.Column("id", sa.Integer(), primary_key=True, nullable=False),
sa.Column("job_id", sa.String(), nullable=False),
sa.Column("receipt_id", sa.String(), nullable=False),
sa.Column("block_height", sa.Integer(), nullable=True),
sa.Column("payload", sa.JSON(), nullable=False),
sa.Column("miner_signature", sa.JSON(), nullable=False),
sa.Column("coordinator_attestations", sa.JSON(), nullable=False),
sa.Column("minted_amount", sa.Integer(), nullable=True),
sa.Column("recorded_at", sa.DateTime(), nullable=False),
)
op.create_index("ix_receipt_job_id", "receipt", ["job_id"], unique=False)
op.create_index("ix_receipt_receipt_id", "receipt", ["receipt_id"], unique=True)
op.create_index("ix_receipt_block_height", "receipt", ["block_height"], unique=False)
op.create_index("ix_receipt_recorded_at", "receipt", ["recorded_at"], unique=False)
# Revert transaction table without foreign key
op.drop_table("transaction")
op.create_table(
"transaction",
sa.Column("id", sa.Integer(), primary_key=True, nullable=False),
sa.Column("tx_hash", sa.String(), nullable=False),
sa.Column("block_height", sa.Integer(), nullable=True),
sa.Column("sender", sa.String(), nullable=False),
sa.Column("recipient", sa.String(), nullable=False),
sa.Column("payload", sa.JSON(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False),
)
op.create_index("ix_transaction_tx_hash", "transaction", ["tx_hash"], unique=True)
op.create_index("ix_transaction_block_height", "transaction", ["block_height"], unique=False)
op.create_index("ix_transaction_created_at", "transaction", ["created_at"], unique=False)

View File

@ -0,0 +1,43 @@
# Blockchain Node Observability
This directory contains Prometheus and Grafana assets for the devnet environment. The stack relies on the HTTP `/metrics` endpoint exposed by:
1. The blockchain node API (`http://127.0.0.1:8080/metrics`).
2. The mock coordinator/miner exporter (`http://127.0.0.1:8090/metrics`).
## Files
- `prometheus.yml` Scrapes both blockchain node and mock coordinator/miner metrics.
- `grafana-dashboard.json` Panels for block interval, RPC throughput, miner activity, coordinator receipt flow, **plus new gossip queue, subscriber, and publication rate panels**.
- `alerts.yml` Alertmanager rules highlighting proposer stalls, miner errors, and coordinator receipt drop-offs.
- `gossip-recording-rules.yml` Prometheus recording rules that derive queue/subscriber gauges and publication rates from gossip metrics.
## Usage
```bash
# Launch Prometheus using the sample config
prometheus --config.file=apps/blockchain-node/observability/prometheus.yml
# Import the dashboard JSON into Grafana
grafana-cli dashboards import apps/blockchain-node/observability/grafana-dashboard.json
# Run Alertmanager with the example rules
alertmanager --config.file=apps/blockchain-node/observability/alerts.yml
# Reload Prometheus and Alertmanager after tuning thresholds
kill -HUP $(pgrep prometheus)
kill -HUP $(pgrep alertmanager)
```
> **Tip:** The devnet helper `scripts/devnet_up.sh` seeds the metrics endpoints. After running it, both scrape targets will begin emitting data in under a minute.
## Gossip Observability
Recent updates instrumented the gossip broker with Prometheus counters and gauges. Key metrics surfaced via the recording rules and dashboard include:
- `gossip_publications_rate_per_sec` and `gossip_broadcast_publications_rate_per_sec` per-second publication throughput for in-memory and broadcast backends.
- `gossip_publications_topic_rate_per_sec` topic-level publication rate time series (Grafana panel “Gossip Publication Rate by Topic”).
- `gossip_queue_size_by_topic` instantaneous queue depth per topic (“Gossip Queue Depth by Topic”).
- `gossip_subscribers_by_topic`, `gossip_subscribers_total`, `gossip_broadcast_subscribers_total` subscriber counts (“Gossip Subscriber Counts”).
Use these panels to monitor convergence/back-pressure during load tests (for example with `scripts/ws_load_test.py`) when running against a Redis-backed broadcast backend.

View File

@ -0,0 +1,43 @@
groups:
- name: blockchain-node
rules:
- alert: BlockProposalStalled
expr: (block_interval_seconds_sum / block_interval_seconds_count) > 5
for: 1m
labels:
severity: warning
annotations:
summary: "Block production interval exceeded 5s"
description: |
Average block interval is {{ $value }} seconds, exceeding the expected cadence.
- alert: BlockProposalDown
expr: (block_interval_seconds_sum / block_interval_seconds_count) > 10
for: 2m
labels:
severity: critical
annotations:
summary: "Block production halted"
description: |
Block intervals have spiked above 10 seconds for more than two minutes.
Check proposer loop and database state.
- alert: MinerErrorsDetected
expr: miner_error_rate > 0
for: 1m
labels:
severity: critical
annotations:
summary: "Miner mock reporting errors"
description: |
The miner mock error gauge is {{ $value }}. Investigate miner telemetry.
- alert: CoordinatorReceiptDrop
expr: rate(miner_receipts_attested_total[5m]) == 0
for: 5m
labels:
severity: warning
annotations:
summary: "No receipts attested in 5 minutes"
description: |
Receipt attestations ceased during the last five minutes. Inspect coordinator connectivity.

View File

@ -0,0 +1,36 @@
groups:
- name: gossip_metrics
interval: 15s
rules:
- record: gossip_publications_rate_per_sec
expr: rate(gossip_publications_total[1m])
- record: gossip_broadcast_publications_rate_per_sec
expr: rate(gossip_broadcast_publications_total[1m])
- record: gossip_publications_topic_rate_per_sec
expr: label_replace(
rate({__name__=~"gossip_publications_topic_.*"}[1m]),
"topic",
"$1",
"__name__",
"gossip_publications_topic_(.*)"
)
- record: gossip_queue_size_by_topic
expr: label_replace(
{__name__=~"gossip_queue_size_.*"},
"topic",
"$1",
"__name__",
"gossip_queue_size_(.*)"
)
- record: gossip_subscribers_by_topic
expr: label_replace(
{__name__=~"gossip_subscribers_topic_.*"},
"topic",
"$1",
"__name__",
"gossip_subscribers_topic_(.*)"
)

View File

@ -0,0 +1,377 @@
{
"annotations": {
"list": [
{
"builtIn": 1,
"datasource": {
"type": "grafana",
"uid": "grafana"
},
"enable": true,
"hide": true,
"iconColor": "rgba(0, 211, 255, 1)",
"name": "Annotations & Alerts",
"type": "dashboard"
}
]
},
"description": "AITBC devnet observability for blockchain node, coordinator, and miner mock.",
"editable": true,
"fiscalYearStartMonth": 0,
"gnetId": null,
"graphTooltip": 0,
"id": null,
"iteration": 1727420700000,
"links": [],
"liveNow": false,
"panels": [
{
"datasource": {
"type": "prometheus",
"uid": "PROMETHEUS_DS"
},
"fieldConfig": {
"defaults": {
"custom": {},
"unit": "s"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 0
},
"id": 1,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
}
},
"targets": [
{
"expr": "block_interval_seconds_sum / block_interval_seconds_count",
"legendFormat": "avg block interval",
"refId": "A"
}
],
"title": "Block Interval (seconds)",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "PROMETHEUS_DS"
},
"fieldConfig": {
"defaults": {
"custom": {},
"unit": "ops"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 0
},
"id": 2,
"options": {
"legend": {
"calcs": ["lastNotNull"],
"displayMode": "table",
"placement": "bottom",
"showLegend": true
}
},
"targets": [
{
"expr": "rate(rpc_send_tx_total[5m])",
"legendFormat": "sendTx",
"refId": "A"
},
{
"expr": "rate(rpc_submit_receipt_total[5m])",
"legendFormat": "submitReceipt",
"refId": "B"
},
{
"expr": "rate(rpc_get_head_total[5m])",
"legendFormat": "getHead",
"refId": "C"
}
],
"title": "RPC Throughput",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "PROMETHEUS_DS"
},
"fieldConfig": {
"defaults": {
"custom": {}
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 8
},
"id": 3,
"options": {
"legend": {
"calcs": ["lastNotNull"],
"displayMode": "table",
"placement": "bottom",
"showLegend": true
}
},
"targets": [
{
"expr": "miner_active_jobs",
"legendFormat": "active jobs",
"refId": "A"
},
{
"expr": "miner_error_rate",
"legendFormat": "error gauge",
"refId": "B"
}
],
"title": "Miner Activity",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "PROMETHEUS_DS"
},
"fieldConfig": {
"defaults": {
"custom": {},
"unit": "short"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 8
},
"id": 4,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
}
},
"targets": [
{
"expr": "rate(miner_receipts_attested_total[5m])",
"legendFormat": "receipts attested",
"refId": "A"
},
{
"expr": "rate(miner_receipts_unknown_total[5m])",
"legendFormat": "unknown receipts",
"refId": "B"
}
],
"title": "Coordinator Receipt Flow",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "PROMETHEUS_DS"
},
"fieldConfig": {
"defaults": {
"custom": {}
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 16
},
"id": 5,
"options": {
"legend": {
"calcs": ["lastNotNull"],
"displayMode": "table",
"placement": "bottom",
"showLegend": true
}
},
"targets": [
{
"expr": "gossip_queue_size_by_topic",
"legendFormat": "{{topic}}",
"refId": "A"
}
],
"title": "Gossip Queue Depth by Topic",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "PROMETHEUS_DS"
},
"fieldConfig": {
"defaults": {
"custom": {}
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 16
},
"id": 6,
"options": {
"legend": {
"calcs": ["lastNotNull"],
"displayMode": "table",
"placement": "bottom",
"showLegend": true
}
},
"targets": [
{
"expr": "gossip_subscribers_by_topic",
"legendFormat": "{{topic}}",
"refId": "A"
},
{
"expr": "gossip_subscribers_total",
"legendFormat": "total subscribers",
"refId": "B"
},
{
"expr": "gossip_broadcast_subscribers_total",
"legendFormat": "broadcast subscribers",
"refId": "C"
}
],
"title": "Gossip Subscriber Counts",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "PROMETHEUS_DS"
},
"fieldConfig": {
"defaults": {
"custom": {},
"unit": "ops"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 24
},
"id": 7,
"options": {
"legend": {
"calcs": ["lastNotNull"],
"displayMode": "table",
"placement": "bottom",
"showLegend": true
}
},
"targets": [
{
"expr": "gossip_publications_rate_per_sec",
"legendFormat": "memory backend",
"refId": "A"
},
{
"expr": "gossip_broadcast_publications_rate_per_sec",
"legendFormat": "broadcast backend",
"refId": "B"
}
],
"title": "Gossip Publication Rate (total)",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "PROMETHEUS_DS"
},
"fieldConfig": {
"defaults": {
"custom": {},
"unit": "ops"
},
"overrides": []
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 24
},
"id": 8,
"options": {
"legend": {
"calcs": ["lastNotNull"],
"displayMode": "table",
"placement": "bottom",
"showLegend": true
}
},
"targets": [
{
"expr": "gossip_publications_topic_rate_per_sec",
"legendFormat": "{{topic}}",
"refId": "A"
}
],
"title": "Gossip Publication Rate by Topic",
"type": "timeseries"
}
],
"refresh": "10s",
"schemaVersion": 39,
"style": "dark",
"tags": [
"aitbc",
"blockchain-node"
],
"templating": {
"list": []
},
"time": {
"from": "now-30m",
"to": "now"
},
"timepicker": {},
"timezone": "",
"title": "AITBC Blockchain Node",
"uid": null,
"version": 1,
"weekStart": ""
}

View File

@ -0,0 +1,28 @@
global:
scrape_interval: 5s
evaluation_interval: 10s
alerting:
alertmanagers:
- static_configs:
- targets:
- "127.0.0.1:9093"
scrape_configs:
- job_name: "blockchain-node"
static_configs:
- targets:
- "127.0.0.1:8080"
labels:
service: "blockchain-node"
- job_name: "mock-coordinator"
static_configs:
- targets:
- "127.0.0.1:8090"
labels:
service: "mock-coordinator"
rule_files:
- alerts.yml
- gossip-recording-rules.yml

View File

@ -3,9 +3,14 @@
from __future__ import annotations
import random
import time
from typing import Dict
from fastapi import FastAPI
from fastapi.responses import PlainTextResponse
from aitbc_chain.metrics import metrics_registry
app = FastAPI(title="Mock Coordinator API", version="0.1.0")
@ -15,6 +20,17 @@ MOCK_JOBS: Dict[str, Dict[str, str]] = {
}
def _simulate_miner_metrics() -> None:
metrics_registry.set_gauge("miner_active_jobs", float(random.randint(0, 5)))
metrics_registry.set_gauge("miner_error_rate", float(random.randint(0, 1)))
metrics_registry.observe("miner_job_duration_seconds", random.uniform(1.0, 5.0))
@app.on_event("startup")
async def _startup() -> None:
_simulate_miner_metrics()
@app.get("/health")
def health() -> Dict[str, str]:
return {"status": "ok"}
@ -24,15 +40,23 @@ def health() -> Dict[str, str]:
def attest_receipt(payload: Dict[str, str]) -> Dict[str, str | bool]:
job_id = payload.get("job_id")
if job_id in MOCK_JOBS:
metrics_registry.increment("miner_receipts_attested_total")
return {
"exists": True,
"paid": True,
"not_double_spent": True,
"quote": MOCK_JOBS[job_id],
}
metrics_registry.increment("miner_receipts_unknown_total")
return {
"exists": False,
"paid": False,
"not_double_spent": False,
"quote": {},
}
@app.get("/metrics", response_class=PlainTextResponse)
def metrics() -> str:
metrics_registry.observe("miner_metrics_scrape_duration_seconds", random.uniform(0.001, 0.01))
return metrics_registry.render_prometheus()

View File

@ -0,0 +1,224 @@
#!/usr/bin/env python3
"""Asynchronous load harness for blockchain-node WebSocket + gossip pipeline."""
from __future__ import annotations
import argparse
import asyncio
import json
import random
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import httpx
import websockets
@dataclass
class PublishStats:
sent: int = 0
failed: int = 0
latencies: List[float] = field(default_factory=list)
@property
def average_latency_ms(self) -> Optional[float]:
if not self.latencies:
return None
return (sum(self.latencies) / len(self.latencies)) * 1000.0
@property
def p95_latency_ms(self) -> Optional[float]:
if not self.latencies:
return None
sorted_latencies = sorted(self.latencies)
index = int(len(sorted_latencies) * 0.95)
index = min(index, len(sorted_latencies) - 1)
return sorted_latencies[index] * 1000.0
@dataclass
class SubscriptionStats:
messages: int = 0
disconnects: int = 0
async def _publish_transactions(
base_url: str,
stats: PublishStats,
stop_event: asyncio.Event,
rate_hz: float,
job_id: str,
client_id: str,
timeout: float,
) -> None:
interval = 1 / rate_hz if rate_hz > 0 else 0
async with httpx.AsyncClient(base_url=base_url, timeout=timeout) as client:
while not stop_event.is_set():
payload = {
"type": "TRANSFER",
"sender": f"miner-{client_id}",
"nonce": stats.sent,
"fee": 1,
"payload": {
"job_id": job_id,
"amount": random.randint(1, 10),
"timestamp": time.time_ns(),
},
}
started = time.perf_counter()
try:
response = await client.post("/rpc/sendTx", json=payload)
response.raise_for_status()
except httpx.HTTPError:
stats.failed += 1
else:
stats.sent += 1
stats.latencies.append(time.perf_counter() - started)
if interval:
try:
await asyncio.wait_for(stop_event.wait(), timeout=interval)
except asyncio.TimeoutError:
continue
else:
await asyncio.sleep(0)
async def _subscription_worker(
websocket_url: str,
stats: SubscriptionStats,
stop_event: asyncio.Event,
client_name: str,
) -> None:
while not stop_event.is_set():
try:
async with websockets.connect(websocket_url) as ws:
while not stop_event.is_set():
try:
message = await asyncio.wait_for(ws.recv(), timeout=1.0)
except asyncio.TimeoutError:
continue
except websockets.ConnectionClosed:
stats.disconnects += 1
break
else:
_ = message # lightweight backpressure test only
stats.messages += 1
except OSError:
stats.disconnects += 1
await asyncio.sleep(0.5)
async def run_load(args: argparse.Namespace) -> None:
stop_event = asyncio.Event()
publish_stats: List[PublishStats] = [PublishStats() for _ in range(args.publishers)]
subscription_stats: Dict[str, SubscriptionStats] = {
"blocks": SubscriptionStats(),
"transactions": SubscriptionStats(),
}
publisher_tasks = [
asyncio.create_task(
_publish_transactions(
base_url=args.http_base,
stats=publish_stats[i],
stop_event=stop_event,
rate_hz=args.publish_rate,
job_id=f"load-test-job-{i}",
client_id=f"{i}",
timeout=args.http_timeout,
),
name=f"publisher-{i}",
)
for i in range(args.publishers)
]
subscriber_tasks = [
asyncio.create_task(
_subscription_worker(
websocket_url=f"{args.ws_base}/blocks",
stats=subscription_stats["blocks"],
stop_event=stop_event,
client_name="blocks",
),
name="subscriber-blocks",
),
asyncio.create_task(
_subscription_worker(
websocket_url=f"{args.ws_base}/transactions",
stats=subscription_stats["transactions"],
stop_event=stop_event,
client_name="transactions",
),
name="subscriber-transactions",
),
]
all_tasks = publisher_tasks + subscriber_tasks
try:
await asyncio.wait_for(stop_event.wait(), timeout=args.duration)
except asyncio.TimeoutError:
pass
finally:
stop_event.set()
await asyncio.gather(*all_tasks, return_exceptions=True)
_print_summary(publish_stats, subscription_stats)
def _print_summary(publish_stats: List[PublishStats], subscription_stats: Dict[str, SubscriptionStats]) -> None:
total_sent = sum(s.sent for s in publish_stats)
total_failed = sum(s.failed for s in publish_stats)
all_latencies = [lat for s in publish_stats for lat in s.latencies]
summary = {
"publish": {
"total_sent": total_sent,
"total_failed": total_failed,
"average_latency_ms": (sum(all_latencies) / len(all_latencies) * 1000.0) if all_latencies else None,
"p95_latency_ms": _p95(all_latencies),
},
"subscriptions": {
name: {
"messages": stats.messages,
"disconnects": stats.disconnects,
}
for name, stats in subscription_stats.items()
},
}
print(json.dumps(summary, indent=2))
def _p95(latencies: List[float]) -> Optional[float]:
if not latencies:
return None
sorted_latencies = sorted(latencies)
index = int(len(sorted_latencies) * 0.95)
index = min(index, len(sorted_latencies) - 1)
return sorted_latencies[index] * 1000.0
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="AITBC blockchain-node WebSocket load harness")
parser.add_argument("--http-base", default="http://127.0.0.1:8080", help="Base URL for REST API")
parser.add_argument("--ws-base", default="ws://127.0.0.1:8080/rpc/ws", help="Base URL for WebSocket API")
parser.add_argument("--duration", type=float, default=30.0, help="Duration in seconds")
parser.add_argument("--publishers", type=int, default=4, help="Concurrent transaction publishers")
parser.add_argument("--publish-rate", type=float, default=5.0, help="Transactions per second per publisher")
parser.add_argument("--http-timeout", type=float, default=5.0, help="HTTP client timeout in seconds")
return parser.parse_args()
def main() -> None:
args = _parse_args()
try:
asyncio.run(run_load(args))
except KeyboardInterrupt:
pass
if __name__ == "__main__":
main()

View File

@ -1,24 +1,35 @@
from __future__ import annotations
from contextlib import asynccontextmanager
from fastapi import APIRouter, FastAPI
from fastapi.responses import PlainTextResponse
from .config import settings
from .database import init_db
from .gossip import create_backend, gossip_broker
from .metrics import metrics_registry
from .rpc.router import router as rpc_router
from .rpc.websocket import router as websocket_router
@asynccontextmanager
async def lifespan(app: FastAPI):
init_db()
backend = create_backend(
settings.gossip_backend,
broadcast_url=settings.gossip_broadcast_url,
)
await gossip_broker.set_backend(backend)
try:
yield
finally:
await gossip_broker.shutdown()
def create_app() -> FastAPI:
app = FastAPI(title="AITBC Blockchain Node", version="0.1.0", lifespan=lifespan)
app.include_router(rpc_router, prefix="/rpc", tags=["rpc"])
app.include_router(websocket_router, prefix="/rpc")
metrics_router = APIRouter()
@metrics_router.get("/metrics", response_class=PlainTextResponse, tags=["metrics"], summary="Prometheus metrics")

View File

@ -26,5 +26,8 @@ class ChainSettings(BaseSettings):
block_time_seconds: int = 2
gossip_backend: str = "memory"
gossip_broadcast_url: Optional[str] = None
settings = ChainSettings()

View File

@ -11,6 +11,7 @@ from sqlmodel import Session, select
from ..logging import get_logger
from ..metrics import metrics_registry
from ..models import Block
from ..gossip import gossip_broker
@dataclass
@ -78,9 +79,11 @@ class PoAProposer:
head = session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
next_height = 0
parent_hash = "0x00"
interval_seconds: Optional[float] = None
if head is not None:
next_height = head.height + 1
parent_hash = head.hash
interval_seconds = (datetime.utcnow() - head.timestamp).total_seconds()
timestamp = datetime.utcnow()
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp)
@ -99,6 +102,21 @@ class PoAProposer:
metrics_registry.increment("blocks_proposed_total")
metrics_registry.set_gauge("chain_head_height", float(next_height))
if interval_seconds is not None and interval_seconds >= 0:
metrics_registry.observe("block_interval_seconds", interval_seconds)
asyncio.create_task(
gossip_broker.publish(
"blocks",
{
"height": block.height,
"hash": block.hash,
"parent_hash": block.parent_hash,
"timestamp": block.timestamp.isoformat(),
"tx_count": block.tx_count,
},
)
)
self._logger.info(
"Proposed block",
@ -129,6 +147,19 @@ class PoAProposer:
)
session.add(genesis)
session.commit()
asyncio.create_task(
gossip_broker.publish(
"blocks",
{
"height": genesis.height,
"hash": genesis.hash,
"parent_hash": genesis.parent_hash,
"timestamp": genesis.timestamp.isoformat(),
"tx_count": genesis.tx_count,
},
)
)
self._logger.info("Created genesis block", extra={"hash": genesis_hash})
def _fetch_chain_head(self) -> Optional[Block]:

View File

@ -0,0 +1,17 @@
from .broker import (
BroadcastGossipBackend,
GossipBroker,
InMemoryGossipBackend,
TopicSubscription,
create_backend,
gossip_broker,
)
__all__ = [
"BroadcastGossipBackend",
"GossipBroker",
"InMemoryGossipBackend",
"TopicSubscription",
"create_backend",
"gossip_broker",
]

View File

@ -0,0 +1,254 @@
from __future__ import annotations
import asyncio
import json
from collections import defaultdict
from contextlib import suppress
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set
try:
from starlette.broadcast import Broadcast
except ImportError: # pragma: no cover - Starlette is an indirect dependency of FastAPI
Broadcast = None # type: ignore[assignment]
from ..metrics import metrics_registry
def _increment_publication(metric_prefix: str, topic: str) -> None:
metrics_registry.increment(f"{metric_prefix}_total")
metrics_registry.increment(f"{metric_prefix}_topic_{topic}")
def _set_queue_gauge(topic: str, size: int) -> None:
metrics_registry.set_gauge(f"gossip_queue_size_{topic}", float(size))
def _update_subscriber_metrics(topics: Dict[str, List["asyncio.Queue[Any]"]]) -> None:
for topic, queues in topics.items():
metrics_registry.set_gauge(f"gossip_subscribers_topic_{topic}", float(len(queues)))
total = sum(len(queues) for queues in topics.values())
metrics_registry.set_gauge("gossip_subscribers_total", float(total))
def _clear_topic_metrics(topic: str) -> None:
metrics_registry.set_gauge(f"gossip_subscribers_topic_{topic}", 0.0)
_set_queue_gauge(topic, 0)
@dataclass
class TopicSubscription:
topic: str
queue: "asyncio.Queue[Any]"
_unsubscribe: Callable[[], None]
def close(self) -> None:
self._unsubscribe()
async def get(self) -> Any:
return await self.queue.get()
async def __aiter__(self): # type: ignore[override]
try:
while True:
yield await self.queue.get()
finally:
self.close()
class GossipBackend:
async def start(self) -> None: # pragma: no cover - overridden as needed
return None
async def publish(self, topic: str, message: Any) -> None:
raise NotImplementedError
async def subscribe(self, topic: str, max_queue_size: int = 100) -> TopicSubscription:
raise NotImplementedError
async def shutdown(self) -> None:
return None
class InMemoryGossipBackend(GossipBackend):
def __init__(self) -> None:
self._topics: Dict[str, List["asyncio.Queue[Any]"]] = defaultdict(list)
self._lock = asyncio.Lock()
async def publish(self, topic: str, message: Any) -> None:
async with self._lock:
queues = list(self._topics.get(topic, []))
for queue in queues:
await queue.put(message)
_set_queue_gauge(topic, queue.qsize())
_increment_publication("gossip_publications", topic)
async def subscribe(self, topic: str, max_queue_size: int = 100) -> TopicSubscription:
queue: "asyncio.Queue[Any]" = asyncio.Queue(maxsize=max_queue_size)
async with self._lock:
self._topics[topic].append(queue)
_update_subscriber_metrics(self._topics)
_set_queue_gauge(topic, queue.qsize())
def _unsubscribe() -> None:
async def _remove() -> None:
async with self._lock:
queues = self._topics.get(topic)
if queues is None:
return
if queue in queues:
queues.remove(queue)
if not queues:
self._topics.pop(topic, None)
_clear_topic_metrics(topic)
_update_subscriber_metrics(self._topics)
asyncio.create_task(_remove())
return TopicSubscription(topic=topic, queue=queue, _unsubscribe=_unsubscribe)
async def shutdown(self) -> None:
async with self._lock:
topics = list(self._topics.keys())
self._topics.clear()
for topic in topics:
_clear_topic_metrics(topic)
_update_subscriber_metrics(self._topics)
class BroadcastGossipBackend(GossipBackend):
def __init__(self, url: str) -> None:
if Broadcast is None: # pragma: no cover - dependency is optional
raise RuntimeError("Starlette Broadcast backend requested but starlette is not available")
self._broadcast = Broadcast(url) # type: ignore[arg-type]
self._tasks: Set[asyncio.Task[None]] = set()
self._lock = asyncio.Lock()
self._running = False
async def start(self) -> None:
if not self._running:
await self._broadcast.connect() # type: ignore[union-attr]
self._running = True
async def publish(self, topic: str, message: Any) -> None:
if not self._running:
raise RuntimeError("Broadcast backend not started")
payload = _encode_message(message)
await self._broadcast.publish(topic, payload) # type: ignore[union-attr]
_increment_publication("gossip_broadcast_publications", topic)
async def subscribe(self, topic: str, max_queue_size: int = 100) -> TopicSubscription:
if not self._running:
raise RuntimeError("Broadcast backend not started")
queue: "asyncio.Queue[Any]" = asyncio.Queue(maxsize=max_queue_size)
stop_event = asyncio.Event()
async def _run_subscription() -> None:
async with self._broadcast.subscribe(topic) as subscriber: # type: ignore[attr-defined,union-attr]
async for event in subscriber: # type: ignore[union-attr]
if stop_event.is_set():
break
data = _decode_message(getattr(event, "message", event))
try:
await queue.put(data)
_set_queue_gauge(topic, queue.qsize())
except asyncio.CancelledError:
break
task = asyncio.create_task(_run_subscription(), name=f"broadcast-sub:{topic}")
async with self._lock:
self._tasks.add(task)
metrics_registry.set_gauge("gossip_broadcast_subscribers_total", float(len(self._tasks)))
def _unsubscribe() -> None:
async def _stop() -> None:
stop_event.set()
task.cancel()
with suppress(asyncio.CancelledError):
await task
async with self._lock:
self._tasks.discard(task)
metrics_registry.set_gauge("gossip_broadcast_subscribers_total", float(len(self._tasks)))
asyncio.create_task(_stop())
return TopicSubscription(topic=topic, queue=queue, _unsubscribe=_unsubscribe)
async def shutdown(self) -> None:
async with self._lock:
tasks = list(self._tasks)
self._tasks.clear()
metrics_registry.set_gauge("gossip_broadcast_subscribers_total", 0.0)
for task in tasks:
task.cancel()
with suppress(asyncio.CancelledError):
await task
if self._running:
await self._broadcast.disconnect() # type: ignore[union-attr]
self._running = False
class GossipBroker:
def __init__(self, backend: GossipBackend) -> None:
self._backend = backend
self._lock = asyncio.Lock()
self._started = False
async def publish(self, topic: str, message: Any) -> None:
if not self._started:
await self._backend.start()
self._started = True
await self._backend.publish(topic, message)
async def subscribe(self, topic: str, max_queue_size: int = 100) -> TopicSubscription:
if not self._started:
await self._backend.start()
self._started = True
return await self._backend.subscribe(topic, max_queue_size=max_queue_size)
async def set_backend(self, backend: GossipBackend) -> None:
await backend.start()
async with self._lock:
previous = self._backend
self._backend = backend
self._started = True
await previous.shutdown()
async def shutdown(self) -> None:
await self._backend.shutdown()
self._started = False
metrics_registry.set_gauge("gossip_subscribers_total", 0.0)
def create_backend(backend_type: str, *, broadcast_url: Optional[str] = None) -> GossipBackend:
backend = backend_type.lower()
if backend in {"memory", "inmemory", "local"}:
return InMemoryGossipBackend()
if backend in {"broadcast", "starlette", "redis"}:
if not broadcast_url:
raise ValueError("Broadcast backend requires a gossip_broadcast_url setting")
return BroadcastGossipBackend(broadcast_url)
raise ValueError(f"Unsupported gossip backend '{backend_type}'")
def _encode_message(message: Any) -> Any:
if isinstance(message, (str, bytes, bytearray)):
return message
return json.dumps(message, separators=(",", ":"))
def _decode_message(message: Any) -> Any:
if isinstance(message, (bytes, bytearray)):
message = message.decode("utf-8")
if isinstance(message, str):
try:
return json.loads(message)
except json.JSONDecodeError:
return message
return message
gossip_broker = GossipBroker(InMemoryGossipBackend())

View File

@ -15,6 +15,7 @@ class MetricsRegistry:
def __init__(self) -> None:
self._counters: Dict[str, float] = {}
self._gauges: Dict[str, float] = {}
self._summaries: Dict[str, tuple[float, float]] = {}
self._lock = Lock()
def increment(self, name: str, amount: float = 1.0) -> None:
@ -25,6 +26,17 @@ class MetricsRegistry:
with self._lock:
self._gauges[name] = value
def observe(self, name: str, value: float) -> None:
with self._lock:
count, total = self._summaries.get(name, (0.0, 0.0))
self._summaries[name] = (count + 1.0, total + value)
def reset(self) -> None:
with self._lock:
self._counters.clear()
self._gauges.clear()
self._summaries.clear()
def render_prometheus(self) -> str:
with self._lock:
lines: list[str] = []
@ -34,6 +46,10 @@ class MetricsRegistry:
for name, value in sorted(self._gauges.items()):
lines.append(f"# TYPE {name} gauge")
lines.append(f"{name} {value}")
for name, (count, total) in sorted(self._summaries.items()):
lines.append(f"# TYPE {name} summary")
lines.append(f"{name}_count {count}")
lines.append(f"{name}_sum {total}")
return "\n".join(lines) + "\n"

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import asyncio
import json
import time
from typing import Any, Dict, Optional
from fastapi import APIRouter, HTTPException, status
@ -8,6 +10,7 @@ from pydantic import BaseModel, Field, model_validator
from sqlmodel import select
from ..database import session_scope
from ..gossip import gossip_broker
from ..mempool import get_mempool
from ..metrics import metrics_registry
from ..models import Account, Block, Receipt, Transaction
@ -64,84 +67,134 @@ class MintFaucetRequest(BaseModel):
@router.get("/head", summary="Get current chain head")
async def get_head() -> Dict[str, Any]:
metrics_registry.increment("rpc_get_head_total")
start = time.perf_counter()
with session_scope() as session:
result = session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
if result is None:
metrics_registry.increment("rpc_get_head_not_found_total")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="no blocks yet")
return {
"height": result.height,
"hash": result.hash,
"timestamp": result.timestamp.isoformat(),
"tx_count": result.tx_count,
}
metrics_registry.increment("rpc_get_head_success_total")
metrics_registry.observe("rpc_get_head_duration_seconds", time.perf_counter() - start)
return {
"height": result.height,
"hash": result.hash,
"timestamp": result.timestamp.isoformat(),
"tx_count": result.tx_count,
}
@router.get("/blocks/{height}", summary="Get block by height")
async def get_block(height: int) -> Dict[str, Any]:
metrics_registry.increment("rpc_get_block_total")
start = time.perf_counter()
with session_scope() as session:
block = session.exec(select(Block).where(Block.height == height)).first()
if block is None:
metrics_registry.increment("rpc_get_block_not_found_total")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="block not found")
return {
"height": block.height,
"hash": block.hash,
"parent_hash": block.parent_hash,
"timestamp": block.timestamp.isoformat(),
"tx_count": block.tx_count,
"state_root": block.state_root,
}
metrics_registry.increment("rpc_get_block_success_total")
metrics_registry.observe("rpc_get_block_duration_seconds", time.perf_counter() - start)
return {
"height": block.height,
"hash": block.hash,
"parent_hash": block.parent_hash,
"timestamp": block.timestamp.isoformat(),
"tx_count": block.tx_count,
"state_root": block.state_root,
}
@router.get("/tx/{tx_hash}", summary="Get transaction by hash")
async def get_transaction(tx_hash: str) -> Dict[str, Any]:
metrics_registry.increment("rpc_get_transaction_total")
start = time.perf_counter()
with session_scope() as session:
tx = session.exec(select(Transaction).where(Transaction.tx_hash == tx_hash)).first()
if tx is None:
metrics_registry.increment("rpc_get_transaction_not_found_total")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="transaction not found")
return {
"tx_hash": tx.tx_hash,
"block_height": tx.block_height,
"sender": tx.sender,
"recipient": tx.recipient,
"payload": tx.payload,
"created_at": tx.created_at.isoformat(),
}
metrics_registry.increment("rpc_get_transaction_success_total")
metrics_registry.observe("rpc_get_transaction_duration_seconds", time.perf_counter() - start)
return {
"tx_hash": tx.tx_hash,
"block_height": tx.block_height,
"sender": tx.sender,
"recipient": tx.recipient,
"payload": tx.payload,
"created_at": tx.created_at.isoformat(),
}
@router.get("/receipts/{receipt_id}", summary="Get receipt by ID")
async def get_receipt(receipt_id: str) -> Dict[str, Any]:
metrics_registry.increment("rpc_get_receipt_total")
start = time.perf_counter()
with session_scope() as session:
receipt = session.exec(select(Receipt).where(Receipt.receipt_id == receipt_id)).first()
if receipt is None:
metrics_registry.increment("rpc_get_receipt_not_found_total")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="receipt not found")
return _serialize_receipt(receipt)
metrics_registry.increment("rpc_get_receipt_success_total")
metrics_registry.observe("rpc_get_receipt_duration_seconds", time.perf_counter() - start)
return _serialize_receipt(receipt)
@router.get("/getBalance/{address}", summary="Get account balance")
async def get_balance(address: str) -> Dict[str, Any]:
metrics_registry.increment("rpc_get_balance_total")
start = time.perf_counter()
with session_scope() as session:
account = session.get(Account, address)
if account is None:
metrics_registry.increment("rpc_get_balance_empty_total")
metrics_registry.observe("rpc_get_balance_duration_seconds", time.perf_counter() - start)
return {"address": address, "balance": 0, "nonce": 0}
return {
"address": account.address,
"balance": account.balance,
"nonce": account.nonce,
"updated_at": account.updated_at.isoformat(),
}
metrics_registry.increment("rpc_get_balance_success_total")
metrics_registry.observe("rpc_get_balance_duration_seconds", time.perf_counter() - start)
return {
"address": account.address,
"balance": account.balance,
"nonce": account.nonce,
"updated_at": account.updated_at.isoformat(),
}
@router.post("/sendTx", summary="Submit a new transaction")
async def send_transaction(request: TransactionRequest) -> Dict[str, Any]:
metrics_registry.increment("rpc_send_tx_total")
start = time.perf_counter()
mempool = get_mempool()
tx_dict = request.model_dump()
tx_hash = mempool.add(tx_dict)
metrics_registry.increment("rpc_send_tx_total")
return {"tx_hash": tx_hash}
try:
asyncio.create_task(
gossip_broker.publish(
"transactions",
{
"tx_hash": tx_hash,
"sender": request.sender,
"recipient": request.recipient,
"payload": request.payload,
"nonce": request.nonce,
"fee": request.fee,
"type": request.type,
},
)
)
metrics_registry.increment("rpc_send_tx_success_total")
return {"tx_hash": tx_hash}
except Exception:
metrics_registry.increment("rpc_send_tx_failed_total")
raise
finally:
metrics_registry.observe("rpc_send_tx_duration_seconds", time.perf_counter() - start)
@router.post("/submitReceipt", summary="Submit receipt claim transaction")
async def submit_receipt(request: ReceiptSubmissionRequest) -> Dict[str, Any]:
metrics_registry.increment("rpc_submit_receipt_total")
start = time.perf_counter()
tx_payload = {
"type": "RECEIPT_CLAIM",
"sender": request.sender,
@ -151,17 +204,31 @@ async def submit_receipt(request: ReceiptSubmissionRequest) -> Dict[str, Any]:
"sig": request.sig,
}
tx_request = TransactionRequest.model_validate(tx_payload)
metrics_registry.increment("rpc_submit_receipt_total")
return await send_transaction(tx_request)
try:
response = await send_transaction(tx_request)
metrics_registry.increment("rpc_submit_receipt_success_total")
return response
except HTTPException:
metrics_registry.increment("rpc_submit_receipt_failed_total")
raise
except Exception:
metrics_registry.increment("rpc_submit_receipt_failed_total")
raise
finally:
metrics_registry.observe("rpc_submit_receipt_duration_seconds", time.perf_counter() - start)
@router.post("/estimateFee", summary="Estimate transaction fee")
async def estimate_fee(request: EstimateFeeRequest) -> Dict[str, Any]:
metrics_registry.increment("rpc_estimate_fee_total")
start = time.perf_counter()
base_fee = 10
per_byte = 1
payload_bytes = len(json.dumps(request.payload, sort_keys=True, separators=(",", ":")).encode())
estimated_fee = base_fee + per_byte * payload_bytes
tx_type = (request.type or "TRANSFER").upper()
metrics_registry.increment("rpc_estimate_fee_success_total")
metrics_registry.observe("rpc_estimate_fee_duration_seconds", time.perf_counter() - start)
return {
"type": tx_type,
"base_fee": base_fee,
@ -172,6 +239,8 @@ async def estimate_fee(request: EstimateFeeRequest) -> Dict[str, Any]:
@router.post("/admin/mintFaucet", summary="Mint devnet funds to an address")
async def mint_faucet(request: MintFaucetRequest) -> Dict[str, Any]:
metrics_registry.increment("rpc_mint_faucet_total")
start = time.perf_counter()
with session_scope() as session:
account = session.get(Account, request.address)
if account is None:
@ -181,4 +250,6 @@ async def mint_faucet(request: MintFaucetRequest) -> Dict[str, Any]:
account.balance += request.amount
session.commit()
updated_balance = account.balance
metrics_registry.increment("rpc_mint_faucet_success_total")
metrics_registry.observe("rpc_mint_faucet_duration_seconds", time.perf_counter() - start)
return {"address": request.address, "balance": updated_balance}

View File

@ -0,0 +1,34 @@
from __future__ import annotations
import asyncio
from typing import AsyncIterator, Dict
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from ..gossip import gossip_broker
router = APIRouter(prefix="/ws", tags=["ws"])
async def _stream_topic(topic: str, websocket: WebSocket) -> None:
subscription = await gossip_broker.subscribe(topic)
try:
while True:
message = await subscription.get()
await websocket.send_json(message)
except WebSocketDisconnect:
pass
finally:
subscription.close()
@router.websocket("/blocks")
async def blocks_stream(websocket: WebSocket) -> None:
await websocket.accept()
await _stream_topic("blocks", websocket)
@router.websocket("/transactions")
async def transactions_stream(websocket: WebSocket) -> None:
await websocket.accept()
await _stream_topic("transactions", websocket)

View File

@ -0,0 +1,23 @@
from __future__ import annotations
import pytest
from sqlmodel import SQLModel, Session, create_engine
from aitbc_chain.models import Block, Transaction, Receipt # noqa: F401 - ensure models imported for metadata
@pytest.fixture(name="engine")
def engine_fixture():
engine = create_engine("sqlite:///:memory:", connect_args={"check_same_thread": False})
SQLModel.metadata.create_all(engine)
try:
yield engine
finally:
SQLModel.metadata.drop_all(engine)
@pytest.fixture(name="session")
def session_fixture(engine):
with Session(engine) as session:
yield session
session.rollback()

View File

@ -0,0 +1,76 @@
from __future__ import annotations
import asyncio
import pytest
from fastapi.testclient import TestClient
from aitbc_chain.app import create_app
from aitbc_chain.gossip import BroadcastGossipBackend, InMemoryGossipBackend, gossip_broker
@pytest.fixture(autouse=True)
async def reset_broker_backend():
previous_backend = InMemoryGossipBackend()
await gossip_broker.set_backend(previous_backend)
yield
await gossip_broker.set_backend(InMemoryGossipBackend())
def _run_in_thread(fn):
loop = asyncio.get_event_loop()
return loop.run_in_executor(None, fn)
@pytest.mark.asyncio
async def test_websocket_fanout_with_broadcast_backend():
backend = BroadcastGossipBackend("memory://")
await gossip_broker.set_backend(backend)
app = create_app()
loop = asyncio.get_running_loop()
def _sync_test() -> None:
with TestClient(app) as client:
with client.websocket_connect("/rpc/ws/transactions") as ws_a, client.websocket_connect(
"/rpc/ws/transactions"
) as ws_b:
payload = {
"tx_hash": "0x01",
"sender": "alice",
"recipient": "bob",
"payload": {"amount": 1},
"nonce": 0,
"fee": 0,
"type": "TRANSFER",
}
fut = asyncio.run_coroutine_threadsafe(gossip_broker.publish("transactions", payload), loop)
fut.result(timeout=5.0)
assert ws_a.receive_json() == payload
assert ws_b.receive_json() == payload
await _run_in_thread(_sync_test)
@pytest.mark.asyncio
async def test_broadcast_backend_decodes_cursorless_payload():
backend = BroadcastGossipBackend("memory://")
await gossip_broker.set_backend(backend)
app = create_app()
loop = asyncio.get_running_loop()
def _sync_test() -> None:
with TestClient(app) as client:
with client.websocket_connect("/rpc/ws/blocks") as ws:
payload = [
{"height": 1, "hash": "0xabc"},
{"height": 2, "hash": "0xdef"},
]
fut = asyncio.run_coroutine_threadsafe(gossip_broker.publish("blocks", payload), loop)
fut.result(timeout=5.0)
assert ws.receive_json() == payload
await _run_in_thread(_sync_test)

View File

@ -0,0 +1,92 @@
from __future__ import annotations
import pytest
from sqlmodel import Session
from aitbc_chain.models import Block, Transaction, Receipt
def _insert_block(session: Session, height: int = 0) -> Block:
block = Block(
height=height,
hash=f"0x{'0'*63}{height}",
parent_hash="0x" + "0" * 64,
proposer="validator",
tx_count=0,
)
session.add(block)
session.commit()
session.refresh(block)
return block
def test_relationships(session: Session) -> None:
block = _insert_block(session, height=1)
tx = Transaction(
tx_hash="0x" + "1" * 64,
block_height=block.height,
sender="alice",
recipient="bob",
payload={"foo": "bar"},
)
receipt = Receipt(
job_id="job-1",
receipt_id="0x" + "2" * 64,
block_height=block.height,
payload={},
miner_signature={},
coordinator_attestations=[],
)
session.add(tx)
session.add(receipt)
session.commit()
session.refresh(tx)
session.refresh(receipt)
assert tx.block is not None
assert tx.block.hash == block.hash
assert receipt.block is not None
assert receipt.block.hash == block.hash
def test_hash_validation_accepts_hex(session: Session) -> None:
block = Block(
height=10,
hash="0x" + "a" * 64,
parent_hash="0x" + "b" * 64,
proposer="validator",
)
session.add(block)
session.commit()
session.refresh(block)
assert block.hash.startswith("0x")
assert block.parent_hash.startswith("0x")
def test_hash_validation_rejects_non_hex(session: Session) -> None:
with pytest.raises(ValueError):
Block(
height=20,
hash="not-hex",
parent_hash="0x" + "c" * 64,
proposer="validator",
)
with pytest.raises(ValueError):
Transaction(
tx_hash="bad",
sender="alice",
recipient="bob",
payload={},
)
with pytest.raises(ValueError):
Receipt(
job_id="job",
receipt_id="oops",
payload={},
miner_signature={},
coordinator_attestations=[],
)

View File

@ -0,0 +1,46 @@
from __future__ import annotations
import asyncio
from fastapi.testclient import TestClient
from aitbc_chain.app import create_app
from aitbc_chain.gossip import gossip_broker
def _publish(topic: str, message: dict) -> None:
asyncio.run(gossip_broker.publish(topic, message))
def test_blocks_websocket_stream() -> None:
client = TestClient(create_app())
with client.websocket_connect("/rpc/ws/blocks") as websocket:
payload = {
"height": 1,
"hash": "0x" + "1" * 64,
"parent_hash": "0x" + "0" * 64,
"timestamp": "2025-01-01T00:00:00Z",
"tx_count": 2,
}
_publish("blocks", payload)
message = websocket.receive_json()
assert message == payload
def test_transactions_websocket_stream() -> None:
client = TestClient(create_app())
with client.websocket_connect("/rpc/ws/transactions") as websocket:
payload = {
"tx_hash": "0x" + "a" * 64,
"sender": "alice",
"recipient": "bob",
"payload": {"amount": 1},
"nonce": 1,
"fee": 0,
"type": "TRANSFER",
}
_publish("transactions", payload)
message = websocket.receive_json()
assert message == payload

View File

@ -6,6 +6,45 @@
z-index: 1000;
}
@media (max-width: 600px) {
.page {
padding: 1.5rem 1rem 3rem;
}
.site-header__inner {
flex-direction: column;
align-items: flex-start;
}
.site-header__controls {
align-items: stretch;
gap: 0.5rem;
}
.site-header__nav {
gap: 0.5rem;
}
.site-header__nav a {
flex: 1 1 45%;
text-align: center;
}
.addresses__input-group,
.receipts__input-group {
flex-direction: column;
}
.toast-container {
left: 0;
right: 0;
top: auto;
bottom: 1rem;
width: min(90vw, 360px);
margin: 0 auto;
}
}
.site-header__inner {
margin: 0 auto;
max-width: 1200px;
@ -80,6 +119,37 @@
padding: 2rem 1.5rem 4rem;
}
.toast-container {
position: fixed;
top: 1.25rem;
right: 1.25rem;
display: grid;
gap: 0.75rem;
z-index: 1200;
}
.toast {
opacity: 0;
transform: translateY(-6px);
transition: opacity 150ms ease, transform 180ms ease;
border-radius: 0.75rem;
padding: 0.75rem 1rem;
font-size: 0.9rem;
min-width: 220px;
}
.toast--error {
background: rgba(255, 102, 102, 0.16);
border: 1px solid rgba(255, 102, 102, 0.35);
color: #ffd3d3;
box-shadow: 0 12px 30px rgba(0, 0, 0, 0.35);
}
.toast.is-visible {
opacity: 1;
transform: translateY(0px);
}
@media (max-width: 768px) {
.site-header__inner {
justify-content: space-between;

View File

@ -0,0 +1,34 @@
const TOAST_DURATION_MS = 4000;
let container: HTMLDivElement | null = null;
export function initNotifications(): void {
if (!container) {
container = document.createElement("div");
container.className = "toast-container";
document.body.appendChild(container);
}
}
export function notifyError(message: string): void {
if (!container) {
initNotifications();
}
if (!container) {
return;
}
const toast = document.createElement("div");
toast.className = "toast toast--error";
toast.textContent = message;
container.appendChild(toast);
requestAnimationFrame(() => {
toast.classList.add("is-visible");
});
setTimeout(() => {
toast.classList.remove("is-visible");
setTimeout(() => toast.remove(), 250);
}, TOAST_DURATION_MS);
}

View File

@ -1,4 +1,5 @@
import { CONFIG, type DataMode } from "../config";
import { notifyError } from "../components/notifications";
import type {
BlockListResponse,
TransactionListResponse,
@ -35,6 +36,7 @@ export async function fetchBlocks(): Promise<BlockSummary[]> {
return data.items;
} catch (error) {
console.warn("[Explorer] Failed to fetch live block data", error);
notifyError("Unable to load live block data. Displaying placeholders.");
return [];
}
}
@ -54,6 +56,7 @@ export async function fetchTransactions(): Promise<TransactionSummary[]> {
return data.items;
} catch (error) {
console.warn("[Explorer] Failed to fetch live transaction data", error);
notifyError("Unable to load live transaction data. Displaying placeholders.");
return [];
}
}
@ -73,6 +76,7 @@ export async function fetchAddresses(): Promise<AddressSummary[]> {
return Array.isArray(data) ? data : data.items;
} catch (error) {
console.warn("[Explorer] Failed to fetch live address data", error);
notifyError("Unable to load live address data. Displaying placeholders.");
return [];
}
}
@ -92,6 +96,7 @@ export async function fetchReceipts(): Promise<ReceiptSummary[]> {
return data.items;
} catch (error) {
console.warn("[Explorer] Failed to fetch live receipt data", error);
notifyError("Unable to load live receipt data. Displaying placeholders.");
return [];
}
}
@ -107,6 +112,7 @@ async function fetchMock<T>(resource: string): Promise<T> {
return (await response.json()) as T;
} catch (error) {
console.warn(`[Explorer] Failed to fetch mock data from ${url}`, error);
notifyError("Mock data is unavailable. Please verify development assets.");
return [] as unknown as T;
}
}

View File

@ -10,6 +10,7 @@ import { addressesTitle, renderAddressesPage, initAddressesPage } from "./pages/
import { receiptsTitle, renderReceiptsPage, initReceiptsPage } from "./pages/receipts";
import { initDataModeToggle } from "./components/dataModeToggle";
import { getDataMode } from "./lib/mockData";
import { initNotifications } from "./components/notifications";
type PageConfig = {
title: string;
@ -49,14 +50,13 @@ const routes: Record<string, PageConfig> = {
};
function render(): void {
initNotifications();
const root = document.querySelector<HTMLDivElement>("#app");
if (!root) {
console.warn("[Explorer] Missing #app root element");
return;
}
document.documentElement.dataset.mode = getDataMode();
const currentPath = window.location.pathname.replace(/\/$/, "");
const normalizedPath = currentPath === "" ? "/" : currentPath;
const page = routes[normalizedPath] ?? null;

View File

@ -40,7 +40,6 @@ export async function initOverviewPage(): Promise<void> {
fetchTransactions(),
fetchReceipts(),
]);
const blockStats = document.querySelector<HTMLUListElement>(
"#overview-block-stats",
);
@ -54,13 +53,12 @@ export async function initOverviewPage(): Promise<void> {
<li><strong>Time:</strong> ${new Date(latest.timestamp).toLocaleString()}</li>
`;
} else {
blockStats.innerHTML = `<li class="placeholder">No mock block data available.</li>`;
blockStats.innerHTML = `
<li class="placeholder">No blocks available. Try switching data mode.</li>
`;
}
}
const txStats = document.querySelector<HTMLUListElement>(
"#overview-transaction-stats",
);
const txStats = document.querySelector<HTMLUListElement>("#overview-transaction-stats");
if (txStats) {
if (transactions.length > 0) {
const succeeded = transactions.filter((tx) => tx.status === "Succeeded");
@ -70,7 +68,7 @@ export async function initOverviewPage(): Promise<void> {
<li><strong>Pending:</strong> ${transactions.length - succeeded.length}</li>
`;
} else {
txStats.innerHTML = `<li class="placeholder">No mock transaction data available.</li>`;
txStats.innerHTML = `<li class="placeholder">No transactions available. Try switching data mode.</li>`;
}
}
@ -86,7 +84,7 @@ export async function initOverviewPage(): Promise<void> {
<li><strong>Pending:</strong> ${receipts.length - attested.length}</li>
`;
} else {
receiptStats.innerHTML = `<li class="placeholder">No mock receipt data available.</li>`;
receiptStats.innerHTML = `<li class="placeholder">No receipts available. Try switching data mode.</li>`;
}
}
}

View File

@ -0,0 +1,47 @@
from __future__ import annotations
import asyncio
from logging.config import fileConfig
from alembic import context
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from poolhub.models import Base
from poolhub.settings import settings
config = context.config
if config.config_file_name is not None:
fileConfig(config.config_file_name)
target_metadata = Base.metadata
def _configure_context(connection=None, *, url: str | None = None) -> None:
context.configure(
connection=connection,
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
def run_migrations_offline() -> None:
_configure_context(url=settings.postgres_dsn)
with context.begin_transaction():
context.run_migrations()
async def run_migrations_online() -> None:
connectable = create_async_engine(settings.postgres_dsn, pool_pre_ping=True)
async with connectable.connect() as connection:
await connection.run_sync(_configure_context)
await connection.run_sync(lambda conn: context.run_migrations())
await connectable.dispose()
if context.is_offline_mode():
run_migrations_offline()
else:
asyncio.run(run_migrations_online())

View File

@ -0,0 +1,104 @@
"""initial schema
Revision ID: a58c1f3b3e87
Revises:
Create Date: 2025-09-27 12:07:40.000000
"""
from __future__ import annotations
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a58c1f3b3e87"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"miners",
sa.Column("miner_id", sa.String(length=64), primary_key=True),
sa.Column("api_key_hash", sa.String(length=128), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("NOW()")),
sa.Column("last_seen_at", sa.DateTime(timezone=True)),
sa.Column("addr", sa.String(length=256)),
sa.Column("proto", sa.String(length=32)),
sa.Column("gpu_vram_gb", sa.Float()),
sa.Column("gpu_name", sa.String(length=128)),
sa.Column("cpu_cores", sa.Integer()),
sa.Column("ram_gb", sa.Float()),
sa.Column("max_parallel", sa.Integer()),
sa.Column("base_price", sa.Float()),
sa.Column("tags", postgresql.JSONB(astext_type=sa.Text())),
sa.Column("capabilities", postgresql.JSONB(astext_type=sa.Text())),
sa.Column("trust_score", sa.Float(), server_default="0.5"),
sa.Column("region", sa.String(length=64)),
)
op.create_table(
"miner_status",
sa.Column("miner_id", sa.String(length=64), sa.ForeignKey("miners.miner_id", ondelete="CASCADE"), primary_key=True),
sa.Column("queue_len", sa.Integer(), server_default="0"),
sa.Column("busy", sa.Boolean(), server_default=sa.text("false")),
sa.Column("avg_latency_ms", sa.Integer()),
sa.Column("temp_c", sa.Integer()),
sa.Column("mem_free_gb", sa.Float()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("NOW()")),
)
op.create_table(
"match_requests",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("job_id", sa.String(length=64), nullable=False),
sa.Column("requirements", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.Column("hints", postgresql.JSONB(astext_type=sa.Text()), server_default=sa.text("'{}'::jsonb")),
sa.Column("top_k", sa.Integer(), server_default="1"),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("NOW()")),
)
op.create_table(
"match_results",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("request_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("match_requests.id", ondelete="CASCADE"), nullable=False),
sa.Column("miner_id", sa.String(length=64), nullable=False),
sa.Column("score", sa.Float(), nullable=False),
sa.Column("explain", sa.Text()),
sa.Column("eta_ms", sa.Integer()),
sa.Column("price", sa.Float()),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("NOW()")),
)
op.create_index("ix_match_results_request_id", "match_results", ["request_id"])
op.create_table(
"feedback",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column("job_id", sa.String(length=64), nullable=False),
sa.Column("miner_id", sa.String(length=64), sa.ForeignKey("miners.miner_id", ondelete="CASCADE"), nullable=False),
sa.Column("outcome", sa.String(length=32), nullable=False),
sa.Column("latency_ms", sa.Integer()),
sa.Column("fail_code", sa.String(length=64)),
sa.Column("tokens_spent", sa.Float()),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("NOW()")),
)
op.create_index("ix_feedback_miner_id", "feedback", ["miner_id"])
op.create_index("ix_feedback_job_id", "feedback", ["job_id"])
def downgrade() -> None:
op.drop_index("ix_feedback_job_id", table_name="feedback")
op.drop_index("ix_feedback_miner_id", table_name="feedback")
op.drop_table("feedback")
op.drop_index("ix_match_results_request_id", table_name="match_results")
op.drop_table("match_results")
op.drop_table("match_requests")
op.drop_table("miner_status")
op.drop_table("miners")

View File

@ -0,0 +1,13 @@
"""AITBC Pool Hub service package."""
from .settings import Settings, settings
from .database import create_engine, get_session
from .redis_cache import get_redis
__all__ = [
"Settings",
"settings",
"create_engine",
"get_session",
"get_redis",
]

View File

@ -0,0 +1,5 @@
"""FastAPI application wiring for the AITBC Pool Hub."""
from .main import create_app, app
__all__ = ["create_app", "app"]

View File

@ -0,0 +1,27 @@
from __future__ import annotations
from typing import AsyncGenerator
from fastapi import Depends
from ..database import get_session
from ..redis_cache import get_redis
def get_db_session() -> AsyncGenerator:
return get_session()
def get_redis_client() -> AsyncGenerator:
return get_redis()
# FastAPI dependency wrappers
async def db_session_dep(session=Depends(get_session)):
async for s in session:
yield s
async def redis_dep(client=Depends(get_redis)):
async for c in client:
yield c

View File

@ -0,0 +1,31 @@
from __future__ import annotations
from contextlib import asynccontextmanager
from fastapi import FastAPI
from ..database import close_engine, create_engine
from ..redis_cache import close_redis, create_redis
from ..settings import settings
from .routers import health_router, match_router, metrics_router
@asynccontextmanager
async def lifespan(_: FastAPI):
create_engine()
create_redis()
try:
yield
finally:
await close_engine()
await close_redis()
app = FastAPI(**settings.asgi_kwargs(), lifespan=lifespan)
app.include_router(match_router, prefix="/v1")
app.include_router(health_router)
app.include_router(metrics_router)
def create_app() -> FastAPI:
return app

View File

@ -0,0 +1,39 @@
from __future__ import annotations
from prometheus_client import CONTENT_TYPE_LATEST, Counter, Gauge, Histogram, generate_latest
match_requests_total = Counter(
"poolhub_match_requests_total",
"Total number of match requests received",
)
match_candidates_returned = Counter(
"poolhub_match_candidates_total",
"Total number of candidates returned",
)
match_failures_total = Counter(
"poolhub_match_failures_total",
"Total number of match request failures",
)
match_latency_seconds = Histogram(
"poolhub_match_latency_seconds",
"Latency of match processing",
buckets=(0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0),
)
miners_online_gauge = Gauge(
"poolhub_miners_online",
"Number of miners considered online",
)
def render_metrics() -> tuple[str, str]:
return generate_latest(), CONTENT_TYPE_LATEST
def reset_metrics() -> None:
match_requests_total._value.set(0) # type: ignore[attr-defined]
match_candidates_returned._value.set(0) # type: ignore[attr-defined]
match_failures_total._value.set(0) # type: ignore[attr-defined]
match_latency_seconds._sum.set(0) # type: ignore[attr-defined]
match_latency_seconds._count.set(0) # type: ignore[attr-defined]
match_latency_seconds._samples = [] # type: ignore[attr-defined]
miners_online_gauge._value.set(0) # type: ignore[attr-defined]

View File

@ -0,0 +1,7 @@
"""FastAPI routers for Pool Hub."""
from .match import router as match_router
from .health import router as health_router
from .metrics import router as metrics_router
__all__ = ["match_router", "health_router", "metrics_router"]

View File

@ -0,0 +1,50 @@
from __future__ import annotations
from fastapi import APIRouter, Depends
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from ..deps import db_session_dep, redis_dep
from ..prometheus import miners_online_gauge
from poolhub.repositories.miner_repository import MinerRepository
from ..schemas import HealthResponse
router = APIRouter(tags=["health"], prefix="/v1")
@router.get("/health", response_model=HealthResponse, summary="Pool Hub health status")
async def health_endpoint(
session: AsyncSession = Depends(db_session_dep),
redis: Redis = Depends(redis_dep),
) -> HealthResponse:
db_ok = True
redis_ok = True
db_error: str | None = None
redis_error: str | None = None
try:
await session.execute("SELECT 1")
except Exception as exc: # pragma: no cover
db_ok = False
db_error = str(exc)
try:
await redis.ping()
except Exception as exc: # pragma: no cover
redis_ok = False
redis_error = str(exc)
miner_repo = MinerRepository(session, redis)
active_miners = await miner_repo.list_active_miners()
miners_online = len(active_miners)
miners_online_gauge.set(miners_online)
status = "ok" if db_ok and redis_ok else "degraded"
return HealthResponse(
status=status,
db=db_ok,
redis=redis_ok,
miners_online=miners_online,
db_error=db_error,
redis_error=redis_error,
)

View File

@ -0,0 +1,116 @@
from __future__ import annotations
import time
from typing import Any, Dict, List
from fastapi import APIRouter, Depends, HTTPException, status
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from ..deps import db_session_dep, redis_dep
from ..prometheus import (
match_candidates_returned,
match_failures_total,
match_latency_seconds,
match_requests_total,
)
from poolhub.repositories.match_repository import MatchRepository
from poolhub.repositories.miner_repository import MinerRepository
from ..schemas import MatchCandidate, MatchRequestPayload, MatchResponse
router = APIRouter(tags=["match"])
def _normalize_requirements(requirements: Dict[str, Any]) -> Dict[str, Any]:
return requirements or {}
def _candidate_from_payload(payload: Dict[str, Any]) -> MatchCandidate:
return MatchCandidate(**payload)
@router.post("/match", response_model=MatchResponse, summary="Find top miners for a job")
async def match_endpoint(
payload: MatchRequestPayload,
session: AsyncSession = Depends(db_session_dep),
redis: Redis = Depends(redis_dep),
) -> MatchResponse:
start = time.perf_counter()
match_requests_total.inc()
miner_repo = MinerRepository(session, redis)
match_repo = MatchRepository(session, redis)
requirements = _normalize_requirements(payload.requirements)
top_k = payload.top_k
try:
request = await match_repo.create_request(
job_id=payload.job_id,
requirements=requirements,
hints=payload.hints,
top_k=top_k,
)
active_miners = await miner_repo.list_active_miners()
candidates = _select_candidates(requirements, payload.hints, active_miners, top_k)
await match_repo.add_results(
request_id=request.id,
candidates=candidates,
)
match_candidates_returned.inc(len(candidates))
duration = time.perf_counter() - start
match_latency_seconds.observe(duration)
return MatchResponse(
job_id=payload.job_id,
candidates=[_candidate_from_payload(candidate) for candidate in candidates],
)
except Exception as exc: # pragma: no cover - safeguards unexpected failures
match_failures_total.inc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="match_failed") from exc
def _select_candidates(
requirements: Dict[str, Any],
hints: Dict[str, Any],
active_miners: List[tuple],
top_k: int,
) -> List[Dict[str, Any]]:
min_vram = float(requirements.get("min_vram_gb", 0))
min_ram = float(requirements.get("min_ram_gb", 0))
capabilities_required = set(requirements.get("capabilities_any", []))
region_hint = hints.get("region")
ranked: List[Dict[str, Any]] = []
for miner, status, score in active_miners:
if miner.gpu_vram_gb and miner.gpu_vram_gb < min_vram:
continue
if miner.ram_gb and miner.ram_gb < min_ram:
continue
if capabilities_required and not capabilities_required.issubset(set(miner.capabilities or [])):
continue
if region_hint and miner.region and miner.region != region_hint:
continue
candidate = {
"miner_id": miner.miner_id,
"addr": miner.addr,
"proto": miner.proto,
"score": float(score),
"explain": _compose_explain(score, miner, status),
"eta_ms": status.avg_latency_ms if status else None,
"price": miner.base_price,
}
ranked.append(candidate)
ranked.sort(key=lambda item: item["score"], reverse=True)
return ranked[:top_k]
def _compose_explain(score: float, miner, status) -> str:
load = status.queue_len if status else 0
latency = status.avg_latency_ms if status else "n/a"
return f"score={score:.3f} load={load} latency={latency}"

View File

@ -0,0 +1,13 @@
from __future__ import annotations
from fastapi import APIRouter, Response
from ..prometheus import render_metrics
router = APIRouter(tags=["metrics"])
@router.get("/metrics", summary="Prometheus metrics")
async def metrics_endpoint() -> Response:
payload, content_type = render_metrics()
return Response(content=payload, media_type=content_type)

View File

@ -0,0 +1,40 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class MatchRequestPayload(BaseModel):
job_id: str
requirements: Dict[str, Any] = Field(default_factory=dict)
hints: Dict[str, Any] = Field(default_factory=dict)
top_k: int = Field(default=1, ge=1, le=50)
class MatchCandidate(BaseModel):
miner_id: str
addr: str
proto: str
score: float
explain: Optional[str] = None
eta_ms: Optional[int] = None
price: Optional[float] = None
class MatchResponse(BaseModel):
job_id: str
candidates: List[MatchCandidate]
class HealthResponse(BaseModel):
status: str
db: bool
redis: bool
miners_online: int
db_error: Optional[str] = None
redis_error: Optional[str] = None
class MetricsResponse(BaseModel):
detail: str = "Prometheus metrics output"

View File

@ -0,0 +1,54 @@
from __future__ import annotations
from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from .settings import settings
_engine: AsyncEngine | None = None
_session_factory: async_sessionmaker[AsyncSession] | None = None
def create_engine() -> AsyncEngine:
global _engine, _session_factory
if _engine is None:
_engine = create_async_engine(
settings.postgres_dsn,
pool_size=settings.postgres_pool_max,
max_overflow=0,
pool_pre_ping=True,
)
_session_factory = async_sessionmaker(
bind=_engine,
expire_on_commit=False,
autoflush=False,
)
return _engine
def get_engine() -> AsyncEngine:
if _engine is None:
return create_engine()
return _engine
def get_session_factory() -> async_sessionmaker[AsyncSession]:
if _session_factory is None:
create_engine()
assert _session_factory is not None
return _session_factory
async def get_session() -> AsyncGenerator[AsyncSession, None]:
session_factory = get_session_factory()
async with session_factory() as session:
yield session
async def close_engine() -> None:
global _engine
if _engine is not None:
await _engine.dispose()
_engine = None

View File

@ -0,0 +1,95 @@
from __future__ import annotations
import datetime as dt
from typing import Dict, List, Optional
from sqlalchemy import Boolean, Column, DateTime, Float, ForeignKey, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID as PGUUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from uuid import uuid4
class Base(DeclarativeBase):
pass
class Miner(Base):
__tablename__ = "miners"
miner_id: Mapped[str] = mapped_column(String(64), primary_key=True)
api_key_hash: Mapped[str] = mapped_column(String(128), nullable=False)
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=dt.datetime.utcnow)
last_seen_at: Mapped[Optional[dt.datetime]] = mapped_column(DateTime(timezone=True))
addr: Mapped[str] = mapped_column(String(256))
proto: Mapped[str] = mapped_column(String(32))
gpu_vram_gb: Mapped[float] = mapped_column(Float)
gpu_name: Mapped[Optional[str]] = mapped_column(String(128))
cpu_cores: Mapped[int] = mapped_column(Integer)
ram_gb: Mapped[float] = mapped_column(Float)
max_parallel: Mapped[int] = mapped_column(Integer)
base_price: Mapped[float] = mapped_column(Float)
tags: Mapped[Dict[str, str]] = mapped_column(JSONB, default=dict)
capabilities: Mapped[List[str]] = mapped_column(JSONB, default=list)
trust_score: Mapped[float] = mapped_column(Float, default=0.5)
region: Mapped[Optional[str]] = mapped_column(String(64))
status: Mapped["MinerStatus"] = relationship(back_populates="miner", cascade="all, delete-orphan", uselist=False)
feedback: Mapped[List["Feedback"]] = relationship(back_populates="miner", cascade="all, delete-orphan")
class MinerStatus(Base):
__tablename__ = "miner_status"
miner_id: Mapped[str] = mapped_column(ForeignKey("miners.miner_id", ondelete="CASCADE"), primary_key=True)
queue_len: Mapped[int] = mapped_column(Integer, default=0)
busy: Mapped[bool] = mapped_column(Boolean, default=False)
avg_latency_ms: Mapped[Optional[int]] = mapped_column(Integer)
temp_c: Mapped[Optional[int]] = mapped_column(Integer)
mem_free_gb: Mapped[Optional[float]] = mapped_column(Float)
updated_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=dt.datetime.utcnow, onupdate=dt.datetime.utcnow)
miner: Mapped[Miner] = relationship(back_populates="status")
class MatchRequest(Base):
__tablename__ = "match_requests"
id: Mapped[PGUUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True, default=uuid4)
job_id: Mapped[str] = mapped_column(String(64), nullable=False)
requirements: Mapped[Dict[str, object]] = mapped_column(JSONB, nullable=False)
hints: Mapped[Dict[str, object]] = mapped_column(JSONB, default=dict)
top_k: Mapped[int] = mapped_column(Integer, default=1)
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=dt.datetime.utcnow)
results: Mapped[List["MatchResult"]] = relationship(back_populates="request", cascade="all, delete-orphan")
class MatchResult(Base):
__tablename__ = "match_results"
id: Mapped[PGUUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True, default=uuid4)
request_id: Mapped[PGUUID] = mapped_column(ForeignKey("match_requests.id", ondelete="CASCADE"), index=True)
miner_id: Mapped[str] = mapped_column(String(64))
score: Mapped[float] = mapped_column(Float)
explain: Mapped[Optional[str]] = mapped_column(Text)
eta_ms: Mapped[Optional[int]] = mapped_column(Integer)
price: Mapped[Optional[float]] = mapped_column(Float)
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=dt.datetime.utcnow)
request: Mapped[MatchRequest] = relationship(back_populates="results")
class Feedback(Base):
__tablename__ = "feedback"
id: Mapped[PGUUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True, default=uuid4)
job_id: Mapped[str] = mapped_column(String(64), nullable=False)
miner_id: Mapped[str] = mapped_column(ForeignKey("miners.miner_id", ondelete="CASCADE"), nullable=False)
outcome: Mapped[str] = mapped_column(String(32), nullable=False)
latency_ms: Mapped[Optional[int]] = mapped_column(Integer)
fail_code: Mapped[Optional[str]] = mapped_column(String(64))
tokens_spent: Mapped[Optional[float]] = mapped_column(Float)
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=dt.datetime.utcnow)
miner: Mapped[Miner] = relationship(back_populates="feedback")

View File

@ -0,0 +1,39 @@
from __future__ import annotations
from collections.abc import AsyncGenerator
import redis.asyncio as redis
from .settings import settings
_redis_client: redis.Redis | None = None
def create_redis() -> redis.Redis:
global _redis_client
if _redis_client is None:
_redis_client = redis.from_url(
settings.redis_url,
max_connections=settings.redis_max_connections,
encoding="utf-8",
decode_responses=True,
)
return _redis_client
def get_redis_client() -> redis.Redis:
if _redis_client is None:
return create_redis()
return _redis_client
async def get_redis() -> AsyncGenerator[redis.Redis, None]:
client = get_redis_client()
yield client
async def close_redis() -> None:
global _redis_client
if _redis_client is not None:
await _redis_client.close()
_redis_client = None

View File

@ -0,0 +1,11 @@
"""Repository layer for Pool Hub."""
from .miner_repository import MinerRepository
from .match_repository import MatchRepository
from .feedback_repository import FeedbackRepository
__all__ = [
"MinerRepository",
"MatchRepository",
"FeedbackRepository",
]

View File

@ -0,0 +1,81 @@
from __future__ import annotations
import datetime as dt
import json
import logging
from typing import Iterable, List, Optional
from uuid import UUID
from redis.asyncio import Redis
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import Feedback
from ..storage.redis_keys import RedisKeys
logger = logging.getLogger(__name__)
class FeedbackRepository:
"""Persists coordinator feedback and emits Redis notifications."""
def __init__(self, session: AsyncSession, redis: Redis) -> None:
self._session = session
self._redis = redis
async def add_feedback(
self,
*,
job_id: str,
miner_id: str,
outcome: str,
latency_ms: Optional[int] = None,
fail_code: Optional[str] = None,
tokens_spent: Optional[float] = None,
) -> Feedback:
feedback = Feedback(
job_id=job_id,
miner_id=miner_id,
outcome=outcome,
latency_ms=latency_ms,
fail_code=fail_code,
tokens_spent=tokens_spent,
created_at=dt.datetime.utcnow(),
)
self._session.add(feedback)
await self._session.flush()
payload = {
"job_id": job_id,
"miner_id": miner_id,
"outcome": outcome,
"latency_ms": latency_ms,
"fail_code": fail_code,
"tokens_spent": tokens_spent,
"created_at": feedback.created_at.isoformat() if feedback.created_at else None,
}
try:
await self._redis.publish(RedisKeys.feedback_channel(), json.dumps(payload))
except Exception as exc: # pragma: no cover - defensive
logger.warning("Failed to publish feedback event for job %s: %s", job_id, exc)
return feedback
async def list_feedback_for_miner(self, miner_id: str, limit: int = 50) -> List[Feedback]:
stmt = (
select(Feedback)
.where(Feedback.miner_id == miner_id)
.order_by(Feedback.created_at.desc())
.limit(limit)
)
result = await self._session.execute(stmt)
return list(result.scalars().all())
async def list_feedback_for_job(self, job_id: str, limit: int = 50) -> List[Feedback]:
stmt = (
select(Feedback)
.where(Feedback.job_id == job_id)
.order_by(Feedback.created_at.desc())
.limit(limit)
)
result = await self._session.execute(stmt)
return list(result.scalars().all())

View File

@ -0,0 +1,122 @@
from __future__ import annotations
import datetime as dt
import json
from typing import Iterable, List, Optional, Sequence
from uuid import UUID
from redis.asyncio import Redis
from sqlalchemy import Select, select
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import MatchRequest, MatchResult
from ..storage.redis_keys import RedisKeys
class MatchRepository:
"""Handles match request logging, result persistence, and Redis fan-out."""
def __init__(self, session: AsyncSession, redis: Redis) -> None:
self._session = session
self._redis = redis
async def create_request(
self,
*,
job_id: str,
requirements: dict[str, object],
hints: Optional[dict[str, object]] = None,
top_k: int = 1,
enqueue: bool = True,
) -> MatchRequest:
request = MatchRequest(
job_id=job_id,
requirements=requirements,
hints=hints or {},
top_k=top_k,
created_at=dt.datetime.utcnow(),
)
self._session.add(request)
await self._session.flush()
if enqueue:
payload = {
"request_id": str(request.id),
"job_id": request.job_id,
"requirements": request.requirements,
"hints": request.hints,
"top_k": request.top_k,
}
await self._redis.rpush(RedisKeys.match_requests(), json.dumps(payload))
return request
async def add_results(
self,
*,
request_id: UUID,
candidates: Sequence[dict[str, object]],
publish: bool = True,
) -> List[MatchResult]:
results: List[MatchResult] = []
created_at = dt.datetime.utcnow()
for candidate in candidates:
result = MatchResult(
request_id=request_id,
miner_id=str(candidate.get("miner_id")),
score=float(candidate.get("score", 0.0)),
explain=candidate.get("explain"),
eta_ms=candidate.get("eta_ms"),
price=candidate.get("price"),
created_at=created_at,
)
self._session.add(result)
results.append(result)
await self._session.flush()
if publish:
request = await self._session.get(MatchRequest, request_id)
if request:
redis_key = RedisKeys.match_results(request.job_id)
await self._redis.delete(redis_key)
if results:
payloads = [json.dumps(self._result_payload(result)) for result in results]
await self._redis.rpush(redis_key, *payloads)
await self._redis.expire(redis_key, 300)
channel = RedisKeys.match_results_channel(request.job_id)
for payload in payloads:
await self._redis.publish(channel, payload)
return results
async def get_request(self, request_id: UUID) -> Optional[MatchRequest]:
return await self._session.get(MatchRequest, request_id)
async def list_recent_requests(self, limit: int = 20) -> List[MatchRequest]:
stmt: Select[MatchRequest] = (
select(MatchRequest)
.order_by(MatchRequest.created_at.desc())
.limit(limit)
)
result = await self._session.execute(stmt)
return list(result.scalars().all())
async def list_results_for_job(self, job_id: str, limit: int = 10) -> List[MatchResult]:
stmt: Select[MatchResult] = (
select(MatchResult)
.join(MatchRequest)
.where(MatchRequest.job_id == job_id)
.order_by(MatchResult.created_at.desc())
.limit(limit)
)
result = await self._session.execute(stmt)
return list(result.scalars().all())
def _result_payload(self, result: MatchResult) -> dict[str, object]:
return {
"request_id": str(result.request_id),
"miner_id": result.miner_id,
"score": result.score,
"explain": result.explain,
"eta_ms": result.eta_ms,
"price": result.price,
"created_at": result.created_at.isoformat() if result.created_at else None,
}

View File

@ -0,0 +1,181 @@
from __future__ import annotations
import datetime as dt
from typing import List, Optional, Tuple
from redis.asyncio import Redis
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import Miner, MinerStatus
from ..settings import settings
from ..storage.redis_keys import RedisKeys
class MinerRepository:
"""Coordinates miner registry persistence across PostgreSQL and Redis."""
def __init__(self, session: AsyncSession, redis: Redis) -> None:
self._session = session
self._redis = redis
async def register_miner(
self,
miner_id: str,
api_key_hash: str,
*,
addr: str,
proto: str,
gpu_vram_gb: float,
gpu_name: Optional[str],
cpu_cores: int,
ram_gb: float,
max_parallel: int,
base_price: float,
tags: dict[str, str],
capabilities: list[str],
region: Optional[str],
) -> Miner:
miner = await self._session.get(Miner, miner_id)
if miner is None:
miner = Miner(
miner_id=miner_id,
api_key_hash=api_key_hash,
addr=addr,
proto=proto,
gpu_vram_gb=gpu_vram_gb,
gpu_name=gpu_name,
cpu_cores=cpu_cores,
ram_gb=ram_gb,
max_parallel=max_parallel,
base_price=base_price,
tags=tags,
capabilities=capabilities,
region=region,
)
self._session.add(miner)
status = MinerStatus(miner_id=miner_id)
self._session.add(status)
else:
miner.addr = addr
miner.proto = proto
miner.gpu_vram_gb = gpu_vram_gb
miner.gpu_name = gpu_name
miner.cpu_cores = cpu_cores
miner.ram_gb = ram_gb
miner.max_parallel = max_parallel
miner.base_price = base_price
miner.tags = tags
miner.capabilities = capabilities
miner.region = region
miner.last_seen_at = dt.datetime.utcnow()
await self._session.flush()
await self._sync_miner_to_redis(miner_id)
return miner
async def update_status(
self,
miner_id: str,
*,
queue_len: Optional[int] = None,
busy: Optional[bool] = None,
avg_latency_ms: Optional[int] = None,
temp_c: Optional[int] = None,
mem_free_gb: Optional[float] = None,
) -> None:
stmt = (
update(MinerStatus)
.where(MinerStatus.miner_id == miner_id)
.values(
{
k: v
for k, v in {
"queue_len": queue_len,
"busy": busy,
"avg_latency_ms": avg_latency_ms,
"temp_c": temp_c,
"mem_free_gb": mem_free_gb,
"updated_at": dt.datetime.utcnow(),
}.items()
if v is not None
}
)
)
await self._session.execute(stmt)
miner = await self._session.get(Miner, miner_id)
if miner:
miner.last_seen_at = dt.datetime.utcnow()
await self._session.flush()
await self._sync_miner_to_redis(miner_id)
async def touch_heartbeat(self, miner_id: str) -> None:
miner = await self._session.get(Miner, miner_id)
if miner is None:
return
miner.last_seen_at = dt.datetime.utcnow()
await self._session.flush()
await self._sync_miner_to_redis(miner_id)
async def get_miner(self, miner_id: str) -> Optional[Miner]:
return await self._session.get(Miner, miner_id)
async def iter_miners(self) -> List[Miner]:
result = await self._session.execute(select(Miner))
return list(result.scalars().all())
async def get_status(self, miner_id: str) -> Optional[MinerStatus]:
return await self._session.get(MinerStatus, miner_id)
async def list_active_miners(self) -> List[Tuple[Miner, Optional[MinerStatus], float]]:
stmt = select(Miner, MinerStatus).join(MinerStatus, MinerStatus.miner_id == Miner.miner_id, isouter=True)
result = await self._session.execute(stmt)
records: List[Tuple[Miner, Optional[MinerStatus], float]] = []
for miner, status in result.all():
score = self._compute_score(miner, status)
records.append((miner, status, score))
return records
async def _sync_miner_to_redis(self, miner_id: str) -> None:
miner = await self._session.get(Miner, miner_id)
if miner is None:
return
status = await self._session.get(MinerStatus, miner_id)
payload = {
"miner_id": miner.miner_id,
"addr": miner.addr,
"proto": miner.proto,
"region": miner.region or "",
"gpu_vram_gb": str(miner.gpu_vram_gb),
"ram_gb": str(miner.ram_gb),
"max_parallel": str(miner.max_parallel),
"base_price": str(miner.base_price),
"trust_score": str(miner.trust_score),
"queue_len": str(status.queue_len if status else 0),
"busy": str(status.busy if status else False),
}
redis_key = RedisKeys.miner_hash(miner_id)
await self._redis.hset(redis_key, mapping=payload)
await self._redis.expire(redis_key, settings.session_ttl_seconds + settings.heartbeat_grace_seconds)
score = self._compute_score(miner, status)
ranking_key = RedisKeys.miner_rankings(miner.region)
await self._redis.zadd(ranking_key, {miner_id: score})
await self._redis.expire(ranking_key, settings.session_ttl_seconds + settings.heartbeat_grace_seconds)
def _compute_score(self, miner: Miner, status: Optional[MinerStatus]) -> float:
load_factor = 1.0
if status and miner.max_parallel:
utilization = min(status.queue_len / max(miner.max_parallel, 1), 1.0)
load_factor = 1.0 - utilization
price_factor = 1.0 if miner.base_price <= 0 else min(1.0, 1.0 / miner.base_price)
trust_factor = max(miner.trust_score, 0.0)
return (settings.default_score_weights.capability * 1.0) + (
settings.default_score_weights.price * price_factor
) + (settings.default_score_weights.load * load_factor) + (
settings.default_score_weights.trust * trust_factor
)

View File

@ -0,0 +1,59 @@
from __future__ import annotations
from functools import lru_cache
from typing import Any, Dict, List
from pydantic import AnyHttpUrl, BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class ScoreWeights(BaseModel):
capability: float = Field(default=0.40, alias="cap")
price: float = Field(default=0.20)
latency: float = Field(default=0.20)
trust: float = Field(default=0.15)
load: float = Field(default=0.05)
model_config = SettingsConfigDict(populate_by_name=True)
def as_vector(self) -> List[float]:
return [self.capability, self.price, self.latency, self.trust, self.load]
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="poolhub_", env_file=".env", case_sensitive=False)
app_name: str = "AITBC Pool Hub"
bind_host: str = Field(default="127.0.0.1")
bind_port: int = Field(default=8203)
coordinator_shared_secret: str = Field(default="changeme")
postgres_dsn: str = Field(default="postgresql+asyncpg://poolhub:poolhub@127.0.0.1:5432/aitbc")
postgres_pool_min: int = Field(default=1)
postgres_pool_max: int = Field(default=10)
redis_url: str = Field(default="redis://127.0.0.1:6379/4")
redis_max_connections: int = Field(default=32)
session_ttl_seconds: int = Field(default=60)
heartbeat_grace_seconds: int = Field(default=120)
default_score_weights: ScoreWeights = Field(default_factory=ScoreWeights)
allowed_origins: List[AnyHttpUrl] = Field(default_factory=list)
prometheus_namespace: str = Field(default="poolhub")
def asgi_kwargs(self) -> Dict[str, Any]:
return {
"title": self.app_name,
}
@lru_cache(maxsize=1)
def get_settings() -> Settings:
return Settings()
settings = get_settings()

View File

@ -0,0 +1,5 @@
"""Storage utilities for the Pool Hub service."""
from .redis_keys import RedisKeys
__all__ = ["RedisKeys"]

View File

@ -0,0 +1,40 @@
from __future__ import annotations
from typing import Final
class RedisKeys:
namespace: Final[str] = "poolhub"
@classmethod
def miner_hash(cls, miner_id: str) -> str:
return f"{cls.namespace}:miner:{miner_id}"
@classmethod
def miner_rankings(cls, region: str | None = None) -> str:
suffix = region or "global"
return f"{cls.namespace}:rankings:{suffix}"
@classmethod
def miner_session(cls, session_token: str) -> str:
return f"{cls.namespace}:session:{session_token}"
@classmethod
def heartbeat_stream(cls) -> str:
return f"{cls.namespace}:heartbeat-stream"
@classmethod
def match_requests(cls) -> str:
return f"{cls.namespace}:match-requests"
@classmethod
def match_results(cls, job_id: str) -> str:
return f"{cls.namespace}:match-results:{job_id}"
@classmethod
def feedback_channel(cls) -> str:
return f"{cls.namespace}:events:feedback"
@classmethod
def match_results_channel(cls, job_id: str) -> str:
return f"{cls.namespace}:events:match-results:{job_id}"

View File

@ -0,0 +1,63 @@
from __future__ import annotations
import os
import sys
from pathlib import Path
import pytest
import pytest_asyncio
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
BASE_DIR = Path(__file__).resolve().parents[2]
POOLHUB_SRC = BASE_DIR / "pool-hub" / "src"
if str(POOLHUB_SRC) not in sys.path:
sys.path.insert(0, str(POOLHUB_SRC))
from poolhub.models import Base
def _get_required_env(name: str) -> str:
value = os.getenv(name)
if not value:
pytest.skip(f"Set {name} to run Pool Hub integration tests")
return value
@pytest_asyncio.fixture()
async def db_engine() -> AsyncEngine:
dsn = _get_required_env("POOLHUB_TEST_POSTGRES_DSN")
engine = create_async_engine(dsn, pool_pre_ping=True)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
yield engine
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await engine.dispose()
@pytest_asyncio.fixture
async def db_session(db_engine: AsyncEngine) -> AsyncSession:
session_factory = async_sessionmaker(db_engine, expire_on_commit=False, autoflush=False)
async with session_factory() as session:
yield session
await session.rollback()
@pytest_asyncio.fixture()
async def redis_client() -> Redis:
redis_url = _get_required_env("POOLHUB_TEST_REDIS_URL")
client = Redis.from_url(redis_url, encoding="utf-8", decode_responses=True)
await client.flushdb()
yield client
await client.flushdb()
await client.close()
@pytest_asyncio.fixture(autouse=True)
async def _clear_redis(redis_client: Redis) -> None:
await redis_client.flushdb()

View File

@ -0,0 +1,153 @@
from __future__ import annotations
import uuid
import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import async_sessionmaker
from poolhub.app import deps
from poolhub.app.main import create_app
from poolhub.app.prometheus import reset_metrics
from poolhub.repositories.miner_repository import MinerRepository
@pytest_asyncio.fixture()
async def async_client(db_engine, redis_client): # noqa: F811
async def _session_override():
factory = async_sessionmaker(db_engine, expire_on_commit=False, autoflush=False)
async with factory() as session:
yield session
async def _redis_override():
yield redis_client
app = create_app()
app.dependency_overrides.clear()
app.dependency_overrides[deps.db_session_dep] = _session_override
app.dependency_overrides[deps.redis_dep] = _redis_override
reset_metrics()
async with AsyncClient(app=app, base_url="http://testserver") as client:
yield client
app.dependency_overrides.clear()
@pytest.mark.asyncio
async def test_match_endpoint(async_client, db_session, redis_client): # noqa: F811
repo = MinerRepository(db_session, redis_client)
await repo.register_miner(
miner_id="miner-1",
api_key_hash="hash",
addr="127.0.0.1",
proto="grpc",
gpu_vram_gb=16,
gpu_name="A100",
cpu_cores=32,
ram_gb=128,
max_parallel=4,
base_price=0.8,
tags={"tier": "gold"},
capabilities=["embedding"],
region="eu",
)
await db_session.commit()
response = await async_client.post(
"/v1/match",
json={
"job_id": "job-123",
"requirements": {"min_vram_gb": 8},
"hints": {"region": "eu"},
"top_k": 1,
},
)
assert response.status_code == 200
payload = response.json()
assert payload["job_id"] == "job-123"
assert len(payload["candidates"]) == 1
@pytest.mark.asyncio
async def test_match_endpoint_no_miners(async_client):
response = await async_client.post(
"/v1/match",
json={"job_id": "empty", "requirements": {}, "hints": {}, "top_k": 2},
)
assert response.status_code == 200
payload = response.json()
assert payload["candidates"] == []
@pytest.mark.asyncio
async def test_health_endpoint(async_client): # noqa: F811
response = await async_client.get("/v1/health")
assert response.status_code == 200
data = response.json()
assert data["status"] in {"ok", "degraded"}
assert "db_error" in data
assert "redis_error" in data
@pytest.mark.asyncio
async def test_health_endpoint_degraded(db_engine, redis_client): # noqa: F811
async def _session_override():
factory = async_sessionmaker(db_engine, expire_on_commit=False, autoflush=False)
async with factory() as session:
yield session
class FailingRedis:
async def ping(self) -> None:
raise RuntimeError("redis down")
def __getattr__(self, _: str) -> None: # pragma: no cover - minimal stub
raise RuntimeError("redis down")
async def _redis_override():
yield FailingRedis()
app = create_app()
app.dependency_overrides.clear()
app.dependency_overrides[deps.db_session_dep] = _session_override
app.dependency_overrides[deps.redis_dep] = _redis_override
reset_metrics()
async with AsyncClient(app=app, base_url="http://testserver") as client:
response = await client.get("/v1/health")
assert response.status_code == 200
payload = response.json()
assert payload["status"] == "degraded"
assert payload["redis_error"]
assert payload["db_error"] is None
app.dependency_overrides.clear()
@pytest.mark.asyncio
async def test_metrics_endpoint(async_client):
baseline = await async_client.get("/metrics")
before = _extract_counter(baseline.text, "poolhub_match_requests_total")
for _ in range(2):
await async_client.post(
"/v1/match",
json={"job_id": str(uuid.uuid4()), "requirements": {}, "hints": {}, "top_k": 1},
)
updated = await async_client.get("/metrics")
after = _extract_counter(updated.text, "poolhub_match_requests_total")
assert after >= before + 2
def _extract_counter(metrics_text: str, metric: str) -> float:
for line in metrics_text.splitlines():
if line.startswith(metric):
parts = line.split()
if len(parts) >= 2:
try:
return float(parts[1])
except ValueError: # pragma: no cover
return 0.0
return 0.0

View File

@ -0,0 +1,96 @@
from __future__ import annotations
import json
import uuid
import pytest
from poolhub.repositories.feedback_repository import FeedbackRepository
from poolhub.repositories.match_repository import MatchRepository
from poolhub.repositories.miner_repository import MinerRepository
from poolhub.storage.redis_keys import RedisKeys
@pytest.mark.asyncio
async def test_register_miner_persists_and_syncs(db_session, redis_client):
repo = MinerRepository(db_session, redis_client)
await repo.register_miner(
miner_id="miner-1",
api_key_hash="hash",
addr="127.0.0.1",
proto="grpc",
gpu_vram_gb=16,
gpu_name="A100",
cpu_cores=32,
ram_gb=128,
max_parallel=4,
base_price=0.8,
tags={"tier": "gold"},
capabilities=["embedding"],
region="eu",
)
miner = await repo.get_miner("miner-1")
assert miner is not None
assert miner.addr == "127.0.0.1"
redis_hash = await redis_client.hgetall(RedisKeys.miner_hash("miner-1"))
assert redis_hash["miner_id"] == "miner-1"
ranking = await redis_client.zscore(RedisKeys.miner_rankings("eu"), "miner-1")
assert ranking is not None
@pytest.mark.asyncio
async def test_match_request_flow(db_session, redis_client):
match_repo = MatchRepository(db_session, redis_client)
req = await match_repo.create_request(
job_id="job-123",
requirements={"min_vram_gb": 8},
hints={"region": "eu"},
top_k=2,
)
await db_session.commit()
queue_entry = await redis_client.lpop(RedisKeys.match_requests())
assert queue_entry is not None
payload = json.loads(queue_entry)
assert payload["job_id"] == "job-123"
await match_repo.add_results(
request_id=req.id,
candidates=[
{"miner_id": "miner-1", "score": 0.9, "explain": "fit"},
{"miner_id": "miner-2", "score": 0.8, "explain": "backup"},
],
)
await db_session.commit()
results = await match_repo.list_results_for_job("job-123")
assert len(results) == 2
redis_results = await redis_client.lrange(RedisKeys.match_results("job-123"), 0, -1)
assert len(redis_results) == 2
@pytest.mark.asyncio
async def test_feedback_repository(db_session, redis_client):
feedback_repo = FeedbackRepository(db_session, redis_client)
feedback = await feedback_repo.add_feedback(
job_id="job-321",
miner_id="miner-1",
outcome="completed",
latency_ms=1200,
tokens_spent=1.5,
)
await db_session.commit()
rows = await feedback_repo.list_feedback_for_job("job-321")
assert len(rows) == 1
assert rows[0].outcome == "completed"
# Ensure Redis publish occurred by checking pubsub message count via monitor list (best effort)
# Redis doesn't buffer publishes for inspection, so this is a smoke check ensuring repository returns object
assert feedback.miner_id == "miner-1"

View File

@ -1,11 +1,14 @@
from __future__ import annotations
import base64
from typing import Any, Dict, Optional
from fastapi import APIRouter, Depends
from .deps import get_receipt_service
from .deps import get_receipt_service, get_keystore, get_ledger
from .models import ReceiptVerificationModel, from_validation_result
from .keystore.service import KeystoreService
from .ledger_mock import SQLiteLedgerAdapter
from .receipts.service import ReceiptVerifierService
router = APIRouter(tags=["jsonrpc"])
@ -24,6 +27,8 @@ def _response(result: Optional[Dict[str, Any]] = None, error: Optional[Dict[str,
def handle_jsonrpc(
request: Dict[str, Any],
service: ReceiptVerifierService = Depends(get_receipt_service),
keystore: KeystoreService = Depends(get_keystore),
ledger: SQLiteLedgerAdapter = Depends(get_ledger),
) -> Dict[str, Any]:
method = request.get("method")
params = request.get("params") or {}
@ -46,4 +51,68 @@ def handle_jsonrpc(
results = [from_validation_result(item).model_dump() for item in service.verify_history(str(job_id))]
return _response(result={"items": results}, request_id=request_id)
if method == "wallet.list":
items = []
for record in keystore.list_records():
ledger_record = ledger.get_wallet(record.wallet_id)
metadata = ledger_record.metadata if ledger_record else record.metadata
items.append({"wallet_id": record.wallet_id, "public_key": record.public_key, "metadata": metadata})
return _response(result={"items": items}, request_id=request_id)
if method == "wallet.create":
wallet_id = params.get("wallet_id")
password = params.get("password")
metadata = params.get("metadata") or {}
secret_b64 = params.get("secret_key")
if not wallet_id or not password:
return _response(error={"code": -32602, "message": "wallet_id and password required"}, request_id=request_id)
secret = base64.b64decode(secret_b64) if secret_b64 else None
record = keystore.create_wallet(wallet_id=wallet_id, password=password, secret=secret, metadata=metadata)
ledger.upsert_wallet(record.wallet_id, record.public_key, record.metadata)
ledger.record_event(record.wallet_id, "created", {"metadata": record.metadata})
return _response(
result={
"wallet": {
"wallet_id": record.wallet_id,
"public_key": record.public_key,
"metadata": record.metadata,
}
},
request_id=request_id,
)
if method == "wallet.unlock":
wallet_id = params.get("wallet_id")
password = params.get("password")
if not wallet_id or not password:
return _response(error={"code": -32602, "message": "wallet_id and password required"}, request_id=request_id)
try:
keystore.unlock_wallet(wallet_id, password)
ledger.record_event(wallet_id, "unlocked", {"success": True})
return _response(result={"wallet_id": wallet_id, "unlocked": True}, request_id=request_id)
except (KeyError, ValueError):
ledger.record_event(wallet_id, "unlocked", {"success": False})
return _response(error={"code": -32001, "message": "invalid credentials"}, request_id=request_id)
if method == "wallet.sign":
wallet_id = params.get("wallet_id")
password = params.get("password")
message_b64 = params.get("message")
if not wallet_id or not password or not message_b64:
return _response(error={"code": -32602, "message": "wallet_id, password, message required"}, request_id=request_id)
try:
message = base64.b64decode(message_b64)
except Exception:
return _response(error={"code": -32602, "message": "invalid base64 message"}, request_id=request_id)
try:
signature = keystore.sign_message(wallet_id, password, message)
ledger.record_event(wallet_id, "sign", {"success": True})
except (KeyError, ValueError):
ledger.record_event(wallet_id, "sign", {"success": False})
return _response(error={"code": -32001, "message": "invalid credentials"}, request_id=request_id)
signature_b64 = base64.b64encode(signature).decode()
return _response(result={"wallet_id": wallet_id, "signature": signature_b64}, request_id=request_id)
return _response(error={"code": -32601, "message": "Method not found"}, request_id=request_id)

View File

@ -1,18 +1,52 @@
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, status
import base64
from .deps import get_receipt_service
import logging
import base64
from fastapi import APIRouter, Depends, HTTPException, status, Request
from .deps import get_receipt_service, get_keystore, get_ledger
from .models import (
ReceiptVerificationListResponse,
ReceiptVerificationModel,
ReceiptVerifyResponse,
SignatureValidationModel,
WalletCreateRequest,
WalletCreateResponse,
WalletListResponse,
WalletUnlockRequest,
WalletUnlockResponse,
WalletSignRequest,
WalletSignResponse,
WalletDescriptor,
from_validation_result,
)
from .keystore.service import KeystoreService
from .ledger_mock import SQLiteLedgerAdapter
from .receipts.service import ReceiptValidationResult, ReceiptVerifierService
from .security import RateLimiter, wipe_buffer
router = APIRouter(prefix="/v1", tags=["receipts"])
logger = logging.getLogger(__name__)
_rate_limiter = RateLimiter(max_requests=30, window_seconds=60)
def _rate_key(action: str, request: Request, wallet_id: Optional[str] = None) -> str:
host = request.client.host if request.client else "unknown"
parts = [action, host]
if wallet_id:
parts.append(wallet_id)
return ":".join(parts)
def _enforce_limit(action: str, request: Request, wallet_id: Optional[str] = None) -> None:
key = _rate_key(action, request, wallet_id)
if not _rate_limiter.allow(key):
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="rate limit exceeded")
router = APIRouter(prefix="/v1", tags=["wallets", "receipts"])
def _result_to_response(result: ReceiptValidationResult) -> ReceiptVerifyResponse:
@ -47,3 +81,97 @@ def verify_receipt_history(
results = service.verify_history(job_id)
items = [from_validation_result(result) for result in results]
return ReceiptVerificationListResponse(items=items)
@router.get("/wallets", response_model=WalletListResponse, summary="List wallets")
def list_wallets(
keystore: KeystoreService = Depends(get_keystore),
ledger: SQLiteLedgerAdapter = Depends(get_ledger),
) -> WalletListResponse:
descriptors = []
for record in keystore.list_records():
ledger_record = ledger.get_wallet(record.wallet_id)
metadata = ledger_record.metadata if ledger_record else record.metadata
descriptors.append(
WalletDescriptor(wallet_id=record.wallet_id, public_key=record.public_key, metadata=metadata)
)
@router.post("/wallets", response_model=WalletCreateResponse, status_code=status.HTTP_201_CREATED, summary="Create wallet")
def create_wallet(
request: WalletCreateRequest,
http_request: Request,
keystore: KeystoreService = Depends(get_keystore),
ledger: SQLiteLedgerAdapter = Depends(get_ledger),
) -> WalletCreateResponse:
_enforce_limit("wallet-create", http_request)
try:
secret = base64.b64decode(request.secret_key) if request.secret_key else None
except Exception as exc:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="invalid base64 secret") from exc
try:
record = keystore.create_wallet(
wallet_id=request.wallet_id,
password=request.password,
secret=secret,
metadata=request.metadata,
)
except ValueError as exc:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
ledger.upsert_wallet(record.wallet_id, record.public_key, record.metadata)
ledger.record_event(record.wallet_id, "created", {"metadata": record.metadata})
logger.info("Created wallet", extra={"wallet_id": record.wallet_id})
wallet = WalletDescriptor(wallet_id=record.wallet_id, public_key=record.public_key, metadata=record.metadata)
return WalletCreateResponse(wallet=wallet)
@router.post("/wallets/{wallet_id}/unlock", response_model=WalletUnlockResponse, summary="Unlock wallet")
def unlock_wallet(
wallet_id: str,
request: WalletUnlockRequest,
http_request: Request,
keystore: KeystoreService = Depends(get_keystore),
ledger: SQLiteLedgerAdapter = Depends(get_ledger),
) -> WalletUnlockResponse:
_enforce_limit("wallet-unlock", http_request, wallet_id)
try:
secret = bytearray(keystore.unlock_wallet(wallet_id, request.password))
ledger.record_event(wallet_id, "unlocked", {"success": True})
logger.info("Unlocked wallet", extra={"wallet_id": wallet_id})
except (KeyError, ValueError):
ledger.record_event(wallet_id, "unlocked", {"success": False})
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid credentials")
finally:
if "secret" in locals():
wipe_buffer(secret)
# We don't expose the secret in response
return WalletUnlockResponse(wallet_id=wallet_id, unlocked=True)
@router.post("/wallets/{wallet_id}/sign", response_model=WalletSignResponse, summary="Sign payload")
def sign_payload(
wallet_id: str,
request: WalletSignRequest,
http_request: Request,
keystore: KeystoreService = Depends(get_keystore),
ledger: SQLiteLedgerAdapter = Depends(get_ledger),
) -> WalletSignResponse:
_enforce_limit("wallet-sign", http_request, wallet_id)
try:
message = base64.b64decode(request.message_base64)
except Exception as exc:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="invalid base64 message") from exc
try:
signature = keystore.sign_message(wallet_id, request.password, message)
ledger.record_event(wallet_id, "sign", {"success": True})
logger.debug("Signed payload", extra={"wallet_id": wallet_id})
except (KeyError, ValueError):
ledger.record_event(wallet_id, "sign", {"success": False})
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid credentials")
signature_b64 = base64.b64encode(signature).decode()
return WalletSignResponse(wallet_id=wallet_id, signature_base64=signature_b64)

View File

@ -9,6 +9,8 @@ from nacl.bindings import (
crypto_aead_xchacha20poly1305_ietf_encrypt,
)
from ..security import wipe_buffer
class EncryptionError(Exception):
"""Raised when encryption or decryption fails."""
@ -50,13 +52,15 @@ class EncryptionSuite:
raise EncryptionError("encryption failed") from exc
def decrypt(self, *, password: str, ciphertext: bytes, salt: bytes, nonce: bytes) -> bytes:
key = self._derive_key(password=password, salt=salt)
key_bytes = bytearray(self._derive_key(password=password, salt=salt))
try:
return crypto_aead_xchacha20poly1305_ietf_decrypt(
ciphertext=ciphertext,
aad=b"",
nonce=nonce,
key=key,
key=bytes(key_bytes),
)
except Exception as exc:
raise EncryptionError("decryption failed") from exc
finally:
wipe_buffer(key_bytes)

View File

@ -5,6 +5,7 @@ from functools import lru_cache
from fastapi import Depends
from .keystore.service import KeystoreService
from .ledger_mock import SQLiteLedgerAdapter
from .receipts.service import ReceiptVerifierService
from .settings import Settings, settings
@ -24,3 +25,8 @@ def get_receipt_service(config: Settings = Depends(get_settings)) -> ReceiptVeri
@lru_cache
def get_keystore() -> KeystoreService:
return KeystoreService()
@lru_cache
def get_ledger(config: Settings = Depends(get_settings)) -> SQLiteLedgerAdapter:
return SQLiteLedgerAdapter(config.ledger_db_path)

View File

@ -1,16 +1,20 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import Dict, Iterable, List, Optional
from secrets import token_bytes
from nacl.signing import SigningKey
from ..crypto.encryption import EncryptionSuite, EncryptionError
from ..security import validate_password_rules, wipe_buffer
@dataclass
class WalletRecord:
wallet_id: str
public_key: str
salt: bytes
nonce: bytes
ciphertext: bytes
@ -27,14 +31,46 @@ class KeystoreService:
def list_wallets(self) -> List[str]:
return list(self._wallets.keys())
def list_records(self) -> Iterable[WalletRecord]:
return list(self._wallets.values())
def get_wallet(self, wallet_id: str) -> Optional[WalletRecord]:
return self._wallets.get(wallet_id)
def create_wallet(self, wallet_id: str, password: str, plaintext: bytes, metadata: Optional[Dict[str, str]] = None) -> WalletRecord:
def create_wallet(
self,
wallet_id: str,
password: str,
secret: Optional[bytes] = None,
metadata: Optional[Dict[str, str]] = None,
) -> WalletRecord:
if wallet_id in self._wallets:
raise ValueError("wallet already exists")
validate_password_rules(password)
metadata_map = {str(k): str(v) for k, v in (metadata or {}).items()}
if secret is None:
signing_key = SigningKey.generate()
secret_bytes = signing_key.encode()
else:
if len(secret) != SigningKey.seed_size:
raise ValueError("secret key must be 32 bytes")
secret_bytes = secret
signing_key = SigningKey(secret_bytes)
salt = token_bytes(self._encryption.salt_bytes)
nonce = token_bytes(self._encryption.nonce_bytes)
ciphertext = self._encryption.encrypt(password=password, plaintext=plaintext, salt=salt, nonce=nonce)
record = WalletRecord(wallet_id=wallet_id, salt=salt, nonce=nonce, ciphertext=ciphertext, metadata=metadata or {})
ciphertext = self._encryption.encrypt(password=password, plaintext=secret_bytes, salt=salt, nonce=nonce)
record = WalletRecord(
wallet_id=wallet_id,
public_key=signing_key.verify_key.encode().hex(),
salt=salt,
nonce=nonce,
ciphertext=ciphertext,
metadata=metadata_map,
)
self._wallets[wallet_id] = record
return record
@ -49,3 +85,12 @@ class KeystoreService:
def delete_wallet(self, wallet_id: str) -> bool:
return self._wallets.pop(wallet_id, None) is not None
def sign_message(self, wallet_id: str, password: str, message: bytes) -> bytes:
secret_bytes = bytearray(self.unlock_wallet(wallet_id, password))
try:
signing_key = SigningKey(bytes(secret_bytes))
signed = signing_key.sign(message)
return signed.signature
finally:
wipe_buffer(secret_bytes)

View File

@ -0,0 +1,3 @@
from .sqlite_adapter import SQLiteLedgerAdapter, WalletRecord, WalletEvent
__all__ = ["SQLiteLedgerAdapter", "WalletRecord", "WalletEvent"]

View File

@ -0,0 +1,106 @@
from __future__ import annotations
import json
import sqlite3
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional
@dataclass
class WalletRecord:
wallet_id: str
public_key: str
metadata: dict
@dataclass
class WalletEvent:
wallet_id: str
event_type: str
payload: dict
class SQLiteLedgerAdapter:
def __init__(self, db_path: Path) -> None:
self._db_path = db_path
self._ensure_schema()
def _connect(self) -> sqlite3.Connection:
conn = sqlite3.connect(self._db_path)
conn.row_factory = sqlite3.Row
return conn
def _ensure_schema(self) -> None:
self._db_path.parent.mkdir(parents=True, exist_ok=True)
with self._connect() as conn:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS wallets (
wallet_id TEXT PRIMARY KEY,
public_key TEXT NOT NULL,
metadata TEXT NOT NULL
)
"""
)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS wallet_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
wallet_id TEXT NOT NULL,
event_type TEXT NOT NULL,
payload TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(wallet_id) REFERENCES wallets(wallet_id)
)
"""
)
def upsert_wallet(self, wallet_id: str, public_key: str, metadata: dict) -> None:
payload = json.dumps(metadata)
with self._connect() as conn:
conn.execute(
"""
INSERT INTO wallets(wallet_id, public_key, metadata)
VALUES (?, ?, ?)
ON CONFLICT(wallet_id) DO UPDATE SET public_key=excluded.public_key, metadata=excluded.metadata
""",
(wallet_id, public_key, payload),
)
def get_wallet(self, wallet_id: str) -> Optional[WalletRecord]:
with self._connect() as conn:
row = conn.execute(
"SELECT wallet_id, public_key, metadata FROM wallets WHERE wallet_id = ?",
(wallet_id,),
).fetchone()
if row is None:
return None
return WalletRecord(wallet_id=row["wallet_id"], public_key=row["public_key"], metadata=json.loads(row["metadata"]))
def list_wallets(self) -> Iterable[WalletRecord]:
with self._connect() as conn:
rows = conn.execute("SELECT wallet_id, public_key, metadata FROM wallets").fetchall()
for row in rows:
yield WalletRecord(wallet_id=row["wallet_id"], public_key=row["public_key"], metadata=json.loads(row["metadata"]))
def record_event(self, wallet_id: str, event_type: str, payload: dict) -> None:
data = json.dumps(payload)
with self._connect() as conn:
conn.execute(
"INSERT INTO wallet_events(wallet_id, event_type, payload) VALUES (?, ?, ?)",
(wallet_id, event_type, data),
)
def list_events(self, wallet_id: str) -> Iterable[WalletEvent]:
with self._connect() as conn:
rows = conn.execute(
"SELECT wallet_id, event_type, payload FROM wallet_events WHERE wallet_id = ? ORDER BY id",
(wallet_id,),
).fetchall()
for row in rows:
yield WalletEvent(
wallet_id=row["wallet_id"],
event_type=row["event_type"],
payload=json.loads(row["payload"]),
)

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import List
from typing import Any, Dict, List, Optional
from aitbc_sdk import SignatureValidation
@ -43,3 +43,43 @@ def from_validation_result(result) -> ReceiptVerificationModel:
class ReceiptVerificationListResponse(BaseModel):
items: List[ReceiptVerificationModel]
class WalletDescriptor(BaseModel):
wallet_id: str
public_key: str
metadata: Dict[str, Any]
class WalletListResponse(BaseModel):
items: List[WalletDescriptor]
class WalletCreateRequest(BaseModel):
wallet_id: str
password: str
metadata: Dict[str, Any] = {}
secret_key: Optional[str] = None
class WalletCreateResponse(BaseModel):
wallet: WalletDescriptor
class WalletUnlockRequest(BaseModel):
password: str
class WalletUnlockResponse(BaseModel):
wallet_id: str
unlocked: bool
class WalletSignRequest(BaseModel):
password: str
message_base64: str
class WalletSignResponse(BaseModel):
wallet_id: str
signature_base64: str

View File

@ -0,0 +1,43 @@
from __future__ import annotations
import re
import threading
import time
from collections import defaultdict, deque
class RateLimiter:
def __init__(self, max_requests: int = 30, window_seconds: int = 60) -> None:
self._max_requests = max_requests
self._window_seconds = window_seconds
self._lock = threading.Lock()
self._records: dict[str, deque[float]] = defaultdict(deque)
def allow(self, key: str) -> bool:
now = time.monotonic()
with self._lock:
entries = self._records[key]
while entries and now - entries[0] > self._window_seconds:
entries.popleft()
if len(entries) >= self._max_requests:
return False
entries.append(now)
return True
def validate_password_rules(password: str) -> None:
if len(password) < 12:
raise ValueError("password must be at least 12 characters long")
if not re.search(r"[A-Z]", password):
raise ValueError("password must include at least one uppercase letter")
if not re.search(r"[a-z]", password):
raise ValueError("password must include at least one lowercase letter")
if not re.search(r"\d", password):
raise ValueError("password must include at least one digit")
if not re.search(r"[^A-Za-z0-9]", password):
raise ValueError("password must include at least one symbol")
def wipe_buffer(buffer: bytearray) -> None:
for index in range(len(buffer)):
buffer[index] = 0

View File

@ -1,5 +1,7 @@
from __future__ import annotations
from pathlib import Path
from pydantic import Field
from pydantic_settings import BaseSettings
@ -14,6 +16,7 @@ class Settings(BaseSettings):
coordinator_api_key: str = Field(default="client_dev_key_1", alias="COORDINATOR_API_KEY")
rest_prefix: str = Field(default="/v1", alias="REST_PREFIX")
ledger_db_path: Path = Field(default=Path("./data/wallet_ledger.db"), alias="LEDGER_DB_PATH")
class Config:
env_file = ".env"

View File

@ -0,0 +1,38 @@
from __future__ import annotations
from pathlib import Path
from app.ledger_mock import SQLiteLedgerAdapter
def test_upsert_and_get_wallet(tmp_path: Path) -> None:
db_path = tmp_path / "ledger.db"
adapter = SQLiteLedgerAdapter(db_path)
adapter.upsert_wallet("wallet-1", "pubkey", {"label": "primary"})
record = adapter.get_wallet("wallet-1")
assert record is not None
assert record.wallet_id == "wallet-1"
assert record.public_key == "pubkey"
assert record.metadata["label"] == "primary"
# Update metadata and ensure persistence
adapter.upsert_wallet("wallet-1", "pubkey", {"label": "updated"})
updated = adapter.get_wallet("wallet-1")
assert updated is not None
assert updated.metadata["label"] == "updated"
def test_event_ordering(tmp_path: Path) -> None:
db_path = tmp_path / "ledger.db"
adapter = SQLiteLedgerAdapter(db_path)
adapter.upsert_wallet("wallet-1", "pubkey", {})
adapter.record_event("wallet-1", "created", {"step": 1})
adapter.record_event("wallet-1", "unlock", {"step": 2})
adapter.record_event("wallet-1", "sign", {"step": 3})
events = list(adapter.list_events("wallet-1"))
assert [event.event_type for event in events] == ["created", "unlock", "sign"]
assert [event.payload["step"] for event in events] == [1, 2, 3]

View File

@ -0,0 +1,82 @@
from __future__ import annotations
import base64
import pytest
from fastapi.testclient import TestClient
from aitbc_chain.app import create_app # noqa: I100
from app.deps import get_keystore, get_ledger
@pytest.fixture(name="client")
def client_fixture(tmp_path, monkeypatch):
# Override ledger path to temporary directory
from app.settings import Settings
class TestSettings(Settings):
ledger_db_path = tmp_path / "ledger.db"
monkeypatch.setattr("app.deps.get_settings", lambda: TestSettings())
app = create_app()
return TestClient(app)
def _create_wallet(client: TestClient, wallet_id: str, password: str = "Password!234") -> None:
payload = {
"wallet_id": wallet_id,
"password": password,
}
response = client.post("/v1/wallets", json=payload)
assert response.status_code == 201, response.text
def test_wallet_workflow(client: TestClient):
wallet_id = "wallet-1"
password = "StrongPass!234"
# Create wallet
response = client.post(
"/v1/wallets",
json={
"wallet_id": wallet_id,
"password": password,
"metadata": {"label": "test"},
},
)
assert response.status_code == 201, response.text
data = response.json()["wallet"]
assert data["wallet_id"] == wallet_id
assert "public_key" in data
# List wallets
response = client.get("/v1/wallets")
assert response.status_code == 200
items = response.json()["items"]
assert any(item["wallet_id"] == wallet_id for item in items)
# Unlock wallet
response = client.post(f"/v1/wallets/{wallet_id}/unlock", json={"password": password})
assert response.status_code == 200
assert response.json()["unlocked"] is True
# Sign payload
message = base64.b64encode(b"hello").decode()
response = client.post(
f"/v1/wallets/{wallet_id}/sign",
json={"password": password, "message_base64": message},
)
assert response.status_code == 200, response.text
signature = response.json()["signature_base64"]
assert isinstance(signature, str) and len(signature) > 0
def test_wallet_password_rules(client: TestClient):
response = client.post(
"/v1/wallets",
json={"wallet_id": "weak", "password": "short"},
)
assert response.status_code == 400
***