- Change file mode from 644 to 755 for all project files - Add chain_id parameter to get_balance RPC endpoint with default "ait-devnet" - Rename Miner.extra_meta_data to extra_metadata for consistency
439 lines
16 KiB
Python
Executable File
439 lines
16 KiB
Python
Executable File
"""
|
|
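Tests for coordinator billing stubs: credits and charges, quota adjustment and daily resets, pending billing events, monthly invoice generation, and tenant extraction from signed tokens.

Uses lightweight in-memory stand-ins for the ORM models to avoid PostgreSQL/UUID dependencies.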
|
|
"""
|
|
|
|
import asyncio
|
|
import uuid
|
|
from datetime import datetime, timedelta
|
|
from decimal import Decimal
|
|
from unittest.mock import MagicMock, AsyncMock, patch
|
|
from dataclasses import dataclass
|
|
|
|
import pytest
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Lightweight stubs for the ORM models so we don't need a real DB
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@dataclass
|
|
class FakeTenant:
|
|
id: str
|
|
slug: str
|
|
name: str
|
|
status: str = "active"
|
|
plan: str = "basic"
|
|
contact_email: str = "t@test.com"
|
|
billing_email: str = "b@test.com"
|
|
settings: dict = None
|
|
features: dict = None
|
|
balance: Decimal = Decimal("100.00")
|
|
|
|
def __post_init__(self):
|
|
self.settings = self.settings or {}
|
|
self.features = self.features or {}
|
|
|
|
|
|
@dataclass
|
|
class FakeQuota:
|
|
id: str
|
|
tenant_id: str
|
|
resource_type: str
|
|
limit_value: Decimal
|
|
used_value: Decimal = Decimal("0")
|
|
period_type: str = "daily"
|
|
period_start: datetime = None
|
|
period_end: datetime = None
|
|
is_active: bool = True
|
|
|
|
def __post_init__(self):
|
|
if self.period_start is None:
|
|
self.period_start = datetime.utcnow() - timedelta(hours=1)
|
|
if self.period_end is None:
|
|
self.period_end = datetime.utcnow() + timedelta(hours=23)
|
|
|
|
|
|
@dataclass
|
|
class FakeUsageRecord:
|
|
id: str
|
|
tenant_id: str
|
|
resource_type: str
|
|
quantity: Decimal
|
|
unit: str
|
|
unit_price: Decimal
|
|
total_cost: Decimal
|
|
currency: str = "USD"
|
|
usage_start: datetime = None
|
|
usage_end: datetime = None
|
|
job_id: str = None
|
|
metadata: dict = None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# In-memory billing store used by the implementations under test
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class InMemoryBillingStore:
    """Replaces the DB session for testing.

    Keeps tenants in a dict keyed by tenant id. Quotas, usage records,
    applied credits/charges, generated invoice ids and queued billing events
    are kept in plain lists.
    """

    def __init__(self):
        self.tenants: dict[str, FakeTenant] = {}
        self.quotas: list[FakeQuota] = []
        self.usage_records: list[FakeUsageRecord] = []
        self.credits: list[dict] = []
        self.charges: list[dict] = []
        self.invoices_generated: list[str] = []
        self.pending_events: list[dict] = []

    # helpers

    def get_tenant(self, tenant_id: str):
        """Look up a tenant by id; ``None`` when it is not registered."""
        return self.tenants.get(tenant_id)

    def get_active_quota(self, tenant_id: str, resource_type: str):
        """Return the first quota matching tenant and resource that is active.

        A match must have ``is_active`` set and a period window containing the
        current naive-UTC time. Returns ``None`` when no quota matches.
        """
        current = datetime.utcnow()
        candidates = (
            quota
            for quota in self.quotas
            if quota.tenant_id == tenant_id
            and quota.resource_type == resource_type
            and quota.is_active
            and quota.period_start <= current <= quota.period_end
        )
        return next(candidates, None)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Implementations (the actual code we're testing / implementing)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def apply_credit(store: InMemoryBillingStore, tenant_id: str, amount: Decimal, reason: str = "") -> bool:
    """Apply credit to tenant account.

    Adds ``amount`` to the tenant's balance and appends an audit entry
    (``tenant_id``, ``amount``, ``reason``, naive-UTC ``timestamp``) to
    ``store.credits``.

    Returns:
        ``True`` on success.

    Raises:
        ValueError: if the tenant is unknown, ``amount`` is a non-finite
            Decimal (NaN/Infinity), or ``amount`` is not positive.
    """
    tenant = store.get_tenant(tenant_id)
    if tenant is None:
        raise ValueError(f"Tenant not found: {tenant_id}")
    # Reject NaN/Infinity before comparing. An ordering comparison with a
    # Decimal NaN raises decimal.InvalidOperation (not ValueError), and
    # Infinity would pass the positivity check and poison the balance.
    if isinstance(amount, Decimal) and not amount.is_finite():
        raise ValueError("Credit amount must be finite")
    if amount <= 0:
        raise ValueError("Credit amount must be positive")
    tenant.balance += amount
    store.credits.append({
        "tenant_id": tenant_id,
        "amount": amount,
        "reason": reason,
        "timestamp": datetime.utcnow(),
    })
    return True
|
|
|
|
|
|
async def apply_charge(store: InMemoryBillingStore, tenant_id: str, amount: Decimal, reason: str = "") -> bool:
    """Apply charge to tenant account.

    Subtracts ``amount`` from the tenant's balance and appends an audit entry
    (``tenant_id``, ``amount``, ``reason``, naive-UTC ``timestamp``) to
    ``store.charges``. The balance is never driven below zero.

    Returns:
        ``True`` on success.

    Raises:
        ValueError: if the tenant is unknown, ``amount`` is a non-finite
            Decimal (NaN/Infinity), ``amount`` is not positive, or the
            balance is smaller than ``amount``.
    """
    tenant = store.get_tenant(tenant_id)
    if tenant is None:
        raise ValueError(f"Tenant not found: {tenant_id}")
    # Reject NaN/Infinity before comparing: an ordering comparison with a
    # Decimal NaN raises decimal.InvalidOperation rather than ValueError.
    if isinstance(amount, Decimal) and not amount.is_finite():
        raise ValueError("Charge amount must be finite")
    if amount <= 0:
        raise ValueError("Charge amount must be positive")
    if tenant.balance < amount:
        raise ValueError(f"Insufficient balance: {tenant.balance} < {amount}")
    tenant.balance -= amount
    store.charges.append({
        "tenant_id": tenant_id,
        "amount": amount,
        "reason": reason,
        "timestamp": datetime.utcnow(),
    })
    return True
|
|
|
|
|
|
async def adjust_quota(
    store: InMemoryBillingStore,
    tenant_id: str,
    resource_type: str,
    new_limit: Decimal,
) -> bool:
    """Adjust quota limit for a tenant resource.

    Overwrites ``limit_value`` on the quota returned by
    ``store.get_active_quota``; ``used_value`` is left untouched.

    Returns:
        ``True`` on success.

    Raises:
        ValueError: if no active quota exists for the tenant/resource, or
            ``new_limit`` is negative.
    """
    target = store.get_active_quota(tenant_id, resource_type)
    if not target:
        raise ValueError(f"No active quota for {tenant_id}/{resource_type}")

    if new_limit < 0:
        raise ValueError("Quota limit must be non-negative")

    target.limit_value = new_limit
    return True
|
|
|
|
|
|
async def reset_daily_quotas(store: InMemoryBillingStore) -> int:
    """Reset used_value to 0 for all daily quotas whose period has ended.

    Every active ``"daily"`` quota whose ``period_end`` is at or before the
    current naive-UTC time gets ``used_value`` zeroed. It also gets a new
    one-day window starting at the time of this call.

    Returns:
        How many quotas were reset.
    """
    now = datetime.utcnow()
    next_end = now + timedelta(days=1)
    reset = 0
    for quota in store.quotas:
        due = quota.period_type == "daily" and quota.is_active and quota.period_end <= now
        if not due:
            continue
        quota.used_value = Decimal("0")
        quota.period_start, quota.period_end = now, next_end
        reset += 1
    return reset
|
|
|
|
|
|
async def process_pending_events(store: InMemoryBillingStore) -> int:
    """Process all pending billing events and clear the queue.

    Each event dict is read via ``event_type``, ``tenant_id`` and ``amount``
    (default 0, converted to ``Decimal`` through ``str``).
    - ``"credit"`` events are applied with ``apply_credit``.
    - ``"charge"`` events are applied with ``apply_charge``.
    - Any other event type has no effect but is still counted and removed.

    Returns:
        The number of events processed (and removed from the queue).

    Raises:
        Any exception raised while handling an event (e.g. ``ValueError``
        from ``apply_charge``) propagates. In that case only the events that
        were fully applied are removed. The failing event and everything
        after it stay queued, so a retry never re-applies an earlier credit
        or charge.
    """
    queue = store.pending_events
    done = 0
    try:
        for event in queue:
            etype = event.get("event_type")
            tid = event.get("tenant_id")
            amount = Decimal(str(event.get("amount", 0)))
            if etype == "credit":
                await apply_credit(store, tid, amount, reason="pending_event")
            elif etype == "charge":
                await apply_charge(store, tid, amount, reason="pending_event")
            done += 1
    finally:
        # Drop exactly the prefix that was applied. A single slice delete
        # avoids repeated O(n) list.pop(0) calls.
        del queue[:done]
    return done
|
|
|
|
|
|
async def generate_monthly_invoices(store: InMemoryBillingStore) -> list[str]:
    """Generate invoices for all active tenants with usage.

    A tenant gets an invoice when its ``status`` is ``"active"`` and it has
    at least one entry in ``store.usage_records``. Each invoice id has the
    form ``INV-<slug>-<YYYYMM>-<NNNN>`` and is appended to
    ``store.invoices_generated``.

    ``NNNN`` is a 1-based counter over the invoices produced by this call, in
    ``store.tenants`` iteration order. Calling twice in the same month
    therefore produces the same ids again.

    Returns:
        The invoice ids generated by this call.
    """
    # One pass over the usage records instead of rescanning them per tenant.
    billed_tenants = {record.tenant_id for record in store.usage_records}
    # Stamp the whole batch with a single period so a run that straddles a
    # month boundary cannot mix YYYYMM values.
    period = datetime.utcnow().strftime("%Y%m")
    generated = []
    for tid, tenant in store.tenants.items():
        if tenant.status != "active" or tid not in billed_tenants:
            continue
        inv_id = f"INV-{tenant.slug}-{period}-{len(generated) + 1:04d}"
        store.invoices_generated.append(inv_id)
        generated.append(inv_id)
    return generated
|
|
|
|
|
|
async def extract_from_token(token: str, secret: str = "test-secret") -> dict | None:
|
|
"""Extract tenant_id from a JWT-like token. Returns claims dict or None."""
|
|
import json, hmac, hashlib, base64
|
|
parts = token.split(".")
|
|
if len(parts) != 3:
|
|
return None
|
|
try:
|
|
# Verify signature (HS256-like)
|
|
payload_b64 = parts[1]
|
|
sig = parts[2]
|
|
expected_sig = hmac.new(
|
|
secret.encode(), f"{parts[0]}.{payload_b64}".encode(), hashlib.sha256
|
|
).hexdigest()[:16]
|
|
if not hmac.compare_digest(sig, expected_sig):
|
|
return None
|
|
# Decode payload
|
|
padded = payload_b64 + "=" * (-len(payload_b64) % 4)
|
|
payload = json.loads(base64.urlsafe_b64decode(padded))
|
|
if "tenant_id" not in payload:
|
|
return None
|
|
return payload
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
def _make_token(claims: dict, secret: str = "test-secret") -> str:
    """Helper to create a test token.

    Layout is ``<header>.<payload>.<sig>``:
    - ``header`` and ``payload`` are URL-safe base64 with the ``=`` padding
      stripped.
    - ``sig`` is the first 16 hex characters of HMAC-SHA256 over
      ``"<header>.<payload>"`` keyed with ``secret``.
    """
    import base64
    import hashlib
    import hmac
    import json

    def encode_segment(raw: bytes) -> str:
        return base64.urlsafe_b64encode(raw).decode().rstrip("=")

    segments = [
        encode_segment(b'{"alg":"HS256"}'),
        encode_segment(json.dumps(claims).encode()),
    ]
    signing_input = ".".join(segments)
    digest = hmac.new(secret.encode(), signing_input.encode(), hashlib.sha256).hexdigest()
    segments.append(digest[:16])
    return ".".join(segments)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.fixture
def store():
    """Fresh in-memory store seeded with two tenants and two quotas.

    - ``t1`` (slug ``acme``): active, balance 500.00.
    - ``t2`` (slug ``beta``): inactive, balance 50.00.
    - ``q1``: t1 ``gpu_hours`` quota, 40 of 100 used, default period window.
    - ``q2``: t1 ``api_calls`` daily quota, 5000 of 10000 used, whose window
      ended an hour ago (expired).
    """
    seeded = InMemoryBillingStore()
    for tenant in (
        FakeTenant(id="t1", slug="acme", name="Acme Corp", balance=Decimal("500.00")),
        FakeTenant(id="t2", slug="beta", name="Beta Inc", balance=Decimal("50.00"), status="inactive"),
    ):
        seeded.tenants[tenant.id] = tenant

    seeded.quotas.extend([
        FakeQuota(
            id="q1", tenant_id="t1", resource_type="gpu_hours",
            limit_value=Decimal("100"), used_value=Decimal("40"),
        ),
        FakeQuota(
            id="q2", tenant_id="t1", resource_type="api_calls",
            limit_value=Decimal("10000"), used_value=Decimal("5000"),
            period_type="daily",
            period_start=datetime.utcnow() - timedelta(days=2),
            period_end=datetime.utcnow() - timedelta(hours=1),  # expired
        ),
    ])
    return seeded
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: apply_credit
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestApplyCredit:
    """apply_credit: balance increase, audit entry, and input validation."""

    @pytest.mark.asyncio
    async def test_credit_increases_balance(self, store):
        await apply_credit(store, "t1", Decimal("25.00"), reason="promo")
        assert store.tenants["t1"].balance == Decimal("525.00")
        # Exactly one audit entry, carrying the credited amount.
        assert [entry["amount"] for entry in store.credits] == [Decimal("25.00")]

    @pytest.mark.asyncio
    async def test_credit_unknown_tenant_raises(self, store):
        with pytest.raises(ValueError, match="Tenant not found"):
            await apply_credit(store, "unknown", Decimal("10"))

    @pytest.mark.asyncio
    async def test_credit_zero_or_negative_raises(self, store):
        for bad_amount in (Decimal("0"), Decimal("-5")):
            with pytest.raises(ValueError, match="positive"):
                await apply_credit(store, "t1", bad_amount)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: apply_charge
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestApplyCharge:
    """apply_charge: balance decrease and every rejection path."""

    @pytest.mark.asyncio
    async def test_charge_decreases_balance(self, store):
        acme = store.tenants["t1"]
        await apply_charge(store, "t1", Decimal("100.00"), reason="usage")
        assert acme.balance == Decimal("400.00")
        assert len(store.charges) == 1

    @pytest.mark.asyncio
    async def test_charge_insufficient_balance_raises(self, store):
        # t1 holds 500.00, so this charge must be refused.
        with pytest.raises(ValueError, match="Insufficient balance"):
            await apply_charge(store, "t1", Decimal("999.99"))

    @pytest.mark.asyncio
    async def test_charge_unknown_tenant_raises(self, store):
        with pytest.raises(ValueError, match="Tenant not found"):
            await apply_charge(store, "nope", Decimal("1"))

    @pytest.mark.asyncio
    async def test_charge_zero_raises(self, store):
        with pytest.raises(ValueError, match="positive"):
            await apply_charge(store, "t1", Decimal("0"))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: adjust_quota
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestAdjustQuota:
    """adjust_quota: limit update, missing quota, and negative limit."""

    @pytest.mark.asyncio
    async def test_adjust_quota_updates_limit(self, store):
        await adjust_quota(store, "t1", "gpu_hours", Decimal("200"))
        updated = store.get_active_quota("t1", "gpu_hours")
        assert updated.limit_value == Decimal("200")

    @pytest.mark.asyncio
    async def test_adjust_quota_no_active_raises(self, store):
        # The fixture seeds no storage_gb quota for t1.
        with pytest.raises(ValueError, match="No active quota"):
            await adjust_quota(store, "t1", "storage_gb", Decimal("50"))

    @pytest.mark.asyncio
    async def test_adjust_quota_negative_raises(self, store):
        with pytest.raises(ValueError, match="non-negative"):
            await adjust_quota(store, "t1", "gpu_hours", Decimal("-1"))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: reset_daily_quotas
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestResetDailyQuotas:
    """reset_daily_quotas: only expired, active daily quotas are reset."""

    @pytest.mark.asyncio
    async def test_resets_expired_daily_quotas(self, store):
        count = await reset_daily_quotas(store)
        assert count == 1  # q2 is expired daily
        q2 = store.quotas[1]
        assert q2.used_value == Decimal("0")
        assert q2.period_end > datetime.utcnow()

    @pytest.mark.asyncio
    async def test_does_not_reset_active_quotas(self, store):
        # q1 is still active (not expired), so its usage and window survive.
        q1 = store.quotas[0]
        end_before = q1.period_end
        count = await reset_daily_quotas(store)
        # Pin the return value too: only the expired q2 may be counted.
        assert count == 1
        assert q1.used_value == Decimal("40")  # unchanged
        assert q1.period_end == end_before
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: process_pending_events
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestProcessPendingEvents:
    """process_pending_events: applying queued credits/charges."""

    @pytest.mark.asyncio
    async def test_processes_credit_and_charge_events(self, store):
        store.pending_events = [
            {"event_type": "credit", "tenant_id": "t1", "amount": 10},
            {"event_type": "charge", "tenant_id": "t1", "amount": 5},
        ]
        assert await process_pending_events(store) == 2
        assert store.pending_events == []
        assert store.tenants["t1"].balance == Decimal("505.00")  # 500 + 10 - 5

    @pytest.mark.asyncio
    async def test_empty_queue_returns_zero(self, store):
        processed = await process_pending_events(store)
        assert processed == 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: generate_monthly_invoices
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestGenerateMonthlyInvoices:
    """generate_monthly_invoices: which tenants receive an invoice."""

    @pytest.mark.asyncio
    async def test_generates_for_active_tenants_with_usage(self, store):
        usage = FakeUsageRecord(
            id="u1", tenant_id="t1", resource_type="gpu_hours",
            quantity=Decimal("10"), unit="hours",
            unit_price=Decimal("0.50"), total_cost=Decimal("5.00"),
        )
        store.usage_records.append(usage)
        invoices = await generate_monthly_invoices(store)
        assert len(invoices) == 1
        first = invoices[0]
        assert first.startswith("INV-acme-")

    @pytest.mark.asyncio
    async def test_skips_inactive_tenants(self, store):
        usage = FakeUsageRecord(
            id="u2", tenant_id="t2", resource_type="gpu_hours",
            quantity=Decimal("5"), unit="hours",
            unit_price=Decimal("0.50"), total_cost=Decimal("2.50"),
        )
        store.usage_records.append(usage)
        invoices = await generate_monthly_invoices(store)
        assert invoices == []  # t2 is inactive

    @pytest.mark.asyncio
    async def test_skips_tenants_without_usage(self, store):
        assert await generate_monthly_invoices(store) == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: extract_from_token
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestExtractFromToken:
    """extract_from_token: the happy path plus each way a token is rejected."""

    @pytest.mark.asyncio
    async def test_valid_token_returns_claims(self):
        claims = await extract_from_token(_make_token({"tenant_id": "t1", "role": "admin"}))
        assert claims is not None
        assert claims["tenant_id"] == "t1"

    @pytest.mark.asyncio
    async def test_invalid_signature_returns_none(self):
        forged = _make_token({"tenant_id": "t1"}, secret="wrong-secret")
        assert await extract_from_token(forged, secret="test-secret") is None

    @pytest.mark.asyncio
    async def test_missing_tenant_id_returns_none(self):
        assert await extract_from_token(_make_token({"role": "admin"})) is None

    @pytest.mark.asyncio
    async def test_malformed_token_returns_none(self):
        for bad_token in ("not.a.valid.token.format", "garbage", ""):
            assert await extract_from_token(bad_token) is None
|