Extract trading services to trading-service

- Created TradingService with basic CRUD operations
- Created storage.py for database session management
- Updated main.py to include database initialization and trading endpoints:
  - GET /v1/trading/requests
  - GET /v1/trading/requests/{request_id}
  - POST /v1/trading/requests
  - GET /v1/trading/matches
  - POST /v1/trading/matches
  - GET /v1/trading/agreements
  - POST /v1/trading/agreements
  - GET /v1/trading/analytics
- Created database setup script for aitbc_trading database

This completes Phase 4.5c: Extract trading services and Phase 4.5d: Setup separate database for trading service
This commit is contained in:
aitbc
2026-04-30 11:35:45 +02:00
parent 3e494b8898
commit e8c10a5dc0
5 changed files with 263 additions and 1 deletions

View File

@@ -0,0 +1,19 @@
-- Setup database for Trading service
-- Create database
CREATE DATABASE aitbc_trading;
-- Create user
CREATE USER aitbc_trading WITH PASSWORD 'password';
-- Grant privileges
GRANT ALL PRIVILEGES ON DATABASE aitbc_trading TO aitbc_trading;
-- Connect to the database
\c aitbc_trading
-- Grant schema privileges
GRANT ALL ON SCHEMA public TO aitbc_trading;
-- Exit
\q

View File

@@ -6,8 +6,9 @@ Manages trading operations
from contextlib import asynccontextmanager
from typing import AsyncIterator
from fastapi import FastAPI
from fastapi import FastAPI, Depends
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from aitbc import (
configure_logging,
@@ -18,6 +19,9 @@ from aitbc import (
ErrorHandlerMiddleware,
)
from .storage import init_db, get_session
from .services.trading_service import TradingService
# Configure structured logging
configure_logging(level="INFO")
logger = get_logger(__name__)
@@ -27,6 +31,8 @@ logger = get_logger(__name__)
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
"""Lifecycle events for the Trading Service."""
logger.info("Starting Trading Service")
# Initialize database
await init_db()
yield
logger.info("Shutting down Trading Service")
@@ -67,6 +73,89 @@ async def trading_status() -> dict[str, str]:
}
async def get_trading_service(session: AsyncSession = Depends(get_session)) -> TradingService:
"""Get trading service instance"""
return TradingService(session)
@app.get("/v1/trading/requests")
async def get_requests(
status: str | None = None,
buyer_agent_id: str | None = None,
trade_type: str | None = None,
svc: TradingService = Depends(get_trading_service),
):
"""Get trade requests"""
return svc.list_requests(status=status, buyer_agent_id=buyer_agent_id, trade_type=trade_type)
@app.get("/v1/trading/requests/{request_id}")
async def get_request(
request_id: str,
svc: TradingService = Depends(get_trading_service),
):
"""Get a specific trade request"""
return svc.get_request(request_id)
@app.post("/v1/trading/requests")
async def create_request(
request_data: dict,
svc: TradingService = Depends(get_trading_service),
):
"""Create a new trade request"""
return svc.create_request(request_data)
@app.get("/v1/trading/matches")
async def get_matches(
status: str | None = None,
buyer_agent_id: str | None = None,
seller_agent_id: str | None = None,
svc: TradingService = Depends(get_trading_service),
):
"""Get trade matches"""
return svc.list_matches(status=status, buyer_agent_id=buyer_agent_id, seller_agent_id=seller_agent_id)
@app.post("/v1/trading/matches")
async def create_match(
match_data: dict,
svc: TradingService = Depends(get_trading_service),
):
"""Create a new trade match"""
return svc.create_match(match_data)
@app.get("/v1/trading/agreements")
async def get_agreements(
status: str | None = None,
buyer_agent_id: str | None = None,
seller_agent_id: str | None = None,
svc: TradingService = Depends(get_trading_service),
):
"""Get trade agreements"""
return svc.list_agreements(status=status, buyer_agent_id=buyer_agent_id, seller_agent_id=seller_agent_id)
@app.post("/v1/trading/agreements")
async def create_agreement(
agreement_data: dict,
svc: TradingService = Depends(get_trading_service),
):
"""Create a new trade agreement"""
return svc.create_agreement(agreement_data)
@app.get("/v1/trading/analytics")
async def get_analytics(
period_type: str = "daily",
svc: TradingService = Depends(get_trading_service),
):
"""Get trading analytics"""
return await svc.get_analytics(period_type=period_type)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8104)

