# Files
# aitbc/apps/coordinator-api/tests/test_billing.py
# aitbc 5f03ded7ff fix: replace deprecated datetime.utcnow() with datetime.now(datetime.UTC)
# - Replace all 2,087 uses of datetime.utcnow() across 294 files
# - Add UTC import to datetime statements where needed
# - Addresses Python 3.12+ deprecation warning (report item #3)
# 2026-04-30 08:36:55 +02:00
#
# 440 lines
# 16 KiB
# Python
# Executable File

"""
Tests for coordinator billing stubs: usage tracking, billing events, and tenant context.
Uses lightweight in-memory mocks to avoid PostgreSQL/UUID dependencies.
import sys
"""
import asyncio
import uuid
from datetime import datetime, UTC, timedelta
from decimal import Decimal
from unittest.mock import MagicMock, AsyncMock, patch
from dataclasses import dataclass
import pytest
# ---------------------------------------------------------------------------
# Lightweight stubs for the ORM models so we don't need a real DB
# ---------------------------------------------------------------------------
@dataclass
class FakeTenant:
id: str
slug: str
name: str
status: str = "active"
plan: str = "basic"
contact_email: str = "t@test.com"
billing_email: str = "b@test.com"
settings: dict = None
features: dict = None
balance: Decimal = Decimal("100.00")
def __post_init__(self):
self.settings = self.settings or {}
self.features = self.features or {}
@dataclass
class FakeQuota:
id: str
tenant_id: str
resource_type: str
limit_value: Decimal
used_value: Decimal = Decimal("0")
period_type: str = "daily"
period_start: datetime = None
period_end: datetime = None
is_active: bool = True
def __post_init__(self):
if self.period_start is None:
self.period_start = datetime.now(datetime.UTC) - timedelta(hours=1)
if self.period_end is None:
self.period_end = datetime.now(datetime.UTC) + timedelta(hours=23)
@dataclass
class FakeUsageRecord:
id: str
tenant_id: str
resource_type: str
quantity: Decimal
unit: str
unit_price: Decimal
total_cost: Decimal
currency: str = "USD"
usage_start: datetime = None
usage_end: datetime = None
job_id: str = None
metadata: dict = None
# ---------------------------------------------------------------------------
# In-memory billing store used by the implementations under test
# ---------------------------------------------------------------------------
class InMemoryBillingStore:
"""Replaces the DB session for testing."""
def __init__(self):
self.tenants: dict[str, FakeTenant] = {}
self.quotas: list[FakeQuota] = []
self.usage_records: list[FakeUsageRecord] = []
self.credits: list[dict] = []
self.charges: list[dict] = []
self.invoices_generated: list[str] = []
self.pending_events: list[dict] = []
# helpers
def get_tenant(self, tenant_id: str):
return self.tenants.get(tenant_id)
def get_active_quota(self, tenant_id: str, resource_type: str):
now = datetime.now(datetime.UTC)
for q in self.quotas:
if (q.tenant_id == tenant_id
and q.resource_type == resource_type
and q.is_active
and q.period_start <= now <= q.period_end):
return q
return None
# ---------------------------------------------------------------------------
# Implementations (the actual code we're testing / implementing)
# ---------------------------------------------------------------------------
async def apply_credit(store: InMemoryBillingStore, tenant_id: str, amount: Decimal, reason: str = "") -> bool:
"""Apply credit to tenant account."""
tenant = store.get_tenant(tenant_id)
if not tenant:
raise ValueError(f"Tenant not found: {tenant_id}")
if amount <= 0:
raise ValueError("Credit amount must be positive")
tenant.balance += amount
store.credits.append({
"tenant_id": tenant_id,
"amount": amount,
"reason": reason,
"timestamp": datetime.now(datetime.UTC),
})
return True
async def apply_charge(store: InMemoryBillingStore, tenant_id: str, amount: Decimal, reason: str = "") -> bool:
"""Apply charge to tenant account."""
tenant = store.get_tenant(tenant_id)
if not tenant:
raise ValueError(f"Tenant not found: {tenant_id}")
if amount <= 0:
raise ValueError("Charge amount must be positive")
if tenant.balance < amount:
raise ValueError(f"Insufficient balance: {tenant.balance} < {amount}")
tenant.balance -= amount
store.charges.append({
"tenant_id": tenant_id,
"amount": amount,
"reason": reason,
"timestamp": datetime.now(datetime.UTC),
})
return True
async def adjust_quota(
    store: InMemoryBillingStore,
    tenant_id: str,
    resource_type: str,
    new_limit: Decimal,
) -> bool:
    """Set a new ``limit_value`` on the tenant's currently active quota.

    Args:
        store: Billing store used to look up the active quota.
        tenant_id: Tenant that owns the quota.
        resource_type: Resource whose quota is changed (e.g. ``"gpu_hours"``).
        new_limit: Replacement limit; must be >= 0.

    Returns:
        ``True`` after the limit has been updated.

    Raises:
        ValueError: If no active quota exists (checked first), or if
            ``new_limit`` is negative.
    """
    if not (target := store.get_active_quota(tenant_id, resource_type)):
        raise ValueError(f"No active quota for {tenant_id}/{resource_type}")
    if new_limit < 0:
        raise ValueError("Quota limit must be non-negative")
    target.limit_value = new_limit
    return True
async def reset_daily_quotas(store: InMemoryBillingStore) -> int:
"""Reset used_value to 0 for all daily quotas whose period has ended."""
now = datetime.now(datetime.UTC)
count = 0
for q in store.quotas:
if q.period_type == "daily" and q.is_active and q.period_end <= now:
q.used_value = Decimal("0")
q.period_start = now
q.period_end = now + timedelta(days=1)
count += 1
return count
async def process_pending_events(store: "InMemoryBillingStore") -> int:
    """Process all pending billing events and clear the queue.

    ``credit`` events call ``apply_credit`` and ``charge`` events call
    ``apply_charge`` (both with ``reason="pending_event"``). Events with any
    other ``event_type`` are counted and removed but otherwise ignored.

    If an event raises (e.g. a charge with insufficient balance), the exception
    propagates and only the events already applied are removed from
    ``store.pending_events``. The failing event and everything after it stay
    queued, so a retry does not apply earlier events a second time.

    Returns:
        Number of events consumed from the queue.
    """
    events = store.pending_events
    done = 0
    try:
        for event in events:
            etype = event.get("event_type")
            tid = event.get("tenant_id")
            amount = Decimal(str(event.get("amount", 0)))
            if etype == "credit":
                await apply_credit(store, tid, amount, reason="pending_event")
            elif etype == "charge":
                await apply_charge(store, tid, amount, reason="pending_event")
            done += 1
    finally:
        # Drop exactly the events that completed. On success this empties the
        # list in place (same list object).
        del events[:done]
    return done
async def generate_monthly_invoices(store: InMemoryBillingStore) -> list[str]:
"""Generate invoices for all active tenants with usage."""
generated = []
for tid, tenant in store.tenants.items():
if tenant.status != "active":
continue
tenant_usage = [r for r in store.usage_records if r.tenant_id == tid]
if not tenant_usage:
continue
total = sum(r.total_cost for r in tenant_usage)
inv_id = f"INV-{tenant.slug}-{datetime.now(datetime.UTC).strftime('%Y%m')}-{len(generated)+1:04d}"
store.invoices_generated.append(inv_id)
generated.append(inv_id)
return generated
async def extract_from_token(token: str, secret: str = "test-secret") -> dict | None:
"""Extract tenant_id from a JWT-like token. Returns claims dict or None."""
import json, hmac, hashlib, base64
parts = token.split(".")
if len(parts) != 3:
return None
try:
# Verify signature (HS256-like)
payload_b64 = parts[1]
sig = parts[2]
expected_sig = hmac.new(
secret.encode(), f"{parts[0]}.{payload_b64}".encode(), hashlib.sha256
).hexdigest()[:16]
if not hmac.compare_digest(sig, expected_sig):
return None
# Decode payload
padded = payload_b64 + "=" * (-len(payload_b64) % 4)
payload = json.loads(base64.urlsafe_b64decode(padded))
if "tenant_id" not in payload:
return None
return payload
except Exception:
return None
def _make_token(claims: dict, secret: str = "test-secret") -> str:
"""Helper to create a test token."""
import json, hmac, hashlib, base64
header = base64.urlsafe_b64encode(b'{"alg":"HS256"}').decode().rstrip("=")
payload = base64.urlsafe_b64encode(json.dumps(claims).encode()).decode().rstrip("=")
sig = hmac.new(secret.encode(), f"{header}.{payload}".encode(), hashlib.sha256).hexdigest()[:16]
return f"{header}.{payload}.{sig}"
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def store():
    """Billing store with two tenants and two quotas.

    - t1 ("acme", active, balance 500.00) has a live ``gpu_hours`` quota (q1,
      default window) and a daily ``api_calls`` quota (q2) that expired an
      hour ago.
    - t2 ("beta", inactive, balance 50.00) has no quotas.
    """
    # ``UTC`` is the imported module-level alias; ``datetime.UTC`` raises
    # AttributeError because ``datetime`` here is the class. Read the clock once.
    now = datetime.now(UTC)
    s = InMemoryBillingStore()
    s.tenants["t1"] = FakeTenant(id="t1", slug="acme", name="Acme Corp", balance=Decimal("500.00"))
    s.tenants["t2"] = FakeTenant(id="t2", slug="beta", name="Beta Inc", balance=Decimal("50.00"), status="inactive")
    s.quotas.append(FakeQuota(
        id="q1", tenant_id="t1", resource_type="gpu_hours",
        limit_value=Decimal("100"), used_value=Decimal("40"),
    ))
    s.quotas.append(FakeQuota(
        id="q2", tenant_id="t1", resource_type="api_calls",
        limit_value=Decimal("10000"), used_value=Decimal("5000"),
        period_type="daily",
        period_start=now - timedelta(days=2),
        period_end=now - timedelta(hours=1),  # expired
    ))
    return s