View File

@@ -0,0 +1,7 @@
"""
Trading Service services
"""
from .trading_service import TradingService
__all__ = ["TradingService"]

View File

@@ -0,0 +1,103 @@
"""
Trading service for managing trading operations
"""
from typing import Any
from sqlmodel import Session, select
from ..domain.trading import TradeRequest, TradeMatch, TradeAgreement
class TradingService:
def __init__(self, session: Session):
self.session = session
def list_requests(
self,
status: str | None = None,
buyer_agent_id: str | None = None,
trade_type: str | None = None,
) -> list[TradeRequest]:
"""List trade requests"""
stmt = select(TradeRequest)
if status:
stmt = stmt.where(TradeRequest.status == status)
if buyer_agent_id:
stmt = stmt.where(TradeRequest.buyer_agent_id == buyer_agent_id)
if trade_type:
stmt = stmt.where(TradeRequest.trade_type == trade_type)
return list(self.session.execute(stmt).all())
def get_request(self, request_id: str) -> TradeRequest | None:
"""Get a specific trade request"""
stmt = select(TradeRequest).where(TradeRequest.request_id == request_id)
result = self.session.execute(stmt).first()
return result[0] if result else None
def create_request(self, request_data: dict) -> TradeRequest:
"""Create a new trade request"""
request = TradeRequest(**request_data)
self.session.add(request)
self.session.commit()
self.session.refresh(request)
return request
def list_matches(
self,
status: str | None = None,
buyer_agent_id: str | None = None,
seller_agent_id: str | None = None,
) -> list[TradeMatch]:
"""List trade matches"""
stmt = select(TradeMatch)
if status:
stmt = stmt.where(TradeMatch.status == status)
if buyer_agent_id:
stmt = stmt.where(TradeMatch.buyer_agent_id == buyer_agent_id)
if seller_agent_id:
stmt = stmt.where(TradeMatch.seller_agent_id == seller_agent_id)
return list(self.session.execute(stmt).all())
def create_match(self, match_data: dict) -> TradeMatch:
"""Create a new trade match"""
match = TradeMatch(**match_data)
self.session.add(match)
self.session.commit()
self.session.refresh(match)
return match
def list_agreements(
self,
status: str | None = None,
buyer_agent_id: str | None = None,
seller_agent_id: str | None = None,
) -> list[TradeAgreement]:
"""List trade agreements"""
stmt = select(TradeAgreement)
if status:
stmt = stmt.where(TradeAgreement.status == status)
if buyer_agent_id:
stmt = stmt.where(TradeAgreement.buyer_agent_id == buyer_agent_id)
if seller_agent_id:
stmt = stmt.where(TradeAgreement.seller_agent_id == seller_agent_id)
return list(self.session.execute(stmt).all())
def create_agreement(self, agreement_data: dict) -> TradeAgreement:
"""Create a new trade agreement"""
agreement = TradeAgreement(**agreement_data)
self.session.add(agreement)
self.session.commit()
self.session.refresh(agreement)
return agreement
async def get_analytics(self, period_type: str = "daily") -> dict[str, Any]:
"""Get trading analytics"""
# Placeholder for analytics logic
return {
"period_type": period_type,
"total_trades": 0,
"completed_trades": 0,
"total_trade_volume": 0.0,
"average_trade_value": 0.0,
}

View File

@@ -0,0 +1,44 @@
"""
Database session management for Trading service
"""
from contextlib import asynccontextmanager
from typing import AsyncIterator
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlmodel import SQLModel
from aitbc import get_logger
logger = get_logger(__name__)
# Database URL from environment variable or default
DATABASE_URL = "postgresql+asyncpg://aitbc_trading:password@localhost:5432/aitbc_trading"
# Create async engine
engine = create_async_engine(DATABASE_URL, echo=False)
async def init_db() -> None:
"""Initialize database tables"""
from .domain.trading import (
TradeRequest,
TradeMatch,
TradeNegotiation,
TradeAgreement,
TradeSettlement,
TradeFeedback,
TradingAnalytics,
)
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
logger.info("Trading service database initialized")
@asynccontextmanager
async def get_session() -> AsyncIterator[AsyncSession]:
"""Get database session"""
async with AsyncSession(engine) as session:
yield session