# ---------------------------------------------------------------------------
# Tests: apply_credit
# ---------------------------------------------------------------------------
class TestApplyCredit:
    """apply_credit: balance bookkeeping and input validation."""

    @pytest.mark.asyncio
    async def test_credit_increases_balance(self, store):
        credit = Decimal("25.00")
        await apply_credit(store, "t1", credit, reason="promo")
        assert store.tenants["t1"].balance == Decimal("525.00")
        log = store.credits
        assert len(log) == 1
        assert log[0]["amount"] == credit

    @pytest.mark.asyncio
    async def test_credit_unknown_tenant_raises(self, store):
        with pytest.raises(ValueError, match="Tenant not found"):
            await apply_credit(store, "unknown", Decimal("10"))

    @pytest.mark.asyncio
    async def test_credit_zero_or_negative_raises(self, store):
        # Both zero and negative credits must be rejected.
        for bad_amount in (Decimal("0"), Decimal("-5")):
            with pytest.raises(ValueError, match="positive"):
                await apply_credit(store, "t1", bad_amount)
# ---------------------------------------------------------------------------
# Tests: apply_charge
# ---------------------------------------------------------------------------
class TestApplyCharge:
    """apply_charge: debits the balance, logs the charge, and rejects bad input."""

    @pytest.mark.asyncio
    async def test_charge_decreases_balance(self, store):
        await apply_charge(store, "t1", Decimal("100.00"), reason="usage")
        assert store.tenants["t1"].balance == Decimal("400.00")
        assert len(store.charges) == 1
        assert store.charges[0]["amount"] == Decimal("100.00")
        assert store.charges[0]["reason"] == "usage"

    @pytest.mark.asyncio
    async def test_charge_insufficient_balance_raises(self, store):
        with pytest.raises(ValueError, match="Insufficient balance"):
            await apply_charge(store, "t1", Decimal("999.99"))
        # A rejected charge must leave the account and the charge log untouched.
        assert store.tenants["t1"].balance == Decimal("500.00")
        assert store.charges == []

    @pytest.mark.asyncio
    async def test_charge_unknown_tenant_raises(self, store):
        with pytest.raises(ValueError, match="Tenant not found"):
            await apply_charge(store, "nope", Decimal("1"))

    @pytest.mark.asyncio
    async def test_charge_zero_raises(self, store):
        with pytest.raises(ValueError, match="positive"):
            await apply_charge(store, "t1", Decimal("0"))

    @pytest.mark.asyncio
    async def test_charge_negative_raises(self, store):
        # Parallels TestApplyCredit, which rejects both zero and negative amounts.
        with pytest.raises(ValueError, match="positive"):
            await apply_charge(store, "t1", Decimal("-5"))
        assert store.tenants["t1"].balance == Decimal("500.00")
# ---------------------------------------------------------------------------
# Tests: adjust_quota
# ---------------------------------------------------------------------------
class TestAdjustQuota:
    """adjust_quota: updates the active quota's limit and validates input."""

    @pytest.mark.asyncio
    async def test_adjust_quota_updates_limit(self, store):
        new_limit = Decimal("200")
        await adjust_quota(store, "t1", "gpu_hours", new_limit)
        assert store.get_active_quota("t1", "gpu_hours").limit_value == new_limit

    @pytest.mark.asyncio
    async def test_adjust_quota_no_active_raises(self, store):
        # The fixture defines no storage_gb quota at all.
        with pytest.raises(ValueError, match="No active quota"):
            await adjust_quota(store, "t1", "storage_gb", Decimal("50"))

    @pytest.mark.asyncio
    async def test_adjust_quota_negative_raises(self, store):
        negative = Decimal("-1")
        with pytest.raises(ValueError, match="non-negative"):
            await adjust_quota(store, "t1", "gpu_hours", negative)
# ---------------------------------------------------------------------------
# Tests: reset_daily_quotas
# ---------------------------------------------------------------------------
class TestResetDailyQuotas:
    """reset_daily_quotas: only expired daily quotas are reset."""

    @pytest.mark.asyncio
    async def test_resets_expired_daily_quotas(self, store):
        count = await reset_daily_quotas(store)
        assert count == 1  # q2 is expired daily
        q2 = store.quotas[1]
        assert q2.used_value == Decimal("0")
        # ``UTC`` is the imported alias; ``datetime.UTC`` raises AttributeError
        # because ``datetime`` here is the class, not the module.
        assert q2.period_end > datetime.now(UTC)

    @pytest.mark.asyncio
    async def test_does_not_reset_active_quotas(self, store):
        # q1 is still active (not expired); only q2 may be reset.
        count = await reset_daily_quotas(store)
        assert count == 1
        q1 = store.quotas[0]
        assert q1.used_value == Decimal("40")  # unchanged
# ---------------------------------------------------------------------------
# Tests: process_pending_events
# ---------------------------------------------------------------------------
class TestProcessPendingEvents:
    """process_pending_events: applies queued credits/charges and drains the queue."""

    @pytest.mark.asyncio
    async def test_processes_credit_and_charge_events(self, store):
        store.pending_events = [
            {"event_type": "credit", "tenant_id": "t1", "amount": 10},
            {"event_type": "charge", "tenant_id": "t1", "amount": 5},
        ]
        assert await process_pending_events(store) == 2
        assert not store.pending_events
        # 500.00 + 10 - 5
        assert store.tenants["t1"].balance == Decimal("505.00")

    @pytest.mark.asyncio
    async def test_empty_queue_returns_zero(self, store):
        processed = await process_pending_events(store)
        assert processed == 0
# ---------------------------------------------------------------------------
# Tests: generate_monthly_invoices
# ---------------------------------------------------------------------------
class TestGenerateMonthlyInvoices:
    """generate_monthly_invoices: one invoice per active tenant with usage."""

    @staticmethod
    def _gpu_usage(record_id, tenant_id, quantity, total_cost):
        # GPU-hour usage row priced at 0.50 per hour.
        return FakeUsageRecord(
            id=record_id, tenant_id=tenant_id, resource_type="gpu_hours",
            quantity=Decimal(quantity), unit="hours",
            unit_price=Decimal("0.50"), total_cost=Decimal(total_cost),
        )

    @pytest.mark.asyncio
    async def test_generates_for_active_tenants_with_usage(self, store):
        store.usage_records.append(self._gpu_usage("u1", "t1", "10", "5.00"))
        invoices = await generate_monthly_invoices(store)
        assert len(invoices) == 1
        assert invoices[0].startswith("INV-acme-")

    @pytest.mark.asyncio
    async def test_skips_inactive_tenants(self, store):
        store.usage_records.append(self._gpu_usage("u2", "t2", "5", "2.50"))
        invoices = await generate_monthly_invoices(store)
        assert len(invoices) == 0  # t2 is inactive

    @pytest.mark.asyncio
    async def test_skips_tenants_without_usage(self, store):
        assert len(await generate_monthly_invoices(store)) == 0
# ---------------------------------------------------------------------------
# Tests: extract_from_token
# ---------------------------------------------------------------------------
class TestExtractFromToken:
    """extract_from_token: signature check, tenant_id requirement, malformed input."""

    @pytest.mark.asyncio
    async def test_valid_token_returns_claims(self):
        claims = await extract_from_token(_make_token({"tenant_id": "t1", "role": "admin"}))
        assert claims is not None
        assert claims["tenant_id"] == "t1"

    @pytest.mark.asyncio
    async def test_invalid_signature_returns_none(self):
        forged = _make_token({"tenant_id": "t1"}, secret="wrong-secret")
        assert await extract_from_token(forged, secret="test-secret") is None

    @pytest.mark.asyncio
    async def test_missing_tenant_id_returns_none(self):
        no_tenant = _make_token({"role": "admin"})
        assert await extract_from_token(no_tenant) is None

    @pytest.mark.asyncio
    async def test_malformed_token_returns_none(self):
        for bad in ("not.a.valid.token.format", "garbage", ""):
            assert await extract_from_token(bad) is None