feat(coordinator-api): integrate dynamic pricing engine with GPU marketplace and add agent identity router
- Add DynamicPricingEngine and MarketDataCollector dependencies to GPU marketplace endpoints
- Implement dynamic pricing calculation for GPU registration with market_balance strategy
- Calculate real-time dynamic prices at booking time with confidence scores and pricing factors
- Enhance /marketplace/pricing/{model} endpoint with comprehensive dynamic pricing analysis
- Add static vs dynamic price comparison to the pricing endpoint response
This commit is contained in:
687
apps/coordinator-api/src/app/services/agent_portfolio_manager.py
Normal file
687
apps/coordinator-api/src/app/services/agent_portfolio_manager.py
Normal file
@@ -0,0 +1,687 @@
|
||||
"""
|
||||
Agent Portfolio Manager Service
|
||||
|
||||
Advanced portfolio management for autonomous AI agents in the AITBC ecosystem.
|
||||
Provides portfolio creation, rebalancing, risk assessment, and trading strategy execution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlmodel import Session
|
||||
|
||||
from ..domain.agent_portfolio import (
|
||||
AgentPortfolio,
|
||||
PortfolioStrategy,
|
||||
PortfolioAsset,
|
||||
PortfolioTrade,
|
||||
RiskMetrics,
|
||||
StrategyType,
|
||||
TradeStatus,
|
||||
RiskLevel
|
||||
)
|
||||
from ..schemas.portfolio import (
|
||||
PortfolioCreate,
|
||||
PortfolioResponse,
|
||||
PortfolioUpdate,
|
||||
TradeRequest,
|
||||
TradeResponse,
|
||||
RiskAssessmentResponse,
|
||||
RebalanceRequest,
|
||||
RebalanceResponse,
|
||||
StrategyCreate,
|
||||
StrategyResponse
|
||||
)
|
||||
from ..blockchain.contract_interactions import ContractInteractionService
|
||||
from ..marketdata.price_service import PriceService
|
||||
from ..risk.risk_calculator import RiskCalculator
|
||||
from ..ml.strategy_optimizer import StrategyOptimizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentPortfolioManager:
    """Advanced portfolio management for autonomous agents.

    Coordinates database state, on-chain contract calls, market data,
    risk metrics, and ML-driven strategy optimization for agent portfolios.
    """

    def __init__(
        self,
        session: Session,
        contract_service: ContractInteractionService,
        price_service: PriceService,
        risk_calculator: RiskCalculator,
        strategy_optimizer: StrategyOptimizer,
    ) -> None:
        # All work is delegated to these injected collaborators.
        self.session = session
        self.contract_service = contract_service
        self.price_service = price_service
        self.risk_calculator = risk_calculator
        self.strategy_optimizer = strategy_optimizer
||||
async def create_portfolio(
    self,
    portfolio_data: PortfolioCreate,
    agent_address: str
) -> PortfolioResponse:
    """Create a new portfolio for an autonomous agent.

    Validates the agent address, enforces one portfolio per agent, links
    the portfolio to an active strategy, seeds the asset rows, and deploys
    the on-chain portfolio contract.

    Raises:
        HTTPException: 400 for a bad address or duplicate portfolio,
            404 for a missing/inactive strategy, 500 on internal errors.
    """
    try:
        # Validate agent address format before touching the database.
        if not self._is_valid_address(agent_address):
            raise HTTPException(status_code=400, detail="Invalid agent address")

        # Enforce a single portfolio per agent.
        existing_portfolio = self.session.exec(
            select(AgentPortfolio).where(
                AgentPortfolio.agent_address == agent_address
            )
        ).first()
        if existing_portfolio:
            raise HTTPException(
                status_code=400,
                detail="Portfolio already exists for this agent"
            )

        # The referenced strategy must exist and be active.
        strategy = self.session.get(PortfolioStrategy, portfolio_data.strategy_id)
        if not strategy or not strategy.is_active:
            raise HTTPException(status_code=404, detail="Strategy not found")

        portfolio = AgentPortfolio(
            agent_address=agent_address,
            strategy_id=portfolio_data.strategy_id,
            initial_capital=portfolio_data.initial_capital,
            risk_tolerance=portfolio_data.risk_tolerance,
            is_active=True,
            created_at=datetime.utcnow(),
            last_rebalance=datetime.utcnow()
        )
        self.session.add(portfolio)
        self.session.commit()
        self.session.refresh(portfolio)

        # Seed asset rows from the strategy's target allocations.
        await self._initialize_portfolio_assets(portfolio, strategy)

        # Mirror the portfolio on-chain and persist the contract id.
        contract_portfolio_id = await self._deploy_contract_portfolio(
            portfolio, agent_address, strategy
        )
        portfolio.contract_portfolio_id = contract_portfolio_id
        self.session.commit()

        logger.info(f"Created portfolio {portfolio.id} for agent {agent_address}")
        return PortfolioResponse.from_orm(portfolio)

    except HTTPException:
        # BUGFIX: HTTPException was previously caught by the generic handler
        # below and rewrapped as a 500; preserve the intended 400/404 codes.
        self.session.rollback()
        raise
    except Exception as e:
        logger.error(f"Error creating portfolio: {str(e)}")
        self.session.rollback()
        raise HTTPException(status_code=500, detail=str(e))
||||
async def execute_trade(
    self,
    trade_request: TradeRequest,
    agent_address: str
) -> TradeResponse:
    """Execute a trade within the agent's portfolio.

    Validates the request, applies slippage protection, settles the swap
    on-chain, records it, and refreshes portfolio assets and metrics.
    """
    try:
        portfolio = self._get_agent_portfolio(agent_address)

        # Reject invalid requests (unknown token, low balance, risk limits).
        validation = await self._validate_trade_request(portfolio, trade_request)
        if not validation.is_valid:
            raise HTTPException(status_code=400, detail=validation.error_message)

        # Price both legs and derive the expected proceeds.
        price_sell = await self.price_service.get_price(trade_request.sell_token)
        price_buy = await self.price_service.get_price(trade_request.buy_token)
        expected_buy_amount = self._calculate_buy_amount(
            trade_request.sell_amount, price_sell, price_buy
        )

        # Slippage protection: honor the caller's minimum acceptable amount.
        if expected_buy_amount < trade_request.min_buy_amount:
            raise HTTPException(
                status_code=400,
                detail="Insufficient buy amount (slippage protection)"
            )

        # Settle the swap on-chain before recording anything locally.
        trade_result = await self.contract_service.execute_portfolio_trade(
            portfolio.contract_portfolio_id,
            trade_request.sell_token,
            trade_request.buy_token,
            trade_request.sell_amount,
            trade_request.min_buy_amount
        )

        trade = PortfolioTrade(
            portfolio_id=portfolio.id,
            sell_token=trade_request.sell_token,
            buy_token=trade_request.buy_token,
            sell_amount=trade_request.sell_amount,
            buy_amount=trade_result.buy_amount,
            price=trade_result.price,
            status=TradeStatus.EXECUTED,
            transaction_hash=trade_result.transaction_hash,
            executed_at=datetime.utcnow()
        )
        self.session.add(trade)

        # Bring balances and valuation/allocation metrics up to date.
        await self._update_portfolio_assets(portfolio, trade)
        await self._update_portfolio_metrics(portfolio)

        self.session.commit()
        self.session.refresh(trade)

        logger.info(f"Executed trade {trade.id} for portfolio {portfolio.id}")
        return TradeResponse.from_orm(trade)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error executing trade: {str(e)}")
        self.session.rollback()
        raise HTTPException(status_code=500, detail=str(e))
||||
async def execute_rebalancing(
    self,
    rebalance_request: RebalanceRequest,
    agent_address: str
) -> RebalanceResponse:
    """Automated portfolio rebalancing based on market conditions.

    Raises:
        HTTPException: 404 when the agent has no portfolio, 500 on
            internal errors.
    """
    try:
        portfolio = self._get_agent_portfolio(agent_address)

        # Skip when neither the time nor the drift trigger fires.
        if not await self._needs_rebalancing(portfolio):
            return RebalanceResponse(
                success=False,
                message="Rebalancing not needed at this time"
            )

        market_conditions = await self.price_service.get_market_conditions()

        # ML-optimized target allocations for current market conditions.
        optimal_allocations = await self.strategy_optimizer.calculate_optimal_allocations(
            portfolio, market_conditions
        )

        rebalance_trades = await self._generate_rebalance_trades(
            portfolio, optimal_allocations
        )
        if not rebalance_trades:
            return RebalanceResponse(
                success=False,
                message="No rebalancing trades required"
            )

        # Best-effort execution: one failed trade does not abort the run.
        executed_trades = []
        for trade in rebalance_trades:
            try:
                trade_response = await self.execute_trade(trade, agent_address)
                executed_trades.append(trade_response)
            except Exception as e:
                logger.warning(f"Failed to execute rebalancing trade: {str(e)}")
                continue

        portfolio.last_rebalance = datetime.utcnow()
        self.session.commit()

        logger.info(f"Rebalanced portfolio {portfolio.id} with {len(executed_trades)} trades")
        return RebalanceResponse(
            success=True,
            message=f"Rebalanced with {len(executed_trades)} trades",
            trades_executed=len(executed_trades)
        )

    except HTTPException:
        # BUGFIX: preserve intended status codes (e.g. the 404 from
        # _get_agent_portfolio) instead of rewrapping them as 500.
        raise
    except Exception as e:
        logger.error(f"Error executing rebalancing: {str(e)}")
        # Roll back partial session state, consistent with the other methods.
        self.session.rollback()
        raise HTTPException(status_code=500, detail=str(e))
||||
async def risk_assessment(self, agent_address: str) -> RiskAssessmentResponse:
    """Real-time risk assessment; upserts and returns the portfolio's risk metrics.

    Raises:
        HTTPException: 404 when the agent has no portfolio, 500 on
            internal errors.
    """
    try:
        portfolio = self._get_agent_portfolio(agent_address)
        portfolio_value = await self._calculate_portfolio_value(portfolio)

        risk_metrics = await self.risk_calculator.calculate_portfolio_risk(
            portfolio, portfolio_value
        )

        # Upsert the metrics row for this portfolio.
        existing_metrics = self.session.exec(
            select(RiskMetrics).where(RiskMetrics.portfolio_id == portfolio.id)
        ).first()
        if existing_metrics:
            existing_metrics.volatility = risk_metrics.volatility
            existing_metrics.max_drawdown = risk_metrics.max_drawdown
            existing_metrics.sharpe_ratio = risk_metrics.sharpe_ratio
            existing_metrics.var_95 = risk_metrics.var_95
            existing_metrics.risk_level = risk_metrics.risk_level
            existing_metrics.updated_at = datetime.utcnow()
        else:
            risk_metrics.portfolio_id = portfolio.id
            risk_metrics.updated_at = datetime.utcnow()
            self.session.add(risk_metrics)

        # Cache the headline score on the portfolio itself.
        portfolio.risk_score = risk_metrics.overall_risk_score
        self.session.commit()

        logger.info(f"Risk assessment completed for portfolio {portfolio.id}")
        return RiskAssessmentResponse.from_orm(risk_metrics)

    except HTTPException:
        # BUGFIX: preserve the 404 from _get_agent_portfolio instead of 500.
        raise
    except Exception as e:
        logger.error(f"Error in risk assessment: {str(e)}")
        # Roll back partial session state, consistent with the other methods.
        self.session.rollback()
        raise HTTPException(status_code=500, detail=str(e))
||||
async def get_portfolio_performance(
    self,
    agent_address: str,
    period: str = "30d"
) -> Dict:
    """Get portfolio performance metrics for the given look-back period.

    Raises:
        HTTPException: 404 when the agent has no portfolio, 500 on
            internal errors.
    """
    try:
        portfolio = self._get_agent_portfolio(agent_address)
        return await self._calculate_performance_metrics(portfolio, period)
    except HTTPException:
        # BUGFIX: preserve the 404 from _get_agent_portfolio instead of
        # collapsing it into a 500 below.
        raise
    except Exception as e:
        logger.error(f"Error getting portfolio performance: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
||||
async def create_portfolio_strategy(
    self,
    strategy_data: StrategyCreate
) -> StrategyResponse:
    """Create a new portfolio strategy after validating its allocations.

    Raises:
        HTTPException: 400 when target allocations do not sum to 100%,
            500 on internal errors.
    """
    try:
        # Target allocations are percentages and must total 100.
        total_allocation = sum(strategy_data.target_allocations.values())
        if abs(total_allocation - 100.0) > 0.01:  # Allow small rounding errors
            raise HTTPException(
                status_code=400,
                detail="Target allocations must sum to 100%"
            )

        strategy = PortfolioStrategy(
            name=strategy_data.name,
            strategy_type=strategy_data.strategy_type,
            target_allocations=strategy_data.target_allocations,
            max_drawdown=strategy_data.max_drawdown,
            rebalance_frequency=strategy_data.rebalance_frequency,
            is_active=True,
            created_at=datetime.utcnow()
        )
        self.session.add(strategy)
        self.session.commit()
        self.session.refresh(strategy)

        logger.info(f"Created strategy {strategy.id}: {strategy.name}")
        return StrategyResponse.from_orm(strategy)

    except HTTPException:
        # BUGFIX: the allocation 400 above was previously caught by the
        # generic handler and rewrapped as a 500.
        raise
    except Exception as e:
        logger.error(f"Error creating strategy: {str(e)}")
        self.session.rollback()
        raise HTTPException(status_code=500, detail=str(e))
# Private helper methods
|
||||
|
||||
def _get_agent_portfolio(self, agent_address: str) -> AgentPortfolio:
    """Return the portfolio owned by ``agent_address`` or raise a 404."""
    stmt = select(AgentPortfolio).where(
        AgentPortfolio.agent_address == agent_address
    )
    portfolio = self.session.exec(stmt).first()
    if portfolio is None:
        raise HTTPException(status_code=404, detail="Portfolio not found")
    return portfolio
||||
def _is_valid_address(self, address: str) -> bool:
|
||||
"""Validate Ethereum address"""
|
||||
return (
|
||||
address.startswith("0x") and
|
||||
len(address) == 42 and
|
||||
all(c in "0123456789abcdefABCDEF" for c in address[2:])
|
||||
)
|
||||
|
||||
async def _initialize_portfolio_assets(
    self,
    portfolio: AgentPortfolio,
    strategy: PortfolioStrategy
) -> None:
    """Seed one zero-balance asset row per positive strategy allocation."""
    for token_symbol, allocation in strategy.target_allocations.items():
        if allocation > 0:
            # Balances start at zero; they are filled by later trades.
            self.session.add(
                PortfolioAsset(
                    portfolio_id=portfolio.id,
                    token_symbol=token_symbol,
                    target_allocation=allocation,
                    current_allocation=0.0,
                    balance=0,
                    created_at=datetime.utcnow()
                )
            )
||||
async def _deploy_contract_portfolio(
    self,
    portfolio: AgentPortfolio,
    agent_address: str,
    strategy: PortfolioStrategy
) -> str:
    """Create the on-chain portfolio and return its contract id as a string."""
    try:
        # Percent allocations -> basis points, as the contract expects.
        bps_allocations = {
            token: int(pct * 100)
            for token, pct in strategy.target_allocations.items()
        }

        contract_id = await self.contract_service.create_portfolio(
            agent_address,
            strategy.strategy_type.value,
            bps_allocations
        )
        return str(contract_id)

    except Exception as e:
        logger.error(f"Error deploying contract portfolio: {str(e)}")
        raise
||||
async def _validate_trade_request(
    self,
    portfolio: AgentPortfolio,
    trade_request: TradeRequest
) -> ValidationResult:
    """Check token ownership, balance, and risk limits for a trade request."""
    held_asset = self.session.exec(
        select(PortfolioAsset).where(
            PortfolioAsset.portfolio_id == portfolio.id,
            PortfolioAsset.token_symbol == trade_request.sell_token
        )
    ).first()

    # The portfolio must actually hold the token being sold.
    if not held_asset:
        return ValidationResult(
            is_valid=False,
            error_message="Sell token not found in portfolio"
        )

    # ... in sufficient quantity.
    if held_asset.balance < trade_request.sell_amount:
        return ValidationResult(
            is_valid=False,
            error_message="Insufficient balance"
        )

    # Risk check is last because it is the most expensive.
    trade_risk = await self.risk_calculator.calculate_trade_risk(
        portfolio, trade_request
    )
    if trade_risk > portfolio.risk_tolerance:
        return ValidationResult(
            is_valid=False,
            error_message="Trade exceeds risk tolerance"
        )

    return ValidationResult(is_valid=True)
||||
def _calculate_buy_amount(
|
||||
self,
|
||||
sell_amount: float,
|
||||
sell_price: float,
|
||||
buy_price: float
|
||||
) -> float:
|
||||
"""Calculate expected buy amount"""
|
||||
sell_value = sell_amount * sell_price
|
||||
return sell_value / buy_price
|
||||
|
||||
async def _update_portfolio_assets(
    self,
    portfolio: AgentPortfolio,
    trade: PortfolioTrade
) -> None:
    """Apply an executed trade's balance deltas to the portfolio's asset rows."""

    def _find_asset(symbol):
        # Look up one asset row of this portfolio by token symbol.
        return self.session.exec(
            select(PortfolioAsset).where(
                PortfolioAsset.portfolio_id == portfolio.id,
                PortfolioAsset.token_symbol == symbol
            )
        ).first()

    # Debit the sold token.
    sold = _find_asset(trade.sell_token)
    if sold:
        sold.balance -= trade.sell_amount
        sold.updated_at = datetime.utcnow()

    # Credit the bought token, creating the row on first acquisition.
    bought = _find_asset(trade.buy_token)
    if bought:
        bought.balance += trade.buy_amount
        bought.updated_at = datetime.utcnow()
    else:
        self.session.add(
            PortfolioAsset(
                portfolio_id=portfolio.id,
                token_symbol=trade.buy_token,
                target_allocation=0.0,
                current_allocation=0.0,
                balance=trade.buy_amount,
                created_at=datetime.utcnow()
            )
        )
||||
async def _update_portfolio_metrics(self, portfolio: AgentPortfolio) -> None:
    """Recompute the portfolio's total value and per-asset current allocations."""
    portfolio_value = await self._calculate_portfolio_value(portfolio)

    assets = self.session.exec(
        select(PortfolioAsset).where(
            PortfolioAsset.portfolio_id == portfolio.id
        )
    ).all()

    for asset in assets:
        if asset.balance > 0:
            price = await self.price_service.get_price(asset.token_symbol)
            asset_value = asset.balance * price
            # BUGFIX: guard against a zero total value (e.g. every price
            # reported as 0), which previously raised ZeroDivisionError.
            if portfolio_value > 0:
                asset.current_allocation = (asset_value / portfolio_value) * 100
            else:
                asset.current_allocation = 0.0
            asset.updated_at = datetime.utcnow()

    portfolio.total_value = portfolio_value
    portfolio.updated_at = datetime.utcnow()
||||
async def _calculate_portfolio_value(self, portfolio: AgentPortfolio) -> float:
    """Sum the value of every positively-held asset at current prices."""
    holdings = self.session.exec(
        select(PortfolioAsset).where(
            PortfolioAsset.portfolio_id == portfolio.id
        )
    ).all()

    total = 0.0
    for holding in holdings:
        if holding.balance > 0:
            unit_price = await self.price_service.get_price(holding.token_symbol)
            total += holding.balance * unit_price
    return total
||||
async def _needs_rebalancing(self, portfolio: AgentPortfolio) -> bool:
    """Return True when a time- or drift-based rebalance is due."""
    strategy = self.session.get(PortfolioStrategy, portfolio.strategy_id)
    if strategy is None:
        return False

    # Time-based trigger: rebalance_frequency is expressed in seconds.
    elapsed = datetime.utcnow() - portfolio.last_rebalance
    if elapsed > timedelta(seconds=strategy.rebalance_frequency):
        return True

    # Drift-based trigger: any held asset more than 5% off its target.
    assets = self.session.exec(
        select(PortfolioAsset).where(
            PortfolioAsset.portfolio_id == portfolio.id
        )
    ).all()
    for asset in assets:
        if asset.balance > 0:
            if abs(asset.current_allocation - asset.target_allocation) > 5.0:
                return True

    return False
||||
async def _generate_rebalance_trades(
    self,
    portfolio: AgentPortfolio,
    optimal_allocations: Dict[str, float]
) -> List[TradeRequest]:
    """Build sell->buy trade requests that move allocations toward targets.

    For every over-allocated asset (more than 1% above target) a single
    trade is generated against the first under-allocated asset found.
    """
    assets = self.session.exec(
        select(PortfolioAsset).where(
            PortfolioAsset.portfolio_id == portfolio.id
        )
    ).all()

    trades: List[TradeRequest] = []
    for asset in assets:
        target = optimal_allocations.get(asset.token_symbol, 0.0)
        current = asset.current_allocation

        if abs(current - target) > 1.0:  # 1% minimum deviation
            if current > target:
                # Sell only the slice above the target allocation.
                excess_pct = current - target
                sell_amount = (asset.balance * excess_pct) / 100

                # Pair with the first asset sitting below its target.
                for counterpart in assets:
                    cp_target = optimal_allocations.get(counterpart.token_symbol, 0.0)
                    if counterpart.current_allocation < cp_target:
                        trades.append(TradeRequest(
                            sell_token=asset.token_symbol,
                            buy_token=counterpart.token_symbol,
                            sell_amount=sell_amount,
                            min_buy_amount=0  # Will be calculated during execution
                        ))
                        break

    return trades
||||
async def _calculate_performance_metrics(
    self,
    portfolio: AgentPortfolio,
    period: str
) -> Dict:
    """Compute simplified performance metrics for the portfolio.

    NOTE(review): ``period`` is accepted but not yet used by this
    simplified implementation — confirm intended period filtering.
    """
    # Historical trades, most recent first.
    trades = self.session.exec(
        select(PortfolioTrade)
        .where(PortfolioTrade.portfolio_id == portfolio.id)
        .order_by(PortfolioTrade.executed_at.desc())
    ).all()

    current_value = await self._calculate_portfolio_value(portfolio)
    initial_value = portfolio.initial_capital

    # BUGFIX: guard against a zero initial capital, which previously
    # raised ZeroDivisionError and surfaced as a 500.
    if initial_value:
        total_return = ((current_value - initial_value) / initial_value) * 100
    else:
        total_return = 0.0

    return {
        "total_return": total_return,
        "current_value": current_value,
        "initial_value": initial_value,
        "total_trades": len(trades),
        "last_updated": datetime.utcnow().isoformat()
    }
||||
class ValidationResult:
    """Outcome of a trade-request validation.

    Attributes:
        is_valid: True when the trade may proceed.
        error_message: Human-readable reason when ``is_valid`` is False.
    """

    def __init__(self, is_valid: bool, error_message: str = ""):
        self.is_valid = is_valid
        self.error_message = error_message
|
||||
771
apps/coordinator-api/src/app/services/amm_service.py
Normal file
771
apps/coordinator-api/src/app/services/amm_service.py
Normal file
@@ -0,0 +1,771 @@
|
||||
"""
|
||||
AMM Service
|
||||
|
||||
Automated market making for AI service tokens in the AITBC ecosystem.
|
||||
Provides liquidity pool management, token swapping, and dynamic fee adjustment.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlmodel import Session
|
||||
|
||||
from ..domain.amm import (
|
||||
LiquidityPool,
|
||||
LiquidityPosition,
|
||||
SwapTransaction,
|
||||
PoolMetrics,
|
||||
FeeStructure,
|
||||
IncentiveProgram
|
||||
)
|
||||
from ..schemas.amm import (
|
||||
PoolCreate,
|
||||
PoolResponse,
|
||||
LiquidityAddRequest,
|
||||
LiquidityAddResponse,
|
||||
LiquidityRemoveRequest,
|
||||
LiquidityRemoveResponse,
|
||||
SwapRequest,
|
||||
SwapResponse,
|
||||
PoolMetricsResponse,
|
||||
FeeAdjustmentRequest,
|
||||
IncentiveCreateRequest
|
||||
)
|
||||
from ..blockchain.contract_interactions import ContractInteractionService
|
||||
from ..marketdata.price_service import PriceService
|
||||
from ..risk.volatility_calculator import VolatilityCalculator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AMMService:
    """Automated market making for AI service tokens.

    Manages liquidity pools, provider positions, and swaps, delegating
    settlement to the on-chain contract service.
    """

    def __init__(
        self,
        session: Session,
        contract_service: ContractInteractionService,
        price_service: PriceService,
        volatility_calculator: VolatilityCalculator
    ) -> None:
        self.session = session
        self.contract_service = contract_service
        self.price_service = price_service
        self.volatility_calculator = volatility_calculator

        # Default configuration
        self.default_fee_percentage = 0.3    # 0.3% default swap fee
        self.min_liquidity_threshold = 1000  # minimum pool liquidity, USD
        self.max_slippage_percentage = 5.0   # hard cap on allowed slippage
        self.incentive_duration_days = 30    # default incentive duration
||||
async def create_service_pool(
    self,
    pool_data: PoolCreate,
    creator_address: str
) -> PoolResponse:
    """Create a liquidity pool for AI service trading.

    Validates the request, rejects duplicate token pairs, deploys the
    pool contract, persists the pool row, and seeds its metrics.
    """
    try:
        validation = await self._validate_pool_creation(pool_data, creator_address)
        if not validation.is_valid:
            raise HTTPException(status_code=400, detail=validation.error_message)

        # Exactly one pool per token pair.
        existing_pool = await self._get_existing_pool(pool_data.token_a, pool_data.token_b)
        if existing_pool:
            raise HTTPException(
                status_code=400,
                detail="Pool already exists for this token pair"
            )

        # Deploy on-chain first so the DB row can reference the contract id.
        contract_pool_id = await self.contract_service.create_amm_pool(
            pool_data.token_a,
            pool_data.token_b,
            int(pool_data.fee_percentage * 100)  # percent -> basis points
        )

        pool = LiquidityPool(
            contract_pool_id=str(contract_pool_id),
            token_a=pool_data.token_a,
            token_b=pool_data.token_b,
            fee_percentage=pool_data.fee_percentage,
            total_liquidity=0.0,
            reserve_a=0.0,
            reserve_b=0.0,
            is_active=True,
            created_at=datetime.utcnow(),
            created_by=creator_address
        )
        self.session.add(pool)
        self.session.commit()
        self.session.refresh(pool)

        await self._initialize_pool_metrics(pool)

        logger.info(f"Created AMM pool {pool.id} for {pool_data.token_a}/{pool_data.token_b}")
        return PoolResponse.from_orm(pool)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error creating service pool: {str(e)}")
        self.session.rollback()
        raise HTTPException(status_code=500, detail=str(e))
||||
async def add_liquidity(
    self,
    liquidity_request: LiquidityAddRequest,
    provider_address: str
) -> LiquidityAddResponse:
    """Add liquidity to a pool and upsert the provider's position."""
    try:
        pool = await self._get_pool_by_id(liquidity_request.pool_id)

        validation = await self._validate_liquidity_addition(
            pool, liquidity_request, provider_address
        )
        if not validation.is_valid:
            raise HTTPException(status_code=400, detail=validation.error_message)

        # Token B must at least match the pool-ratio-optimal amount.
        optimal_amount_b = await self._calculate_optimal_amount_b(
            pool, liquidity_request.amount_a
        )
        if liquidity_request.amount_b < optimal_amount_b:
            raise HTTPException(
                status_code=400,
                detail=f"Insufficient token B amount. Minimum required: {optimal_amount_b}"
            )

        liquidity_result = await self.contract_service.add_liquidity(
            pool.contract_pool_id,
            liquidity_request.amount_a,
            liquidity_request.amount_b,
            liquidity_request.min_amount_a,
            liquidity_request.min_amount_b
        )

        # Mirror the on-chain deposit in the local reserves.
        pool.reserve_a += liquidity_request.amount_a
        pool.reserve_b += liquidity_request.amount_b
        pool.total_liquidity += liquidity_result.liquidity_received
        pool.updated_at = datetime.utcnow()

        # Upsert the provider's position and recompute their share.
        position = self.session.exec(
            select(LiquidityPosition).where(
                LiquidityPosition.pool_id == pool.id,
                LiquidityPosition.provider_address == provider_address
            )
        ).first()

        if position:
            position.liquidity_amount += liquidity_result.liquidity_received
            position.shares_owned = (position.liquidity_amount / pool.total_liquidity) * 100
            position.last_deposit = datetime.utcnow()
        else:
            position = LiquidityPosition(
                pool_id=pool.id,
                provider_address=provider_address,
                liquidity_amount=liquidity_result.liquidity_received,
                shares_owned=(liquidity_result.liquidity_received / pool.total_liquidity) * 100,
                last_deposit=datetime.utcnow(),
                created_at=datetime.utcnow()
            )
            self.session.add(position)

        self.session.commit()
        self.session.refresh(position)

        await self._update_pool_metrics(pool)

        logger.info(f"Added liquidity to pool {pool.id} by {provider_address}")
        return LiquidityAddResponse.from_orm(position)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error adding liquidity: {str(e)}")
        self.session.rollback()
        raise HTTPException(status_code=500, detail=str(e))
||||
async def remove_liquidity(
    self,
    liquidity_request: LiquidityRemoveRequest,
    provider_address: str
) -> LiquidityRemoveResponse:
    """Withdraw liquidity from a pool and shrink or delete the position."""
    try:
        pool = await self._get_pool_by_id(liquidity_request.pool_id)

        position = self.session.exec(
            select(LiquidityPosition).where(
                LiquidityPosition.pool_id == pool.id,
                LiquidityPosition.provider_address == provider_address
            )
        ).first()
        if not position:
            raise HTTPException(
                status_code=404,
                detail="Liquidity position not found"
            )
        if position.liquidity_amount < liquidity_request.liquidity_amount:
            raise HTTPException(
                status_code=400,
                detail="Insufficient liquidity amount"
            )

        removal_result = await self.contract_service.remove_liquidity(
            pool.contract_pool_id,
            liquidity_request.liquidity_amount,
            liquidity_request.min_amount_a,
            liquidity_request.min_amount_b
        )

        # Mirror the withdrawal in the local reserves.
        pool.reserve_a -= removal_result.amount_a
        pool.reserve_b -= removal_result.amount_b
        pool.total_liquidity -= liquidity_request.liquidity_amount
        pool.updated_at = datetime.utcnow()

        # Shrink the provider's position; share is 0 when the pool drains.
        position.liquidity_amount -= liquidity_request.liquidity_amount
        position.shares_owned = (position.liquidity_amount / pool.total_liquidity) * 100 if pool.total_liquidity > 0 else 0
        position.last_withdrawal = datetime.utcnow()

        # Drop fully-exited positions.
        if position.liquidity_amount == 0:
            self.session.delete(position)

        self.session.commit()

        await self._update_pool_metrics(pool)

        logger.info(f"Removed liquidity from pool {pool.id} by {provider_address}")
        return LiquidityRemoveResponse(
            pool_id=pool.id,
            amount_a=removal_result.amount_a,
            amount_b=removal_result.amount_b,
            liquidity_removed=liquidity_request.liquidity_amount,
            remaining_liquidity=position.liquidity_amount if position.liquidity_amount > 0 else 0
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error removing liquidity: {str(e)}")
        self.session.rollback()
        raise HTTPException(status_code=500, detail=str(e))
||||
async def execute_swap(
    self,
    swap_request: SwapRequest,
    user_address: str
) -> SwapResponse:
    """Execute a token swap against a pool.

    Validates the request, quotes the output via the constant-product
    formula, enforces the slippage cap, executes the swap on-chain, then
    mirrors the reserve changes and records a SwapTransaction locally.

    Raises:
        HTTPException 400: validation failure or slippage above the cap.
        HTTPException 404: unknown/inactive pool (via _get_pool_by_id).
        HTTPException 500: any other failure (session is rolled back).
    """

    try:
        # Get pool (raises 404 when missing or inactive)
        pool = await self._get_pool_by_id(swap_request.pool_id)

        # Validate swap request (liquidity, deadline, positive amount)
        validation_result = await self._validate_swap_request(
            pool, swap_request, user_address
        )
        if not validation_result.is_valid:
            raise HTTPException(
                status_code=400,
                detail=validation_result.error_message
            )

        # Calculate expected output amount
        expected_output = await self._calculate_swap_output(
            pool, swap_request.amount_in, swap_request.token_in
        )

        # Check slippage.
        # NOTE(review): this divides by expected_output (ZeroDivisionError if the
        # quote is 0) and measures the gap to the caller's min_amount_out rather
        # than realized slippage — confirm the intended semantics.
        slippage_percentage = ((expected_output - swap_request.min_amount_out) / expected_output) * 100
        if slippage_percentage > self.max_slippage_percentage:
            raise HTTPException(
                status_code=400,
                detail=f"Slippage too high: {slippage_percentage:.2f}%"
            )

        # Execute swap on blockchain — must succeed before any local mutation
        swap_result = await self.contract_service.execute_swap(
            pool.contract_pool_id,
            swap_request.token_in,
            swap_request.token_out,
            swap_request.amount_in,
            swap_request.min_amount_out,
            user_address,
            swap_request.deadline
        )

        # Update pool reserves to mirror the on-chain state change
        if swap_request.token_in == pool.token_a:
            pool.reserve_a += swap_request.amount_in
            pool.reserve_b -= swap_result.amount_out
        else:
            pool.reserve_b += swap_request.amount_in
            pool.reserve_a -= swap_result.amount_out

        pool.updated_at = datetime.utcnow()

        # Record swap transaction for auditing and 24h metrics
        swap_transaction = SwapTransaction(
            pool_id=pool.id,
            user_address=user_address,
            token_in=swap_request.token_in,
            token_out=swap_request.token_out,
            amount_in=swap_request.amount_in,
            amount_out=swap_result.amount_out,
            price=swap_result.price,
            fee_amount=swap_result.fee_amount,
            transaction_hash=swap_result.transaction_hash,
            executed_at=datetime.utcnow()
        )

        self.session.add(swap_transaction)
        self.session.commit()
        self.session.refresh(swap_transaction)

        # Update pool metrics (TVL/APR/utilization) after the commit
        await self._update_pool_metrics(pool)

        logger.info(f"Executed swap {swap_transaction.id} in pool {pool.id}")

        return SwapResponse.from_orm(swap_transaction)

    except HTTPException:
        # Propagate deliberate HTTP errors unchanged
        raise
    except Exception as e:
        logger.error(f"Error executing swap: {str(e)}")
        self.session.rollback()
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
async def dynamic_fee_adjustment(
    self,
    pool_id: int,
    volatility: float
) -> FeeStructure:
    """Adjust a pool's trading fee based on market volatility.

    The fee scales linearly with *volatility* (a percentage figure) from the
    configured base fee, clamped to the [0.05%, 1.0%] band, and is pushed
    on-chain before being persisted locally.

    Raises:
        HTTPException 404: unknown/inactive pool (via _get_pool_by_id).
        HTTPException 500: any other failure (session is rolled back).
    """

    try:
        # Get pool (raises 404 when missing or inactive)
        pool = await self._get_pool_by_id(pool_id)

        # Calculate optimal fee based on volatility
        base_fee = self.default_fee_percentage
        volatility_multiplier = 1.0 + (volatility / 100.0)  # Increase fee with volatility

        # Apply fee caps
        new_fee = min(base_fee * volatility_multiplier, 1.0)  # Max 1% fee
        new_fee = max(new_fee, 0.05)  # Min 0.05% fee

        # Update pool fee on blockchain first, so the DB never records a
        # fee the chain rejected
        await self.contract_service.update_pool_fee(
            pool.contract_pool_id,
            int(new_fee * 100)  # percent -> basis points (0.3% == 30 bps)
        )

        # Update pool in database
        pool.fee_percentage = new_fee
        pool.updated_at = datetime.utcnow()
        self.session.commit()

        # Create fee structure response
        fee_structure = FeeStructure(
            pool_id=pool_id,
            base_fee_percentage=base_fee,
            current_fee_percentage=new_fee,
            volatility_adjustment=volatility_multiplier - 1.0,
            adjusted_at=datetime.utcnow()
        )

        logger.info(f"Adjusted fee for pool {pool_id} to {new_fee:.3f}%")

        return fee_structure

    except HTTPException:
        # FIX: previously the 404 from _get_pool_by_id was swallowed by the
        # generic handler and rewrapped as a 500; re-raise HTTP errors
        # unchanged, consistent with the other service methods.
        raise
    except Exception as e:
        logger.error(f"Error adjusting fees: {str(e)}")
        # Keep the session usable after a failed commit (consistency with siblings)
        self.session.rollback()
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
async def liquidity_incentives(
    self,
    pool_id: int
) -> IncentiveProgram:
    """Create or refresh the liquidity-provider incentive program for a pool.

    Smaller pools (by TVL) receive a larger reward multiplier (up to 2x at
    zero TVL, tapering to 1x at / above $1M TVL) applied to a $100/day base.

    Raises:
        HTTPException 404: unknown/inactive pool (via _get_pool_by_id).
        HTTPException 500: any other failure (session is rolled back).
    """

    try:
        # Get pool (raises 404 when missing or inactive)
        pool = await self._get_pool_by_id(pool_id)

        # Calculate incentive parameters based on pool metrics
        pool_metrics = await self._get_pool_metrics(pool)

        # Higher incentives for lower liquidity pools
        liquidity_ratio = pool_metrics.total_value_locked / 1000000  # Normalize to 1M USD
        incentive_multiplier = max(1.0, 2.0 - liquidity_ratio)  # 2x for small pools, 1x for large

        # Calculate daily reward amount
        daily_reward = 100 * incentive_multiplier  # Base $100 per day, adjusted by multiplier

        # Create or update incentive program
        existing_program = self.session.exec(
            select(IncentiveProgram).where(IncentiveProgram.pool_id == pool_id)
        ).first()

        if existing_program:
            existing_program.daily_reward_amount = daily_reward
            existing_program.incentive_multiplier = incentive_multiplier
            existing_program.updated_at = datetime.utcnow()
            program = existing_program
        else:
            program = IncentiveProgram(
                pool_id=pool_id,
                daily_reward_amount=daily_reward,
                incentive_multiplier=incentive_multiplier,
                duration_days=self.incentive_duration_days,
                is_active=True,
                created_at=datetime.utcnow()
            )
            self.session.add(program)

        self.session.commit()
        self.session.refresh(program)

        logger.info(f"Created incentive program for pool {pool_id} with daily reward ${daily_reward:.2f}")

        return program

    except HTTPException:
        # FIX: previously the 404 from _get_pool_by_id was rewrapped as a 500
        # by the generic handler; re-raise HTTP errors unchanged.
        raise
    except Exception as e:
        logger.error(f"Error creating incentive program: {str(e)}")
        # Keep the session usable after a failed commit (consistency with siblings)
        self.session.rollback()
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
async def get_pool_metrics(self, pool_id: int) -> PoolMetricsResponse:
    """Return the latest metrics snapshot (TVL, APR, 24h volume/fees) for a pool.

    Raises:
        HTTPException 404: unknown/inactive pool (via _get_pool_by_id).
        HTTPException 500: any other failure.
    """

    try:
        # Get pool (raises 404 when missing or inactive)
        pool = await self._get_pool_by_id(pool_id)

        # Get detailed metrics (creates a zeroed row on first access)
        metrics = await self._get_pool_metrics(pool)

        return PoolMetricsResponse.from_orm(metrics)

    except HTTPException:
        # FIX: previously the 404 from _get_pool_by_id was caught by the
        # generic handler and rewrapped as a 500; propagate it unchanged,
        # matching the other service methods.
        raise
    except Exception as e:
        logger.error(f"Error getting pool metrics: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
async def get_user_positions(self, user_address: str) -> List[LiquidityPosition]:
    """Return every liquidity position held by *user_address* across all pools."""

    try:
        query = select(LiquidityPosition).where(
            LiquidityPosition.provider_address == user_address
        )
        return self.session.exec(query).all()
    except Exception as exc:
        logger.error(f"Error getting user positions: {str(exc)}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
# Private helper methods
|
||||
|
||||
async def _get_pool_by_id(self, pool_id: int) -> LiquidityPool:
    """Fetch a pool by primary key; 404 when it is missing or deactivated."""
    pool = self.session.get(LiquidityPool, pool_id)
    if pool and pool.is_active:
        return pool
    raise HTTPException(status_code=404, detail="Pool not found")
|
||||
|
||||
async def _get_existing_pool(self, token_a: str, token_b: str) -> Optional[LiquidityPool]:
    """Look up a pool for the token pair in either order; None when absent."""
    forward = (LiquidityPool.token_a == token_a) & (LiquidityPool.token_b == token_b)
    reverse = (LiquidityPool.token_a == token_b) & (LiquidityPool.token_b == token_a)
    statement = select(LiquidityPool).where(forward | reverse)
    return self.session.exec(statement).first()
|
||||
|
||||
async def _validate_pool_creation(
    self,
    pool_data: PoolCreate,
    creator_address: str
) -> ValidationResult:
    """Sanity-check a pool-creation request: distinct tokens, fee within bounds."""

    # A pool pairs two different tokens by definition.
    if pool_data.token_a == pool_data.token_b:
        return ValidationResult(
            is_valid=False,
            error_message="Token addresses must be different"
        )

    # Fee must sit inside the protocol's allowed band.
    if not (0.05 <= pool_data.fee_percentage <= 1.0):
        return ValidationResult(
            is_valid=False,
            error_message="Fee percentage must be between 0.05% and 1.0%"
        )

    # Token-registry support checks are not wired up yet; for now every
    # token pair is accepted.
    return ValidationResult(is_valid=True)
|
||||
|
||||
async def _validate_liquidity_addition(
    self,
    pool: LiquidityPool,
    liquidity_request: LiquidityAddRequest,
    provider_address: str
) -> ValidationResult:
    """Check that deposit amounts are positive and, for a seeded pool,
    that token B matches the reserve ratio within a 1% tolerance."""

    amount_a = liquidity_request.amount_a
    amount_b = liquidity_request.amount_b

    # Both legs of the deposit must be strictly positive.
    if amount_a <= 0 or amount_b <= 0:
        return ValidationResult(
            is_valid=False,
            error_message="Amounts must be greater than 0"
        )

    # The very first deposit sets the price, so any ratio is acceptable.
    if pool.total_liquidity == 0:
        return ValidationResult(is_valid=True)

    # Subsequent deposits must track the current reserve ratio.
    optimal_b = await self._calculate_optimal_amount_b(pool, amount_a)
    min_required = optimal_b * 0.99  # tolerate up to 1% below the optimal ratio
    if amount_b < min_required:
        return ValidationResult(
            is_valid=False,
            error_message=f"Insufficient token B amount. Minimum: {min_required}"
        )

    return ValidationResult(is_valid=True)
|
||||
|
||||
async def _validate_swap_request(
    self,
    pool: LiquidityPool,
    swap_request: SwapRequest,
    user_address: str
) -> ValidationResult:
    """Check output-side liquidity, deadline, and input amount for a swap."""

    # The output-side reserve must hold at least the caller's minimum
    # acceptable amount.
    selling_a = swap_request.token_in == pool.token_a
    out_reserve = pool.reserve_b if selling_a else pool.reserve_a
    if out_reserve < swap_request.min_amount_out:
        return ValidationResult(
            is_valid=False,
            error_message="Insufficient liquidity in pool"
        )

    # Reject requests whose deadline already passed.
    if datetime.utcnow() > swap_request.deadline:
        return ValidationResult(
            is_valid=False,
            error_message="Transaction deadline expired"
        )

    # Input amount must be strictly positive.
    if swap_request.amount_in <= 0:
        return ValidationResult(
            is_valid=False,
            error_message="Amount must be greater than 0"
        )

    return ValidationResult(is_valid=True)
|
||||
|
||||
async def _calculate_optimal_amount_b(
|
||||
self,
|
||||
pool: LiquidityPool,
|
||||
amount_a: float
|
||||
) -> float:
|
||||
"""Calculate optimal amount of token B for adding liquidity"""
|
||||
|
||||
if pool.reserve_a == 0:
|
||||
return 0.0
|
||||
|
||||
return (amount_a * pool.reserve_b) / pool.reserve_a
|
||||
|
||||
async def _calculate_swap_output(
|
||||
self,
|
||||
pool: LiquidityPool,
|
||||
amount_in: float,
|
||||
token_in: str
|
||||
) -> float:
|
||||
"""Calculate output amount for swap using constant product formula"""
|
||||
|
||||
# Determine reserves
|
||||
if token_in == pool.token_a:
|
||||
reserve_in = pool.reserve_a
|
||||
reserve_out = pool.reserve_b
|
||||
else:
|
||||
reserve_in = pool.reserve_b
|
||||
reserve_out = pool.reserve_a
|
||||
|
||||
# Apply fee
|
||||
fee_amount = (amount_in * pool.fee_percentage) / 100
|
||||
amount_in_after_fee = amount_in - fee_amount
|
||||
|
||||
# Calculate output using constant product formula
|
||||
# x * y = k
|
||||
# (x + amount_in) * (y - amount_out) = k
|
||||
# amount_out = (amount_in_after_fee * y) / (x + amount_in_after_fee)
|
||||
|
||||
amount_out = (amount_in_after_fee * reserve_out) / (reserve_in + amount_in_after_fee)
|
||||
|
||||
return amount_out
|
||||
|
||||
async def _initialize_pool_metrics(self, pool: LiquidityPool) -> None:
    """Create and commit a zeroed PoolMetrics row for a freshly created pool.

    Callers re-query the row after this returns (see _update_pool_metrics
    and _get_pool_metrics).
    """

    metrics = PoolMetrics(
        pool_id=pool.id,
        total_volume_24h=0.0,
        total_fees_24h=0.0,
        total_value_locked=0.0,
        apr=0.0,
        utilization_rate=0.0,
        updated_at=datetime.utcnow()
    )

    self.session.add(metrics)
    self.session.commit()
|
||||
|
||||
async def _update_pool_metrics(self, pool: LiquidityPool) -> None:
    """Recompute TVL, APR, and utilization for *pool* and persist them.

    Creates the metrics row first if the pool has none yet. Prices come
    from the injected price service.
    """

    # Get existing metrics
    metrics = self.session.exec(
        select(PoolMetrics).where(PoolMetrics.pool_id == pool.id)
    ).first()

    if not metrics:
        # First access: create a zeroed row, then re-query it.
        await self._initialize_pool_metrics(pool)
        metrics = self.session.exec(
            select(PoolMetrics).where(PoolMetrics.pool_id == pool.id)
        ).first()

    # Calculate TVL (simplified - would use actual token prices)
    token_a_price = await self.price_service.get_price(pool.token_a)
    token_b_price = await self.price_service.get_price(pool.token_b)

    tvl = (pool.reserve_a * token_a_price) + (pool.reserve_b * token_b_price)

    # Calculate APR (simplified): naive annualization of the trailing 24h fees
    apr = 0.0
    if tvl > 0 and pool.total_liquidity > 0:
        daily_fees = metrics.total_fees_24h
        annual_fees = daily_fees * 365
        apr = (annual_fees / tvl) * 100

    # Calculate utilization rate
    utilization_rate = 0.0
    if pool.total_liquidity > 0:
        # Simplified utilization calculation.
        # NOTE(review): tvl is a value figure while total_liquidity appears to
        # be in LP-share units — confirm the intended semantics of this ratio.
        utilization_rate = (tvl / pool.total_liquidity) * 100

    # Update metrics
    metrics.total_value_locked = tvl
    metrics.apr = apr
    metrics.utilization_rate = utilization_rate
    metrics.updated_at = datetime.utcnow()

    self.session.commit()
|
||||
|
||||
async def _get_pool_metrics(self, pool: LiquidityPool) -> PoolMetrics:
    """Return the PoolMetrics row for *pool*, refreshing its trailing-24h
    volume and fee figures from recorded swap transactions.

    Creates a zeroed metrics row first if none exists yet. The refreshed
    24h figures are set on the instance but not committed here; the caller
    (or a subsequent _update_pool_metrics) persists them.
    """

    metrics = self.session.exec(
        select(PoolMetrics).where(PoolMetrics.pool_id == pool.id)
    ).first()

    if not metrics:
        # First access: create a zeroed row, then re-query it.
        await self._initialize_pool_metrics(pool)
        metrics = self.session.exec(
            select(PoolMetrics).where(PoolMetrics.pool_id == pool.id)
        ).first()

    # Calculate 24h volume and fees from swaps executed in the last day
    twenty_four_hours_ago = datetime.utcnow() - timedelta(hours=24)

    recent_swaps = self.session.exec(
        select(SwapTransaction).where(
            SwapTransaction.pool_id == pool.id,
            SwapTransaction.executed_at >= twenty_four_hours_ago
        )
    ).all()

    # Volume is measured in input-token units summed across both directions
    total_volume = sum(swap.amount_in for swap in recent_swaps)
    total_fees = sum(swap.fee_amount for swap in recent_swaps)

    metrics.total_volume_24h = total_volume
    metrics.total_fees_24h = total_fees

    return metrics
|
||||
|
||||
|
||||
class ValidationResult:
    """Outcome of a request-validation check.

    Attributes:
        is_valid: True when the request passed all checks.
        error_message: Human-readable reason when is_valid is False;
            empty string otherwise.
    """

    def __init__(self, is_valid: bool, error_message: str = ""):
        self.is_valid = is_valid
        self.error_message = error_message

    def __repr__(self) -> str:
        # Debuggability fix: instances previously printed as opaque object ids,
        # which made validation failures hard to read in logs.
        return (
            f"{type(self).__name__}("
            f"is_valid={self.is_valid!r}, error_message={self.error_message!r})"
        )
|
||||
803
apps/coordinator-api/src/app/services/cross_chain_bridge.py
Normal file
803
apps/coordinator-api/src/app/services/cross_chain_bridge.py
Normal file
@@ -0,0 +1,803 @@
|
||||
"""
|
||||
Cross-Chain Bridge Service
|
||||
|
||||
Secure cross-chain asset transfer protocol with ZK proof validation.
|
||||
Enables bridging of assets between different blockchain networks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlmodel import Session
|
||||
|
||||
from ..domain.cross_chain_bridge import (
|
||||
BridgeRequest,
|
||||
BridgeRequestStatus,
|
||||
SupportedToken,
|
||||
ChainConfig,
|
||||
Validator,
|
||||
BridgeTransaction,
|
||||
MerkleProof
|
||||
)
|
||||
from ..schemas.cross_chain_bridge import (
|
||||
BridgeCreateRequest,
|
||||
BridgeResponse,
|
||||
BridgeConfirmRequest,
|
||||
BridgeCompleteRequest,
|
||||
BridgeStatusResponse,
|
||||
TokenSupportRequest,
|
||||
ChainSupportRequest,
|
||||
ValidatorAddRequest
|
||||
)
|
||||
from ..blockchain.contract_interactions import ContractInteractionService
|
||||
from ..crypto.zk_proofs import ZKProofService
|
||||
from ..crypto.merkle_tree import MerkleTreeService
|
||||
from ..monitoring.bridge_monitor import BridgeMonitor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CrossChainBridgeService:
|
||||
"""Secure cross-chain asset transfer protocol"""
|
||||
|
||||
def __init__(
    self,
    session: Session,
    contract_service: ContractInteractionService,
    zk_proof_service: ZKProofService,
    merkle_tree_service: MerkleTreeService,
    bridge_monitor: BridgeMonitor
) -> None:
    """Wire up the bridge's persistence, chain, proof, and monitoring
    collaborators and set its protocol configuration constants."""
    self.session = session
    self.contract_service = contract_service
    self.zk_proof_service = zk_proof_service
    self.merkle_tree_service = merkle_tree_service
    self.bridge_monitor = bridge_monitor

    # Configuration
    self.bridge_fee_percentage = 0.5  # 0.5% bridge fee
    self.max_bridge_amount = 1000000  # Max 1M tokens per bridge
    self.min_confirmations = 3
    self.bridge_timeout = 24 * 60 * 60  # 24 hours, in seconds (request expiry)
    # NOTE(review): per-transfer limits are also enforced via each token's
    # bridge_limit in initiate_transfer — confirm how max_bridge_amount and
    # validator_threshold are intended to interact with those checks.
    self.validator_threshold = 0.67  # 67% of validators required
||||
|
||||
async def initiate_transfer(
    self,
    transfer_request: BridgeCreateRequest,
    sender_address: str
) -> BridgeResponse:
    """Initiate a cross-chain asset transfer with ZK proof validation.

    Validates the request, checks token/chain support and bridge limits,
    generates a ZK proof, locks the transfer on-chain, then records a
    pending BridgeRequest (with a 24h expiry) and starts monitoring it.

    Raises:
        HTTPException 400: invalid request, unsupported token/chain, or
            amount above the token's bridge limit.
        HTTPException 500: any other failure (session is rolled back).
    """

    try:
        # Validate transfer request (address formats etc.)
        validation_result = await self._validate_transfer_request(
            transfer_request, sender_address
        )
        if not validation_result.is_valid:
            raise HTTPException(
                status_code=400,
                detail=validation_result.error_message
            )

        # Get supported token configuration
        token_config = await self._get_supported_token(transfer_request.source_token)
        if not token_config or not token_config.is_active:
            raise HTTPException(
                status_code=400,
                detail="Source token not supported for bridging"
            )

        # Get chain configuration for both ends of the transfer
        source_chain = await self._get_chain_config(transfer_request.source_chain_id)
        target_chain = await self._get_chain_config(transfer_request.target_chain_id)

        if not source_chain or not target_chain:
            raise HTTPException(
                status_code=400,
                detail="Unsupported blockchain network"
            )

        # Calculate bridge fee.
        # NOTE(review): uses the service-wide bridge_fee_percentage even though
        # the token config carries its own fee_percentage — confirm which one
        # is authoritative.
        bridge_fee = (transfer_request.amount * self.bridge_fee_percentage) / 100
        total_amount = transfer_request.amount + bridge_fee

        # Check per-token bridge limits
        if transfer_request.amount > token_config.bridge_limit:
            raise HTTPException(
                status_code=400,
                detail=f"Amount exceeds bridge limit of {token_config.bridge_limit}"
            )

        # Generate ZK proof for transfer
        zk_proof = await self._generate_transfer_zk_proof(
            transfer_request, sender_address
        )

        # Create bridge request on blockchain — must succeed before any
        # local record is written
        contract_request_id = await self.contract_service.initiate_bridge(
            transfer_request.source_token,
            transfer_request.target_token,
            transfer_request.amount,
            transfer_request.target_chain_id,
            transfer_request.recipient_address
        )

        # Create bridge request record mirroring the on-chain request
        bridge_request = BridgeRequest(
            contract_request_id=str(contract_request_id),
            sender_address=sender_address,
            recipient_address=transfer_request.recipient_address,
            source_token=transfer_request.source_token,
            target_token=transfer_request.target_token,
            source_chain_id=transfer_request.source_chain_id,
            target_chain_id=transfer_request.target_chain_id,
            amount=transfer_request.amount,
            bridge_fee=bridge_fee,
            total_amount=total_amount,
            status=BridgeRequestStatus.PENDING,
            zk_proof=zk_proof.proof,
            created_at=datetime.utcnow(),
            expires_at=datetime.utcnow() + timedelta(seconds=self.bridge_timeout)
        )

        self.session.add(bridge_request)
        self.session.commit()
        self.session.refresh(bridge_request)

        # Start monitoring the bridge request (stopped again on completion)
        await self.bridge_monitor.start_monitoring(bridge_request.id)

        logger.info(f"Initiated bridge transfer {bridge_request.id} from {sender_address}")

        return BridgeResponse.from_orm(bridge_request)

    except HTTPException:
        # Propagate deliberate HTTP errors unchanged
        raise
    except Exception as e:
        logger.error(f"Error initiating bridge transfer: {str(e)}")
        self.session.rollback()
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
async def monitor_bridge_status(self, request_id: int) -> BridgeStatusResponse:
    """Report the current status of a bridge request, syncing it with the chain.

    Reads the authoritative status from the bridge contract, updates the
    local record when it drifted, and bundles confirmations, transactions,
    and an estimated completion time into the response.

    Raises:
        HTTPException 404: unknown bridge request.
        HTTPException 500: any other failure.
    """

    try:
        # Get bridge request
        bridge_request = self.session.get(BridgeRequest, request_id)
        if not bridge_request:
            raise HTTPException(status_code=404, detail="Bridge request not found")

        # Get current status from blockchain (the source of truth)
        contract_status = await self.contract_service.get_bridge_status(
            bridge_request.contract_request_id
        )

        # Update local status if it diverged from the chain
        if contract_status.status != bridge_request.status.value:
            bridge_request.status = BridgeRequestStatus(contract_status.status)
            bridge_request.updated_at = datetime.utcnow()
            self.session.commit()

        # Get confirmation details
        confirmations = await self._get_bridge_confirmations(request_id)

        # Get transaction details
        transactions = await self._get_bridge_transactions(request_id)

        # Calculate estimated completion time
        estimated_completion = await self._calculate_estimated_completion(bridge_request)

        status_response = BridgeStatusResponse(
            request_id=request_id,
            status=bridge_request.status,
            source_chain_id=bridge_request.source_chain_id,
            target_chain_id=bridge_request.target_chain_id,
            amount=bridge_request.amount,
            created_at=bridge_request.created_at,
            updated_at=bridge_request.updated_at,
            confirmations=confirmations,
            transactions=transactions,
            estimated_completion=estimated_completion
        )

        return status_response

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error monitoring bridge status: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
async def dispute_resolution(self, dispute_data: Dict) -> Dict:
    """Resolve a dispute over a FAILED bridge transfer automatically.

    Expects ``dispute_data`` to carry ``request_id`` and ``reason`` keys.
    Analyzes the failure, picks and executes a resolution action, then
    marks the request RESOLVED with the dispute details recorded.

    Raises:
        HTTPException 400: request is not in FAILED status.
        HTTPException 404: unknown bridge request.
        HTTPException 500: any other failure.
    """

    try:
        # Both keys default to None when absent — a missing request_id
        # surfaces as a 404 below.
        request_id = dispute_data.get('request_id')
        dispute_reason = dispute_data.get('reason')

        # Get bridge request
        bridge_request = self.session.get(BridgeRequest, request_id)
        if not bridge_request:
            raise HTTPException(status_code=404, detail="Bridge request not found")

        # Disputes only apply to transfers that actually failed
        if bridge_request.status != BridgeRequestStatus.FAILED:
            raise HTTPException(
                status_code=400,
                detail="Dispute only available for failed transfers"
            )

        # Analyze failure reason
        failure_analysis = await self._analyze_bridge_failure(bridge_request)

        # Determine resolution action
        resolution_action = await self._determine_resolution_action(
            bridge_request, failure_analysis
        )

        # Execute resolution (may move funds — happens before the local
        # record is updated)
        resolution_result = await self._execute_resolution(
            bridge_request, resolution_action
        )

        # Record dispute resolution
        bridge_request.dispute_reason = dispute_reason
        bridge_request.resolution_action = resolution_action.action_type
        bridge_request.resolved_at = datetime.utcnow()
        bridge_request.status = BridgeRequestStatus.RESOLVED

        self.session.commit()

        logger.info(f"Resolved dispute for bridge request {request_id}")

        return resolution_result

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error resolving dispute: {str(e)}")
        # NOTE(review): unlike the sibling methods, no session.rollback() here
        # even though this path commits — confirm whether that is intentional.
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
async def confirm_bridge_transfer(
    self,
    confirm_request: BridgeConfirmRequest,
    validator_address: str
) -> Dict:
    """Record a validator's confirmation of a pending bridge transfer.

    Verifies the caller is an active validator with a valid signature,
    rejects duplicate confirmations, and — once enough confirmations have
    accumulated for the source chain — flips the request to CONFIRMED and
    attaches a Merkle proof for the completion step. The new confirmation
    and any status change are committed together at the end.

    Returns a dict with the running confirmation tally and current status.

    Raises:
        HTTPException 400: not pending, bad signature, or duplicate confirmation.
        HTTPException 403: caller is not an active validator.
        HTTPException 404: unknown bridge request.
        HTTPException 500: any other failure (session is rolled back).
    """

    try:
        # Validate validator
        validator = await self._get_validator(validator_address)
        if not validator or not validator.is_active:
            raise HTTPException(
                status_code=403,
                detail="Not an active validator"
            )

        # Get bridge request
        bridge_request = self.session.get(BridgeRequest, confirm_request.request_id)
        if not bridge_request:
            raise HTTPException(status_code=404, detail="Bridge request not found")

        if bridge_request.status != BridgeRequestStatus.PENDING:
            raise HTTPException(
                status_code=400,
                detail="Bridge request not in pending status"
            )

        # Verify validator signature
        signature_valid = await self._verify_validator_signature(
            confirm_request, validator_address
        )
        if not signature_valid:
            raise HTTPException(
                status_code=400,
                detail="Invalid validator signature"
            )

        # Check if already confirmed by this validator (one vote each)
        existing_confirmation = self.session.exec(
            select(BridgeTransaction).where(
                BridgeTransaction.bridge_request_id == bridge_request.id,
                BridgeTransaction.validator_address == validator_address,
                BridgeTransaction.transaction_type == "confirmation"
            )
        ).first()

        if existing_confirmation:
            raise HTTPException(
                status_code=400,
                detail="Already confirmed by this validator"
            )

        # Record confirmation (committed below, together with any status change)
        confirmation = BridgeTransaction(
            bridge_request_id=bridge_request.id,
            validator_address=validator_address,
            transaction_type="confirmation",
            transaction_hash=confirm_request.lock_tx_hash,
            signature=confirm_request.signature,
            confirmed_at=datetime.utcnow()
        )

        self.session.add(confirmation)

        # Check if we have enough confirmations for the source chain.
        # NOTE(review): whether _count_confirmations sees the yet-uncommitted
        # confirmation above depends on session flush behavior — confirm the
        # tally includes this vote.
        total_confirmations = await self._count_confirmations(bridge_request.id)
        required_confirmations = await self._get_required_confirmations(
            bridge_request.source_chain_id
        )

        if total_confirmations >= required_confirmations:
            # Update bridge request status
            bridge_request.status = BridgeRequestStatus.CONFIRMED
            bridge_request.confirmed_at = datetime.utcnow()

            # Generate Merkle proof for completion on the target chain
            merkle_proof = await self._generate_merkle_proof(bridge_request)
            bridge_request.merkle_proof = merkle_proof.proof_hash

            logger.info(f"Bridge request {bridge_request.id} confirmed by validators")

        self.session.commit()

        return {
            "request_id": bridge_request.id,
            "confirmations": total_confirmations,
            "required": required_confirmations,
            "status": bridge_request.status.value
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error confirming bridge transfer: {str(e)}")
        self.session.rollback()
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
async def complete_bridge_transfer(
    self,
    complete_request: BridgeCompleteRequest,
    executor_address: str
) -> Dict:
    """Complete a CONFIRMED bridge transfer on the target chain.

    Verifies the supplied Merkle proof, finalizes the bridge on-chain,
    records the completion transaction, marks the request COMPLETED, and
    stops its monitor.

    Raises:
        HTTPException 400: request not confirmed, or invalid Merkle proof.
        HTTPException 404: unknown bridge request.
        HTTPException 500: any other failure (session is rolled back).
    """

    try:
        # Get bridge request
        bridge_request = self.session.get(BridgeRequest, complete_request.request_id)
        if not bridge_request:
            raise HTTPException(status_code=404, detail="Bridge request not found")

        # Only validator-confirmed requests may be completed
        if bridge_request.status != BridgeRequestStatus.CONFIRMED:
            raise HTTPException(
                status_code=400,
                detail="Bridge request not confirmed"
            )

        # Verify Merkle proof against the request's recorded proof data
        proof_valid = await self._verify_merkle_proof(
            complete_request.merkle_proof, bridge_request
        )
        if not proof_valid:
            raise HTTPException(
                status_code=400,
                detail="Invalid Merkle proof"
            )

        # Complete bridge on blockchain — must succeed before any local update
        completion_result = await self.contract_service.complete_bridge(
            bridge_request.contract_request_id,
            complete_request.unlock_tx_hash,
            complete_request.merkle_proof
        )

        # Record completion transaction
        completion = BridgeTransaction(
            bridge_request_id=bridge_request.id,
            validator_address=executor_address,
            transaction_type="completion",
            transaction_hash=complete_request.unlock_tx_hash,
            merkle_proof=complete_request.merkle_proof,
            completed_at=datetime.utcnow()
        )

        self.session.add(completion)

        # Update bridge request status
        bridge_request.status = BridgeRequestStatus.COMPLETED
        bridge_request.completed_at = datetime.utcnow()
        bridge_request.unlock_tx_hash = complete_request.unlock_tx_hash

        self.session.commit()

        # Stop monitoring — the transfer lifecycle is over
        await self.bridge_monitor.stop_monitoring(bridge_request.id)

        logger.info(f"Completed bridge transfer {bridge_request.id}")

        return {
            "request_id": bridge_request.id,
            "status": "completed",
            "unlock_tx_hash": complete_request.unlock_tx_hash,
            "completed_at": bridge_request.completed_at
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error completing bridge transfer: {str(e)}")
        self.session.rollback()
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
async def add_supported_token(self, token_request: TokenSupportRequest) -> Dict:
    """Register a new token as bridgeable; rejects duplicates."""

    try:
        # Reject a token that is already registered.
        if await self._get_supported_token(token_request.token_address):
            raise HTTPException(
                status_code=400,
                detail="Token already supported"
            )

        # Persist the new token configuration, active immediately.
        record = SupportedToken(
            token_address=token_request.token_address,
            token_symbol=token_request.token_symbol,
            bridge_limit=token_request.bridge_limit,
            fee_percentage=token_request.fee_percentage,
            requires_whitelist=token_request.requires_whitelist,
            is_active=True,
            created_at=datetime.utcnow()
        )
        self.session.add(record)
        self.session.commit()
        self.session.refresh(record)

        logger.info(f"Added supported token {token_request.token_symbol}")
        return {"token_id": record.id, "status": "supported"}

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Error adding supported token: {str(exc)}")
        self.session.rollback()
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
async def add_supported_chain(self, chain_request: ChainSupportRequest) -> Dict:
    """Register a new blockchain as a bridge endpoint.

    Args:
        chain_request: Chain metadata (id, name, type, bridge contract
            address, confirmation requirements, average block time).

    Returns:
        Dict with the new record's ``chain_id`` (DB primary key) and
        status ``"supported"``.

    Raises:
        HTTPException: 400 if the chain id is already configured; 500 on
            any other failure (session rolled back first).
    """
    try:
        # Reject duplicates — one configuration per chain id.
        existing_chain = await self._get_chain_config(chain_request.chain_id)
        if existing_chain:
            raise HTTPException(
                status_code=400,
                detail="Chain already supported"
            )

        # Create chain configuration
        chain_config = ChainConfig(
            chain_id=chain_request.chain_id,
            chain_name=chain_request.chain_name,
            chain_type=chain_request.chain_type,
            bridge_contract_address=chain_request.bridge_contract_address,
            min_confirmations=chain_request.min_confirmations,
            avg_block_time=chain_request.avg_block_time,
            is_active=True,  # new chains start enabled
            created_at=datetime.utcnow()
        )

        self.session.add(chain_config)
        self.session.commit()
        self.session.refresh(chain_config)  # populate the DB-generated id

        logger.info(f"Added supported chain {chain_request.chain_name}")

        return {"chain_id": chain_config.id, "status": "supported"}

    except HTTPException:
        # Preserve the 400 above instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error adding supported chain: {str(e)}")
        self.session.rollback()
        raise HTTPException(status_code=500, detail=str(e))
||||
# Private helper methods
|
||||
|
||||
async def _validate_transfer_request(
    self,
    transfer_request: BridgeCreateRequest,
    sender_address: str
) -> ValidationResult:
    """Validate a bridge transfer request.

    Checks run in order (addresses, amount bounds, distinct chains) and the
    first failure short-circuits with its error message.
    """
    # Sender address must be well-formed.
    if not self._is_valid_address(sender_address):
        return ValidationResult(is_valid=False, error_message="Invalid sender address")

    # Recipient address must be well-formed.
    if not self._is_valid_address(transfer_request.recipient_address):
        return ValidationResult(is_valid=False, error_message="Invalid recipient address")

    # Amount must be strictly positive...
    if transfer_request.amount <= 0:
        return ValidationResult(is_valid=False, error_message="Amount must be greater than 0")

    # ...and within the configured bridge ceiling.
    if transfer_request.amount > self.max_bridge_amount:
        return ValidationResult(
            is_valid=False,
            error_message=f"Amount exceeds maximum bridge limit of {self.max_bridge_amount}"
        )

    # A bridge only makes sense between two different chains.
    if transfer_request.source_chain_id == transfer_request.target_chain_id:
        return ValidationResult(
            is_valid=False,
            error_message="Source and target chains must be different"
        )

    return ValidationResult(is_valid=True)
||||
def _is_valid_address(self, address: str) -> bool:
|
||||
"""Validate blockchain address"""
|
||||
return (
|
||||
address.startswith("0x") and
|
||||
len(address) == 42 and
|
||||
all(c in "0123456789abcdefABCDEF" for c in address[2:])
|
||||
)
|
||||
|
||||
async def _get_supported_token(self, token_address: str) -> Optional[SupportedToken]:
    """Look up a supported-token record by address; None when not supported.

    NOTE(review): declared async but performs a synchronous ``session.exec`` —
    this blocks the event loop for the duration of the query; confirm the
    session is intentionally synchronous.
    """
    return self.session.exec(
        select(SupportedToken).where(
            SupportedToken.token_address == token_address
        )
    ).first()
||||
async def _get_chain_config(self, chain_id: int) -> Optional[ChainConfig]:
    """Look up the bridge configuration for a chain id; None when unsupported.

    NOTE(review): async wrapper around a synchronous ``session.exec`` — see
    the matching note on _get_supported_token.
    """
    return self.session.exec(
        select(ChainConfig).where(
            ChainConfig.chain_id == chain_id
        )
    ).first()
||||
async def _generate_transfer_zk_proof(
    self,
    transfer_request: BridgeCreateRequest,
    sender_address: str
) -> Dict:
    """Build public inputs for a transfer and delegate ZK proof generation.

    The proof binds sender, recipient, amount, both chain ids and the current
    UTC timestamp, and is produced by the "bridge_transfer" circuit of the
    ZK proof service.
    """
    # Create proof inputs
    proof_inputs = {
        "sender": sender_address,
        "recipient": transfer_request.recipient_address,
        "amount": transfer_request.amount,
        "source_chain": transfer_request.source_chain_id,
        "target_chain": transfer_request.target_chain_id,
        # Whole-second UTC timestamp so the proof input is an integer.
        "timestamp": int(datetime.utcnow().timestamp())
    }

    # Generate ZK proof
    zk_proof = await self.zk_proof_service.generate_proof(
        "bridge_transfer",
        proof_inputs
    )

    return zk_proof
||||
async def _get_bridge_confirmations(self, request_id: int) -> List[Dict]:
    """Return validator confirmation records for a bridge request as plain dicts."""
    # Only rows of type "confirmation" count as validator confirmations.
    confirmations = self.session.exec(
        select(BridgeTransaction).where(
            BridgeTransaction.bridge_request_id == request_id,
            BridgeTransaction.transaction_type == "confirmation"
        )
    ).all()

    return [
        {
            "validator_address": conf.validator_address,
            "transaction_hash": conf.transaction_hash,
            "confirmed_at": conf.confirmed_at
        }
        for conf in confirmations
    ]
||||
async def _get_bridge_transactions(self, request_id: int) -> List[Dict]:
    """Return every bridge transaction recorded for a request, serialized as dicts."""
    rows = self.session.exec(
        select(BridgeTransaction).where(
            BridgeTransaction.bridge_request_id == request_id
        )
    ).all()

    serialized: List[Dict] = []
    for row in rows:
        serialized.append({
            "transaction_type": row.transaction_type,
            "validator_address": row.validator_address,
            "transaction_hash": row.transaction_hash,
            "created_at": row.created_at,
        })
    return serialized
||||
async def _calculate_estimated_completion(
    self,
    bridge_request: BridgeRequest
) -> Optional[datetime]:
    """Estimate when an in-flight bridge request will complete.

    Returns None for terminal requests (completed/failed) or when either
    chain's configuration is missing. The estimate is confirmation time on
    both chains plus a fixed 5-minute processing buffer, counted from the
    request's creation time.
    """
    # Terminal states have no estimated completion.
    if bridge_request.status in [BridgeRequestStatus.COMPLETED, BridgeRequestStatus.FAILED]:
        return None

    # Get chain configuration
    source_chain = await self._get_chain_config(bridge_request.source_chain_id)
    target_chain = await self._get_chain_config(bridge_request.target_chain_id)

    if not source_chain or not target_chain:
        return None

    # Estimate based on block times and confirmations
    # NOTE(review): assumes avg_block_time is in seconds — confirm on ChainConfig.
    source_confirmation_time = source_chain.avg_block_time * source_chain.min_confirmations
    target_confirmation_time = target_chain.avg_block_time * target_chain.min_confirmations

    total_estimated_time = source_confirmation_time + target_confirmation_time + 300  # 5 min buffer

    return bridge_request.created_at + timedelta(seconds=total_estimated_time)
||||
async def _analyze_bridge_failure(self, bridge_request: BridgeRequest) -> Dict:
    """Classify why a bridge request failed.

    Placeholder implementation: always reports a recoverable timeout. A full
    version would consult monitoring/analytics for the actual failure cause.
    """
    analysis: Dict = {}
    analysis["failure_type"] = "timeout"
    analysis["failure_reason"] = "Bridge request expired"
    analysis["recoverable"] = True
    return analysis
||||
async def _determine_resolution_action(
    self,
    bridge_request: BridgeRequest,
    failure_analysis: Dict
) -> Dict:
    """Pick the remediation for a failed bridge: refund when the failure is
    recoverable, otherwise escalate to manual review."""
    recoverable = failure_analysis.get("recoverable", False)

    # Non-recoverable failures go to a human.
    if not recoverable:
        return {
            "action_type": "manual_review",
            "escalate_to": "support_team"
        }

    # Recoverable failures: return the full amount to the sender.
    return {
        "action_type": "refund",
        "refund_amount": bridge_request.total_amount,
        "refund_to": bridge_request.sender_address
    }
||||
async def _execute_resolution(
    self,
    bridge_request: BridgeRequest,
    resolution_action: Dict
) -> Dict:
    """Execute a previously-determined resolution action.

    Refunds are processed on-chain via the contract service; any other
    action type is simply reported as escalated.
    """
    if resolution_action["action_type"] == "refund":
        # Process refund on blockchain
        refund_result = await self.contract_service.process_bridge_refund(
            bridge_request.contract_request_id,
            resolution_action["refund_amount"],
            resolution_action["refund_to"]
        )

        return {
            "resolution_type": "refund_processed",
            "refund_tx_hash": refund_result.transaction_hash,
            "refund_amount": resolution_action["refund_amount"]
        }

    # Any non-refund action (e.g. manual_review) is handled outside this service.
    return {"resolution_type": "escalated"}
||||
async def _get_validator(self, validator_address: str) -> Optional[Validator]:
    """Look up a validator record by address; None when unknown.

    NOTE(review): async wrapper around a synchronous ``session.exec`` — see
    the matching note on _get_supported_token.
    """
    return self.session.exec(
        select(Validator).where(
            Validator.validator_address == validator_address
        )
    ).first()
||||
async def _verify_validator_signature(
    self,
    confirm_request: BridgeConfirmRequest,
    validator_address: str
) -> bool:
    """Verify a validator's signature over a confirmation request.

    WARNING: stub — always returns True. Until real signature verification
    is implemented, any caller can submit confirmations on behalf of any
    validator. Must not ship to production in this state.
    """
    # This would implement proper signature verification
    # For now, return True for demonstration
    return True
||||
async def _count_confirmations(self, request_id: int) -> int:
    """Count validator confirmations recorded for a bridge request.

    NOTE(review): loads full rows just to count them; a COUNT(*) query would
    avoid materializing the objects if this becomes hot.
    """
    confirmations = self.session.exec(
        select(BridgeTransaction).where(
            BridgeTransaction.bridge_request_id == request_id,
            BridgeTransaction.transaction_type == "confirmation"
        )
    ).all()

    return len(confirmations)
||||
async def _get_required_confirmations(self, chain_id: int) -> int:
    """Confirmations required for *chain_id*; falls back to the service-wide
    default when the chain has no configuration."""
    chain_config = await self._get_chain_config(chain_id)
    if chain_config:
        return chain_config.min_confirmations
    return self.min_confirmations
||||
async def _generate_merkle_proof(self, bridge_request: BridgeRequest) -> MerkleProof:
    """Generate a Merkle inclusion proof for a bridge completion.

    The leaf binds request id, sender, recipient, amount and target chain;
    _verify_merkle_proof must reproduce exactly the same field layout.
    """
    # Create leaf data
    leaf_data = {
        "request_id": bridge_request.id,
        "sender": bridge_request.sender_address,
        "recipient": bridge_request.recipient_address,
        "amount": bridge_request.amount,
        "target_chain": bridge_request.target_chain_id
    }

    # Generate Merkle proof
    merkle_proof = await self.merkle_tree_service.generate_proof(leaf_data)

    return merkle_proof
||||
async def _verify_merkle_proof(
    self,
    merkle_proof: List[str],
    bridge_request: BridgeRequest
) -> bool:
    """Verify a Merkle proof against the request's reconstructed leaf.

    The leaf layout here must stay byte-for-byte in sync with
    _generate_merkle_proof, otherwise valid proofs will fail verification.
    """
    # Recreate leaf data
    leaf_data = {
        "request_id": bridge_request.id,
        "sender": bridge_request.sender_address,
        "recipient": bridge_request.recipient_address,
        "amount": bridge_request.amount,
        "target_chain": bridge_request.target_chain_id
    }

    # Verify proof
    return await self.merkle_tree_service.verify_proof(leaf_data, merkle_proof)
||||
class ValidationResult:
    """Outcome of a bridge request validation.

    Attributes:
        is_valid: True when the request passed all checks.
        error_message: Human-readable reason when is_valid is False; empty otherwise.
    """

    def __init__(self, is_valid: bool, error_message: str = ""):
        self.is_valid = is_valid
        self.error_message = error_message

    def __repr__(self) -> str:
        # Added for debuggability: the default object repr hid both fields in logs.
        return (
            f"{type(self).__name__}"
            f"(is_valid={self.is_valid!r}, error_message={self.error_message!r})"
        )
872
apps/coordinator-api/src/app/services/dynamic_pricing_engine.py
Normal file
872
apps/coordinator-api/src/app/services/dynamic_pricing_engine.py
Normal file
@@ -0,0 +1,872 @@
|
||||
"""
|
||||
Dynamic Pricing Engine for AITBC Marketplace
|
||||
Implements sophisticated pricing algorithms based on real-time market conditions
|
||||
"""
|
||||
|
||||
# Standard library
import asyncio
import json
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, List, Any, Optional, Tuple

# Third-party
import numpy as np

# Local
from aitbc.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PricingStrategy(str, Enum):
    """Dynamic pricing strategy types.

    str-valued so members serialize naturally in JSON/API payloads. Each
    member maps to a parameter set in DynamicPricingEngine.strategy_configs.
    """
    AGGRESSIVE_GROWTH = "aggressive_growth"        # undercut (base multiplier < 1) to win share
    PROFIT_MAXIMIZATION = "profit_maximization"    # premium pricing, margin first
    MARKET_BALANCE = "market_balance"              # neutral default
    COMPETITIVE_RESPONSE = "competitive_response"  # weights competitor prices heavily
    DEMAND_ELASTICITY = "demand_elasticity"        # most demand-sensitive configuration
||||
class ResourceType(str, Enum):
    """Resource categories the pricing engine can price.

    Values double as keys into the regional-adjustment tables.
    """
    GPU = "gpu"
    SERVICE = "service"
    STORAGE = "storage"
||||
class PriceTrend(str, Enum):
    """Qualitative direction of a resource's recent price movement."""
    INCREASING = "increasing"
    DECREASING = "decreasing"
    STABLE = "stable"
    VOLATILE = "volatile"  # takes precedence over direction when variance is high
||||
@dataclass
class PricingFactors:
    """Factors that influence dynamic pricing.

    All ``*_multiplier`` fields are combined multiplicatively with
    ``base_price`` by the engine; the inline ranges are the intended bounds.
    """
    base_price: float
    demand_multiplier: float = 1.0  # 0.5 - 3.0
    supply_multiplier: float = 1.0  # 0.8 - 2.5
    time_multiplier: float = 1.0  # 0.7 - 1.5
    performance_multiplier: float = 1.0  # 0.9 - 1.3
    competition_multiplier: float = 1.0  # 0.8 - 1.4
    sentiment_multiplier: float = 1.0  # 0.9 - 1.2
    regional_multiplier: float = 1.0  # 0.8 - 1.3

    # Confidence and risk factors
    confidence_score: float = 0.8
    risk_adjustment: float = 0.0

    # Market conditions (copied from the MarketConditions snapshot)
    demand_level: float = 0.5
    supply_level: float = 0.5
    market_volatility: float = 0.1

    # Provider-specific factors
    provider_reputation: float = 1.0
    utilization_rate: float = 0.5
    historical_performance: float = 1.0
||||
|
||||
@dataclass
class PriceConstraints:
    """Constraints for pricing calculations.

    NOTE(review): min_change_interval and strategy_lock_period are declared
    here but their enforcement is not visible in this module view — confirm
    where (or whether) they are applied.
    """
    min_price: Optional[float] = None  # provider floor; None = no floor
    max_price: Optional[float] = None  # provider ceiling; None = no ceiling
    max_change_percent: float = 0.5  # Maximum 50% change per update
    min_change_interval: int = 300  # Minimum 5 minutes between changes
    strategy_lock_period: int = 3600  # 1 hour strategy lock
||||
@dataclass
class PricePoint:
    """Single price point in a resource's time series (also used for forecasts)."""
    timestamp: datetime
    price: float
    demand_level: float  # observed/forecast demand, compared against a 0.5 midpoint
    supply_level: float  # observed/forecast supply, compared against a 0.5 midpoint
    confidence: float  # 0-1; forecast points decay toward 0.3
    strategy_used: str  # strategy value, or "forecast" for projected points
||||
@dataclass
class MarketConditions:
    """Snapshot of current market conditions for a (region, resource_type) pair."""
    region: str
    resource_type: ResourceType
    demand_level: float
    supply_level: float
    average_price: float
    price_volatility: float  # feeds the circuit-breaker threshold check
    utilization_rate: float
    competitor_prices: List[float] = field(default_factory=list)  # empty = no competitor data
    market_sentiment: float = 0.0  # -1 to 1
    # Snapshot time; naive UTC via datetime.utcnow.
    timestamp: datetime = field(default_factory=datetime.utcnow)
||||
@dataclass
class PricingResult:
    """Result of a dynamic pricing calculation."""
    resource_id: str
    resource_type: ResourceType
    current_price: float  # the caller-supplied base price
    recommended_price: float  # final price after strategy, constraints and risk
    price_trend: PriceTrend
    confidence_score: float
    factors_exposed: Dict[str, float]  # PricingFactors serialized for transparency
    reasoning: List[str]  # human-readable explanation lines
    next_update: datetime  # when the engine intends to reprice
    strategy_used: PricingStrategy
||||
class DynamicPricingEngine:
|
||||
"""Core dynamic pricing engine with advanced algorithms"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
    """Set up in-memory state, per-strategy parameters, and engine limits.

    Args:
        config: Engine configuration; recognized keys (with defaults):
            min_price (0.001), max_price (1000.0), update_interval (300 s),
            forecast_horizon (72 h), max_volatility_threshold (0.3),
            circuit_breaker_threshold (0.5).
    """
    self.config = config
    # Per-resource recorded price points (also used for trends and forecasts).
    self.pricing_history: Dict[str, List[PricePoint]] = {}
    self.market_conditions_cache: Dict[str, MarketConditions] = {}
    # Per-provider chosen strategy and optional constraints.
    self.provider_strategies: Dict[str, PricingStrategy] = {}
    self.price_constraints: Dict[str, PriceConstraints] = {}

    # Strategy configuration: base_multiplier scales the base price,
    # demand_sensitivity/competition_weight scale the respective adjustments,
    # growth_priority > 0.5 triggers a volume-seeking discount.
    self.strategy_configs = {
        PricingStrategy.AGGRESSIVE_GROWTH: {
            "base_multiplier": 0.85,
            "demand_sensitivity": 0.3,
            "competition_weight": 0.4,
            "growth_priority": 0.8
        },
        PricingStrategy.PROFIT_MAXIMIZATION: {
            "base_multiplier": 1.25,
            "demand_sensitivity": 0.7,
            "competition_weight": 0.2,
            "growth_priority": 0.2
        },
        PricingStrategy.MARKET_BALANCE: {
            "base_multiplier": 1.0,
            "demand_sensitivity": 0.5,
            "competition_weight": 0.3,
            "growth_priority": 0.5
        },
        PricingStrategy.COMPETITIVE_RESPONSE: {
            "base_multiplier": 0.95,
            "demand_sensitivity": 0.4,
            "competition_weight": 0.6,
            "growth_priority": 0.4
        },
        PricingStrategy.DEMAND_ELASTICITY: {
            "base_multiplier": 1.0,
            "demand_sensitivity": 0.8,
            "competition_weight": 0.3,
            "growth_priority": 0.6
        }
    }

    # Pricing parameters
    self.min_price = config.get("min_price", 0.001)
    self.max_price = config.get("max_price", 1000.0)
    self.update_interval = config.get("update_interval", 300)  # 5 minutes
    self.forecast_horizon = config.get("forecast_horizon", 72)  # 72 hours

    # Risk management
    self.max_volatility_threshold = config.get("max_volatility_threshold", 0.3)
    self.circuit_breaker_threshold = config.get("circuit_breaker_threshold", 0.5)
    # Per-resource flag; when True the last recorded price is frozen in.
    self.circuit_breakers: Dict[str, bool] = {}
||||
async def initialize(self):
    """Initialize the dynamic pricing engine.

    Loads persisted pricing history and provider strategies, then launches
    the background maintenance loops (market-condition refresh, volatility
    monitor, strategy optimizer).
    """
    logger.info("Initializing Dynamic Pricing Engine")

    # Load historical pricing data
    await self._load_pricing_history()

    # Load provider strategies
    await self._load_provider_strategies()

    # Keep strong references to the background tasks: the event loop only
    # holds weak references to scheduled tasks, so unreferenced tasks can be
    # garbage-collected mid-flight (see asyncio.create_task documentation).
    # The original discarded these handles.
    self._background_tasks = [
        asyncio.create_task(self._update_market_conditions()),
        asyncio.create_task(self._monitor_price_volatility()),
        asyncio.create_task(self._optimize_strategies()),
    ]

    logger.info("Dynamic Pricing Engine initialized")
|
||||
async def calculate_dynamic_price(
    self,
    resource_id: str,
    resource_type: ResourceType,
    base_price: float,
    strategy: Optional[PricingStrategy] = None,
    constraints: Optional[PriceConstraints] = None,
    region: str = "global"
) -> PricingResult:
    """Calculate the recommended dynamic price for a resource.

    Pipeline: resolve strategy -> snapshot market conditions -> compute
    pricing factors -> apply strategy pricing -> apply constraints and risk
    management -> derive trend, reasoning and confidence -> persist the
    price point and return the result.

    Args:
        resource_id: Identifier of the priced resource (also the history key).
        resource_type: GPU / service / storage.
        base_price: Provider's current listed price.
        strategy: Explicit strategy; defaults to the provider's registered
            strategy, falling back to MARKET_BALANCE.
        constraints: Optional per-call price constraints.
        region: Market region key (default "global").

    Returns:
        PricingResult with the recommended price and full transparency data.

    Raises:
        Exception: any failure is logged and re-raised to the caller.
    """
    try:
        # Resolve strategy: explicit arg > registered provider strategy > balance.
        if strategy is None:
            strategy = self.provider_strategies.get(resource_id, PricingStrategy.MARKET_BALANCE)

        # Get current market conditions
        market_conditions = await self._get_market_conditions(resource_type, region)

        # Calculate pricing factors
        factors = await self._calculate_pricing_factors(
            resource_id, resource_type, base_price, strategy, market_conditions
        )

        # Apply strategy-specific calculations
        strategy_price = await self._apply_strategy_pricing(
            base_price, factors, strategy, market_conditions
        )

        # Apply constraints and risk management
        final_price = await self._apply_constraints_and_risk(
            resource_id, strategy_price, constraints, factors
        )

        # Determine price trend
        price_trend = await self._determine_price_trend(resource_id, final_price)

        # Generate human-readable reasoning for the recommendation
        reasoning = await self._generate_pricing_reasoning(
            factors, strategy, market_conditions, price_trend
        )

        # Calculate confidence score
        confidence = await self._calculate_confidence_score(factors, market_conditions)

        # Schedule next update
        next_update = datetime.utcnow() + timedelta(seconds=self.update_interval)

        # Record the decision so trend/forecast computations have history.
        await self._store_price_point(resource_id, final_price, factors, strategy)

        result = PricingResult(
            resource_id=resource_id,
            resource_type=resource_type,
            current_price=base_price,
            recommended_price=final_price,
            price_trend=price_trend,
            confidence_score=confidence,
            # asdict comes from the dataclasses import at module top; it was
            # previously referenced without being imported (NameError at runtime).
            factors_exposed=asdict(factors),
            reasoning=reasoning,
            next_update=next_update,
            strategy_used=strategy
        )

        logger.info(f"Calculated dynamic price for {resource_id}: {final_price:.6f} (was {base_price:.6f})")
        return result

    except Exception as e:
        logger.error(f"Failed to calculate dynamic price for {resource_id}: {e}")
        raise
|
||||
async def get_price_forecast(
    self,
    resource_id: str,
    hours_ahead: int = 24
) -> List[PricePoint]:
    """Generate an hourly price forecast for the specified horizon.

    Uses a linear trend over the last 12 recorded prices with seasonal,
    demand and supply adjustments per hour. Returns an empty list when the
    resource has no history or fewer than 24 recorded points, or on error.
    Confidence decays linearly from 0.9 toward a floor of 0.3.
    """
    try:
        if resource_id not in self.pricing_history:
            return []

        historical_data = self.pricing_history[resource_id]
        if len(historical_data) < 24:  # Need at least 24 data points
            return []

        # Slice the 48-point window once (the original sliced it three times).
        window = historical_data[-48:]
        prices = [point.price for point in window]
        demand_levels = [point.demand_level for point in window]
        supply_levels = [point.supply_level for point in window]

        # The trend depends only on the last 12 recorded prices, not on the
        # forecast hour — hoisted out of the loop (was recomputed per hour).
        price_trend = self._calculate_price_trend(prices[-12:])

        forecast_points = []
        for hour in range(1, hours_ahead + 1):
            seasonal_factor = self._calculate_seasonal_factor(hour)
            demand_forecast = self._forecast_demand_level(demand_levels, hour)
            supply_forecast = self._forecast_supply_level(supply_levels, hour)

            # Linear trend, then seasonal / demand / supply adjustments.
            base_forecast = prices[-1] + (price_trend * hour)
            seasonal_adjusted = base_forecast * seasonal_factor
            demand_adjusted = seasonal_adjusted * (1 + (demand_forecast - 0.5) * 0.3)
            supply_adjusted = demand_adjusted * (1 + (0.5 - supply_forecast) * 0.2)

            # Clamp into the engine-wide price band.
            forecast_price = max(self.min_price, min(supply_adjusted, self.max_price))

            # Confidence decreases the further out we forecast, floored at 0.3.
            confidence = max(0.3, 0.9 - (hour / hours_ahead) * 0.6)

            forecast_points.append(PricePoint(
                timestamp=datetime.utcnow() + timedelta(hours=hour),
                price=forecast_price,
                demand_level=demand_forecast,
                supply_level=supply_forecast,
                confidence=confidence,
                strategy_used="forecast"
            ))

        return forecast_points

    except Exception as e:
        logger.error(f"Failed to generate price forecast for {resource_id}: {e}")
        return []
|
||||
async def set_provider_strategy(
    self,
    provider_id: str,
    strategy: PricingStrategy,
    constraints: Optional[PriceConstraints] = None
) -> bool:
    """Assign a pricing strategy (and optional constraints) to a provider.

    Returns True on success, False on any unexpected failure (logged).
    """
    try:
        self.provider_strategies[provider_id] = strategy
        # Constraints are optional; only store them when supplied.
        if constraints is not None:
            self.price_constraints[provider_id] = constraints
        logger.info(f"Set strategy {strategy.value} for provider {provider_id}")
    except Exception as e:
        logger.error(f"Failed to set strategy for provider {provider_id}: {e}")
        return False
    return True
|
||||
async def _calculate_pricing_factors(
    self,
    resource_id: str,
    resource_type: ResourceType,
    base_price: float,
    strategy: PricingStrategy,
    market_conditions: MarketConditions
) -> PricingFactors:
    """Compute all pricing factors for one resource.

    Populates each multiplier from its dedicated helper and copies the
    market-condition levels onto the factors object so downstream steps
    (strategy pricing, risk checks, transparency output) see one snapshot.
    """
    factors = PricingFactors(base_price=base_price)

    # Demand multiplier based on market conditions
    factors.demand_multiplier = self._calculate_demand_multiplier(
        market_conditions.demand_level, strategy
    )

    # Supply multiplier based on availability
    factors.supply_multiplier = self._calculate_supply_multiplier(
        market_conditions.supply_level, strategy
    )

    # Time-based multiplier (peak/off-peak)
    factors.time_multiplier = self._calculate_time_multiplier()

    # Performance multiplier based on provider history
    factors.performance_multiplier = await self._calculate_performance_multiplier(resource_id)

    # Competition multiplier based on competitor prices
    factors.competition_multiplier = self._calculate_competition_multiplier(
        base_price, market_conditions.competitor_prices, strategy
    )

    # Market sentiment multiplier
    factors.sentiment_multiplier = self._calculate_sentiment_multiplier(
        market_conditions.market_sentiment
    )

    # Regional multiplier
    factors.regional_multiplier = self._calculate_regional_multiplier(
        market_conditions.region, resource_type
    )

    # Copy market condition levels onto the factor snapshot.
    factors.demand_level = market_conditions.demand_level
    factors.supply_level = market_conditions.supply_level
    factors.market_volatility = market_conditions.price_volatility

    return factors
|
||||
async def _apply_strategy_pricing(
    self,
    base_price: float,
    factors: PricingFactors,
    strategy: PricingStrategy,
    market_conditions: MarketConditions
) -> float:
    """Apply strategy-specific pricing logic.

    Multiplies base_price by the strategy's base multiplier, a
    demand-sensitive adjustment, a competition adjustment (when competitor
    prices are known), the per-factor multipliers, and a growth discount for
    growth-oriented strategies. The result is floored at the engine minimum.
    """
    config = self.strategy_configs[strategy]
    price = base_price

    # Apply base strategy multiplier
    price *= config["base_multiplier"]

    # Demand above/below the 0.5 midpoint moves the price up/down proportionally.
    demand_adjustment = (factors.demand_level - 0.5) * config["demand_sensitivity"]
    price *= (1 + demand_adjustment)

    # Pull toward the competitor average, weighted by the strategy's competition weight.
    if market_conditions.competitor_prices:
        avg_competitor_price = np.mean(market_conditions.competitor_prices)
        competition_ratio = avg_competitor_price / base_price
        competition_adjustment = (competition_ratio - 1) * config["competition_weight"]
        price *= (1 + competition_adjustment)

    # Apply individual multipliers
    price *= factors.time_multiplier
    price *= factors.performance_multiplier
    price *= factors.sentiment_multiplier
    price *= factors.regional_multiplier

    # Growth-oriented strategies give up some margin (up to 10%) to win volume.
    if config["growth_priority"] > 0.5:
        price *= (1 - (config["growth_priority"] - 0.5) * 0.2)  # Discount for growth

    return max(price, self.min_price)
|
||||
async def _apply_constraints_and_risk(
    self,
    resource_id: str,
    price: float,
    constraints: Optional[PriceConstraints],
    factors: PricingFactors
) -> float:
    """Apply provider/global price constraints and risk management.

    Order: circuit-breaker freeze -> provider floor/ceiling -> global
    floor/ceiling -> maximum per-update change -> volatility circuit-breaker
    arming. Returns the constrained price.
    """
    # Circuit breaker: when tripped, freeze at the last recorded price
    # (falls through if no history exists yet).
    if self.circuit_breakers.get(resource_id, False):
        logger.warning(f"Circuit breaker active for {resource_id}, using last price")
        if resource_id in self.pricing_history and self.pricing_history[resource_id]:
            return self.pricing_history[resource_id][-1].price

    # Provider-specific floor/ceiling.
    if constraints:
        if constraints.min_price:
            price = max(price, constraints.min_price)
        if constraints.max_price:
            price = min(price, constraints.max_price)

    # Global floor/ceiling.
    price = max(price, self.min_price)
    price = min(price, self.max_price)

    # Maximum per-update change. Honors PriceConstraints.max_change_percent —
    # the field was previously declared but ignored in favor of a hardcoded
    # 0.5; the dataclass default is 0.5, so behavior is unchanged when no
    # constraints are supplied.
    if resource_id in self.pricing_history and self.pricing_history[resource_id]:
        last_price = self.pricing_history[resource_id][-1].price
        max_change_percent = constraints.max_change_percent if constraints else 0.5
        max_change = last_price * max_change_percent
        if abs(price - last_price) > max_change:
            price = last_price + (max_change if price > last_price else -max_change)
            logger.info(f"Applied max change constraint for {resource_id}")

    # Check for high volatility and trigger circuit breaker if needed.
    if factors.market_volatility > self.circuit_breaker_threshold:
        self.circuit_breakers[resource_id] = True
        logger.warning(f"Triggered circuit breaker for {resource_id} due to high volatility")
        # Schedule the breaker to reset after one hour.
        asyncio.create_task(self._reset_circuit_breaker(resource_id, 3600))  # 1 hour

    return price
||||
|
||||
def _calculate_demand_multiplier(self, demand_level: float, strategy: PricingStrategy) -> float:
    """Map a demand level to a price multiplier, with a strategy bias.

    Piecewise curve: steep premium above 0.8 demand (up to +50% at 1.0),
    gentle slope between 0.5 and 0.8, and a discount band (0.8-1.0) below
    0.5 demand. Growth strategy discounts the result; profit strategy adds
    a premium.
    """
    # Base demand curve
    if demand_level > 0.8:
        base_multiplier = 1.0 + (demand_level - 0.8) * 2.5  # High demand
    elif demand_level > 0.5:
        base_multiplier = 1.0 + (demand_level - 0.5) * 0.5  # Normal demand
    else:
        base_multiplier = 0.8 + (demand_level * 0.4)  # Low demand

    # Strategy adjustment
    if strategy == PricingStrategy.AGGRESSIVE_GROWTH:
        return base_multiplier * 0.9  # Discount for growth
    elif strategy == PricingStrategy.PROFIT_MAXIMIZATION:
        return base_multiplier * 1.3  # Premium for profit
    else:
        return base_multiplier
|
||||
def _calculate_supply_multiplier(self, supply_level: float, strategy: PricingStrategy) -> float:
    """Inverse supply curve: scarcity raises price, glut lowers it; clamped to [0.5, 2.0].

    NOTE(review): *strategy* is currently unused here, unlike the demand
    counterpart — confirm whether a strategy adjustment was intended.
    """
    if supply_level < 0.3:
        # Scarce supply: premium grows as supply approaches zero.
        multiplier = 1.0 + (0.3 - supply_level) * 1.5
    elif supply_level < 0.7:
        # Normal supply: mild discount as supply rises.
        multiplier = 1.0 - (supply_level - 0.3) * 0.3
    else:
        # Abundant supply: discount deepens past 0.7.
        multiplier = 0.9 - (supply_level - 0.7) * 0.3

    return min(2.0, max(0.5, multiplier))
|
||||
def _calculate_time_multiplier(self) -> float:
|
||||
"""Calculate time-based price multiplier"""
|
||||
|
||||
hour = datetime.utcnow().hour
|
||||
day_of_week = datetime.utcnow().weekday()
|
||||
|
||||
# Business hours premium (8 AM - 8 PM, Monday-Friday)
|
||||
if 8 <= hour <= 20 and day_of_week < 5:
|
||||
return 1.2
|
||||
# Evening premium (8 PM - 12 AM)
|
||||
elif 20 <= hour <= 24 or 0 <= hour <= 2:
|
||||
return 1.1
|
||||
# Late night discount (2 AM - 6 AM)
|
||||
elif 2 <= hour <= 6:
|
||||
return 0.8
|
||||
# Weekend premium
|
||||
elif day_of_week >= 5:
|
||||
return 1.15
|
||||
else:
|
||||
return 1.0
|
||||
|
||||
async def _calculate_performance_multiplier(self, resource_id: str) -> float:
|
||||
"""Calculate performance-based multiplier"""
|
||||
|
||||
# In a real implementation, this would fetch from performance metrics
|
||||
# For now, return a default based on historical data
|
||||
if resource_id in self.pricing_history and len(self.pricing_history[resource_id]) > 10:
|
||||
# Simple performance calculation based on consistency
|
||||
recent_prices = [p.price for p in self.pricing_history[resource_id][-10:]]
|
||||
price_variance = np.var(recent_prices)
|
||||
avg_price = np.mean(recent_prices)
|
||||
|
||||
# Lower variance = higher performance multiplier
|
||||
if price_variance < (avg_price * 0.01):
|
||||
return 1.1 # High consistency
|
||||
elif price_variance < (avg_price * 0.05):
|
||||
return 1.05 # Good consistency
|
||||
else:
|
||||
return 0.95 # Low consistency
|
||||
else:
|
||||
return 1.0 # Default for new resources
|
||||
|
||||
def _calculate_competition_multiplier(
    self,
    base_price: float,
    competitor_prices: List[float],
    strategy: PricingStrategy
) -> float:
    """Competition-based multiplier; 1.0 when no competitor prices are known.

    COMPETITIVE_RESPONSE reacts in steps (undercut when >10% above market,
    small premium when >10% below); PROFIT_MAXIMIZATION is least sensitive;
    every other strategy scales moderately with the price ratio.
    """
    if not competitor_prices:
        return 1.0

    avg_competitor_price = np.mean(competitor_prices)
    # Ratio > 1 means we are priced above the market average.
    price_ratio = base_price / avg_competitor_price

    # Strategy-specific competition response
    if strategy == PricingStrategy.COMPETITIVE_RESPONSE:
        if price_ratio > 1.1:  # We're more expensive
            return 0.9  # Discount to compete
        elif price_ratio < 0.9:  # We're cheaper
            return 1.05  # Slight premium
        else:
            return 1.0
    elif strategy == PricingStrategy.PROFIT_MAXIMIZATION:
        return 1.0 + (price_ratio - 1) * 0.3  # Less sensitive to competition
    else:
        return 1.0 + (price_ratio - 1) * 0.5  # Moderate competition sensitivity
|
||||
def _calculate_sentiment_multiplier(self, sentiment: float) -> float:
|
||||
"""Calculate market sentiment multiplier"""
|
||||
|
||||
# Sentiment ranges from -1 (negative) to 1 (positive)
|
||||
if sentiment > 0.3:
|
||||
return 1.1 # Positive sentiment premium
|
||||
elif sentiment < -0.3:
|
||||
return 0.9 # Negative sentiment discount
|
||||
else:
|
||||
return 1.0 # Neutral sentiment
|
||||
|
||||
def _calculate_regional_multiplier(self, region: str, resource_type: ResourceType) -> float:
    """Per-region, per-resource price adjustment; 1.0 for unknown keys."""
    regional_adjustments = {
        "us_west": {"gpu": 1.1, "service": 1.05, "storage": 1.0},
        "us_east": {"gpu": 1.2, "service": 1.1, "storage": 1.05},
        "europe": {"gpu": 1.15, "service": 1.08, "storage": 1.02},
        "asia": {"gpu": 0.9, "service": 0.95, "storage": 0.9},
        "global": {"gpu": 1.0, "service": 1.0, "storage": 1.0},
    }
    per_resource = regional_adjustments.get(region, {})
    return per_resource.get(resource_type.value, 1.0)
|
||||
|
||||
async def _determine_price_trend(self, resource_id: str, current_price: float) -> PriceTrend:
    """Classify the recent price movement for a resource.

    NOTE(review): ``current_price`` is not used by the current heuristic —
    the classification relies purely on stored history; confirm intent.
    """
    history = self.pricing_history.get(resource_id, [])
    if len(history) < 5:
        return PriceTrend.STABLE  # not enough data to call a trend

    window = [p.price for p in history[-10:]]
    # The >= 5 guard above guarantees at least 3 samples in the window.
    newer_avg = np.mean(window[-3:])
    older_avg = np.mean(window[-6:-3]) if len(window) >= 6 else np.mean(window[:-3])
    relative_change = (newer_avg - older_avg) / older_avg if older_avg > 0 else 0

    window_mean = np.mean(window)
    volatility = np.std(window) / window_mean if window_mean > 0 else 0

    # Volatility dominates the classification; otherwise use the change
    # between the newest 3 samples and the preceding ones.
    if volatility > 0.2:
        return PriceTrend.VOLATILE
    if relative_change > 0.05:
        return PriceTrend.INCREASING
    if relative_change < -0.05:
        return PriceTrend.DECREASING
    return PriceTrend.STABLE
|
||||
|
||||
async def _generate_pricing_reasoning(
    self,
    factors: PricingFactors,
    strategy: PricingStrategy,
    market_conditions: MarketConditions,
    trend: PriceTrend
) -> List[str]:
    """Build a human-readable list of reasons behind the computed price.

    Entries are appended in a fixed order: strategy, demand, supply,
    time-of-day, performance, competition, then trend.
    NOTE(review): ``market_conditions`` is currently unused — confirm intent.
    """
    notes: List[str] = [f"Strategy: {strategy.value} applied"]

    # Demand pressure.
    if factors.demand_level > 0.8:
        notes.append("High demand increases prices")
    elif factors.demand_level < 0.3:
        notes.append("Low demand allows competitive pricing")

    # Supply pressure.
    if factors.supply_level < 0.3:
        notes.append("Limited supply justifies premium pricing")
    elif factors.supply_level > 0.8:
        notes.append("High supply enables competitive pricing")

    # Time-of-day effects.
    current_hour = datetime.utcnow().hour
    if 8 <= current_hour <= 20:
        notes.append("Business hours premium applied")
    elif 2 <= current_hour <= 6:
        notes.append("Late night discount applied")

    # Provider performance.
    if factors.performance_multiplier > 1.05:
        notes.append("High performance justifies premium")
    elif factors.performance_multiplier < 0.95:
        notes.append("Performance issues require discount")

    # Competitive positioning (silent when exactly neutral).
    if factors.competition_multiplier < 1.0:
        notes.append("Competitive pricing applied")
    elif factors.competition_multiplier > 1.0:
        notes.append("Premium pricing over competitors")

    notes.append(f"Price trend: {trend.value}")
    return notes
|
||||
|
||||
async def _calculate_confidence_score(
    self,
    factors: PricingFactors,
    market_conditions: MarketConditions
) -> float:
    """Score how trustworthy the pricing decision is, clamped to [0.3, 0.95]."""
    # Base confidence, scaled down by market volatility.
    score = 0.8 * (1.0 - market_conditions.price_volatility)

    # Blend in data availability: 5+ competitor prices = full data factor.
    data_factor = min(1.0, len(market_conditions.competitor_prices) / 5)
    score = score * 0.7 + data_factor * 0.3

    # Extreme demand/supply adjustments each shave 10% off confidence.
    if abs(factors.demand_multiplier - 1.0) > 1.5:
        score *= 0.9
    if abs(factors.supply_multiplier - 1.0) > 1.0:
        score *= 0.9

    return max(0.3, min(0.95, score))
|
||||
|
||||
async def _store_price_point(
    self,
    resource_id: str,
    price: float,
    factors: PricingFactors,
    strategy: PricingStrategy
):
    """Append a PricePoint to the per-resource history (capped at 1000)."""
    history = self.pricing_history.setdefault(resource_id, [])
    history.append(
        PricePoint(
            timestamp=datetime.utcnow(),
            price=price,
            demand_level=factors.demand_level,
            supply_level=factors.supply_level,
            confidence=factors.confidence_score,
            strategy_used=strategy.value,
        )
    )

    # Trim to the newest 1000 points to bound memory use.
    if len(history) > 1000:
        self.pricing_history[resource_id] = history[-1000:]
|
||||
|
||||
async def _get_market_conditions(
    self,
    resource_type: ResourceType,
    region: str
) -> MarketConditions:
    """Return market conditions for a region/resource pair, with caching.

    Cached entries younger than 5 minutes are reused; otherwise new
    (simulated) conditions are generated and cached.
    """
    cache_key = f"{region}_{resource_type.value}"

    cached = self.market_conditions_cache.get(cache_key)
    if cached is not None:
        age_seconds = (datetime.utcnow() - cached.timestamp).total_seconds()
        if age_seconds < 300:
            return cached

    # Simulated market data; a production build would query real sources.
    # Keep the np.random draw order stable so seeded runs reproduce.
    conditions = MarketConditions(
        region=region,
        resource_type=resource_type,
        demand_level=0.6 + np.random.normal(0, 0.1),
        supply_level=0.7 + np.random.normal(0, 0.1),
        average_price=0.05 + np.random.normal(0, 0.01),
        price_volatility=0.1 + np.random.normal(0, 0.05),
        utilization_rate=0.65 + np.random.normal(0, 0.1),
        competitor_prices=[0.045, 0.055, 0.048, 0.052],  # simulated competitors
        market_sentiment=np.random.normal(0.1, 0.2),
    )

    self.market_conditions_cache[cache_key] = conditions
    return conditions
|
||||
|
||||
async def _load_pricing_history(self):
    """Load historical pricing data.

    Stub: persistence is not wired up yet, so this is a no-op.
    """
    # In a real implementation, this would load from database
    pass
|
||||
|
||||
async def _load_provider_strategies(self):
    """Load provider strategies from storage.

    Stub: persistence is not wired up yet, so this is a no-op.
    """
    # In a real implementation, this would load from database
    pass
|
||||
|
||||
async def _update_market_conditions(self):
    """Background task to update market conditions.

    Runs forever: empties the conditions cache every 5 minutes so the
    next conditions lookup regenerates fresh data.
    """
    while True:
        try:
            # Clear cache to force refresh
            self.market_conditions_cache.clear()
            await asyncio.sleep(300)  # Update every 5 minutes
        except Exception as e:
            logger.error(f"Error updating market conditions: {e}")
            await asyncio.sleep(60)  # shorter back-off after a failure
|
||||
|
||||
async def _monitor_price_volatility(self):
    """Background loop: warn about resources whose recent prices swing too much."""
    while True:
        try:
            for resource_id, history in self.pricing_history.items():
                if len(history) < 10:
                    continue  # not enough samples for a volatility estimate
                window = [p.price for p in history[-10:]]
                mean_price = np.mean(window)
                volatility = np.std(window) / mean_price if mean_price > 0 else 0
                if volatility > self.max_volatility_threshold:
                    logger.warning(f"High volatility detected for {resource_id}: {volatility:.3f}")

            await asyncio.sleep(600)  # check every 10 minutes
        except Exception as e:
            logger.error(f"Error monitoring volatility: {e}")
            await asyncio.sleep(120)  # shorter retry after a failure
|
||||
|
||||
async def _optimize_strategies(self):
    """Background task to optimize pricing strategies.

    Currently a placeholder: wakes hourly; the analysis itself is not
    implemented yet.
    """
    while True:
        try:
            # Analyze strategy performance and recommend optimizations
            await asyncio.sleep(3600)  # Optimize every hour
        except Exception as e:
            logger.error(f"Error optimizing strategies: {e}")
            await asyncio.sleep(300)  # retry sooner after a failure
|
||||
|
||||
async def _reset_circuit_breaker(self, resource_id: str, delay: int):
    """Reset circuit breaker after delay.

    Sleeps ``delay`` seconds, then clears the breaker flag so pricing
    for ``resource_id`` resumes.
    """
    await asyncio.sleep(delay)
    self.circuit_breakers[resource_id] = False
    logger.info(f"Reset circuit breaker for {resource_id}")
|
||||
|
||||
def _calculate_price_trend(self, prices: List[float]) -> float:
|
||||
"""Calculate simple price trend"""
|
||||
if len(prices) < 2:
|
||||
return 0.0
|
||||
|
||||
# Simple linear regression
|
||||
x = np.arange(len(prices))
|
||||
y = np.array(prices)
|
||||
|
||||
# Calculate slope
|
||||
slope = np.polyfit(x, y, 1)[0]
|
||||
return slope
|
||||
|
||||
def _calculate_seasonal_factor(self, hour: int) -> float:
|
||||
"""Calculate seasonal adjustment factor"""
|
||||
# Simple daily seasonality pattern
|
||||
if 6 <= hour <= 10: # Morning ramp
|
||||
return 1.05
|
||||
elif 10 <= hour <= 16: # Business peak
|
||||
return 1.1
|
||||
elif 16 <= hour <= 20: # Evening ramp
|
||||
return 1.05
|
||||
elif 20 <= hour <= 24: # Night
|
||||
return 0.95
|
||||
else: # Late night
|
||||
return 0.9
|
||||
|
||||
def _forecast_demand_level(self, historical: List[float], hour_ahead: int) -> float:
|
||||
"""Simple demand level forecasting"""
|
||||
if not historical:
|
||||
return 0.5
|
||||
|
||||
# Use recent average with some noise
|
||||
recent_avg = np.mean(historical[-6:]) if len(historical) >= 6 else np.mean(historical)
|
||||
|
||||
# Add some prediction uncertainty
|
||||
noise = np.random.normal(0, 0.05)
|
||||
forecast = max(0.0, min(1.0, recent_avg + noise))
|
||||
|
||||
return forecast
|
||||
|
||||
def _forecast_supply_level(self, historical: List[float], hour_ahead: int) -> float:
|
||||
"""Simple supply level forecasting"""
|
||||
if not historical:
|
||||
return 0.5
|
||||
|
||||
# Supply is usually more stable than demand
|
||||
recent_avg = np.mean(historical[-12:]) if len(historical) >= 12 else np.mean(historical)
|
||||
|
||||
# Add small prediction uncertainty
|
||||
noise = np.random.normal(0, 0.02)
|
||||
forecast = max(0.0, min(1.0, recent_avg + noise))
|
||||
|
||||
return forecast
|
||||
744
apps/coordinator-api/src/app/services/market_data_collector.py
Normal file
744
apps/coordinator-api/src/app/services/market_data_collector.py
Normal file
@@ -0,0 +1,744 @@
|
||||
"""
|
||||
Market Data Collector for Dynamic Pricing Engine
|
||||
Collects real-time market data from various sources for pricing calculations
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any, Optional, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import websockets
|
||||
from aitbc.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DataSource(str, Enum):
    """Market data source types.

    Values double as string identifiers since the enum derives from ``str``.
    """
    GPU_METRICS = "gpu_metrics"              # utilization / capacity telemetry
    BOOKING_DATA = "booking_data"            # recent booking activity
    REGIONAL_DEMAND = "regional_demand"      # time-of-day demand patterns
    COMPETITOR_PRICES = "competitor_prices"  # external price observations
    PERFORMANCE_DATA = "performance_data"    # provider reliability metrics
    MARKET_SENTIMENT = "market_sentiment"    # blended activity/trend signal
|
||||
|
||||
|
||||
@dataclass
class MarketDataPoint:
    """Single market data point produced by one collector."""
    source: DataSource   # which collector produced this sample
    resource_id: str     # e.g. "gpu_us_west", "sentiment_asia"
    resource_type: str   # "gpu" | "service" | "storage"
    region: str          # e.g. "us_west"; "global" for cross-region data
    timestamp: datetime  # UTC collection time
    value: float         # primary metric; meaning depends on `source`
    metadata: Dict[str, Any] = field(default_factory=dict)  # source-specific extras
|
||||
|
||||
|
||||
@dataclass
class AggregatedMarketData:
    """Aggregated market data for a resource type and region."""
    resource_type: str              # "gpu" | "service" | "storage"
    region: str                     # region key, or "global"
    timestamp: datetime             # UTC aggregation time
    demand_level: float             # 0..1 blended demand signal
    supply_level: float             # 0..1 blended supply signal
    average_price: float            # mean observed competitor price
    price_volatility: float         # coefficient of variation, capped at 1.0
    utilization_rate: float         # 0..1 mean GPU utilization
    competitor_prices: List[float]  # recent external price samples
    market_sentiment: float         # -1..1 sentiment score
    data_sources: List[DataSource] = field(default_factory=list)  # sources that contributed
    confidence_score: float = 0.8   # trust in this aggregate (0.1..0.95)
|
||||
|
||||
|
||||
class MarketDataCollector:
|
||||
"""Collects and processes market data from multiple sources"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
    """Set up buffers, polling schedules and WebSocket settings from *config*."""
    self.config = config

    # Callback registry and data stores.
    self.data_callbacks: Dict[DataSource, List[Callable]] = {}
    self.raw_data: List[MarketDataPoint] = []
    self.aggregated_data: Dict[str, AggregatedMarketData] = {}
    self.websocket_connections: Dict[str, websockets.WebSocketServerProtocol] = {}

    # How often (seconds) each source is polled.
    self.collection_intervals = {
        DataSource.GPU_METRICS: 60,         # 1 minute
        DataSource.BOOKING_DATA: 30,        # 30 seconds
        DataSource.REGIONAL_DEMAND: 300,    # 5 minutes
        DataSource.COMPETITOR_PRICES: 600,  # 10 minutes
        DataSource.PERFORMANCE_DATA: 120,   # 2 minutes
        DataSource.MARKET_SENTIMENT: 180,   # 3 minutes
    }

    # Retention policy for raw samples.
    self.max_data_age = timedelta(hours=48)
    self.max_raw_data_points = 10000

    # Real-time update server (started in initialize()).
    self.websocket_port = config.get("websocket_port", 8765)
    self.websocket_server = None
|
||||
|
||||
async def initialize(self):
    """Start all background collection/aggregation tasks and the WebSocket server.

    Fix: the results of ``asyncio.create_task`` were previously discarded.
    The event loop keeps only weak references to tasks, so fire-and-forget
    tasks can be garbage-collected before completion; the task handles are
    now stored on the instance to keep them alive.
    """
    logger.info("Initializing Market Data Collector")

    # Strong references prevent premature garbage collection of the tasks.
    self._background_tasks: List[asyncio.Task] = []

    # One polling task per data source.
    for source in DataSource:
        self._background_tasks.append(
            asyncio.create_task(self._collect_data_source(source))
        )

    # Periodic aggregation and retention cleanup.
    self._background_tasks.append(asyncio.create_task(self._aggregate_market_data()))
    self._background_tasks.append(asyncio.create_task(self._cleanup_old_data()))

    # WebSocket server for pushing real-time updates to subscribers.
    await self._start_websocket_server()

    logger.info("Market Data Collector initialized")
|
||||
|
||||
def register_callback(self, source: DataSource, callback: Callable):
    """Subscribe *callback* to every new data point produced by *source*."""
    self.data_callbacks.setdefault(source, []).append(callback)
    logger.info(f"Registered callback for {source.value}")
|
||||
|
||||
async def get_aggregated_data(
    self,
    resource_type: str,
    region: str = "global"
) -> Optional[AggregatedMarketData]:
    """Return the latest aggregate for *resource_type*/*region*, or None."""
    return self.aggregated_data.get(f"{resource_type}_{region}")
|
||||
|
||||
async def get_recent_data(
    self,
    source: DataSource,
    minutes: int = 60
) -> List[MarketDataPoint]:
    """Return raw points from *source* that are at most *minutes* old."""
    threshold = datetime.utcnow() - timedelta(minutes=minutes)
    return [
        sample
        for sample in self.raw_data
        if sample.source == source and sample.timestamp >= threshold
    ]
|
||||
|
||||
async def _collect_data_source(self, source: DataSource):
    """Poll one data source forever at its configured interval."""
    poll_interval = self.collection_intervals[source]

    while True:
        try:
            await self._collect_from_source(source)
            await asyncio.sleep(poll_interval)
        except Exception as e:
            logger.error(f"Error collecting data from {source.value}: {e}")
            await asyncio.sleep(60)  # back off for a minute after a failure
|
||||
|
||||
async def _collect_from_source(self, source: DataSource):
    """Dispatch to the per-source collection coroutine (no-op if unknown)."""
    collectors = {
        DataSource.GPU_METRICS: self._collect_gpu_metrics,
        DataSource.BOOKING_DATA: self._collect_booking_data,
        DataSource.REGIONAL_DEMAND: self._collect_regional_demand,
        DataSource.COMPETITOR_PRICES: self._collect_competitor_prices,
        DataSource.PERFORMANCE_DATA: self._collect_performance_data,
        DataSource.MARKET_SENTIMENT: self._collect_market_sentiment,
    }
    collector = collectors.get(source)
    if collector is not None:
        await collector()
|
||||
|
||||
async def _collect_gpu_metrics(self):
    """Sample (simulated) GPU utilization and supply for each region."""
    try:
        for region in ("us_west", "us_east", "europe", "asia"):
            # Clock-seeded pseudo-random simulation; a real deployment
            # would query the GPU monitoring stack instead.
            utilization = 0.6 + (hash(region + str(datetime.utcnow().minute)) % 100) / 200
            available_gpus = 100 + (hash(region + str(datetime.utcnow().hour)) % 50)
            total_gpus = 150
            supply_level = available_gpus / total_gpus

            await self._add_data_point(MarketDataPoint(
                source=DataSource.GPU_METRICS,
                resource_id=f"gpu_{region}",
                resource_type="gpu",
                region=region,
                timestamp=datetime.utcnow(),
                value=utilization,
                metadata={
                    "available_gpus": available_gpus,
                    "total_gpus": total_gpus,
                    "supply_level": supply_level,
                },
            ))
    except Exception as e:
        logger.error(f"Error collecting GPU metrics: {e}")
|
||||
|
||||
async def _collect_booking_data(self):
    """Sample (simulated) booking activity and derive a demand level."""
    try:
        for region in ("us_west", "us_east", "europe", "asia"):
            recent_bookings = hash(region + str(datetime.utcnow().minute)) % 20
            total_capacity = 100
            booking_rate = recent_bookings / total_capacity
            # Demand saturates at 1.0 once half the capacity is booked.
            demand_level = min(1.0, booking_rate * 2)

            await self._add_data_point(MarketDataPoint(
                source=DataSource.BOOKING_DATA,
                resource_id=f"bookings_{region}",
                resource_type="gpu",
                region=region,
                timestamp=datetime.utcnow(),
                value=booking_rate,
                metadata={
                    "recent_bookings": recent_bookings,
                    "total_capacity": total_capacity,
                    "demand_level": demand_level,
                },
            ))
    except Exception as e:
        logger.error(f"Error collecting booking data: {e}")
|
||||
|
||||
async def _collect_regional_demand(self):
    """Estimate (simulated) regional demand from time-of-day patterns."""
    # Business-hour windows differ per region (hours in UTC).
    peak_hours_by_region = {
        "asia": [9, 10, 11, 14, 15, 16],
        "europe": [8, 9, 10, 11, 14, 15, 16],
        "us_east": [9, 10, 11, 14, 15, 16, 17],
        "us_west": [10, 11, 12, 14, 15, 16, 17],
    }
    try:
        for region in ("us_west", "us_east", "europe", "asia"):
            hour = datetime.utcnow().hour
            peak_hours = peak_hours_by_region[region]

            base_demand = 0.4
            if hour in peak_hours:
                demand_multiplier = 1.5
            elif (hour - 1) in peak_hours or (hour + 1) in peak_hours:
                demand_multiplier = 1.2  # shoulder hour adjacent to a peak
            else:
                demand_multiplier = 0.8
            demand_level = min(1.0, base_demand * demand_multiplier)

            await self._add_data_point(MarketDataPoint(
                source=DataSource.REGIONAL_DEMAND,
                resource_id=f"demand_{region}",
                resource_type="gpu",
                region=region,
                timestamp=datetime.utcnow(),
                value=demand_level,
                metadata={
                    "hour": hour,
                    "peak_hours": peak_hours,
                    "demand_multiplier": demand_multiplier,
                },
            ))
    except Exception as e:
        logger.error(f"Error collecting regional demand: {e}")
|
||||
|
||||
async def _collect_competitor_prices(self):
    """Sample (simulated) competitor prices per region."""
    try:
        base_price = 0.05  # invariant across regions; hoisted out of the loop
        for region in ("us_west", "us_east", "europe", "asia"):
            # Four pseudo-competitors, each within +/-10% of the base price.
            competitor_prices = [
                base_price * (1 + (hash(f"comp{i}_{region}") % 20 - 10) / 100)
                for i in (1, 2, 3, 4)
            ]
            avg_competitor_price = sum(competitor_prices) / len(competitor_prices)

            await self._add_data_point(MarketDataPoint(
                source=DataSource.COMPETITOR_PRICES,
                resource_id=f"competitors_{region}",
                resource_type="gpu",
                region=region,
                timestamp=datetime.utcnow(),
                value=avg_competitor_price,
                metadata={
                    "competitor_prices": competitor_prices,
                    "price_count": len(competitor_prices),
                },
            ))
    except Exception as e:
        logger.error(f"Error collecting competitor prices: {e}")
|
||||
|
||||
async def _collect_performance_data(self):
    """Sample (simulated) provider performance per region."""
    try:
        for region in ("us_west", "us_east", "europe", "asia"):
            completion_rate = 0.85 + (hash(f"perf_{region}") % 20) / 200
            average_response_time = 120 + (hash(f"resp_{region}") % 60)  # seconds
            error_rate = 0.02 + (hash(f"error_{region}") % 10) / 1000
            # Completion discounted by failures gives the overall score.
            performance_score = completion_rate * (1 - error_rate)

            await self._add_data_point(MarketDataPoint(
                source=DataSource.PERFORMANCE_DATA,
                resource_id=f"performance_{region}",
                resource_type="gpu",
                region=region,
                timestamp=datetime.utcnow(),
                value=performance_score,
                metadata={
                    "completion_rate": completion_rate,
                    "average_response_time": average_response_time,
                    "error_rate": error_rate,
                },
            ))
    except Exception as e:
        logger.error(f"Error collecting performance data: {e}")
|
||||
|
||||
async def _collect_market_sentiment(self):
    """Derive a (simulated) sentiment score in [-1, 1] per region."""
    try:
        for region in ("us_west", "us_east", "europe", "asia"):
            recent_activity = (hash(f"activity_{region}") % 100) / 100
            price_trend = (hash(f"trend_{region}") % 21 - 10) / 100     # -0.1..0.1
            volume_change = (hash(f"volume_{region}") % 31 - 15) / 100  # -0.15..0.15

            # Weighted blend, clamped to the [-1, 1] sentiment range.
            blended = recent_activity * 0.4 + price_trend * 0.3 + volume_change * 0.3
            sentiment = max(-1.0, min(1.0, blended))

            await self._add_data_point(MarketDataPoint(
                source=DataSource.MARKET_SENTIMENT,
                resource_id=f"sentiment_{region}",
                resource_type="gpu",
                region=region,
                timestamp=datetime.utcnow(),
                value=sentiment,
                metadata={
                    "recent_activity": recent_activity,
                    "price_trend": price_trend,
                    "volume_change": volume_change,
                },
            ))
    except Exception as e:
        logger.error(f"Error collecting market sentiment: {e}")
|
||||
|
||||
async def _add_data_point(self, data_point: MarketDataPoint):
    """Record a point, fan it out to callbacks, and broadcast it."""
    self.raw_data.append(data_point)

    # Bound memory: keep only the newest max_raw_data_points samples.
    if len(self.raw_data) > self.max_raw_data_points:
        self.raw_data = self.raw_data[-self.max_raw_data_points:]

    # Fan out to subscribers; one failing callback must not stop the rest.
    for callback in self.data_callbacks.get(data_point.source, []):
        try:
            await callback(data_point)
        except Exception as e:
            logger.error(f"Error in data callback: {e}")

    # Push to any connected WebSocket clients.
    await self._broadcast_data_point(data_point)
|
||||
|
||||
async def _aggregate_market_data(self):
    """Background loop: rebuild aggregated metrics once a minute."""
    while True:
        try:
            await self._perform_aggregation()
            await asyncio.sleep(60)
        except Exception as e:
            logger.error(f"Error aggregating market data: {e}")
            await asyncio.sleep(30)  # shorter retry after a failure
|
||||
|
||||
async def _perform_aggregation(self):
    """Recompute aggregates for every resource-type/region combination."""
    for resource_type in ("gpu", "service", "storage"):
        for region in ("us_west", "us_east", "europe", "asia", "global"):
            aggregated = await self._aggregate_for_resource_region(resource_type, region)
            if aggregated:
                self.aggregated_data[f"{resource_type}_{region}"] = aggregated
|
||||
|
||||
async def _aggregate_for_resource_region(
    self,
    resource_type: str,
    region: str
) -> Optional[AggregatedMarketData]:
    """Fold the last 30 minutes of raw points into one aggregate record.

    Returns None when there is no matching data or aggregation fails.
    """
    try:
        window_start = datetime.utcnow() - timedelta(minutes=30)

        # Group matching points by originating source; dict insertion
        # order doubles as the first-seen "sources" list.
        source_data: Dict[DataSource, List[MarketDataPoint]] = {}
        for point in self.raw_data:
            if (point.resource_type == resource_type
                    and point.region == region
                    and point.timestamp >= window_start):
                source_data.setdefault(point.source, []).append(point)

        if not source_data:
            return None

        data_sources = list(source_data)

        return AggregatedMarketData(
            resource_type=resource_type,
            region=region,
            timestamp=datetime.utcnow(),
            demand_level=self._calculate_aggregated_demand(source_data),
            supply_level=self._calculate_aggregated_supply(source_data),
            average_price=self._calculate_aggregated_price(source_data),
            price_volatility=self._calculate_price_volatility(source_data),
            utilization_rate=self._calculate_aggregated_utilization(source_data),
            competitor_prices=self._get_competitor_prices(source_data),
            market_sentiment=self._calculate_aggregated_sentiment(source_data),
            data_sources=data_sources,
            confidence_score=self._calculate_aggregation_confidence(source_data, data_sources),
        )

    except Exception as e:
        logger.error(f"Error aggregating data for {resource_type}_{region}: {e}")
        return None
|
||||
|
||||
def _calculate_aggregated_demand(self, source_data: Dict[DataSource, List[MarketDataPoint]]) -> float:
    """Average demand signals from booking and regional-demand sources (default 0.5)."""
    samples = [
        point.metadata["demand_level"]
        for point in source_data.get(DataSource.BOOKING_DATA, [])
        if "demand_level" in point.metadata
    ]
    samples.extend(
        point.value for point in source_data.get(DataSource.REGIONAL_DEMAND, [])
    )
    return sum(samples) / len(samples) if samples else 0.5
|
||||
|
||||
def _calculate_aggregated_supply(self, source_data: Dict[DataSource, List[MarketDataPoint]]) -> float:
    """Average the supply_level reported by GPU metric points (default 0.5)."""
    samples = [
        point.metadata["supply_level"]
        for point in source_data.get(DataSource.GPU_METRICS, [])
        if "supply_level" in point.metadata
    ]
    return sum(samples) / len(samples) if samples else 0.5
|
||||
|
||||
def _calculate_aggregated_price(self, source_data: Dict[DataSource, List[MarketDataPoint]]) -> float:
    """Average competitor price observations; 0.05 when none are available."""
    samples = [point.value for point in source_data.get(DataSource.COMPETITOR_PRICES, [])]
    return sum(samples) / len(samples) if samples else 0.05
|
||||
|
||||
def _calculate_price_volatility(self, source_data: Dict[DataSource, List[MarketDataPoint]]) -> float:
    """Coefficient of variation of observed competitor prices, capped at 1.0.

    Fix: removed an unused local ``import numpy as np`` — the computation
    below is plain Python and never referenced it.
    """
    prices: List[float] = []
    if DataSource.COMPETITOR_PRICES in source_data:
        for point in source_data[DataSource.COMPETITOR_PRICES]:
            if "competitor_prices" in point.metadata:
                prices.extend(point.metadata["competitor_prices"])

    if len(prices) < 2:
        return 0.1  # default volatility when too few samples

    # Population variance, then relative standard deviation.
    mean_price = sum(prices) / len(prices)
    variance = sum((p - mean_price) ** 2 for p in prices) / len(prices)
    volatility = (variance ** 0.5) / mean_price if mean_price > 0 else 0
    return min(1.0, volatility)
|
||||
|
||||
def _calculate_aggregated_utilization(self, source_data: Dict[DataSource, List[MarketDataPoint]]) -> float:
    """Average GPU utilization readings; 0.6 when none are available."""
    readings = [point.value for point in source_data.get(DataSource.GPU_METRICS, [])]
    return sum(readings) / len(readings) if readings else 0.6
|
||||
|
||||
def _get_competitor_prices(self, source_data: Dict[DataSource, List[MarketDataPoint]]) -> List[float]:
    """Collect up to 10 of the most recent competitor price observations.

    Fix: the original sliced ``[:10]``, which keeps the *oldest* ten prices
    even though raw points are appended chronologically and the stated
    intent was "10 most recent"; slice from the end instead.
    """
    prices: List[float] = []
    if DataSource.COMPETITOR_PRICES in source_data:
        for point in source_data[DataSource.COMPETITOR_PRICES]:
            if "competitor_prices" in point.metadata:
                prices.extend(point.metadata["competitor_prices"])
    return prices[-10:]  # newest observations are at the tail
|
||||
|
||||
def _calculate_aggregated_sentiment(self, source_data: Dict[DataSource, List[MarketDataPoint]]) -> float:
    """Average the market-sentiment readings.

    Returns:
        Mean of all MARKET_SENTIMENT point values, or 0.0 (neutral) when
        no readings are available.
    """
    readings = [
        point.value
        for point in source_data.get(DataSource.MARKET_SENTIMENT, [])
    ]
    if not readings:
        return 0.0  # Neutral sentiment
    return sum(readings) / len(readings)
def _calculate_aggregation_confidence(
    self,
    source_data: Dict[DataSource, List[MarketDataPoint]],
    data_sources: List[DataSource]
) -> float:
    """Score how trustworthy the aggregated data is.

    Blends three signals: source coverage (how many of the 4 known sources
    contributed), freshness (newest point per source, linearly decaying over
    one hour) and volume (total point count, saturating at 20).

    Returns:
        Confidence clamped to [0.1, 0.95] so it is never zero nor absolute.
    """
    # Coverage: 4 distinct sources earns full marks.
    coverage = min(1.0, len(data_sources) / 4.0)

    # Freshness: score each source by its most recent point.
    now = datetime.utcnow()
    per_source_freshness: List[float] = []
    for points in source_data.values():
        if not points:
            continue
        newest = max(p.timestamp for p in points)
        age_minutes = (now - newest).total_seconds() / 60
        per_source_freshness.append(max(0.0, 1.0 - age_minutes / 60))
    if per_source_freshness:
        freshness = sum(per_source_freshness) / len(per_source_freshness)
    else:
        freshness = 0.5  # No data at all: assume middling freshness

    # Volume: 20 total points earns full marks.
    total_points = sum(len(points) for points in source_data.values())
    volume = min(1.0, total_points / 20.0)

    # Weighted blend: coverage and freshness dominate.
    blended = coverage * 0.4 + freshness * 0.4 + volume * 0.2
    return max(0.1, min(0.95, blended))
async def _cleanup_old_data(self):
    """Background task: prune data older than ``self.max_data_age``.

    Runs forever; sweeps hourly, and retries after 5 minutes on error.
    """
    while True:
        try:
            cutoff_time = datetime.utcnow() - self.max_data_age

            # Keep only raw points newer than the cutoff.
            self.raw_data = [
                p for p in self.raw_data if p.timestamp >= cutoff_time
            ]

            # Drop aggregated entries that have gone stale.
            stale_keys = [
                key for key, agg in self.aggregated_data.items()
                if agg.timestamp < cutoff_time
            ]
            for key in stale_keys:
                del self.aggregated_data[key]

            await asyncio.sleep(3600)  # Clean up every hour
        except Exception as e:
            logger.error(f"Error cleaning up old data: {e}")
            await asyncio.sleep(300)  # Back off briefly before retrying
async def _start_websocket_server(self):
    """Start WebSocket server for real-time data streaming.

    Binds a ``websockets`` server on localhost:``self.websocket_port``.
    Connections are tracked in ``self.websocket_connections`` keyed by a
    remote-address + timestamp id; ``_broadcast_data_point`` fans out to them.
    Failure to bind is logged but not raised.
    """

    # NOTE(review): newer websockets versions call the handler with a single
    # connection argument (no ``path``) — confirm the pinned library version.
    async def handle_websocket(websocket, path):
        """Handle WebSocket connections: register, drain messages, unregister."""
        try:
            # Store connection under a unique id so duplicates from the same
            # address cannot collide.
            connection_id = f"{websocket.remote_address}_{datetime.utcnow().timestamp()}"
            self.websocket_connections[connection_id] = websocket

            logger.info(f"WebSocket client connected: {connection_id}")

            # Keep connection alive; inbound messages are currently ignored.
            try:
                async for message in websocket:
                    # Handle client messages if needed
                    pass
            except websockets.exceptions.ConnectionClosed:
                pass
            finally:
                # Remove connection on any exit path.
                if connection_id in self.websocket_connections:
                    del self.websocket_connections[connection_id]
                    logger.info(f"WebSocket client disconnected: {connection_id}")

        except Exception as e:
            logger.error(f"Error handling WebSocket connection: {e}")

    try:
        self.websocket_server = await websockets.serve(
            handle_websocket,
            "localhost",
            self.websocket_port
        )
        logger.info(f"WebSocket server started on port {self.websocket_port}")
    except Exception as e:
        # Best-effort: the collector keeps running without streaming support.
        logger.error(f"Failed to start WebSocket server: {e}")
async def _broadcast_data_point(self, data_point: MarketDataPoint):
    """Push one market data point to every connected WebSocket client.

    Clients whose connection is closed or errors during send are dropped
    from ``self.websocket_connections``.
    """
    if not self.websocket_connections:
        return

    # Serialize once; all clients receive the identical payload.
    payload = json.dumps({
        "type": "market_data",
        "source": data_point.source.value,
        "resource_id": data_point.resource_id,
        "resource_type": data_point.resource_type,
        "region": data_point.region,
        "timestamp": data_point.timestamp.isoformat(),
        "value": data_point.value,
        "metadata": data_point.metadata
    })

    stale: List[str] = []
    for conn_id, ws in self.websocket_connections.items():
        try:
            await ws.send(payload)
        except websockets.exceptions.ConnectionClosed:
            stale.append(conn_id)
        except Exception as e:
            logger.error(f"Error sending WebSocket message: {e}")
            stale.append(conn_id)

    # Prune clients that failed, outside the iteration over the dict.
    for conn_id in stale:
        if conn_id in self.websocket_connections:
            del self.websocket_connections[conn_id]
361
apps/coordinator-api/src/app/services/multi_language/README.md
Normal file
361
apps/coordinator-api/src/app/services/multi_language/README.md
Normal file
@@ -0,0 +1,361 @@
|
||||
# Multi-Language API Service
|
||||
|
||||
## Overview
|
||||
|
||||
The Multi-Language API service provides comprehensive translation, language detection, and localization capabilities for the AITBC platform. This service enables global agent interactions and marketplace listings with support for 50+ languages.
|
||||
|
||||
## Features
|
||||
|
||||
### Core Capabilities
|
||||
- **Multi-Provider Translation**: OpenAI GPT-4, Google Translate, DeepL, and local models
|
||||
- **Intelligent Fallback**: Automatic provider switching based on language pair and quality
|
||||
- **Language Detection**: Ensemble detection using langdetect, Polyglot, and FastText
|
||||
- **Quality Assurance**: BLEU scores, semantic similarity, and consistency checks
|
||||
- **Redis Caching**: High-performance caching with intelligent eviction
|
||||
- **Real-time Translation**: WebSocket support for live conversations
|
||||
|
||||
### Integration Points
|
||||
- **Agent Communication**: Automatic message translation between agents
|
||||
- **Marketplace Localization**: Multi-language listings and search
|
||||
- **User Preferences**: Per-user language settings and auto-translation
|
||||
- **Cultural Intelligence**: Regional communication style adaptation
|
||||
|
||||
## Architecture
|
||||
|
||||
### Service Components
|
||||
|
||||
```
|
||||
multi_language/
|
||||
├── __init__.py # Service initialization and dependency injection
|
||||
├── translation_engine.py # Core translation orchestration
|
||||
├── language_detector.py # Multi-method language detection
|
||||
├── translation_cache.py # Redis-based caching layer
|
||||
├── quality_assurance.py # Translation quality assessment
|
||||
├── agent_communication.py # Enhanced agent messaging
|
||||
├── marketplace_localization.py # Marketplace content localization
|
||||
├── api_endpoints.py # REST API endpoints
|
||||
├── config.py # Configuration management
|
||||
├── database_schema.sql # Database migrations
|
||||
├── test_multi_language.py # Comprehensive test suite
|
||||
└── requirements.txt # Dependencies
|
||||
```
|
||||
|
||||
### Data Flow
|
||||
|
||||
1. **Translation Request** → Language Detection → Provider Selection → Translation → Quality Check → Cache
|
||||
2. **Agent Message** → Language Detection → Auto-Translation (if needed) → Delivery
|
||||
3. **Marketplace Listing** → Batch Translation → Quality Assessment → Search Indexing
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### Translation
|
||||
- `POST /api/v1/multi-language/translate` - Single text translation
|
||||
- `POST /api/v1/multi-language/translate/batch` - Batch translation
|
||||
- `GET /api/v1/multi-language/languages` - Supported languages
|
||||
|
||||
### Language Detection
|
||||
- `POST /api/v1/multi-language/detect-language` - Detect text language
|
||||
- `POST /api/v1/multi-language/detect-language/batch` - Batch detection
|
||||
|
||||
### Cache Management
|
||||
- `GET /api/v1/multi-language/cache/stats` - Cache statistics
|
||||
- `POST /api/v1/multi-language/cache/clear` - Clear cache entries
|
||||
- `POST /api/v1/multi-language/cache/optimize` - Optimize cache
|
||||
|
||||
### Health & Monitoring
|
||||
- `GET /api/v1/multi-language/health` - Service health check
|
||||
- `GET /api/v1/multi-language/cache/top-translations` - Popular translations
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# Translation Providers
|
||||
OPENAI_API_KEY=your_openai_api_key
|
||||
GOOGLE_TRANSLATE_API_KEY=your_google_api_key
|
||||
DEEPL_API_KEY=your_deepl_api_key
|
||||
|
||||
# Cache Configuration
|
||||
REDIS_URL=redis://localhost:6379
|
||||
REDIS_PASSWORD=your_redis_password
|
||||
REDIS_DB=0
|
||||
|
||||
# Database
|
||||
DATABASE_URL=postgresql://user:pass@localhost/aitbc
|
||||
|
||||
# FastText Model
|
||||
FASTTEXT_MODEL_PATH=models/lid.176.bin
|
||||
|
||||
# Service Settings
|
||||
ENVIRONMENT=development
|
||||
LOG_LEVEL=INFO
|
||||
PORT=8000
|
||||
```
|
||||
|
||||
### Configuration Structure
|
||||
|
||||
```python
|
||||
{
|
||||
"translation": {
|
||||
"providers": {
|
||||
"openai": {"api_key": "...", "model": "gpt-4"},
|
||||
"google": {"api_key": "..."},
|
||||
"deepl": {"api_key": "..."}
|
||||
},
|
||||
"fallback_strategy": {
|
||||
"primary": "openai",
|
||||
"secondary": "google",
|
||||
"tertiary": "deepl"
|
||||
}
|
||||
},
|
||||
"cache": {
|
||||
"redis": {"url": "redis://localhost:6379"},
|
||||
"default_ttl": 86400,
|
||||
"max_cache_size": 100000
|
||||
},
|
||||
"quality": {
|
||||
"thresholds": {
|
||||
"overall": 0.7,
|
||||
"bleu": 0.3,
|
||||
"semantic_similarity": 0.6
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Database Schema
|
||||
|
||||
### Core Tables
|
||||
- `translation_cache` - Cached translation results
|
||||
- `supported_languages` - Language registry
|
||||
- `agent_message_translations` - Agent communication translations
|
||||
- `marketplace_listings_i18n` - Multi-language marketplace listings
|
||||
- `translation_quality_logs` - Quality assessment logs
|
||||
- `translation_statistics` - Usage analytics
|
||||
|
||||
### Key Relationships
|
||||
- Agents → Language Preferences
|
||||
- Listings → Localized Content
|
||||
- Messages → Translations
|
||||
- Users → Language Settings
|
||||
|
||||
## Performance Metrics
|
||||
|
||||
### Target Performance
|
||||
- **Single Translation**: <200ms
|
||||
- **Batch Translation (100 items)**: <2s
|
||||
- **Language Detection**: <50ms
|
||||
- **Cache Hit Ratio**: >85%
|
||||
- **API Response Time**: <100ms
|
||||
|
||||
### Scaling Considerations
|
||||
- **Horizontal Scaling**: Multiple service instances behind load balancer
|
||||
- **Cache Sharding**: Redis cluster for high-volume caching
|
||||
- **Provider Rate Limiting**: Intelligent request distribution
|
||||
- **Database Partitioning**: Time-based partitioning for logs
|
||||
|
||||
## Quality Assurance
|
||||
|
||||
### Translation Quality Metrics
|
||||
- **BLEU Score**: Reference-based quality assessment
|
||||
- **Semantic Similarity**: NLP-based meaning preservation
|
||||
- **Length Ratio**: Appropriate length preservation
|
||||
- **Consistency**: Internal translation consistency
|
||||
- **Confidence Scoring**: Provider confidence aggregation
|
||||
|
||||
### Quality Thresholds
|
||||
- **Minimum Confidence**: 0.6 for cache eligibility
|
||||
- **Quality Threshold**: 0.7 for user-facing translations
|
||||
- **Auto-Retry**: Below 0.4 confidence triggers retry
|
||||
|
||||
## Security & Privacy
|
||||
|
||||
### Data Protection
|
||||
- **Encryption**: All API communications encrypted
|
||||
- **Data Retention**: Minimal cache retention policies
|
||||
- **Privacy Options**: On-premise models for sensitive data
|
||||
- **Compliance**: GDPR and regional privacy law compliance
|
||||
|
||||
### Access Control
|
||||
- **API Authentication**: JWT-based authentication
|
||||
- **Rate Limiting**: Tiered rate limiting by user type
|
||||
- **Audit Logging**: Complete translation audit trail
|
||||
- **Role-Based Access**: Different access levels for different user types
|
||||
|
||||
## Monitoring & Observability
|
||||
|
||||
### Metrics Collection
|
||||
- **Translation Volume**: Requests per language pair
|
||||
- **Provider Performance**: Response times and error rates
|
||||
- **Cache Performance**: Hit ratios and eviction rates
|
||||
- **Quality Metrics**: Average quality scores by provider
|
||||
|
||||
### Health Checks
|
||||
- **Service Health**: Provider availability checks
|
||||
- **Cache Health**: Redis connectivity and performance
|
||||
- **Database Health**: Connection pool and query performance
|
||||
- **Quality Health**: Quality assessment system status
|
||||
|
||||
### Alerting
|
||||
- **Error Rate**: >5% error rate triggers alerts
|
||||
- **Response Time**: P95 >1s triggers alerts
|
||||
- **Cache Performance**: Hit ratio <70% triggers alerts
|
||||
- **Quality Score**: Average quality <60% triggers alerts
|
||||
|
||||
## Deployment
|
||||
|
||||
### Service Dependencies
|
||||
- **Redis**: For translation caching
|
||||
- **PostgreSQL**: For persistent storage and analytics
|
||||
- **External APIs**: OpenAI, Google Translate, DeepL
|
||||
- **NLP Models**: spaCy models for quality assessment
|
||||
|
||||
### Deployment Steps
|
||||
1. Install dependencies: `pip install -r requirements.txt`
|
||||
2. Configure environment variables
|
||||
3. Run database migrations: `psql -f database_schema.sql`
|
||||
4. Download NLP models: `python -m spacy download en_core_web_sm`
|
||||
5. Start service: `uvicorn main:app --host 0.0.0.0 --port 8000`
|
||||
|
||||
### Docker-Free Deployment
|
||||
```bash
|
||||
# Systemd service configuration
|
||||
sudo cp multi-language.service /etc/systemd/system/
|
||||
sudo systemctl enable multi-language
|
||||
sudo systemctl start multi-language
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
### Test Coverage
|
||||
- **Unit Tests**: Individual component testing
|
||||
- **Integration Tests**: Service interaction testing
|
||||
- **Performance Tests**: Load and stress testing
|
||||
- **Quality Tests**: Translation quality validation
|
||||
|
||||
### Running Tests
|
||||
```bash
|
||||
# Run all tests
|
||||
pytest test_multi_language.py -v
|
||||
|
||||
# Run specific test categories
|
||||
pytest test_multi_language.py::TestTranslationEngine -v
|
||||
pytest test_multi_language.py::TestIntegration -v
|
||||
|
||||
# Run with coverage
|
||||
pytest test_multi_language.py --cov=. --cov-report=html
|
||||
```
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Translation
|
||||
```python
|
||||
from app.services.multi_language import initialize_multi_language_service
|
||||
|
||||
# Initialize service
|
||||
service = await initialize_multi_language_service()
|
||||
|
||||
# Translate text
|
||||
result = await service.translation_engine.translate(
|
||||
TranslationRequest(
|
||||
text="Hello world",
|
||||
source_language="en",
|
||||
target_language="es"
|
||||
)
|
||||
)
|
||||
|
||||
print(result.translated_text) # "Hola mundo"
|
||||
```
|
||||
|
||||
### Agent Communication
|
||||
```python
|
||||
# Register agent language profile
|
||||
profile = AgentLanguageProfile(
|
||||
agent_id="agent1",
|
||||
preferred_language="es",
|
||||
supported_languages=["es", "en"],
|
||||
auto_translate_enabled=True
|
||||
)
|
||||
|
||||
await agent_comm.register_agent_language_profile(profile)
|
||||
|
||||
# Send message (auto-translated)
|
||||
message = AgentMessage(
|
||||
id="msg1",
|
||||
sender_id="agent2",
|
||||
receiver_id="agent1",
|
||||
message_type=MessageType.AGENT_TO_AGENT,
|
||||
content="Hello from agent2"
|
||||
)
|
||||
|
||||
translated_message = await agent_comm.send_message(message)
|
||||
print(translated_message.translated_content) # "Hola del agente2"
|
||||
```
|
||||
|
||||
### Marketplace Localization
|
||||
```python
|
||||
# Create localized listing
|
||||
listing = {
|
||||
"id": "service1",
|
||||
"type": "service",
|
||||
"title": "AI Translation Service",
|
||||
"description": "High-quality translation service",
|
||||
"keywords": ["translation", "AI"]
|
||||
}
|
||||
|
||||
localized = await marketplace_loc.create_localized_listing(listing, ["es", "fr"])
|
||||
|
||||
# Search in specific language
|
||||
results = await marketplace_loc.search_localized_listings(
|
||||
"traducción", "es"
|
||||
)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
1. **API Key Errors**: Verify environment variables are set correctly
|
||||
2. **Cache Connection Issues**: Check Redis connectivity and configuration
|
||||
3. **Model Loading Errors**: Ensure NLP models are downloaded
|
||||
4. **Performance Issues**: Monitor cache hit ratio and provider response times
|
||||
|
||||
### Debug Mode
|
||||
```bash
|
||||
# Enable debug logging
|
||||
export LOG_LEVEL=DEBUG
|
||||
export DEBUG=true
|
||||
|
||||
# Run with detailed logging
|
||||
uvicorn main:app --log-level debug
|
||||
```
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
### Short-term (3 months)
|
||||
- **Voice Translation**: Real-time audio translation
|
||||
- **Document Translation**: Bulk document processing
|
||||
- **Custom Models**: Domain-specific translation models
|
||||
- **Enhanced Quality**: Advanced quality assessment metrics
|
||||
|
||||
### Long-term (6+ months)
|
||||
- **Neural Machine Translation**: Custom NMT model training
|
||||
- **Cross-Modal Translation**: Image/video description translation
|
||||
- **Agent Language Learning**: Adaptive language learning
|
||||
- **Blockchain Integration**: Decentralized translation verification
|
||||
|
||||
## Support & Maintenance
|
||||
|
||||
### Regular Maintenance
|
||||
- **Cache Optimization**: Weekly cache cleanup and optimization
|
||||
- **Model Updates**: Monthly NLP model updates
|
||||
- **Performance Monitoring**: Continuous performance monitoring
|
||||
- **Quality Audits**: Regular translation quality audits
|
||||
|
||||
### Support Channels
|
||||
- **Documentation**: Comprehensive API documentation
|
||||
- **Monitoring**: Real-time service monitoring dashboard
|
||||
- **Alerts**: Automated alerting for critical issues
|
||||
- **Logs**: Structured logging for debugging
|
||||
|
||||
This Multi-Language API service provides a robust, scalable foundation for global AI agent interactions and marketplace localization within the AITBC ecosystem.
|
||||
261
apps/coordinator-api/src/app/services/multi_language/__init__.py
Normal file
261
apps/coordinator-api/src/app/services/multi_language/__init__.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
Multi-Language Service Initialization
|
||||
Main entry point for multi-language services
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from .translation_engine import TranslationEngine
|
||||
from .language_detector import LanguageDetector
|
||||
from .translation_cache import TranslationCache
|
||||
from .quality_assurance import TranslationQualityChecker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MultiLanguageService:
    """Main service class for multi-language functionality.

    Facade that owns and wires the four sub-services: translation engine,
    language detector, translation cache and quality checker. The cache and
    quality checker are optional — their init failures are tolerated — while
    engine and detector failures abort initialization.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        # Falls back to the env-var-driven defaults when no config is given
        # (note: an empty dict is falsy and also triggers the defaults).
        self.config = config or self._load_default_config()
        self.translation_engine: Optional[TranslationEngine] = None
        self.language_detector: Optional[LanguageDetector] = None
        self.translation_cache: Optional[TranslationCache] = None
        self.quality_checker: Optional[TranslationQualityChecker] = None
        # Guards initialize() against repeated work.
        self._initialized = False

    def _load_default_config(self) -> Dict[str, Any]:
        """Load default configuration, reading API keys and endpoints from
        environment variables."""
        return {
            "translation": {
                "openai": {
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "model": "gpt-4"
                },
                "google": {
                    "api_key": os.getenv("GOOGLE_TRANSLATE_API_KEY")
                },
                "deepl": {
                    "api_key": os.getenv("DEEPL_API_KEY")
                }
            },
            "cache": {
                "redis_url": os.getenv("REDIS_URL", "redis://localhost:6379"),
                "default_ttl": 86400,  # 24 hours
                "max_cache_size": 100000
            },
            "detection": {
                "fasttext": {
                    "model_path": os.getenv("FASTTEXT_MODEL_PATH", "lid.176.bin")
                }
            },
            "quality": {
                "thresholds": {
                    "overall": 0.7,
                    "bleu": 0.3,
                    "semantic_similarity": 0.6,
                    "length_ratio": 0.5,
                    "confidence": 0.6
                }
            }
        }

    async def initialize(self):
        """Initialize all multi-language services (idempotent).

        Raises:
            Exception: re-raised if the translation engine or language
                detector fails to initialize; cache/quality-checker failures
                are downgraded to warnings by their helpers.
        """
        if self._initialized:
            return

        try:
            logger.info("Initializing Multi-Language Service...")

            # Initialize translation cache first so the engine can use it.
            await self._initialize_cache()

            # Initialize translation engine
            await self._initialize_translation_engine()

            # Initialize language detector
            await self._initialize_language_detector()

            # Initialize quality checker
            await self._initialize_quality_checker()

            self._initialized = True
            logger.info("Multi-Language Service initialized successfully")

        except Exception as e:
            logger.error(f"Failed to initialize Multi-Language Service: {e}")
            raise

    async def _initialize_cache(self):
        """Initialize translation cache; on failure the service runs uncached."""
        try:
            self.translation_cache = TranslationCache(
                redis_url=self.config["cache"]["redis_url"],
                config=self.config["cache"]
            )
            await self.translation_cache.initialize()
            logger.info("Translation cache initialized")
        except Exception as e:
            # Cache is optional: degrade gracefully instead of failing startup.
            logger.warning(f"Failed to initialize translation cache: {e}")
            self.translation_cache = None

    async def _initialize_translation_engine(self):
        """Initialize translation engine (required; re-raises on failure)."""
        try:
            self.translation_engine = TranslationEngine(self.config["translation"])

            # Inject cache dependency
            if self.translation_cache:
                self.translation_engine.cache = self.translation_cache

            logger.info("Translation engine initialized")
        except Exception as e:
            logger.error(f"Failed to initialize translation engine: {e}")
            raise

    async def _initialize_language_detector(self):
        """Initialize language detector (required; re-raises on failure)."""
        try:
            self.language_detector = LanguageDetector(self.config["detection"])
            logger.info("Language detector initialized")
        except Exception as e:
            logger.error(f"Failed to initialize language detector: {e}")
            raise

    async def _initialize_quality_checker(self):
        """Initialize quality checker; optional — left as None on failure."""
        try:
            self.quality_checker = TranslationQualityChecker(self.config["quality"])
            logger.info("Quality checker initialized")
        except Exception as e:
            logger.warning(f"Failed to initialize quality checker: {e}")
            self.quality_checker = None

    async def shutdown(self):
        """Shutdown all services and mark the facade uninitialized."""
        logger.info("Shutting down Multi-Language Service...")

        if self.translation_cache:
            await self.translation_cache.close()

        self._initialized = False
        logger.info("Multi-Language Service shutdown complete")

    async def health_check(self) -> Dict[str, Any]:
        """Comprehensive health check.

        Returns:
            ``{"status": "not_initialized"}`` before initialize(); otherwise
            ``{"overall": "healthy"|"degraded"|"unhealthy", "services": {...}}``.
            Engine/detector errors mark the service "unhealthy"; cache issues
            only "degraded"; a quality-checker exception is recorded but does
            not change the overall status.
        """
        if not self._initialized:
            return {"status": "not_initialized"}

        health_status = {
            "overall": "healthy",
            "services": {}
        }

        # Check translation engine
        if self.translation_engine:
            try:
                translation_health = await self.translation_engine.health_check()
                health_status["services"]["translation_engine"] = translation_health
                if not all(translation_health.values()):
                    health_status["overall"] = "degraded"
            except Exception as e:
                health_status["services"]["translation_engine"] = {"error": str(e)}
                health_status["overall"] = "unhealthy"

        # Check language detector
        if self.language_detector:
            try:
                detection_health = await self.language_detector.health_check()
                health_status["services"]["language_detector"] = detection_health
                if not all(detection_health.values()):
                    health_status["overall"] = "degraded"
            except Exception as e:
                health_status["services"]["language_detector"] = {"error": str(e)}
                health_status["overall"] = "unhealthy"

        # Check cache
        if self.translation_cache:
            try:
                cache_health = await self.translation_cache.health_check()
                health_status["services"]["translation_cache"] = cache_health
                if cache_health.get("status") != "healthy":
                    health_status["overall"] = "degraded"
            except Exception as e:
                health_status["services"]["translation_cache"] = {"error": str(e)}
                health_status["overall"] = "degraded"

        # Check quality checker (optional component: never worse than current).
        if self.quality_checker:
            try:
                quality_health = await self.quality_checker.health_check()
                health_status["services"]["quality_checker"] = quality_health
                if not all(quality_health.values()):
                    health_status["overall"] = "degraded"
            except Exception as e:
                health_status["services"]["quality_checker"] = {"error": str(e)}

        return health_status

    def get_service_status(self) -> Dict[str, bool]:
        """Get basic service status: which components are constructed."""
        return {
            "initialized": self._initialized,
            "translation_engine": self.translation_engine is not None,
            "language_detector": self.language_detector is not None,
            "translation_cache": self.translation_cache is not None,
            "quality_checker": self.quality_checker is not None
        }
# Global service instance shared across the application; the FastAPI
# dependency getters below initialize it lazily on first use.
multi_language_service = MultiLanguageService()
# Initialize function for app startup
async def initialize_multi_language_service(config: Optional[Dict[str, Any]] = None):
    """Apply optional config overrides and initialize the shared service.

    Args:
        config: optional mapping shallow-merged into the current config.

    Returns:
        The module-level ``multi_language_service`` singleton, initialized.
    """
    service = multi_language_service
    if config:
        # Shallow merge: top-level keys in ``config`` replace existing ones.
        service.config.update(config)
    await service.initialize()
    return service
# Dependency getters for FastAPI
async def get_translation_engine():
    """FastAPI dependency: return the translation engine, initializing lazily."""
    svc = multi_language_service
    if not svc.translation_engine:
        await svc.initialize()
    return svc.translation_engine
async def get_language_detector():
    """FastAPI dependency: return the language detector, initializing lazily."""
    svc = multi_language_service
    if not svc.language_detector:
        await svc.initialize()
    return svc.language_detector
async def get_translation_cache():
    """FastAPI dependency: return the translation cache, initializing lazily."""
    svc = multi_language_service
    if not svc.translation_cache:
        await svc.initialize()
    return svc.translation_cache
async def get_quality_checker():
    """FastAPI dependency: return the quality checker, initializing lazily."""
    svc = multi_language_service
    if not svc.quality_checker:
        await svc.initialize()
    return svc.quality_checker
# Export main components — the package's public API surface.
__all__ = [
    "MultiLanguageService",
    "multi_language_service",
    "initialize_multi_language_service",
    "get_translation_engine",
    "get_language_detector",
    "get_translation_cache",
    "get_quality_checker"
]
|
||||
@@ -0,0 +1,509 @@
|
||||
"""
|
||||
Multi-Language Agent Communication Integration
|
||||
Enhanced agent communication with translation support
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from .translation_engine import TranslationEngine, TranslationRequest, TranslationResponse
|
||||
from .language_detector import LanguageDetector, DetectionResult
|
||||
from .translation_cache import TranslationCache
|
||||
from .quality_assurance import TranslationQualityChecker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MessageType(Enum):
    """Categories of messages routed through the multilingual layer.

    The value strings are serialized/compared elsewhere — do not rename.
    """
    TEXT = "text"
    AGENT_TO_AGENT = "agent_to_agent"
    AGENT_TO_USER = "agent_to_user"
    USER_TO_AGENT = "user_to_agent"
    SYSTEM = "system"
@dataclass
class AgentMessage:
    """Enhanced agent message with multi-language support.

    ``content`` is always the original text; ``translated_content`` and the
    translation_* fields are populated only when auto-translation ran.
    """
    id: str                                    # unique message identifier
    sender_id: str
    receiver_id: str
    message_type: MessageType
    content: str                               # text in the original language
    original_language: Optional[str] = None    # detected when not supplied
    translated_content: Optional[str] = None   # set only if translated
    target_language: Optional[str] = None      # receiver's preferred language
    translation_confidence: Optional[float] = None
    translation_provider: Optional[str] = None
    metadata: Dict[str, Any] = None            # normalized to {} in __post_init__
    created_at: datetime = None                # defaults to utcnow in __post_init__

    def __post_init__(self):
        # Normalize unset defaults (None placeholders avoid the shared
        # mutable-default pitfall for the dict field).
        if self.created_at is None:
            self.created_at = datetime.utcnow()
        if self.metadata is None:
            self.metadata = {}
@dataclass
class AgentLanguageProfile:
    """Agent language preferences and capabilities.

    ``__post_init__`` normalizes ``cultural_preferences`` from None to {} and
    stamps ``created_at``; defaults are provided so profiles can be built from
    just id/language/support data (as the service README examples do).
    """
    agent_id: str
    preferred_language: str                      # ISO language code, e.g. "es"
    supported_languages: List[str]
    auto_translate_enabled: bool
    # 0.0 disables the quality gate on incoming translations.
    translation_quality_threshold: float = 0.0
    # None is normalized to {} below; kept Optional to avoid a shared
    # mutable default.
    cultural_preferences: Dict[str, Any] = None
    created_at: datetime = None                  # defaults to utcnow

    def __post_init__(self):
        if self.created_at is None:
            self.created_at = datetime.utcnow()
        if self.cultural_preferences is None:
            self.cultural_preferences = {}
class MultilingualAgentCommunication:
    """Enhanced agent communication with multi-language support.

    Wraps a TranslationEngine and LanguageDetector to provide transparent
    message translation between agents, with optional caching and quality
    checking. Keeps all agent profiles and the full message history in
    memory on this instance.
    """

    def __init__(self, translation_engine: TranslationEngine,
                 language_detector: LanguageDetector,
                 translation_cache: Optional[TranslationCache] = None,
                 quality_checker: Optional[TranslationQualityChecker] = None):
        self.translation_engine = translation_engine
        self.language_detector = language_detector
        self.translation_cache = translation_cache
        self.quality_checker = quality_checker
        # agent_id -> language profile; populated via register_agent_language_profile
        self.agent_profiles: Dict[str, AgentLanguageProfile] = {}
        # NOTE(review): message_history grows without bound for the lifetime
        # of this object — consider a cap or external storage for long runs.
        self.message_history: List[AgentMessage] = []
        # Running counters reported by get_translation_statistics()
        self.translation_stats = {
            "total_translations": 0,
            "successful_translations": 0,
            "failed_translations": 0,
            "cache_hits": 0,
            "cache_misses": 0
        }

    async def register_agent_language_profile(self, profile: AgentLanguageProfile) -> bool:
        """Register agent language preferences.

        Returns True on success, False on failure (error is logged, not raised).
        Re-registering an agent_id silently replaces the previous profile.
        """
        try:
            self.agent_profiles[profile.agent_id] = profile
            logger.info(f"Registered language profile for agent {profile.agent_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to register agent language profile: {e}")
            return False

    async def get_agent_language_profile(self, agent_id: str) -> Optional[AgentLanguageProfile]:
        """Get agent language profile, or None if the agent is unknown."""
        return self.agent_profiles.get(agent_id)

    async def send_message(self, message: AgentMessage) -> AgentMessage:
        """Send message with automatic translation if needed.

        Mutates *message* in place (fills original_language, and, when a
        translation happens, translated_content / translation_confidence /
        translation_provider / target_language), appends it to the history,
        and returns it. Re-raises any error after logging.
        """
        try:
            # Detect source language if not provided
            if not message.original_language:
                detection_result = await self.language_detector.detect_language(message.content)
                message.original_language = detection_result.language

            # Get receiver's language preference
            receiver_profile = await self.get_agent_language_profile(message.receiver_id)

            if receiver_profile and receiver_profile.auto_translate_enabled:
                # Check if translation is needed
                if message.original_language != receiver_profile.preferred_language:
                    message.target_language = receiver_profile.preferred_language

                    # Perform translation
                    translation_result = await self._translate_message(
                        message.content,
                        message.original_language,
                        receiver_profile.preferred_language,
                        message.message_type
                    )

                    if translation_result:
                        message.translated_content = translation_result.translated_text
                        message.translation_confidence = translation_result.confidence
                        message.translation_provider = translation_result.provider.value

                        # Quality check if threshold is set; low confidence is
                        # only logged — the translated message is still delivered.
                        if (receiver_profile.translation_quality_threshold > 0 and
                                translation_result.confidence < receiver_profile.translation_quality_threshold):
                            logger.warning(f"Translation confidence {translation_result.confidence} below threshold {receiver_profile.translation_quality_threshold}")

            # Store message (translated or not) in the in-memory history
            self.message_history.append(message)

            return message

        except Exception as e:
            logger.error(f"Failed to send message: {e}")
            raise

    async def _translate_message(self, content: str, source_lang: str, target_lang: str,
                                 message_type: MessageType) -> Optional[TranslationResponse]:
        """Translate message content with context.

        Checks the cache first, then calls the engine; successful results
        with confidence > 0.8 are written back to the cache. Returns None
        on failure (error logged, failed_translations incremented).
        """
        try:
            # Add context based on message type
            context = self._get_translation_context(message_type)
            domain = self._get_translation_domain(message_type)

            # Check cache first
            # NOTE(review): cache_key is computed but never used — the cache is
            # keyed via cache.get(...) below. Dead code; confirm and remove.
            cache_key = f"agent_message:{hashlib.md5(content.encode()).hexdigest()}:{source_lang}:{target_lang}"
            if self.translation_cache:
                cached_result = await self.translation_cache.get(content, source_lang, target_lang, context, domain)
                if cached_result:
                    self.translation_stats["cache_hits"] += 1
                    return cached_result
                self.translation_stats["cache_misses"] += 1

            # Perform translation
            translation_request = TranslationRequest(
                text=content,
                source_language=source_lang,
                target_language=target_lang,
                context=context,
                domain=domain
            )

            translation_result = await self.translation_engine.translate(translation_request)

            # Cache the result only when the engine is reasonably confident
            if self.translation_cache and translation_result.confidence > 0.8:
                await self.translation_cache.set(content, source_lang, target_lang, translation_result, context=context, domain=domain)

            # NOTE(review): total_translations is only incremented on success,
            # so get_translation_statistics() reports success_rate == 1.0
            # whenever total > 0 — failures never enter the denominator.
            self.translation_stats["total_translations"] += 1
            self.translation_stats["successful_translations"] += 1

            return translation_result

        except Exception as e:
            logger.error(f"Failed to translate message: {e}")
            self.translation_stats["failed_translations"] += 1
            return None

    def _get_translation_context(self, message_type: MessageType) -> str:
        """Get translation context based on message type (free-text hint for the engine)."""
        contexts = {
            MessageType.TEXT: "General text communication between AI agents",
            MessageType.AGENT_TO_AGENT: "Technical communication between AI agents",
            MessageType.AGENT_TO_USER: "AI agent responding to human user",
            MessageType.USER_TO_AGENT: "Human user communicating with AI agent",
            MessageType.SYSTEM: "System notification or status message"
        }
        return contexts.get(message_type, "General communication")

    def _get_translation_domain(self, message_type: MessageType) -> str:
        """Get translation domain based on message type (short domain tag for the engine/cache)."""
        domains = {
            MessageType.TEXT: "general",
            MessageType.AGENT_TO_AGENT: "technical",
            MessageType.AGENT_TO_USER: "customer_service",
            MessageType.USER_TO_AGENT: "user_input",
            MessageType.SYSTEM: "system"
        }
        return domains.get(message_type, "general")

    async def translate_message_history(self, agent_id: str, target_language: str) -> List[AgentMessage]:
        """Translate agent's message history to target language.

        Mutates matching history messages in place (only those not already
        translated and not already in target_language). Returns the agent's
        messages, or [] on error.
        """
        try:
            agent_messages = [msg for msg in self.message_history if msg.receiver_id == agent_id or msg.sender_id == agent_id]
            translated_messages = []

            for message in agent_messages:
                if message.original_language != target_language and not message.translated_content:
                    translation_result = await self._translate_message(
                        message.content,
                        message.original_language,
                        target_language,
                        message.message_type
                    )

                    if translation_result:
                        message.translated_content = translation_result.translated_text
                        message.translation_confidence = translation_result.confidence
                        message.translation_provider = translation_result.provider.value
                        message.target_language = target_language

                # Every message is returned, whether or not it needed translation
                translated_messages.append(message)

            return translated_messages

        except Exception as e:
            logger.error(f"Failed to translate message history: {e}")
            return []

    async def get_conversation_summary(self, agent_ids: List[str], language: Optional[str] = None) -> Dict[str, Any]:
        """Get conversation summary with optional translation.

        Only messages whose sender AND receiver are both in *agent_ids* are
        included. If *language* is given, messages are translated on demand
        when no cached translation into that language exists.
        Returns {"error": ...} on failure instead of raising.
        """
        try:
            # Filter messages by participants
            conversation_messages = [
                msg for msg in self.message_history
                if msg.sender_id in agent_ids and msg.receiver_id in agent_ids
            ]

            if not conversation_messages:
                return {"summary": "No conversation found", "message_count": 0}

            # Sort by timestamp
            conversation_messages.sort(key=lambda x: x.created_at)

            # Generate summary
            summary = {
                "participants": agent_ids,
                "message_count": len(conversation_messages),
                "languages_used": list(set([msg.original_language for msg in conversation_messages if msg.original_language])),
                "start_time": conversation_messages[0].created_at.isoformat(),
                "end_time": conversation_messages[-1].created_at.isoformat(),
                "messages": []
            }

            # Add messages with optional translation
            for message in conversation_messages:
                message_data = {
                    "id": message.id,
                    "sender": message.sender_id,
                    "receiver": message.receiver_id,
                    "type": message.message_type.value,
                    "timestamp": message.created_at.isoformat(),
                    "original_language": message.original_language,
                    "original_content": message.content
                }

                # Add translated content if requested and available
                if language and message.translated_content and message.target_language == language:
                    message_data["translated_content"] = message.translated_content
                    message_data["translation_confidence"] = message.translation_confidence
                elif language and language != message.original_language and not message.translated_content:
                    # Translate on-demand (result is not persisted onto the message)
                    translation_result = await self._translate_message(
                        message.content,
                        message.original_language,
                        language,
                        message.message_type
                    )

                    if translation_result:
                        message_data["translated_content"] = translation_result.translated_text
                        message_data["translation_confidence"] = translation_result.confidence

                summary["messages"].append(message_data)

            return summary

        except Exception as e:
            logger.error(f"Failed to get conversation summary: {e}")
            return {"error": str(e)}

    async def detect_language_conflicts(self, conversation: List[AgentMessage]) -> List[Dict[str, Any]]:
        """Detect potential language conflicts in conversation.

        Flags low-confidence translations (< 0.6) and messages in languages
        no registered agent supports. Returns [] on error.
        """
        try:
            conflicts = []
            # NOTE(review): language_changes is collected below but never
            # returned or used — either include it in the result or drop it.
            language_changes = []

            # Track language changes between consecutive messages
            for i, message in enumerate(conversation):
                if i > 0:
                    prev_message = conversation[i-1]
                    if message.original_language != prev_message.original_language:
                        language_changes.append({
                            "message_id": message.id,
                            "from_language": prev_message.original_language,
                            "to_language": message.original_language,
                            "timestamp": message.created_at.isoformat()
                        })

            # Check for translation quality issues
            for message in conversation:
                if (message.translation_confidence and
                        message.translation_confidence < 0.6):
                    conflicts.append({
                        "type": "low_translation_confidence",
                        "message_id": message.id,
                        "confidence": message.translation_confidence,
                        "recommendation": "Consider manual review or re-translation"
                    })

            # Check for languages not supported by any registered agent
            supported_languages = set()
            for profile in self.agent_profiles.values():
                supported_languages.update(profile.supported_languages)

            for message in conversation:
                if message.original_language not in supported_languages:
                    conflicts.append({
                        "type": "unsupported_language",
                        "message_id": message.id,
                        "language": message.original_language,
                        "recommendation": "Add language support or use fallback translation"
                    })

            return conflicts

        except Exception as e:
            logger.error(f"Failed to detect language conflicts: {e}")
            return []

    async def optimize_agent_languages(self, agent_id: str) -> Dict[str, Any]:
        """Optimize language settings for an agent based on communication patterns.

        Analyzes the agent's message history and suggests a new preferred
        language and/or additional supported languages. Returns a dict with
        frequencies and recommendations, or an error/recommendation message.
        """
        try:
            agent_messages = [
                msg for msg in self.message_history
                if msg.sender_id == agent_id or msg.receiver_id == agent_id
            ]

            if not agent_messages:
                return {"recommendation": "No communication data available"}

            # Analyze language usage
            language_frequency = {}
            translation_frequency = {}

            for message in agent_messages:
                # Count original languages
                lang = message.original_language
                language_frequency[lang] = language_frequency.get(lang, 0) + 1

                # Count translations
                if message.translated_content:
                    target_lang = message.target_language
                    translation_frequency[target_lang] = translation_frequency.get(target_lang, 0) + 1

            # Get current profile
            profile = await self.get_agent_language_profile(agent_id)
            if not profile:
                return {"error": "Agent profile not found"}

            # Generate recommendations
            recommendations = []

            # Suggest switching preferred language to the most-used one
            if language_frequency:
                most_used = max(language_frequency, key=language_frequency.get)
                if most_used != profile.preferred_language:
                    recommendations.append({
                        "type": "preferred_language",
                        "suggestion": most_used,
                        "reason": f"Most frequently used language ({language_frequency[most_used]} messages)"
                    })

            # Add missing languages to supported list
            missing_languages = set(language_frequency.keys()) - set(profile.supported_languages)
            for lang in missing_languages:
                if language_frequency[lang] > 5:  # Significant usage
                    recommendations.append({
                        "type": "add_supported_language",
                        "suggestion": lang,
                        "reason": f"Used in {language_frequency[lang]} messages"
                    })

            return {
                "current_profile": asdict(profile),
                "language_frequency": language_frequency,
                "translation_frequency": translation_frequency,
                "recommendations": recommendations
            }

        except Exception as e:
            logger.error(f"Failed to optimize agent languages: {e}")
            return {"error": str(e)}

    async def get_translation_statistics(self) -> Dict[str, Any]:
        """Get comprehensive translation statistics.

        Returns a copy of the running counters augmented with derived rates
        and per-agent usage breakdown. Returns {"error": ...} on failure.
        """
        try:
            stats = self.translation_stats.copy()

            # Calculate success rate (see NOTE in _translate_message about
            # how total_translations is counted)
            total = stats["total_translations"]
            if total > 0:
                stats["success_rate"] = stats["successful_translations"] / total
                stats["failure_rate"] = stats["failed_translations"] / total
            else:
                stats["success_rate"] = 0.0
                stats["failure_rate"] = 0.0

            # Calculate cache hit ratio
            cache_total = stats["cache_hits"] + stats["cache_misses"]
            if cache_total > 0:
                stats["cache_hit_ratio"] = stats["cache_hits"] / cache_total
            else:
                stats["cache_hit_ratio"] = 0.0

            # Per-agent statistics (O(profiles * history) scan)
            agent_stats = {}
            for agent_id, profile in self.agent_profiles.items():
                agent_messages = [
                    msg for msg in self.message_history
                    if msg.sender_id == agent_id or msg.receiver_id == agent_id
                ]

                translated_count = len([msg for msg in agent_messages if msg.translated_content])

                agent_stats[agent_id] = {
                    "preferred_language": profile.preferred_language,
                    "supported_languages": profile.supported_languages,
                    "total_messages": len(agent_messages),
                    "translated_messages": translated_count,
                    "translation_rate": translated_count / len(agent_messages) if agent_messages else 0.0
                }

            stats["agent_statistics"] = agent_stats

            return stats

        except Exception as e:
            logger.error(f"Failed to get translation statistics: {e}")
            return {"error": str(e)}

    async def health_check(self) -> Dict[str, Any]:
        """Health check for multilingual agent communication.

        Aggregates health of the engine, detector, cache and quality checker.
        Overall status: "healthy" if all up, "degraded" if some, "unhealthy"
        if none (or on error). Missing optional services report as False.
        """
        try:
            health_status = {
                "overall": "healthy",
                "services": {},
                "statistics": {}
            }

            # Check translation engine
            translation_health = await self.translation_engine.health_check()
            health_status["services"]["translation_engine"] = all(translation_health.values())

            # Check language detector
            detection_health = await self.language_detector.health_check()
            health_status["services"]["language_detector"] = all(detection_health.values())

            # Check cache (optional dependency — absent counts as unhealthy)
            if self.translation_cache:
                cache_health = await self.translation_cache.health_check()
                health_status["services"]["translation_cache"] = cache_health.get("status") == "healthy"
            else:
                health_status["services"]["translation_cache"] = False

            # Check quality checker (optional dependency — absent counts as unhealthy)
            if self.quality_checker:
                quality_health = await self.quality_checker.health_check()
                health_status["services"]["quality_checker"] = all(quality_health.values())
            else:
                health_status["services"]["quality_checker"] = False

            # Overall status
            all_healthy = all(health_status["services"].values())
            health_status["overall"] = "healthy" if all_healthy else "degraded" if any(health_status["services"].values()) else "unhealthy"

            # Add statistics
            health_status["statistics"] = {
                "registered_agents": len(self.agent_profiles),
                "total_messages": len(self.message_history),
                "translation_stats": self.translation_stats
            }

            return health_status

        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return {
                "overall": "unhealthy",
                "error": str(e)
            }
|
||||
# ---- begin second module: multi-language REST API endpoints ----
|
||||
"""
|
||||
Multi-Language API Endpoints
|
||||
REST API endpoints for translation and language detection services
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from typing import List, Optional, Dict, Any
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from .translation_engine import TranslationEngine, TranslationRequest, TranslationResponse, TranslationProvider
|
||||
from .language_detector import LanguageDetector, DetectionMethod, DetectionResult
|
||||
from .translation_cache import TranslationCache
|
||||
from .quality_assurance import TranslationQualityChecker, QualityAssessment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Pydantic models for API requests/responses
|
||||
class TranslationAPIRequest(BaseModel):
    """Request body for POST /translate (single translation)."""
    text: str = Field(..., min_length=1, max_length=10000, description="Text to translate")
    source_language: str = Field(..., description="Source language code (e.g., 'en', 'zh')")
    target_language: str = Field(..., description="Target language code (e.g., 'es', 'fr')")
    context: Optional[str] = Field(None, description="Additional context for translation")
    domain: Optional[str] = Field(None, description="Domain-specific context (e.g., 'medical', 'legal')")
    use_cache: bool = Field(True, description="Whether to use cached translations")
    quality_check: bool = Field(False, description="Whether to perform quality assessment")

    @validator('text')
    def validate_text(cls, v):
        # Reject whitespace-only input (min_length alone would accept " ")
        # and normalize by stripping surrounding whitespace.
        if not v.strip():
            raise ValueError('Text cannot be empty')
        return v.strip()
|
||||
|
||||
class BatchTranslationRequest(BaseModel):
    """Request body for POST /translate/batch (up to 100 translations)."""
    translations: List[TranslationAPIRequest] = Field(..., max_items=100, description="List of translation requests")

    @validator('translations')
    def validate_translations(cls, v):
        # max_items bounds the upper end; this enforces a non-empty batch.
        if len(v) == 0:
            raise ValueError('At least one translation request is required')
        return v
|
||||
|
||||
class LanguageDetectionRequest(BaseModel):
    """Request body for POST /detect-language."""
    # min_length=10: very short texts give unreliable detection results
    text: str = Field(..., min_length=10, max_length=10000, description="Text for language detection")
    methods: Optional[List[str]] = Field(None, description="Detection methods to use")

    @validator('methods')
    def validate_methods(cls, v):
        # Each entry must match a DetectionMethod enum value; None means
        # "let the detector choose its defaults".
        if v:
            valid_methods = [method.value for method in DetectionMethod]
            for method in v:
                if method not in valid_methods:
                    raise ValueError(f'Invalid detection method: {method}')
        return v
|
||||
|
||||
class BatchDetectionRequest(BaseModel):
    """Request body for POST /detect-language/batch (up to 100 texts)."""
    texts: List[str] = Field(..., max_items=100, description="List of texts for language detection")
    # NOTE(review): unlike LanguageDetectionRequest, these methods are not
    # validated against DetectionMethod here; the endpoint converts them.
    methods: Optional[List[str]] = Field(None, description="Detection methods to use")
|
||||
|
||||
class TranslationAPIResponse(BaseModel):
    """Response body for a single translation result."""
    translated_text: str
    confidence: float                 # engine-reported confidence in [0, 1]
    provider: str                     # TranslationProvider enum value
    processing_time_ms: int
    source_language: str
    target_language: str
    cached: bool = False              # True when served from the translation cache
    quality_assessment: Optional[Dict[str, Any]] = None  # present only when quality_check was requested
|
||||
|
||||
class BatchTranslationResponse(BaseModel):
    """Response body for POST /translate/batch."""
    translations: List[TranslationAPIResponse]  # successful results only
    total_processed: int                        # number of requests submitted
    failed_count: int                           # requests that raised; details in errors
    processing_time_ms: int
    errors: List[str] = []
|
||||
|
||||
class LanguageDetectionResponse(BaseModel):
    """Response body for a single language-detection result."""
    language: str                        # best-guess language code
    confidence: float
    method: str                          # DetectionMethod enum value that produced the result
    alternatives: List[Dict[str, float]] # runner-up candidates as {"language": ..., "confidence": ...}
    processing_time_ms: int
|
||||
|
||||
class BatchDetectionResponse(BaseModel):
    """Response body for POST /detect-language/batch."""
    detections: List[LanguageDetectionResponse]
    total_processed: int
    processing_time_ms: int
|
||||
|
||||
class SupportedLanguagesResponse(BaseModel):
    """Response body for GET /languages."""
    languages: Dict[str, List[str]]  # Provider -> List of languages
    total_languages: int             # distinct languages across translation and detection
|
||||
|
||||
class HealthResponse(BaseModel):
    """Response body for GET /health."""
    status: str               # "healthy" | "degraded" | "unhealthy"
    services: Dict[str, bool] # per-service up/down flags
    timestamp: datetime
|
||||
|
||||
# Dependency injection
|
||||
async def get_translation_engine() -> TranslationEngine:
    """FastAPI dependency returning the shared translation engine.

    The instance is created once during application startup in the main
    module; importing lazily here avoids a circular import at module load.
    """
    from ..main import translation_engine as engine
    return engine
|
||||
|
||||
async def get_language_detector() -> LanguageDetector:
    """FastAPI dependency returning the shared language detector.

    Imported lazily from the application entry point to avoid a circular
    import at module load time.
    """
    from ..main import language_detector as detector
    return detector
|
||||
|
||||
async def get_translation_cache() -> Optional[TranslationCache]:
    """FastAPI dependency returning the shared translation cache (may be None).

    Imported lazily from the application entry point to avoid a circular
    import at module load time.
    """
    from ..main import translation_cache as cache
    return cache
|
||||
|
||||
async def get_quality_checker() -> Optional[TranslationQualityChecker]:
    """FastAPI dependency returning the shared quality checker (may be None).

    Imported lazily from the application entry point to avoid a circular
    import at module load time.
    """
    from ..main import quality_checker as checker
    return checker
|
||||
|
||||
# Router setup: all endpoints below are mounted under /api/v1/multi-language
router = APIRouter(prefix="/api/v1/multi-language", tags=["multi-language"])
|
||||
|
||||
@router.post("/translate", response_model=TranslationAPIResponse)
|
||||
async def translate_text(
|
||||
request: TranslationAPIRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
engine: TranslationEngine = Depends(get_translation_engine),
|
||||
cache: Optional[TranslationCache] = Depends(get_translation_cache),
|
||||
quality_checker: Optional[TranslationQualityChecker] = Depends(get_quality_checker)
|
||||
):
|
||||
"""
|
||||
Translate text between supported languages with caching and quality assessment
|
||||
"""
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
try:
|
||||
# Check cache first
|
||||
cached_result = None
|
||||
if request.use_cache and cache:
|
||||
cached_result = await cache.get(
|
||||
request.text,
|
||||
request.source_language,
|
||||
request.target_language,
|
||||
request.context,
|
||||
request.domain
|
||||
)
|
||||
|
||||
if cached_result:
|
||||
# Update cache access statistics in background
|
||||
background_tasks.add_task(
|
||||
cache.get, # This will update access count
|
||||
request.text,
|
||||
request.source_language,
|
||||
request.target_language,
|
||||
request.context,
|
||||
request.domain
|
||||
)
|
||||
|
||||
return TranslationAPIResponse(
|
||||
translated_text=cached_result.translated_text,
|
||||
confidence=cached_result.confidence,
|
||||
provider=cached_result.provider.value,
|
||||
processing_time_ms=cached_result.processing_time_ms,
|
||||
source_language=cached_result.source_language,
|
||||
target_language=cached_result.target_language,
|
||||
cached=True
|
||||
)
|
||||
|
||||
# Perform translation
|
||||
translation_request = TranslationRequest(
|
||||
text=request.text,
|
||||
source_language=request.source_language,
|
||||
target_language=request.target_language,
|
||||
context=request.context,
|
||||
domain=request.domain
|
||||
)
|
||||
|
||||
translation_result = await engine.translate(translation_request)
|
||||
|
||||
# Cache the result
|
||||
if cache and translation_result.confidence > 0.8:
|
||||
background_tasks.add_task(
|
||||
cache.set,
|
||||
request.text,
|
||||
request.source_language,
|
||||
request.target_language,
|
||||
translation_result,
|
||||
context=request.context,
|
||||
domain=request.domain
|
||||
)
|
||||
|
||||
# Quality assessment
|
||||
quality_assessment = None
|
||||
if request.quality_check and quality_checker:
|
||||
assessment = await quality_checker.evaluate_translation(
|
||||
request.text,
|
||||
translation_result.translated_text,
|
||||
request.source_language,
|
||||
request.target_language
|
||||
)
|
||||
quality_assessment = {
|
||||
"overall_score": assessment.overall_score,
|
||||
"passed_threshold": assessment.passed_threshold,
|
||||
"recommendations": assessment.recommendations
|
||||
}
|
||||
|
||||
return TranslationAPIResponse(
|
||||
translated_text=translation_result.translated_text,
|
||||
confidence=translation_result.confidence,
|
||||
provider=translation_result.provider.value,
|
||||
processing_time_ms=translation_result.processing_time_ms,
|
||||
source_language=translation_result.source_language,
|
||||
target_language=translation_result.target_language,
|
||||
cached=False,
|
||||
quality_assessment=quality_assessment
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Translation error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/translate/batch", response_model=BatchTranslationResponse)
|
||||
async def translate_batch(
|
||||
request: BatchTranslationRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
engine: TranslationEngine = Depends(get_translation_engine),
|
||||
cache: Optional[TranslationCache] = Depends(get_translation_cache)
|
||||
):
|
||||
"""
|
||||
Translate multiple texts in a single request
|
||||
"""
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
try:
|
||||
# Process translations in parallel
|
||||
tasks = []
|
||||
for translation_req in request.translations:
|
||||
task = translate_text(
|
||||
translation_req,
|
||||
background_tasks,
|
||||
engine,
|
||||
cache,
|
||||
None # Skip quality check for batch
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Process results
|
||||
translations = []
|
||||
errors = []
|
||||
failed_count = 0
|
||||
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, TranslationAPIResponse):
|
||||
translations.append(result)
|
||||
else:
|
||||
errors.append(f"Translation {i+1} failed: {str(result)}")
|
||||
failed_count += 1
|
||||
|
||||
processing_time = int((asyncio.get_event_loop().time() - start_time) * 1000)
|
||||
|
||||
return BatchTranslationResponse(
|
||||
translations=translations,
|
||||
total_processed=len(request.translations),
|
||||
failed_count=failed_count,
|
||||
processing_time_ms=processing_time,
|
||||
errors=errors
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Batch translation error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/detect-language", response_model=LanguageDetectionResponse)
|
||||
async def detect_language(
|
||||
request: LanguageDetectionRequest,
|
||||
detector: LanguageDetector = Depends(get_language_detector)
|
||||
):
|
||||
"""
|
||||
Detect the language of given text
|
||||
"""
|
||||
try:
|
||||
# Convert method strings to enum
|
||||
methods = None
|
||||
if request.methods:
|
||||
methods = [DetectionMethod(method) for method in request.methods]
|
||||
|
||||
result = await detector.detect_language(request.text, methods)
|
||||
|
||||
return LanguageDetectionResponse(
|
||||
language=result.language,
|
||||
confidence=result.confidence,
|
||||
method=result.method.value,
|
||||
alternatives=[
|
||||
{"language": lang, "confidence": conf}
|
||||
for lang, conf in result.alternatives
|
||||
],
|
||||
processing_time_ms=result.processing_time_ms
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Language detection error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/detect-language/batch", response_model=BatchDetectionResponse)
|
||||
async def detect_language_batch(
|
||||
request: BatchDetectionRequest,
|
||||
detector: LanguageDetector = Depends(get_language_detector)
|
||||
):
|
||||
"""
|
||||
Detect languages for multiple texts in a single request
|
||||
"""
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
try:
|
||||
# Convert method strings to enum
|
||||
methods = None
|
||||
if request.methods:
|
||||
methods = [DetectionMethod(method) for method in request.methods]
|
||||
|
||||
results = await detector.batch_detect(request.texts)
|
||||
|
||||
detections = []
|
||||
for result in results:
|
||||
detections.append(LanguageDetectionResponse(
|
||||
language=result.language,
|
||||
confidence=result.confidence,
|
||||
method=result.method.value,
|
||||
alternatives=[
|
||||
{"language": lang, "confidence": conf}
|
||||
for lang, conf in result.alternatives
|
||||
],
|
||||
processing_time_ms=result.processing_time_ms
|
||||
))
|
||||
|
||||
processing_time = int((asyncio.get_event_loop().time() - start_time) * 1000)
|
||||
|
||||
return BatchDetectionResponse(
|
||||
detections=detections,
|
||||
total_processed=len(request.texts),
|
||||
processing_time_ms=processing_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Batch language detection error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get("/languages", response_model=SupportedLanguagesResponse)
|
||||
async def get_supported_languages(
|
||||
engine: TranslationEngine = Depends(get_translation_engine),
|
||||
detector: LanguageDetector = Depends(get_language_detector)
|
||||
):
|
||||
"""
|
||||
Get list of supported languages for translation and detection
|
||||
"""
|
||||
try:
|
||||
translation_languages = engine.get_supported_languages()
|
||||
detection_languages = detector.get_supported_languages()
|
||||
|
||||
# Combine all languages
|
||||
all_languages = set()
|
||||
for lang_list in translation_languages.values():
|
||||
all_languages.update(lang_list)
|
||||
all_languages.update(detection_languages)
|
||||
|
||||
return SupportedLanguagesResponse(
|
||||
languages=translation_languages,
|
||||
total_languages=len(all_languages)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Get supported languages error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get("/cache/stats")
|
||||
async def get_cache_stats(cache: Optional[TranslationCache] = Depends(get_translation_cache)):
|
||||
"""
|
||||
Get translation cache statistics
|
||||
"""
|
||||
if not cache:
|
||||
raise HTTPException(status_code=404, detail="Cache service not available")
|
||||
|
||||
try:
|
||||
stats = await cache.get_cache_stats()
|
||||
return JSONResponse(content=stats)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Cache stats error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/cache/clear")
|
||||
async def clear_cache(
|
||||
source_language: Optional[str] = None,
|
||||
target_language: Optional[str] = None,
|
||||
cache: Optional[TranslationCache] = Depends(get_translation_cache)
|
||||
):
|
||||
"""
|
||||
Clear translation cache (optionally by language pair)
|
||||
"""
|
||||
if not cache:
|
||||
raise HTTPException(status_code=404, detail="Cache service not available")
|
||||
|
||||
try:
|
||||
if source_language and target_language:
|
||||
cleared_count = await cache.clear_by_language_pair(source_language, target_language)
|
||||
return {"cleared_count": cleared_count, "scope": f"{source_language}->{target_language}"}
|
||||
else:
|
||||
# Clear entire cache
|
||||
# This would need to be implemented in the cache service
|
||||
return {"message": "Full cache clear not implemented yet"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Cache clear error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get("/health", response_model=HealthResponse)
async def health_check(
    engine: TranslationEngine = Depends(get_translation_engine),
    detector: LanguageDetector = Depends(get_language_detector),
    cache: Optional[TranslationCache] = Depends(get_translation_cache),
    quality_checker: Optional[TranslationQualityChecker] = Depends(get_quality_checker)
):
    """Aggregate health of all multi-language services.

    Each service is probed in turn; the overall status is "healthy" when
    every probe passes, "degraded" when some pass, "unhealthy" otherwise.
    Optional services (cache, quality checker) report False when absent.
    """
    try:
        services = {
            "translation_engine": all((await engine.health_check()).values()),
            "language_detector": all((await detector.health_check()).values()),
        }

        if cache:
            services["translation_cache"] = (await cache.health_check()).get("status") == "healthy"
        else:
            services["translation_cache"] = False

        if quality_checker:
            services["quality_checker"] = all((await quality_checker.health_check()).values())
        else:
            services["quality_checker"] = False

        if all(services.values()):
            status = "healthy"
        elif any(services.values()):
            status = "degraded"
        else:
            status = "unhealthy"

        return HealthResponse(
            status=status,
            services=services,
            timestamp=datetime.utcnow()
        )
    except Exception as e:
        logger.error(f"Health check error: {e}")
        # NOTE(review): this puts a string where the other entries are bools;
        # if HealthResponse.services is typed Dict[str, bool] this will fail
        # response validation — confirm against the schema definition.
        return HealthResponse(
            status="unhealthy",
            services={"error": str(e)},
            timestamp=datetime.utcnow()
        )
|
||||
|
||||
@router.get("/cache/top-translations")
async def get_top_translations(
    limit: int = 100,
    cache: Optional[TranslationCache] = Depends(get_translation_cache)
):
    """Return the most frequently accessed cached translations (up to *limit*)."""
    if not cache:
        raise HTTPException(status_code=404, detail="Cache service not available")

    try:
        entries = await cache.get_top_translations(limit)
        return JSONResponse(content={"translations": entries})
    except Exception as e:
        logger.error(f"Get top translations error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/cache/optimize")
async def optimize_cache(cache: Optional[TranslationCache] = Depends(get_translation_cache)):
    """Trigger cache optimization (eviction of low-access entries) and return its report."""
    if not cache:
        raise HTTPException(status_code=404, detail="Cache service not available")

    try:
        return JSONResponse(content=await cache.optimize_cache())
    except Exception as e:
        logger.error(f"Cache optimization error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# Error handlers
# NOTE(review): fastapi.APIRouter does not expose an `exception_handler`
# decorator — only FastAPI application instances do. As written this is
# expected to fail (AttributeError) at import time; the handler should be
# registered on the app via app.exception_handler(ValueError) /
# app.add_exception_handler. Confirm against the FastAPI version in use.
@router.exception_handler(ValueError)
async def value_error_handler(request, exc):
    """Map a ValueError raised inside an endpoint to an HTTP 400 response."""
    return JSONResponse(
        status_code=400,
        content={"error": "Validation error", "details": str(exc)}
    )
|
||||
|
||||
# NOTE(review): fastapi.APIRouter does not expose an `exception_handler`
# decorator — only FastAPI application instances do. Register this on the
# app (app.exception_handler(Exception)) instead. Also note that returning
# str(exc) to clients can leak internal details; consider a generic message.
@router.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Catch-all: log any unhandled exception and return an HTTP 500 response."""
    logger.error(f"Unhandled exception: {exc}")
    return JSONResponse(
        status_code=500,
        content={"error": "Internal server error", "details": str(exc)}
    )
|
||||
393
apps/coordinator-api/src/app/services/multi_language/config.py
Normal file
393
apps/coordinator-api/src/app/services/multi_language/config.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
Multi-Language Configuration
|
||||
Configuration file for multi-language services
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
class MultiLanguageConfig:
    """Configuration class for multi-language services.

    All sections are materialized as plain dict attributes in ``__init__`` so
    environment subclasses can override individual entries in place.
    """

    def __init__(self):
        self.translation = self._get_translation_config()
        self.cache = self._get_cache_config()
        self.detection = self._get_detection_config()
        self.quality = self._get_quality_config()
        self.api = self._get_api_config()
        self.localization = self._get_localization_config()
        # Bug fix: the environment subclasses (Development/Production) mutate
        # self.monitoring and self.deployment, which were never initialized
        # here and raised AttributeError. Materialize them like the other
        # sections so in-place overrides work.
        self.monitoring = self.get_monitoring_config()
        self.deployment = self.get_deployment_config()

    def _get_translation_config(self) -> Dict[str, Any]:
        """Translation service configuration (providers, fallbacks, thresholds)."""
        return {
            "providers": {
                "openai": {
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "model": "gpt-4",
                    "max_tokens": 2000,
                    "temperature": 0.3,
                    "timeout": 30,
                    "retry_attempts": 3,
                    "rate_limit": {
                        "requests_per_minute": 60,
                        "tokens_per_minute": 40000
                    }
                },
                "google": {
                    "api_key": os.getenv("GOOGLE_TRANSLATE_API_KEY"),
                    "project_id": os.getenv("GOOGLE_PROJECT_ID"),
                    "timeout": 10,
                    "retry_attempts": 3,
                    "rate_limit": {
                        "requests_per_minute": 100,
                        "characters_per_minute": 100000
                    }
                },
                "deepl": {
                    "api_key": os.getenv("DEEPL_API_KEY"),
                    "timeout": 15,
                    "retry_attempts": 3,
                    "rate_limit": {
                        "requests_per_minute": 60,
                        "characters_per_minute": 50000
                    }
                },
                "local": {
                    "model_path": os.getenv("LOCAL_MODEL_PATH", "models/translation"),
                    "timeout": 5,
                    "max_text_length": 5000
                }
            },
            # Provider order when a translation attempt fails.
            "fallback_strategy": {
                "primary": "openai",
                "secondary": "google",
                "tertiary": "deepl",
                "local": "local"
            },
            "quality_thresholds": {
                "minimum_confidence": 0.6,
                "cache_eligibility": 0.8,
                "auto_retry": 0.4
            }
        }

    def _get_cache_config(self) -> Dict[str, Any]:
        """Cache service configuration (Redis connection, TTLs, eviction)."""
        return {
            "redis": {
                "url": os.getenv("REDIS_URL", "redis://localhost:6379"),
                "password": os.getenv("REDIS_PASSWORD"),
                "db": int(os.getenv("REDIS_DB", 0)),
                "max_connections": 20,
                "retry_on_timeout": True,
                "socket_timeout": 5,
                "socket_connect_timeout": 5
            },
            "cache_settings": {
                "default_ttl": 86400,  # 24 hours
                "max_ttl": 604800,  # 7 days
                "min_ttl": 300,  # 5 minutes
                "max_cache_size": 100000,
                "cleanup_interval": 3600,  # 1 hour
                "compression_threshold": 1000  # Compress entries larger than 1KB
            },
            "optimization": {
                "enable_auto_optimize": True,
                "optimization_threshold": 0.8,  # Optimize when 80% full
                "eviction_policy": "least_accessed",
                "batch_size": 100
            }
        }

    def _get_detection_config(self) -> Dict[str, Any]:
        """Language detection configuration (detectors, ensemble, fallback)."""
        return {
            "methods": {
                "langdetect": {
                    "enabled": True,
                    "priority": 1,
                    "min_text_length": 10,
                    "max_text_length": 10000
                },
                "polyglot": {
                    "enabled": True,
                    "priority": 2,
                    "min_text_length": 5,
                    "max_text_length": 5000
                },
                "fasttext": {
                    "enabled": True,
                    "priority": 3,
                    "model_path": os.getenv("FASTTEXT_MODEL_PATH", "models/lid.176.bin"),
                    "min_text_length": 1,
                    "max_text_length": 100000
                }
            },
            "ensemble": {
                "enabled": True,
                "voting_method": "weighted",
                "min_confidence": 0.5,
                "max_alternatives": 5
            },
            "fallback": {
                "default_language": "en",
                "confidence_threshold": 0.3
            }
        }

    def _get_quality_config(self) -> Dict[str, Any]:
        """Quality assessment configuration (metric thresholds, weights, models)."""
        return {
            "thresholds": {
                "overall": 0.7,
                "bleu": 0.3,
                "semantic_similarity": 0.6,
                "length_ratio": 0.5,
                "confidence": 0.6,
                "consistency": 0.4
            },
            "weights": {
                "confidence": 0.3,
                "length_ratio": 0.2,
                "semantic_similarity": 0.3,
                "bleu": 0.2,
                "consistency": 0.1
            },
            "models": {
                "spacy_models": {
                    "en": "en_core_web_sm",
                    "zh": "zh_core_web_sm",
                    "es": "es_core_news_sm",
                    "fr": "fr_core_news_sm",
                    "de": "de_core_news_sm",
                    "ja": "ja_core_news_sm",
                    "ko": "ko_core_news_sm",
                    "ru": "ru_core_news_sm"
                },
                "download_missing": True,
                "fallback_model": "en_core_web_sm"
            },
            "features": {
                "enable_bleu": True,
                "enable_semantic": True,
                "enable_consistency": True,
                "enable_length_check": True
            }
        }

    def _get_api_config(self) -> Dict[str, Any]:
        """API configuration (rate limits, request limits, response, security)."""
        return {
            "rate_limiting": {
                "enabled": True,
                "requests_per_minute": {
                    "default": 100,
                    "premium": 1000,
                    "enterprise": 10000
                },
                "burst_size": 10,
                "strategy": "fixed_window"
            },
            "request_limits": {
                "max_text_length": 10000,
                "max_batch_size": 100,
                "max_concurrent_requests": 50
            },
            "response_format": {
                "include_confidence": True,
                "include_provider": True,
                "include_processing_time": True,
                "include_cache_info": True
            },
            "security": {
                "enable_api_key_auth": True,
                "enable_jwt_auth": True,
                "cors_origins": ["*"],
                "max_request_size": "10MB"
            }
        }

    def _get_localization_config(self) -> Dict[str, Any]:
        """Localization configuration (supported languages, templates, UI)."""
        return {
            "default_language": "en",
            "supported_languages": [
                "en", "zh", "zh-cn", "zh-tw", "es", "fr", "de", "ja", "ko",
                "ru", "ar", "hi", "pt", "it", "nl", "sv", "da", "no", "fi",
                "pl", "tr", "th", "vi", "id", "ms", "tl", "sw", "zu", "xh"
            ],
            "auto_detect": True,
            "fallback_language": "en",
            "template_cache": {
                "enabled": True,
                "ttl": 3600,  # 1 hour
                "max_size": 10000
            },
            "ui_settings": {
                "show_language_selector": True,
                "show_original_text": False,
                "auto_translate": True,
                "quality_indicator": True
            }
        }

    def get_database_config(self) -> Dict[str, Any]:
        """Database configuration, built fresh from the environment each call."""
        return {
            "connection_string": os.getenv("DATABASE_URL"),
            "pool_size": int(os.getenv("DB_POOL_SIZE", 10)),
            "max_overflow": int(os.getenv("DB_MAX_OVERFLOW", 20)),
            "pool_timeout": int(os.getenv("DB_POOL_TIMEOUT", 30)),
            "pool_recycle": int(os.getenv("DB_POOL_RECYCLE", 3600)),
            "echo": os.getenv("DB_ECHO", "false").lower() == "true"
        }

    def get_monitoring_config(self) -> Dict[str, Any]:
        """Monitoring and logging configuration."""
        return {
            "logging": {
                "level": os.getenv("LOG_LEVEL", "INFO"),
                "format": "json",
                "enable_performance_logs": True,
                "enable_error_logs": True,
                "enable_access_logs": True
            },
            "metrics": {
                "enabled": True,
                "endpoint": "/metrics",
                "include_cache_metrics": True,
                "include_translation_metrics": True,
                "include_quality_metrics": True
            },
            "health_checks": {
                "enabled": True,
                "endpoint": "/health",
                "interval": 30,  # seconds
                "timeout": 10
            },
            "alerts": {
                "enabled": True,
                "thresholds": {
                    "error_rate": 0.05,  # 5%
                    "response_time_p95": 1000,  # 1 second
                    "cache_hit_ratio": 0.7,  # 70%
                    "quality_score_avg": 0.6  # 60%
                }
            }
        }

    def get_deployment_config(self) -> Dict[str, Any]:
        """Deployment configuration (environment, server, SSL, scaling)."""
        return {
            "environment": os.getenv("ENVIRONMENT", "development"),
            "debug": os.getenv("DEBUG", "false").lower() == "true",
            "workers": int(os.getenv("WORKERS", 4)),
            "host": os.getenv("HOST", "0.0.0.0"),
            "port": int(os.getenv("PORT", 8000)),
            "ssl": {
                "enabled": os.getenv("SSL_ENABLED", "false").lower() == "true",
                "cert_path": os.getenv("SSL_CERT_PATH"),
                "key_path": os.getenv("SSL_KEY_PATH")
            },
            "scaling": {
                "auto_scaling": os.getenv("AUTO_SCALING", "false").lower() == "true",
                "min_instances": int(os.getenv("MIN_INSTANCES", 1)),
                "max_instances": int(os.getenv("MAX_INSTANCES", 10)),
                "target_cpu": 70,
                "target_memory": 80
            }
        }

    def validate(self) -> List[str]:
        """Validate configuration and return a list of human-readable issues.

        Returns an empty list when everything required is present.
        """
        issues = []

        # Check required API keys
        if not self.translation["providers"]["openai"]["api_key"]:
            issues.append("OpenAI API key not configured")

        if not self.translation["providers"]["google"]["api_key"]:
            issues.append("Google Translate API key not configured")

        if not self.translation["providers"]["deepl"]["api_key"]:
            issues.append("DeepL API key not configured")

        # Check Redis configuration
        if not self.cache["redis"]["url"]:
            issues.append("Redis URL not configured")

        # Check database configuration
        if not self.get_database_config()["connection_string"]:
            issues.append("Database connection string not configured")

        # Check FastText model file exists on disk when that detector is on
        if self.detection["methods"]["fasttext"]["enabled"]:
            model_path = self.detection["methods"]["fasttext"]["model_path"]
            if not os.path.exists(model_path):
                issues.append(f"FastText model not found at {model_path}")

        # Validate thresholds are proportions in [0, 1]
        quality_thresholds = self.quality["thresholds"]
        for metric, threshold in quality_thresholds.items():
            if not 0 <= threshold <= 1:
                issues.append(f"Invalid threshold for {metric}: {threshold}")

        return issues

    def to_dict(self) -> Dict[str, Any]:
        """Convert configuration to a dictionary.

        Uses the materialized monitoring/deployment attributes so subclass
        overrides are reflected in the exported dict.
        """
        return {
            "translation": self.translation,
            "cache": self.cache,
            "detection": self.detection,
            "quality": self.quality,
            "api": self.api,
            "localization": self.localization,
            "database": self.get_database_config(),
            "monitoring": self.monitoring,
            "deployment": self.deployment
        }
|
||||
|
||||
# Environment-specific configurations
|
||||
class DevelopmentConfig(MultiLanguageConfig):
    """Development environment configuration: verbose logging, debug on."""

    def __init__(self):
        super().__init__()
        self.cache["redis"]["url"] = "redis://localhost:6379/1"
        # Bug fix: the base class historically did not define `monitoring`
        # or `deployment` attributes, so mutating them here raised
        # AttributeError. Materialize them from the getters when absent.
        if not hasattr(self, "monitoring"):
            self.monitoring = self.get_monitoring_config()
        if not hasattr(self, "deployment"):
            self.deployment = self.get_deployment_config()
        self.monitoring["logging"]["level"] = "DEBUG"
        self.deployment["debug"] = True
|
||||
|
||||
class ProductionConfig(MultiLanguageConfig):
    """Production environment configuration: INFO logging, rate limiting on."""

    def __init__(self):
        super().__init__()
        # Bug fix: the base class historically did not define `monitoring`
        # or `deployment` attributes, so mutating them here raised
        # AttributeError. Materialize them from the getters when absent.
        if not hasattr(self, "monitoring"):
            self.monitoring = self.get_monitoring_config()
        if not hasattr(self, "deployment"):
            self.deployment = self.get_deployment_config()
        self.monitoring["logging"]["level"] = "INFO"
        self.deployment["debug"] = False
        self.api["rate_limiting"]["enabled"] = True
        self.cache["cache_settings"]["default_ttl"] = 86400  # 24 hours
|
||||
|
||||
class TestingConfig(MultiLanguageConfig):
    """Testing environment configuration: isolated Redis DB, fixture models."""

    def __init__(self):
        super().__init__()
        # Point at a dedicated Redis DB and local fixture models so tests
        # never touch shared state or real providers.
        self.cache["redis"]["url"] = "redis://localhost:6379/15"
        self.translation["providers"]["local"]["model_path"] = "tests/fixtures/models"
        # BLEU scoring is slow; skip it in the test suite.
        self.quality["features"]["enable_bleu"] = False
|
||||
|
||||
# Configuration factory
|
||||
def get_config() -> MultiLanguageConfig:
    """Build the configuration matching the ENVIRONMENT variable.

    "production" and "testing" select their dedicated subclasses; any other
    value (including unset) falls back to the development configuration.
    """
    environment = os.getenv("ENVIRONMENT", "development").lower()

    if environment == "production":
        return ProductionConfig()
    if environment == "testing":
        return TestingConfig()
    return DevelopmentConfig()
|
||||
|
||||
# Export configuration
# Module-level singleton; note get_config() (and its env-var reads) runs at
# import time, so the environment must be set before this module is imported.
config = get_config()
|
||||
@@ -0,0 +1,436 @@
|
||||
-- Multi-Language Support Database Schema
|
||||
-- Migration script for adding multi-language support to AITBC platform
|
||||
|
||||
-- 1. Translation cache table
-- Stores completed translations keyed by a deterministic cache key so
-- repeated requests can skip the provider call.
CREATE TABLE IF NOT EXISTS translation_cache (
    id SERIAL PRIMARY KEY,
    cache_key VARCHAR(255) UNIQUE NOT NULL,
    source_text TEXT NOT NULL,
    source_language VARCHAR(10) NOT NULL,
    target_language VARCHAR(10) NOT NULL,
    translated_text TEXT NOT NULL,
    provider VARCHAR(50) NOT NULL,
    confidence FLOAT NOT NULL,
    processing_time_ms INTEGER NOT NULL,
    context TEXT,
    domain VARCHAR(50),
    access_count INTEGER DEFAULT 1,
    created_at TIMESTAMP DEFAULT NOW(),
    last_accessed TIMESTAMP DEFAULT NOW(),
    expires_at TIMESTAMP
);

-- Fix: inline "INDEX ..." clauses inside CREATE TABLE are MySQL syntax and
-- are rejected by PostgreSQL (this script targets PostgreSQL: SERIAL,
-- TEXT[], plpgsql, ON CONFLICT). Indexes are created standalone, with
-- table-prefixed names to avoid collisions with identically named indexes
-- on other tables. cache_key needs no extra index: UNIQUE already has one.
CREATE INDEX IF NOT EXISTS idx_translation_cache_source_target ON translation_cache(source_language, target_language);
CREATE INDEX IF NOT EXISTS idx_translation_cache_provider ON translation_cache(provider);
CREATE INDEX IF NOT EXISTS idx_translation_cache_created_at ON translation_cache(created_at);
CREATE INDEX IF NOT EXISTS idx_translation_cache_expires_at ON translation_cache(expires_at);
|
||||
|
||||
-- 2. Supported languages registry
-- One row per language code; `id` is the BCP-47-style code used as the
-- foreign key throughout the rest of the schema.
CREATE TABLE IF NOT EXISTS supported_languages (
    id VARCHAR(10) PRIMARY KEY,
    name VARCHAR(100) NOT NULL,          -- English display name
    native_name VARCHAR(100) NOT NULL,   -- name in the language itself
    is_active BOOLEAN DEFAULT TRUE,
    translation_engine VARCHAR(50),      -- preferred provider for this language
    detection_supported BOOLEAN DEFAULT TRUE,
    created_at TIMESTAMP DEFAULT NOW(),
    updated_at TIMESTAMP DEFAULT NOW()
);
|
||||
|
||||
-- 3. Agent language preferences
-- Extends the existing agents table; IF NOT EXISTS keeps the migration
-- re-runnable. Defaults make every existing agent English-only with
-- auto-translation on.
ALTER TABLE agents ADD COLUMN IF NOT EXISTS preferred_language VARCHAR(10) DEFAULT 'en';
ALTER TABLE agents ADD COLUMN IF NOT EXISTS supported_languages TEXT[] DEFAULT ARRAY['en'];
ALTER TABLE agents ADD COLUMN IF NOT EXISTS auto_translate_enabled BOOLEAN DEFAULT TRUE;
ALTER TABLE agents ADD COLUMN IF NOT EXISTS translation_quality_threshold FLOAT DEFAULT 0.7;
|
||||
|
||||
-- 4. Multi-language marketplace listings
-- Per-language translated copy of a marketplace listing; one row per
-- (listing, language) pair.
CREATE TABLE IF NOT EXISTS marketplace_listings_i18n (
    id SERIAL PRIMARY KEY,
    listing_id INTEGER NOT NULL REFERENCES marketplace_listings(id) ON DELETE CASCADE,
    language VARCHAR(10) NOT NULL,
    title TEXT NOT NULL,
    description TEXT NOT NULL,
    keywords TEXT[],
    features TEXT[],
    requirements TEXT[],
    translated_at TIMESTAMP DEFAULT NOW(),
    translation_confidence FLOAT,
    translator_provider VARCHAR(50),

    -- Unique constraint per listing and language
    UNIQUE(listing_id, language)
);

-- Fix: inline INDEX clauses are MySQL-only; PostgreSQL requires standalone
-- CREATE INDEX. Names are table-prefixed to avoid cross-table collisions.
-- (listing_id, language) is already indexed by the UNIQUE constraint above.
CREATE INDEX IF NOT EXISTS idx_marketplace_listings_i18n_language ON marketplace_listings_i18n(language);
CREATE INDEX IF NOT EXISTS idx_marketplace_listings_i18n_keywords ON marketplace_listings_i18n USING GIN (keywords);
CREATE INDEX IF NOT EXISTS idx_marketplace_listings_i18n_translated_at ON marketplace_listings_i18n(translated_at);
|
||||
|
||||
-- 5. Agent communication translations
-- Audit/log of translated inter-agent messages, linked to the source
-- message row.
CREATE TABLE IF NOT EXISTS agent_message_translations (
    id SERIAL PRIMARY KEY,
    message_id INTEGER NOT NULL REFERENCES agent_messages(id) ON DELETE CASCADE,
    source_language VARCHAR(10) NOT NULL,
    target_language VARCHAR(10) NOT NULL,
    original_text TEXT NOT NULL,
    translated_text TEXT NOT NULL,
    provider VARCHAR(50) NOT NULL,
    confidence FLOAT NOT NULL,
    translation_time_ms INTEGER NOT NULL,
    created_at TIMESTAMP DEFAULT NOW()
);

-- Fix: inline INDEX clauses are MySQL-only; PostgreSQL requires standalone
-- CREATE INDEX. Names are table-prefixed to avoid cross-table collisions.
CREATE INDEX IF NOT EXISTS idx_agent_message_translations_message_id ON agent_message_translations(message_id);
CREATE INDEX IF NOT EXISTS idx_agent_message_translations_source_target ON agent_message_translations(source_language, target_language);
CREATE INDEX IF NOT EXISTS idx_agent_message_translations_created_at ON agent_message_translations(created_at);
|
||||
|
||||
-- 6. Translation quality logs
-- One row per quality assessment, keeping every component score alongside
-- the overall pass/fail verdict.
CREATE TABLE IF NOT EXISTS translation_quality_logs (
    id SERIAL PRIMARY KEY,
    source_text TEXT NOT NULL,
    translated_text TEXT NOT NULL,
    source_language VARCHAR(10) NOT NULL,
    target_language VARCHAR(10) NOT NULL,
    provider VARCHAR(50) NOT NULL,
    overall_score FLOAT NOT NULL,
    bleu_score FLOAT,
    semantic_similarity FLOAT,
    length_ratio FLOAT,
    confidence_score FLOAT,
    consistency_score FLOAT,
    passed_threshold BOOLEAN NOT NULL,
    recommendations TEXT[],
    processing_time_ms INTEGER NOT NULL,
    created_at TIMESTAMP DEFAULT NOW()
);

-- Fix: inline INDEX clauses are MySQL-only; PostgreSQL requires standalone
-- CREATE INDEX. Names are table-prefixed to avoid cross-table collisions.
CREATE INDEX IF NOT EXISTS idx_translation_quality_logs_provider_date ON translation_quality_logs(provider, created_at);
CREATE INDEX IF NOT EXISTS idx_translation_quality_logs_score ON translation_quality_logs(overall_score);
CREATE INDEX IF NOT EXISTS idx_translation_quality_logs_threshold ON translation_quality_logs(passed_threshold);
CREATE INDEX IF NOT EXISTS idx_translation_quality_logs_created_at ON translation_quality_logs(created_at);
|
||||
|
||||
-- 7. User language preferences
-- One row per (user, language); `is_primary` marks the user's main UI
-- language.
CREATE TABLE IF NOT EXISTS user_language_preferences (
    id SERIAL PRIMARY KEY,
    user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
    language VARCHAR(10) NOT NULL,
    is_primary BOOLEAN DEFAULT FALSE,
    auto_translate BOOLEAN DEFAULT TRUE,
    show_original BOOLEAN DEFAULT FALSE,
    quality_threshold FLOAT DEFAULT 0.7,
    created_at TIMESTAMP DEFAULT NOW(),
    updated_at TIMESTAMP DEFAULT NOW(),

    -- Unique constraint per user and language
    UNIQUE(user_id, language)
);

-- Fix: inline INDEX clauses are MySQL-only; PostgreSQL requires standalone
-- CREATE INDEX. Names are table-prefixed to avoid cross-table collisions.
-- user_id lookups are already served by the leading column of the UNIQUE
-- (user_id, language) index.
CREATE INDEX IF NOT EXISTS idx_user_language_preferences_language ON user_language_preferences(language);
CREATE INDEX IF NOT EXISTS idx_user_language_preferences_primary ON user_language_preferences(is_primary);
|
||||
|
||||
-- 8. Translation statistics
-- Daily rollup per (language pair, provider); maintained by the
-- update_translation_stats trigger further down in this script.
CREATE TABLE IF NOT EXISTS translation_statistics (
    id SERIAL PRIMARY KEY,
    date DATE NOT NULL,
    source_language VARCHAR(10) NOT NULL,
    target_language VARCHAR(10) NOT NULL,
    provider VARCHAR(50) NOT NULL,
    total_translations INTEGER DEFAULT 0,
    successful_translations INTEGER DEFAULT 0,
    failed_translations INTEGER DEFAULT 0,
    cache_hits INTEGER DEFAULT 0,
    cache_misses INTEGER DEFAULT 0,
    avg_confidence FLOAT DEFAULT 0,
    avg_processing_time_ms INTEGER DEFAULT 0,
    created_at TIMESTAMP DEFAULT NOW(),
    updated_at TIMESTAMP DEFAULT NOW(),

    -- Unique constraint per date and language pair
    UNIQUE(date, source_language, target_language, provider)
);

-- Fix: inline INDEX clauses are MySQL-only; PostgreSQL requires standalone
-- CREATE INDEX. Names are table-prefixed to avoid cross-table collisions.
-- date lookups are already served by the leading column of the UNIQUE index.
CREATE INDEX IF NOT EXISTS idx_translation_statistics_language_pair ON translation_statistics(source_language, target_language);
CREATE INDEX IF NOT EXISTS idx_translation_statistics_provider ON translation_statistics(provider);
|
||||
|
||||
-- 9. Content localization templates
-- Pre-translated UI strings keyed by (template_key, language); `variables`
-- lists the placeholder names the content expects.
CREATE TABLE IF NOT EXISTS localization_templates (
    id SERIAL PRIMARY KEY,
    template_key VARCHAR(255) NOT NULL,
    language VARCHAR(10) NOT NULL,
    content TEXT NOT NULL,
    variables TEXT[],
    created_at TIMESTAMP DEFAULT NOW(),
    updated_at TIMESTAMP DEFAULT NOW(),

    -- Unique constraint per template key and language
    UNIQUE(template_key, language)
);

-- Fix: inline INDEX clauses are MySQL-only; PostgreSQL requires standalone
-- CREATE INDEX. Names are table-prefixed to avoid cross-table collisions.
-- template_key lookups are already served by the leading column of the
-- UNIQUE (template_key, language) index.
CREATE INDEX IF NOT EXISTS idx_localization_templates_language ON localization_templates(language);
|
||||
|
||||
-- 10. Translation API usage logs
-- Per-request access log for the translation endpoints (latency, status,
-- caller identity where known).
CREATE TABLE IF NOT EXISTS translation_api_logs (
    id SERIAL PRIMARY KEY,
    endpoint VARCHAR(255) NOT NULL,
    method VARCHAR(10) NOT NULL,
    source_language VARCHAR(10),
    target_language VARCHAR(10),
    text_length INTEGER,
    processing_time_ms INTEGER NOT NULL,
    status_code INTEGER NOT NULL,
    error_message TEXT,
    user_id INTEGER REFERENCES users(id),
    ip_address INET,
    user_agent TEXT,
    created_at TIMESTAMP DEFAULT NOW()
);

-- Fix: inline INDEX clauses are MySQL-only; PostgreSQL requires standalone
-- CREATE INDEX. Names are table-prefixed to avoid cross-table collisions.
CREATE INDEX IF NOT EXISTS idx_translation_api_logs_endpoint ON translation_api_logs(endpoint);
CREATE INDEX IF NOT EXISTS idx_translation_api_logs_created_at ON translation_api_logs(created_at);
CREATE INDEX IF NOT EXISTS idx_translation_api_logs_status_code ON translation_api_logs(status_code);
CREATE INDEX IF NOT EXISTS idx_translation_api_logs_user_id ON translation_api_logs(user_id);
|
||||
|
||||
-- Insert supported languages
-- Seed data for the supported_languages registry; idempotent via
-- ON CONFLICT DO NOTHING so the migration can be re-run safely.
INSERT INTO supported_languages (id, name, native_name, is_active, translation_engine, detection_supported) VALUES
('en', 'English', 'English', TRUE, 'openai', TRUE),
('zh', 'Chinese', '中文', TRUE, 'openai', TRUE),
('zh-cn', 'Chinese (Simplified)', '简体中文', TRUE, 'openai', TRUE),
('zh-tw', 'Chinese (Traditional)', '繁體中文', TRUE, 'openai', TRUE),
('es', 'Spanish', 'Español', TRUE, 'openai', TRUE),
('fr', 'French', 'Français', TRUE, 'deepl', TRUE),
('de', 'German', 'Deutsch', TRUE, 'deepl', TRUE),
('ja', 'Japanese', '日本語', TRUE, 'openai', TRUE),
('ko', 'Korean', '한국어', TRUE, 'openai', TRUE),
('ru', 'Russian', 'Русский', TRUE, 'openai', TRUE),
('ar', 'Arabic', 'العربية', TRUE, 'openai', TRUE),
('hi', 'Hindi', 'हिन्दी', TRUE, 'openai', TRUE),
('pt', 'Portuguese', 'Português', TRUE, 'openai', TRUE),
('it', 'Italian', 'Italiano', TRUE, 'deepl', TRUE),
('nl', 'Dutch', 'Nederlands', TRUE, 'google', TRUE),
('sv', 'Swedish', 'Svenska', TRUE, 'google', TRUE),
('da', 'Danish', 'Dansk', TRUE, 'google', TRUE),
('no', 'Norwegian', 'Norsk', TRUE, 'google', TRUE),
('fi', 'Finnish', 'Suomi', TRUE, 'google', TRUE),
('pl', 'Polish', 'Polski', TRUE, 'google', TRUE),
('tr', 'Turkish', 'Türkçe', TRUE, 'google', TRUE),
('th', 'Thai', 'ไทย', TRUE, 'openai', TRUE),
('vi', 'Vietnamese', 'Tiếng Việt', TRUE, 'openai', TRUE),
('id', 'Indonesian', 'Bahasa Indonesia', TRUE, 'google', TRUE),
('ms', 'Malay', 'Bahasa Melayu', TRUE, 'google', TRUE),
('tl', 'Filipino', 'Filipino', TRUE, 'google', TRUE),
('sw', 'Swahili', 'Kiswahili', TRUE, 'google', TRUE),
('zu', 'Zulu', 'IsiZulu', TRUE, 'google', TRUE),
('xh', 'Xhosa', 'isiXhosa', TRUE, 'google', TRUE),
('af', 'Afrikaans', 'Afrikaans', TRUE, 'google', TRUE),
('is', 'Icelandic', 'Íslenska', TRUE, 'google', TRUE),
('mt', 'Maltese', 'Malti', TRUE, 'google', TRUE),
('cy', 'Welsh', 'Cymraeg', TRUE, 'google', TRUE),
('ga', 'Irish', 'Gaeilge', TRUE, 'google', TRUE),
('gd', 'Scottish Gaelic', 'Gàidhlig', TRUE, 'google', TRUE),
('eu', 'Basque', 'Euskara', TRUE, 'google', TRUE),
('ca', 'Catalan', 'Català', TRUE, 'google', TRUE),
('gl', 'Galician', 'Galego', TRUE, 'google', TRUE),
('ast', 'Asturian', 'Asturianu', TRUE, 'google', TRUE),
('lb', 'Luxembourgish', 'Lëtzebuergesch', TRUE, 'google', TRUE),
('rm', 'Romansh', 'Rumantsch', TRUE, 'google', TRUE),
('fur', 'Friulian', 'Furlan', TRUE, 'google', TRUE),
('lld', 'Ladin', 'Ladin', TRUE, 'google', TRUE),
('lij', 'Ligurian', 'Ligure', TRUE, 'google', TRUE),
('lmo', 'Lombard', 'Lombard', TRUE, 'google', TRUE),
('vec', 'Venetian', 'Vèneto', TRUE, 'google', TRUE),
('scn', 'Sicilian', 'Sicilianu', TRUE, 'google', TRUE),
('ro', 'Romanian', 'Română', TRUE, 'google', TRUE),
('mo', 'Moldovan', 'Moldovenească', TRUE, 'google', TRUE),
('hr', 'Croatian', 'Hrvatski', TRUE, 'google', TRUE),
('sr', 'Serbian', 'Српски', TRUE, 'google', TRUE),
('sl', 'Slovenian', 'Slovenščina', TRUE, 'google', TRUE),
('sk', 'Slovak', 'Slovenčina', TRUE, 'google', TRUE),
('cs', 'Czech', 'Čeština', TRUE, 'google', TRUE),
('bg', 'Bulgarian', 'Български', TRUE, 'google', TRUE),
('mk', 'Macedonian', 'Македонски', TRUE, 'google', TRUE),
('sq', 'Albanian', 'Shqip', TRUE, 'google', TRUE),
('hy', 'Armenian', 'Հայերեն', TRUE, 'google', TRUE),
('ka', 'Georgian', 'ქართული', TRUE, 'google', TRUE),
('he', 'Hebrew', 'עברית', TRUE, 'openai', TRUE),
('yi', 'Yiddish', 'ייִדיש', TRUE, 'google', TRUE),
('fa', 'Persian', 'فارسی', TRUE, 'openai', TRUE),
('ps', 'Pashto', 'پښتو', TRUE, 'google', TRUE),
('ur', 'Urdu', 'اردو', TRUE, 'openai', TRUE),
('bn', 'Bengali', 'বাংলা', TRUE, 'openai', TRUE),
('as', 'Assamese', 'অসমীয়া', TRUE, 'google', TRUE),
('or', 'Odia', 'ଓଡ଼ିଆ', TRUE, 'google', TRUE),
('pa', 'Punjabi', 'ਪੰਜਾਬੀ', TRUE, 'google', TRUE),
('gu', 'Gujarati', 'ગુજરાતી', TRUE, 'google', TRUE),
('mr', 'Marathi', 'मराठी', TRUE, 'google', TRUE),
('ne', 'Nepali', 'नेपाली', TRUE, 'google', TRUE),
('si', 'Sinhala', 'සිංහල', TRUE, 'google', TRUE),
('ta', 'Tamil', 'தமிழ்', TRUE, 'openai', TRUE),
('te', 'Telugu', 'తెలుగు', TRUE, 'google', TRUE),
('ml', 'Malayalam', 'മലയാളം', TRUE, 'google', TRUE),
('kn', 'Kannada', 'ಕನ್ನಡ', TRUE, 'google', TRUE),
('my', 'Myanmar', 'မြန်မာ', TRUE, 'google', TRUE),
('km', 'Khmer', 'ខ្មែរ', TRUE, 'google', TRUE),
('lo', 'Lao', 'ລາວ', TRUE, 'google', TRUE)
ON CONFLICT (id) DO NOTHING;
|
||||
|
||||
-- Insert common localization templates
-- Seed data for localization_templates; idempotent via ON CONFLICT.
-- Fixes: (1) `[]` is not a valid PostgreSQL array literal — replaced with
-- the empty-array literal '{}' (implicitly cast to TEXT[]);
-- (2) the French agent_status_online row had broken quoting
-- (''L'agent ...) — the apostrophe is now escaped as '' inside the string.
INSERT INTO localization_templates (template_key, language, content, variables) VALUES
('welcome_message', 'en', 'Welcome to AITBC!', '{}'),
('welcome_message', 'zh', '欢迎使用AITBC!', '{}'),
('welcome_message', 'es', '¡Bienvenido a AITBC!', '{}'),
('welcome_message', 'fr', 'Bienvenue sur AITBC!', '{}'),
('welcome_message', 'de', 'Willkommen bei AITBC!', '{}'),
('welcome_message', 'ja', 'AITBCへようこそ!', '{}'),
('welcome_message', 'ko', 'AITBC에 오신 것을 환영합니다!', '{}'),
('welcome_message', 'ru', 'Добро пожаловать в AITBC!', '{}'),
('welcome_message', 'ar', 'مرحبا بك في AITBC!', '{}'),
('welcome_message', 'hi', 'AITBC में आपका स्वागत है!', '{}'),

('marketplace_title', 'en', 'AI Power Marketplace', '{}'),
('marketplace_title', 'zh', 'AI算力市场', '{}'),
('marketplace_title', 'es', 'Mercado de Poder de IA', '{}'),
('marketplace_title', 'fr', 'Marché de la Puissance IA', '{}'),
('marketplace_title', 'de', 'KI-Leistungsmarktplatz', '{}'),
('marketplace_title', 'ja', 'AIパワーマーケット', '{}'),
('marketplace_title', 'ko', 'AI 파워 마켓플레이스', '{}'),
('marketplace_title', 'ru', 'Рынок мощностей ИИ', '{}'),
('marketplace_title', 'ar', 'سوق قوة الذكاء الاصطناعي', '{}'),
('marketplace_title', 'hi', 'AI पावर मार्केटप्लेस', '{}'),

('agent_status_online', 'en', 'Agent is online and ready', '{}'),
('agent_status_online', 'zh', '智能体在线并准备就绪', '{}'),
('agent_status_online', 'es', 'El agente está en línea y listo', '{}'),
('agent_status_online', 'fr', 'L''agent est en ligne et prêt', '{}'),
('agent_status_online', 'de', 'Agent ist online und bereit', '{}'),
('agent_status_online', 'ja', 'エージェントがオンラインで準備完了', '{}'),
('agent_status_online', 'ko', '에이전트가 온라인 상태이며 준비됨', '{}'),
('agent_status_online', 'ru', 'Агент в сети и готов', '{}'),
('agent_status_online', 'ar', 'العميل متصل وجاهز', '{}'),
('agent_status_online', 'hi', 'एजेंट ऑनलाइन और तैयार है', '{}'),

('transaction_success', 'en', 'Transaction completed successfully', '{}'),
('transaction_success', 'zh', '交易成功完成', '{}'),
('transaction_success', 'es', 'Transacción completada exitosamente', '{}'),
('transaction_success', 'fr', 'Transaction terminée avec succès', '{}'),
('transaction_success', 'de', 'Transaktion erfolgreich abgeschlossen', '{}'),
('transaction_success', 'ja', '取引が正常に完了しました', '{}'),
('transaction_success', 'ko', '거래가 성공적으로 완료되었습니다', '{}'),
('transaction_success', 'ru', 'Транзакция успешно завершена', '{}'),
('transaction_success', 'ar', 'تمت المعاملة بنجاح', '{}'),
('transaction_success', 'hi', 'लेन-देन सफलतापूर्वक पूर्ण हुई', '{}')
ON CONFLICT (template_key, language) DO NOTHING;
|
||||
|
||||
-- Create indexes for better performance
-- Partial index: only rows that actually carry an expiry participate, which
-- keeps the cleanup scan (cleanup_expired_cache below) small.
CREATE INDEX IF NOT EXISTS idx_translation_cache_expires ON translation_cache(expires_at) WHERE expires_at IS NOT NULL;
CREATE INDEX IF NOT EXISTS idx_agent_messages_created_at ON agent_messages(created_at);
CREATE INDEX IF NOT EXISTS idx_marketplace_listings_created_at ON marketplace_listings(created_at);
|
||||
|
||||
-- Create function to update translation statistics
-- Trigger function: upserts one daily rollup row per (language pair,
-- provider) each time a translation lands in translation_cache.
CREATE OR REPLACE FUNCTION update_translation_stats()
RETURNS TRIGGER AS $$
BEGIN
    INSERT INTO translation_statistics (
        date, source_language, target_language, provider,
        total_translations, successful_translations, failed_translations,
        avg_confidence, avg_processing_time_ms
    ) VALUES (
        CURRENT_DATE,
        COALESCE(NEW.source_language, 'unknown'),
        COALESCE(NEW.target_language, 'unknown'),
        COALESCE(NEW.provider, 'unknown'),
        1, 1, 0,
        COALESCE(NEW.confidence, 0),
        COALESCE(NEW.processing_time_ms, 0)
    )
    ON CONFLICT (date, source_language, target_language, provider)
    DO UPDATE SET
        total_translations = translation_statistics.total_translations + 1,
        successful_translations = translation_statistics.successful_translations + 1,
        -- Incremental running mean: uses the pre-increment successful count
        -- as the weight of the existing average, then divides by count + 1.
        avg_confidence = (translation_statistics.avg_confidence * translation_statistics.successful_translations + COALESCE(NEW.confidence, 0)) / (translation_statistics.successful_translations + 1),
        avg_processing_time_ms = (translation_statistics.avg_processing_time_ms * translation_statistics.successful_translations + COALESCE(NEW.processing_time_ms, 0)) / (translation_statistics.successful_translations + 1),
        updated_at = NOW();

    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
|
||||
|
||||
-- Create trigger for automatic statistics updates
-- Dropped first so re-running the migration is idempotent.
DROP TRIGGER IF EXISTS trigger_update_translation_stats ON translation_cache;
CREATE TRIGGER trigger_update_translation_stats
    AFTER INSERT ON translation_cache
    FOR EACH ROW
    EXECUTE FUNCTION update_translation_stats();
|
||||
|
||||
-- Create function to clean up expired cache entries
-- Returns the number of rows removed; intended to be invoked by a scheduled
-- job. The delete predicate matches the partial index
-- idx_translation_cache_expires, so this avoids a full-table scan.
CREATE OR REPLACE FUNCTION cleanup_expired_cache()
RETURNS INTEGER AS $$
DECLARE
    deleted_count INTEGER;
BEGIN
    DELETE FROM translation_cache
    WHERE expires_at IS NOT NULL AND expires_at < NOW();

    GET DIAGNOSTICS deleted_count = ROW_COUNT;

    RETURN deleted_count;
END;
$$ LANGUAGE plpgsql;
|
||||
|
||||
-- Create view for translation analytics
-- Daily aggregates per language pair and provider, newest day first.
CREATE OR REPLACE VIEW translation_analytics AS
SELECT
    DATE(created_at) as date,
    source_language,
    target_language,
    provider,
    COUNT(*) as total_translations,
    AVG(confidence) as avg_confidence,
    AVG(processing_time_ms) as avg_processing_time_ms,
    -- Quality buckets; rows with 0.5 <= confidence <= 0.8 fall in neither.
    COUNT(CASE WHEN confidence > 0.8 THEN 1 END) as high_confidence_count,
    COUNT(CASE WHEN confidence < 0.5 THEN 1 END) as low_confidence_count
FROM translation_cache
GROUP BY DATE(created_at), source_language, target_language, provider
ORDER BY date DESC;
|
||||
|
||||
-- Create view for cache performance metrics
-- Single-row snapshot of translation_cache health. Each column is an
-- independent scalar subquery (several table scans per SELECT — acceptable
-- for dashboard use, not for hot paths).
CREATE OR REPLACE VIEW cache_performance_metrics AS
SELECT
    (SELECT COUNT(*) FROM translation_cache) as total_entries,
    (SELECT COUNT(*) FROM translation_cache WHERE created_at > NOW() - INTERVAL '24 hours') as entries_last_24h,
    (SELECT AVG(access_count) FROM translation_cache) as avg_access_count,
    (SELECT COUNT(*) FROM translation_cache WHERE access_count > 10) as popular_entries,
    -- NULL expires_at rows never compare true here, so they are not counted.
    (SELECT COUNT(*) FROM translation_cache WHERE expires_at < NOW()) as expired_entries,
    (SELECT AVG(confidence) FROM translation_cache) as avg_confidence,
    (SELECT AVG(processing_time_ms) FROM translation_cache) as avg_processing_time;
|
||||
|
||||
-- Grant permissions (adjust as needed for your setup)
-- GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA public TO aitbc_app;
-- GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO aitbc_app;

-- Add comments for documentation (queryable via \dt+ / pg_description)
COMMENT ON TABLE translation_cache IS 'Cache for translation results to improve performance';
COMMENT ON TABLE supported_languages IS 'Registry of supported languages for translation and detection';
COMMENT ON TABLE marketplace_listings_i18n IS 'Multi-language versions of marketplace listings';
COMMENT ON TABLE agent_message_translations IS 'Translations of agent communications';
COMMENT ON TABLE translation_quality_logs IS 'Quality assessment logs for translations';
COMMENT ON TABLE user_language_preferences IS 'User language preferences and settings';
COMMENT ON TABLE translation_statistics IS 'Daily translation usage statistics';
COMMENT ON TABLE localization_templates IS 'Template strings for UI localization';
COMMENT ON TABLE translation_api_logs IS 'API usage logs for monitoring and analytics';

-- Create partition for large tables (optional for high-volume deployments)
-- This would be implemented based on actual usage patterns
-- CREATE TABLE translation_cache_y2024m01 PARTITION OF translation_cache
--     FOR VALUES FROM ('2024-01-01') TO ('2024-02-01');
|
||||
@@ -0,0 +1,351 @@
|
||||
"""
|
||||
Language Detection Service
|
||||
Automatic language detection for multi-language support
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import langdetect
|
||||
from langdetect.lang_detect_exception import LangDetectException
|
||||
import polyglot
|
||||
from polyglot.detect import Detector
|
||||
import fasttext
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DetectionMethod(Enum):
    """Backends available for language detection.

    ENSEMBLE is not a detector itself: it fans out to the individual
    backends and combines their answers by weighted voting.
    """

    LANGDETECT = "langdetect"
    POLYGLOT = "polyglot"
    FASTTEXT = "fasttext"
    ENSEMBLE = "ensemble"
|
||||
|
||||
@dataclass
class DetectionResult:
    """Outcome of one language-detection run."""

    language: str  # detected language code, e.g. "en"
    confidence: float  # score in [0, 1] as reported/derived by the backend
    method: DetectionMethod  # which backend produced this result
    alternatives: List[Tuple[str, float]]  # runner-up (language, score) pairs
    processing_time_ms: int  # wall-clock duration of the detection call
|
||||
|
||||
class LanguageDetector:
    """Advanced language detection with multiple methods and ensemble voting.

    Wraps three detector backends (langdetect, polyglot, fasttext) behind a
    single async API.  Blocking library calls are pushed onto the default
    executor so the event loop is never stalled.  When several backends are
    available, their results can be combined by confidence-weighted voting.
    """

    def __init__(self, config: Dict):
        """Store configuration and eagerly try to load the FastText model.

        Args:
            config: Service configuration.  ``config["fasttext"]["model_path"]``
                optionally points at the lid.176.bin model file.
        """
        self.config = config
        self.fasttext_model = None
        self._initialize_fasttext()

    def _initialize_fasttext(self):
        """Load the FastText language-ID model; degrade gracefully if absent."""
        try:
            # lid.176.bin is FastText's 176-language identification model.
            model_path = self.config.get("fasttext", {}).get("model_path", "lid.176.bin")
            self.fasttext_model = fasttext.load_model(model_path)
            logger.info("FastText model loaded successfully")
        except Exception as e:
            # A missing model is non-fatal: the ensemble simply runs without it.
            logger.warning(f"FastText model initialization failed: {e}")
            self.fasttext_model = None

    async def detect_language(self, text: str, methods: Optional[List[DetectionMethod]] = None) -> DetectionResult:
        """Detect the language of *text* with the requested methods.

        Defaults to ensemble detection when no methods are given; otherwise
        only the first entry of *methods* is used.
        """
        if not methods:
            methods = [DetectionMethod.ENSEMBLE]

        if DetectionMethod.ENSEMBLE in methods:
            return await self._ensemble_detection(text)

        # Use the single specified method.
        return await self._detect_with_method(text, methods[0])

    async def _detect_with_method(self, text: str, method: DetectionMethod) -> DetectionResult:
        """Dispatch to one concrete backend; fall back to langdetect on failure."""
        start_time = asyncio.get_event_loop().time()

        try:
            if method == DetectionMethod.LANGDETECT:
                return await self._langdetect_method(text, start_time)
            elif method == DetectionMethod.POLYGLOT:
                return await self._polyglot_method(text, start_time)
            elif method == DetectionMethod.FASTTEXT:
                return await self._fasttext_method(text, start_time)
            else:
                raise ValueError(f"Unsupported detection method: {method}")
        except Exception as e:
            logger.error(f"Language detection failed with {method.value}: {e}")
            # Fallback to langdetect
            return await self._langdetect_method(text, start_time)

    async def _langdetect_method(self, text: str, start_time: float) -> DetectionResult:
        """Language detection using the langdetect library."""

        def detect():
            try:
                return langdetect.detect_langs(text)
            except LangDetectException:
                # BUGFIX: langdetect exposes no ``DetectLanguage`` class, so the
                # previous fallback raised AttributeError inside this handler.
                # Use a minimal stand-in carrying the two attributes read below.
                class _FallbackLang:
                    lang = "en"
                    prob = 1.0

                return [_FallbackLang()]

        langs = await asyncio.get_event_loop().run_in_executor(None, detect)

        primary_lang = langs[0].lang
        confidence = langs[0].prob
        alternatives = [(lang.lang, lang.prob) for lang in langs[1:]]
        processing_time = int((asyncio.get_event_loop().time() - start_time) * 1000)

        return DetectionResult(
            language=primary_lang,
            confidence=confidence,
            method=DetectionMethod.LANGDETECT,
            alternatives=alternatives,
            processing_time_ms=processing_time,
        )

    async def _polyglot_method(self, text: str, start_time: float) -> DetectionResult:
        """Language detection using the Polyglot library."""

        def detect():
            try:
                return Detector(text)
            except Exception as e:
                logger.warning(f"Polyglot detection failed: {e}")

                # Fallback object mimicking the Detector attributes we read.
                class FallbackDetector:
                    def __init__(self):
                        self.language = "en"
                        self.confidence = 0.5

                return FallbackDetector()

        detector = await asyncio.get_event_loop().run_in_executor(None, detect)

        primary_lang = detector.language
        confidence = getattr(detector, 'confidence', 0.8)
        alternatives = []  # Polyglot doesn't provide alternatives easily
        processing_time = int((asyncio.get_event_loop().time() - start_time) * 1000)

        return DetectionResult(
            language=primary_lang,
            confidence=confidence,
            method=DetectionMethod.POLYGLOT,
            alternatives=alternatives,
            processing_time_ms=processing_time,
        )

    async def _fasttext_method(self, text: str, start_time: float) -> DetectionResult:
        """Language detection using the FastText lid model (top-5 predictions)."""
        if not self.fasttext_model:
            raise Exception("FastText model not available")

        def detect():
            # FastText expects single-line input; pad very short strings so the
            # model has something to work with.
            processed_text = text.replace("\n", " ").strip()
            if len(processed_text) < 10:
                processed_text += " " * (10 - len(processed_text))

            labels, probabilities = self.fasttext_model.predict(processed_text, k=5)

            # Labels come back as "__label__en"; strip the prefix.
            return [
                (label.replace("__label__", ""), float(prob))
                for label, prob in zip(labels, probabilities)
            ]

        results = await asyncio.get_event_loop().run_in_executor(None, detect)

        if not results:
            raise Exception("FastText detection failed")

        primary_lang, confidence = results[0]
        alternatives = results[1:]
        processing_time = int((asyncio.get_event_loop().time() - start_time) * 1000)

        return DetectionResult(
            language=primary_lang,
            confidence=confidence,
            method=DetectionMethod.FASTTEXT,
            alternatives=alternatives,
            processing_time_ms=processing_time,
        )

    async def _ensemble_detection(self, text: str) -> DetectionResult:
        """Run every available backend in parallel and vote on the result."""
        methods = [DetectionMethod.LANGDETECT, DetectionMethod.POLYGLOT]
        if self.fasttext_model:
            methods.append(DetectionMethod.FASTTEXT)

        # Run detections in parallel.
        tasks = [self._detect_with_method(text, method) for method in methods]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Keep only successful results.
        valid_results = []
        for result in results:
            if isinstance(result, DetectionResult):
                valid_results.append(result)
            else:
                logger.warning(f"Detection method failed: {result}")

        if not valid_results:
            # Ultimate fallback: assume English with low confidence.
            return DetectionResult(
                language="en",
                confidence=0.5,
                method=DetectionMethod.LANGDETECT,
                alternatives=[],
                processing_time_ms=0,
            )

        return self._ensemble_voting(valid_results)

    def _ensemble_voting(self, results: List[DetectionResult]) -> DetectionResult:
        """Combine multiple detection results using weighted voting."""
        # Static per-method reliability weights; FastText is trusted most.
        method_weights = {
            DetectionMethod.LANGDETECT: 0.3,
            DetectionMethod.POLYGLOT: 0.2,
            DetectionMethod.FASTTEXT: 0.5,
        }

        votes = {}
        total_confidence = 0
        total_processing_time = 0

        for result in results:
            weight = method_weights.get(result.method, 0.1)
            weighted_confidence = result.confidence * weight

            votes[result.language] = votes.get(result.language, 0) + weighted_confidence
            total_confidence += weighted_confidence
            total_processing_time += result.processing_time_ms

        if not votes:
            # Defensive: cannot occur with non-empty results, but preserve the
            # original fallback of returning the first result unchanged.
            return results[0]

        winner_language = max(votes.keys(), key=lambda x: votes[x])
        winner_confidence = votes[winner_language] / total_confidence if total_confidence > 0 else 0.5

        # Remaining candidates, best first, as scores normalized by total weight.
        alternatives = [
            (lang, score / total_confidence)
            for lang, score in sorted(votes.items(), key=lambda x: x[1], reverse=True)
            if lang != winner_language
        ]

        return DetectionResult(
            language=winner_language,
            confidence=winner_confidence,
            method=DetectionMethod.ENSEMBLE,
            alternatives=alternatives[:5],  # top 5 alternatives
            processing_time_ms=int(total_processing_time / len(results)),
        )

    def get_supported_languages(self) -> List[str]:
        """Return supported language codes, deduplicated, first-seen order.

        BUGFIX: the original hard-coded list contained many duplicate entries
        (e.g. "pl", "th", "vi", "id", "ms"); ``dict.fromkeys`` removes them
        while preserving order, so membership tests are unaffected.
        """
        codes = [
            "en", "zh", "zh-cn", "zh-tw", "es", "fr", "de", "ja", "ko", "ru", "ar",
            "hi", "pt", "it", "nl", "sv", "da", "no", "fi", "pl", "tr", "th", "vi",
            "id", "ms", "tl", "sw", "af", "is", "mt", "cy", "ga", "gd", "eu", "ca",
            "gl", "ast", "lb", "rm", "fur", "lld", "lij", "lmo", "vec", "scn",
            "ro", "mo", "hr", "sr", "sl", "sk", "cs", "uk", "be", "bg",
            "mk", "sq", "hy", "ka", "he", "yi", "fa", "ps", "ur", "bn", "as",
            "or", "pa", "gu", "mr", "ne", "si", "ta", "te", "ml", "kn", "my",
            "km", "lo", "jv", "su", "zu", "xh",
        ]
        return list(dict.fromkeys(codes))

    async def batch_detect(self, texts: List[str]) -> List[DetectionResult]:
        """Detect languages for multiple texts in parallel.

        Failed detections are replaced by a low-confidence English fallback so
        the output always aligns 1:1 with *texts*.
        """
        tasks = [self.detect_language(text) for text in texts]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        processed_results = []
        for i, result in enumerate(results):
            if isinstance(result, DetectionResult):
                processed_results.append(result)
            else:
                logger.error(f"Batch detection failed for text {i}: {result}")
                # Add fallback result
                processed_results.append(DetectionResult(
                    language="en",
                    confidence=0.5,
                    method=DetectionMethod.LANGDETECT,
                    alternatives=[],
                    processing_time_ms=0,
                ))

        return processed_results

    def validate_language_code(self, language_code: str) -> bool:
        """Return True when *language_code* (case-insensitive) is supported."""
        return language_code.lower() in self.get_supported_languages()

    def normalize_language_code(self, language_code: str) -> str:
        """Normalize a language code to the service's canonical form.

        Lowercases, converts "_" separators to "-", then applies alias
        mappings.  BUGFIX: the previous alias table keyed some entries by the
        underscore form ("zh_tw", "en_us"), which could never match the
        already-hyphenated input, and contained a duplicated dict key; the
        table below holds only reachable, normalized keys (output unchanged).
        """
        mappings = {
            "zh": "zh-cn",
            "en-us": "en",
            "en-gb": "en",
        }
        normalized = language_code.lower().replace("_", "-")
        return mappings.get(normalized, normalized)

    async def health_check(self) -> Dict[str, bool]:
        """Probe every available backend (and the ensemble) with English text."""
        health_status = {}
        test_text = "Hello, how are you today?"

        methods_to_test = [DetectionMethod.LANGDETECT, DetectionMethod.POLYGLOT]
        if self.fasttext_model:
            methods_to_test.append(DetectionMethod.FASTTEXT)

        for method in methods_to_test:
            try:
                result = await self._detect_with_method(test_text, method)
                # Healthy == the backend is at least moderately confident.
                health_status[method.value] = result.confidence > 0.5
            except Exception as e:
                logger.error(f"Health check failed for {method.value}: {e}")
                health_status[method.value] = False

        # Test ensemble
        try:
            result = await self._ensemble_detection(test_text)
            health_status["ensemble"] = result.confidence > 0.5
        except Exception as e:
            logger.error(f"Ensemble health check failed: {e}")
            health_status["ensemble"] = False

        return health_status
|
||||
@@ -0,0 +1,557 @@
|
||||
"""
|
||||
Marketplace Localization Support
|
||||
Multi-language support for marketplace listings and content
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from .translation_engine import TranslationEngine, TranslationRequest, TranslationResponse
|
||||
from .language_detector import LanguageDetector, DetectionResult
|
||||
from .translation_cache import TranslationCache
|
||||
from .quality_assurance import TranslationQualityChecker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ListingType(Enum):
    """Kinds of marketplace listing that can be localized."""

    SERVICE = "service"
    AGENT = "agent"
    RESOURCE = "resource"
    DATASET = "dataset"
|
||||
|
||||
@dataclass
class LocalizedListing:
    """One translated variant of a marketplace listing.

    ``original_id`` links back to the source listing and ``language``
    identifies which translation this instance is.  Time-dependent and
    mutable defaults are filled in ``__post_init__`` rather than as field
    defaults (mutable field defaults would be shared across instances).
    """

    id: str  # "<original_id>_<language>"
    original_id: str
    listing_type: ListingType
    language: str
    title: str
    description: str
    keywords: List[str]
    features: List[str]
    requirements: List[str]
    pricing_info: Dict[str, Any]  # copied from the original, never translated
    translation_confidence: Optional[float] = None  # mean per-field confidence
    translation_provider: Optional[str] = None
    # BUGFIX (typing): these two were annotated as non-Optional while
    # defaulting to None; the runtime defaulting below is unchanged.
    translated_at: Optional[datetime] = None
    reviewed: bool = False
    reviewer_id: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        if self.translated_at is None:
            self.translated_at = datetime.utcnow()
        if self.metadata is None:
            self.metadata = {}
|
||||
|
||||
@dataclass
class LocalizationRequest:
    """Request to localize one listing into a set of target languages.

    The ``translate_*`` flags select which fields of the listing are sent
    through the translation engine; everything defaults to on.
    """

    listing_id: str
    target_languages: List[str]
    translate_title: bool = True
    translate_description: bool = True
    translate_keywords: bool = True
    translate_features: bool = True
    translate_requirements: bool = True
    quality_threshold: float = 0.7  # minimum acceptable translation confidence
    priority: str = "normal"  # low, normal, high
|
||||
|
||||
class MarketplaceLocalization:
|
||||
"""Marketplace localization service"""
|
||||
|
||||
def __init__(self, translation_engine: TranslationEngine,
             language_detector: LanguageDetector,
             translation_cache: Optional[TranslationCache] = None,
             quality_checker: Optional[TranslationQualityChecker] = None):
    """Wire up collaborators and reset in-memory state.

    Args:
        translation_engine: Backend used for all text translation.
        language_detector: Used to detect a listing's source language.
        translation_cache: Optional cache consulted before translating.
        quality_checker: Optional post-translation quality assessor.
    """
    self.translation_engine = translation_engine
    self.language_detector = language_detector
    self.translation_cache = translation_cache
    self.quality_checker = quality_checker
    # NOTE(review): localized listings live only in process memory here —
    # they are lost on restart; confirm persistence happens elsewhere.
    self.localized_listings: Dict[str, List[LocalizedListing]] = {}  # listing_id -> [LocalizedListing]
    self.localization_queue: List[LocalizationRequest] = []
    # Running counters, surfaced by get_localization_statistics().
    self.localization_stats = {
        "total_localizations": 0,
        "successful_localizations": 0,
        "failed_localizations": 0,
        "cache_hits": 0,
        "cache_misses": 0,
        "quality_checks": 0
    }
|
||||
|
||||
async def create_localized_listing(self, original_listing: Dict[str, Any],
                                   target_languages: List[str]) -> List[LocalizedListing]:
    """Create localized versions of a marketplace listing.

    Detects the source language when the listing does not declare one,
    translates into every target language (skipping the source language
    itself), stores the results grouped by listing id, and returns them.
    Returns an empty list on failure.
    """
    try:
        localized_listings = []

        # Detect original language if not specified.
        # BUGFIX: the previous ``.get("language", "en")`` supplied a truthy
        # default, which made the detection branch below unreachable; only
        # fall back to detection when the field is genuinely missing/empty.
        original_language = original_listing.get("language")
        if not original_language:
            # Detect from title and description
            text_to_detect = f"{original_listing.get('title', '')} {original_listing.get('description', '')}"
            detection_result = await self.language_detector.detect_language(text_to_detect)
            original_language = detection_result.language

        # Create localized versions for each target language
        for target_lang in target_languages:
            if target_lang == original_language:
                continue  # Skip same language

            localized_listing = await self._translate_listing(
                original_listing, original_language, target_lang
            )

            if localized_listing:
                localized_listings.append(localized_listing)

        # Store localized listings, grouped by the original listing id
        listing_id = original_listing.get("id")
        if listing_id not in self.localized_listings:
            self.localized_listings[listing_id] = []
        self.localized_listings[listing_id].extend(localized_listings)

        return localized_listings

    except Exception as e:
        logger.error(f"Failed to create localized listings: {e}")
        return []
|
||||
|
||||
async def _translate_listing(self, original_listing: Dict[str, Any],
                             source_lang: str, target_lang: str) -> Optional[LocalizedListing]:
    """Translate a single listing to target language.

    Translates title, description, keywords, features and requirements
    field by field (pricing_info is copied untranslated), builds a
    LocalizedListing whose confidence is the unweighted mean of the
    per-field confidences, runs an optional quality check, and updates the
    running counters.  Returns None on failure.
    """
    try:
        translations = {}
        confidence_scores = []

        # Translate title
        title = original_listing.get("title", "")
        if title:
            title_result = await self._translate_text(
                title, source_lang, target_lang, "marketplace_title"
            )
            if title_result:
                translations["title"] = title_result.translated_text
                confidence_scores.append(title_result.confidence)

        # Translate description
        description = original_listing.get("description", "")
        if description:
            desc_result = await self._translate_text(
                description, source_lang, target_lang, "marketplace_description"
            )
            if desc_result:
                translations["description"] = desc_result.translated_text
                confidence_scores.append(desc_result.confidence)

        # Translate keywords (items whose translation fails are silently dropped)
        keywords = original_listing.get("keywords", [])
        translated_keywords = []
        for keyword in keywords:
            keyword_result = await self._translate_text(
                keyword, source_lang, target_lang, "marketplace_keyword"
            )
            if keyword_result:
                translated_keywords.append(keyword_result.translated_text)
                confidence_scores.append(keyword_result.confidence)
        translations["keywords"] = translated_keywords

        # Translate features
        features = original_listing.get("features", [])
        translated_features = []
        for feature in features:
            feature_result = await self._translate_text(
                feature, source_lang, target_lang, "marketplace_feature"
            )
            if feature_result:
                translated_features.append(feature_result.translated_text)
                confidence_scores.append(feature_result.confidence)
        translations["features"] = translated_features

        # Translate requirements
        requirements = original_listing.get("requirements", [])
        translated_requirements = []
        for requirement in requirements:
            req_result = await self._translate_text(
                requirement, source_lang, target_lang, "marketplace_requirement"
            )
            if req_result:
                translated_requirements.append(req_result.translated_text)
                confidence_scores.append(req_result.confidence)
        translations["requirements"] = translated_requirements

        # Calculate overall confidence (mean over all translated fields)
        overall_confidence = sum(confidence_scores) / len(confidence_scores) if confidence_scores else 0.0

        # Create localized listing
        localized_listing = LocalizedListing(
            id=f"{original_listing.get('id')}_{target_lang}",
            original_id=original_listing.get("id"),
            listing_type=ListingType(original_listing.get("type", "service")),
            language=target_lang,
            title=translations.get("title", ""),
            description=translations.get("description", ""),
            keywords=translations.get("keywords", []),
            features=translations.get("features", []),
            requirements=translations.get("requirements", []),
            pricing_info=original_listing.get("pricing_info", {}),
            translation_confidence=overall_confidence,
            translation_provider="mixed",  # Could be enhanced to track actual providers
            translated_at=datetime.utcnow()
        )

        # Quality check (may lower translation_confidence in place)
        if self.quality_checker and overall_confidence > 0.5:
            await self._perform_quality_check(localized_listing, original_listing)

        self.localization_stats["total_localizations"] += 1
        self.localization_stats["successful_localizations"] += 1

        return localized_listing

    except Exception as e:
        logger.error(f"Failed to translate listing: {e}")
        self.localization_stats["failed_localizations"] += 1
        return None
|
||||
|
||||
async def _translate_text(self, text: str, source_lang: str, target_lang: str,
                          context: str) -> Optional[TranslationResponse]:
    """Translate text with caching and context.

    Consults the cache first (keyed by text/languages/context); on a miss it
    calls the translation engine and caches only high-confidence (> 0.8)
    results, so poor translations are not pinned.  Returns None on failure.
    """
    try:
        # Check cache first
        if self.translation_cache:
            cached_result = await self.translation_cache.get(text, source_lang, target_lang, context)
            if cached_result:
                self.localization_stats["cache_hits"] += 1
                return cached_result
            self.localization_stats["cache_misses"] += 1

        # Perform translation
        translation_request = TranslationRequest(
            text=text,
            source_language=source_lang,
            target_language=target_lang,
            context=context,
            domain="marketplace"
        )

        translation_result = await self.translation_engine.translate(translation_request)

        # Cache the result (only when confident, see docstring)
        if self.translation_cache and translation_result.confidence > 0.8:
            await self.translation_cache.set(text, source_lang, target_lang, translation_result, context=context)

        return translation_result

    except Exception as e:
        logger.error(f"Failed to translate text: {e}")
        return None
|
||||
|
||||
async def _perform_quality_check(self, localized_listing: LocalizedListing,
                                 original_listing: Dict[str, Any]):
    """Assess translation quality of title/description and cap confidence.

    Mutates ``localized_listing.translation_confidence`` downwards whenever
    the quality checker scores a field below the current confidence (the
    confidence can only ever decrease here).  Errors are logged, not raised.
    """
    try:
        if not self.quality_checker:
            return

        # Quality check title
        if localized_listing.title and original_listing.get("title"):
            title_assessment = await self.quality_checker.evaluate_translation(
                original_listing["title"],
                localized_listing.title,
                "en",  # Assuming original is English for now
                # NOTE(review): source language is hard-coded here even though
                # _translate_listing may have used a detected language — confirm.
                localized_listing.language
            )

            # Update confidence based on quality check
            if title_assessment.overall_score < localized_listing.translation_confidence:
                localized_listing.translation_confidence = title_assessment.overall_score

        # Quality check description
        if localized_listing.description and original_listing.get("description"):
            desc_assessment = await self.quality_checker.evaluate_translation(
                original_listing["description"],
                localized_listing.description,
                "en",
                localized_listing.language
            )

            # Update confidence
            if desc_assessment.overall_score < localized_listing.translation_confidence:
                localized_listing.translation_confidence = desc_assessment.overall_score

        self.localization_stats["quality_checks"] += 1

    except Exception as e:
        logger.error(f"Failed to perform quality check: {e}")
|
||||
|
||||
async def get_localized_listing(self, listing_id: str, language: str) -> Optional[LocalizedListing]:
    """Return the localized variant of *listing_id* in *language*, or None."""
    try:
        variants = self.localized_listings.get(listing_id, [])
        # First variant in the requested language, if any.
        return next((v for v in variants if v.language == language), None)
    except Exception as e:
        logger.error(f"Failed to get localized listing: {e}")
        return None
|
||||
|
||||
async def search_localized_listings(self, query: str, language: str,
                                    filters: Optional[Dict[str, Any]] = None) -> List[LocalizedListing]:
    """Search localized listings with multi-language support.

    Returns listings in *language* whose text matches *query* (simple
    case-insensitive substring matching), optionally filtered, sorted by
    translation confidence descending.  Returns an empty list on failure.
    """
    try:
        results = []

        # Detect query language if needed
        # NOTE(review): query_language is computed but _matches_query ignores
        # it — either use it for cross-language matching or drop the detection.
        query_language = language
        if language != "en":  # Assume English as default
            detection_result = await self.language_detector.detect_language(query)
            query_language = detection_result.language

        # Search in all localized listings
        for listing_id, listings in self.localized_listings.items():
            for listing in listings:
                if listing.language != language:
                    continue

                # Simple text matching (could be enhanced with proper search)
                if self._matches_query(listing, query, query_language):
                    # Apply filters if provided
                    if filters and not self._matches_filters(listing, filters):
                        continue

                    results.append(listing)

        # Sort by relevance (could be enhanced with proper ranking)
        results.sort(key=lambda x: x.translation_confidence or 0, reverse=True)

        return results

    except Exception as e:
        logger.error(f"Failed to search localized listings: {e}")
        return []
|
||||
|
||||
def _matches_query(self, listing: LocalizedListing, query: str, query_language: str) -> bool:
    """Case-insensitive substring match over title, description, keywords, features.

    ``query_language`` is accepted for interface compatibility but not used.
    """
    needle = query.lower()
    haystacks = [listing.title, listing.description, *listing.keywords, *listing.features]
    return any(needle in text.lower() for text in haystacks)
|
||||
|
||||
def _matches_filters(self, listing: LocalizedListing, filters: Dict[str, Any]) -> bool:
    """Return True when *listing* satisfies every filter present in *filters*."""
    # Listing-type filter compares against the enum's string value.
    if "listing_type" in filters and listing.listing_type.value != filters["listing_type"]:
        return False

    # Minimum translation confidence (missing confidence counts as 0).
    if "min_confidence" in filters and (listing.translation_confidence or 0) < filters["min_confidence"]:
        return False

    # Optionally restrict to human-reviewed translations.
    if filters.get("reviewed_only") and not listing.reviewed:
        return False

    # Price range: reject only when the listing's [min_price, max_price] band
    # and the requested band are disjoint.
    if "price_range" in filters:
        price_info = listing.pricing_info
        if "min_price" in price_info and "max_price" in price_info:
            lo = filters["price_range"].get("min", 0)
            hi = filters["price_range"].get("max", float("inf"))
            if price_info["min_price"] > hi or price_info["max_price"] < lo:
                return False

    return True
|
||||
|
||||
async def batch_localize_listings(self, listings: List[Dict[str, Any]],
|
||||
target_languages: List[str]) -> Dict[str, List[LocalizedListing]]:
|
||||
"""Localize multiple listings in batch"""
|
||||
try:
|
||||
results = {}
|
||||
|
||||
# Process listings in parallel
|
||||
tasks = []
|
||||
for listing in listings:
|
||||
task = self.create_localized_listing(listing, target_languages)
|
||||
tasks.append(task)
|
||||
|
||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Process results
|
||||
for i, result in enumerate(batch_results):
|
||||
listing_id = listings[i].get("id", f"unknown_{i}")
|
||||
if isinstance(result, list):
|
||||
results[listing_id] = result
|
||||
else:
|
||||
logger.error(f"Failed to localize listing {listing_id}: {result}")
|
||||
results[listing_id] = []
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to batch localize listings: {e}")
|
||||
return {}
|
||||
|
||||
async def update_localized_listing(self, localized_listing: LocalizedListing) -> bool:
|
||||
"""Update an existing localized listing"""
|
||||
try:
|
||||
listing_id = localized_listing.original_id
|
||||
|
||||
if listing_id not in self.localized_listings:
|
||||
self.localized_listings[listing_id] = []
|
||||
|
||||
# Find and update existing listing
|
||||
for i, existing in enumerate(self.localized_listings[listing_id]):
|
||||
if existing.id == localized_listing.id:
|
||||
self.localized_listings[listing_id][i] = localized_listing
|
||||
return True
|
||||
|
||||
# Add new listing if not found
|
||||
self.localized_listings[listing_id].append(localized_listing)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update localized listing: {e}")
|
||||
return False
|
||||
|
||||
    async def get_localization_statistics(self) -> Dict[str, Any]:
        """Get comprehensive localization statistics.

        Returns a copy of the running counters in ``self.localization_stats``
        augmented with derived figures: success/failure rates, cache hit
        ratio, a per-language distribution of stored localized listings, and
        quality buckets based on ``translation_confidence``. On failure,
        returns ``{"error": <message>}`` instead of raising.
        """
        try:
            # Work on a copy so derived keys never leak into the live counters.
            stats = self.localization_stats.copy()

            # Calculate success rate
            total = stats["total_localizations"]
            if total > 0:
                stats["success_rate"] = stats["successful_localizations"] / total
                stats["failure_rate"] = stats["failed_localizations"] / total
            else:
                # Avoid division by zero before the first localization.
                stats["success_rate"] = 0.0
                stats["failure_rate"] = 0.0

            # Calculate cache hit ratio
            cache_total = stats["cache_hits"] + stats["cache_misses"]
            if cache_total > 0:
                stats["cache_hit_ratio"] = stats["cache_hits"] / cache_total
            else:
                stats["cache_hit_ratio"] = 0.0

            # Language statistics: count stored variants per target language.
            language_stats = {}
            total_listings = 0

            for listing_id, listings in self.localized_listings.items():
                for listing in listings:
                    lang = listing.language
                    if lang not in language_stats:
                        language_stats[lang] = 0
                    language_stats[lang] += 1
                    total_listings += 1

            stats["language_distribution"] = language_stats
            stats["total_localized_listings"] = total_listings

            # Quality statistics: bucket by translation confidence.
            quality_stats = {
                "high_quality": 0,  # > 0.8
                "medium_quality": 0,  # 0.6-0.8
                "low_quality": 0,  # < 0.6
                "reviewed": 0
            }

            for listings in self.localized_listings.values():
                for listing in listings:
                    # Missing confidence counts as 0 (low quality).
                    confidence = listing.translation_confidence or 0
                    if confidence > 0.8:
                        quality_stats["high_quality"] += 1
                    elif confidence > 0.6:
                        quality_stats["medium_quality"] += 1
                    else:
                        quality_stats["low_quality"] += 1

                    if listing.reviewed:
                        quality_stats["reviewed"] += 1

            stats["quality_statistics"] = quality_stats

            return stats

        except Exception as e:
            logger.error(f"Failed to get localization statistics: {e}")
            return {"error": str(e)}
    async def health_check(self) -> Dict[str, Any]:
        """Health check for marketplace localization.

        Probes each dependency (translation engine, language detector,
        optional translation cache and quality checker) and aggregates an
        overall status: "healthy" when all pass, "degraded" when some pass,
        "unhealthy" when none pass or the check itself fails.
        """
        try:
            health_status = {
                "overall": "healthy",
                "services": {},
                "statistics": {}
            }

            # Check translation engine — healthy only if every sub-check passes.
            translation_health = await self.translation_engine.health_check()
            health_status["services"]["translation_engine"] = all(translation_health.values())

            # Check language detector
            detection_health = await self.language_detector.health_check()
            health_status["services"]["language_detector"] = all(detection_health.values())

            # Check cache (optional component — missing counts as unhealthy).
            if self.translation_cache:
                cache_health = await self.translation_cache.health_check()
                health_status["services"]["translation_cache"] = cache_health.get("status") == "healthy"
            else:
                health_status["services"]["translation_cache"] = False

            # Check quality checker (optional component).
            if self.quality_checker:
                quality_health = await self.quality_checker.health_check()
                health_status["services"]["quality_checker"] = all(quality_health.values())
            else:
                health_status["services"]["quality_checker"] = False

            # Overall status: all -> healthy, some -> degraded, none -> unhealthy.
            all_healthy = all(health_status["services"].values())
            health_status["overall"] = "healthy" if all_healthy else "degraded" if any(health_status["services"].values()) else "unhealthy"

            # Add statistics
            health_status["statistics"] = {
                "total_listings": len(self.localized_listings),
                "localization_stats": self.localization_stats
            }

            return health_status

        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return {
                "overall": "unhealthy",
                "error": str(e)
            }
@@ -0,0 +1,483 @@
|
||||
"""
|
||||
Translation Quality Assurance Module
|
||||
Quality assessment and validation for translation results
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import nltk
|
||||
from nltk.tokenize import word_tokenize, sent_tokenize
|
||||
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
||||
import spacy
|
||||
import numpy as np
|
||||
from collections import Counter
|
||||
import difflib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class QualityMetric(Enum):
    """Dimensions along which a translation's quality is scored."""
    BLEU = "bleu"  # n-gram overlap against a reference translation
    SEMANTIC_SIMILARITY = "semantic_similarity"  # NLP-feature similarity
    LENGTH_RATIO = "length_ratio"  # plausibility of output/input length
    CONFIDENCE = "confidence"  # heuristic confidence in the output
    CONSISTENCY = "consistency"  # internal structure/punctuation consistency
@dataclass
class QualityScore:
    """A single weighted quality measurement for one metric."""
    metric: QualityMetric  # which quality dimension this score measures
    score: float  # normalized score, expected in [0, 1]
    weight: float  # relative weight used in the overall aggregate
    description: str  # human-readable explanation of the score
@dataclass
class QualityAssessment:
    """Aggregated result of evaluating one translation."""
    overall_score: float  # weighted mean of the individual metric scores
    individual_scores: List[QualityScore]  # per-metric breakdown
    passed_threshold: bool  # True when overall_score clears the "overall" threshold
    recommendations: List[str]  # human-readable improvement suggestions
    processing_time_ms: int  # wall-clock evaluation time in milliseconds
class TranslationQualityChecker:
|
||||
"""Advanced quality assessment for translation results"""
|
||||
|
||||
    def __init__(self, config: Dict):
        """Create a quality checker.

        Args:
            config: Service configuration; the optional ``thresholds`` key
                overrides the default per-metric minimum scores below.
        """
        self.config = config
        # Map of language code -> loaded spaCy pipeline, filled by
        # _initialize_models().
        self.nlp_models = {}
        # Minimum acceptable scores; "overall" gates
        # QualityAssessment.passed_threshold, the rest drive recommendations.
        self.thresholds = config.get("thresholds", {
            "overall": 0.7,
            "bleu": 0.3,
            "semantic_similarity": 0.6,
            "length_ratio": 0.5,
            "confidence": 0.6
        })
        self._initialize_models()
def _initialize_models(self):
|
||||
"""Initialize NLP models for quality assessment"""
|
||||
try:
|
||||
# Load spaCy models for different languages
|
||||
languages = ["en", "zh", "es", "fr", "de", "ja", "ko", "ru"]
|
||||
for lang in languages:
|
||||
try:
|
||||
model_name = f"{lang}_core_web_sm"
|
||||
self.nlp_models[lang] = spacy.load(model_name)
|
||||
except OSError:
|
||||
logger.warning(f"Spacy model for {lang} not found, using fallback")
|
||||
# Fallback to English model for basic processing
|
||||
if "en" not in self.nlp_models:
|
||||
self.nlp_models["en"] = spacy.load("en_core_web_sm")
|
||||
self.nlp_models[lang] = self.nlp_models["en"]
|
||||
|
||||
# Download NLTK data if needed
|
||||
try:
|
||||
nltk.data.find('tokenizers/punkt')
|
||||
except LookupError:
|
||||
nltk.download('punkt')
|
||||
|
||||
logger.info("Quality checker models initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize quality checker models: {e}")
|
||||
|
||||
    async def evaluate_translation(self, source_text: str, translated_text: str,
                                   source_lang: str, target_lang: str,
                                   reference_translation: Optional[str] = None) -> QualityAssessment:
        """Comprehensive quality assessment of translation.

        Runs up to five scorers — confidence, length ratio, semantic
        similarity, BLEU (only when a reference is supplied), and
        consistency — combines them into a weighted overall score, and
        reports whether the result clears the configured "overall" threshold.

        Args:
            source_text: Original text in the source language.
            translated_text: Candidate translation to evaluate.
            source_lang: ISO code of the source language.
            target_lang: ISO code of the target language.
            reference_translation: Optional gold translation enabling BLEU.

        Returns:
            QualityAssessment with the aggregate score, per-metric scores,
            pass/fail flag, recommendations, and processing time.
        """

        start_time = asyncio.get_event_loop().time()

        scores = []

        # 1. Confidence-based scoring
        confidence_score = await self._evaluate_confidence(translated_text, source_lang, target_lang)
        scores.append(confidence_score)

        # 2. Length ratio assessment
        length_score = await self._evaluate_length_ratio(source_text, translated_text, source_lang, target_lang)
        scores.append(length_score)

        # 3. Semantic similarity (if models available)
        semantic_score = await self._evaluate_semantic_similarity(source_text, translated_text, source_lang, target_lang)
        scores.append(semantic_score)

        # 4. BLEU score (if reference available)
        if reference_translation:
            bleu_score = await self._evaluate_bleu_score(translated_text, reference_translation)
            scores.append(bleu_score)

        # 5. Consistency check
        consistency_score = await self._evaluate_consistency(source_text, translated_text)
        scores.append(consistency_score)

        # Calculate overall score (weighted mean of the collected scores).
        overall_score = self._calculate_overall_score(scores)

        # Generate recommendations
        recommendations = self._generate_recommendations(scores, overall_score)

        processing_time = int((asyncio.get_event_loop().time() - start_time) * 1000)

        return QualityAssessment(
            overall_score=overall_score,
            individual_scores=scores,
            passed_threshold=overall_score >= self.thresholds["overall"],
            recommendations=recommendations,
            processing_time_ms=processing_time
        )
async def _evaluate_confidence(self, translated_text: str, source_lang: str, target_lang: str) -> QualityScore:
|
||||
"""Evaluate translation confidence based on various factors"""
|
||||
|
||||
confidence_factors = []
|
||||
|
||||
# Text completeness
|
||||
if translated_text.strip():
|
||||
confidence_factors.append(0.8)
|
||||
else:
|
||||
confidence_factors.append(0.1)
|
||||
|
||||
# Language detection consistency
|
||||
try:
|
||||
# Basic language detection (simplified)
|
||||
if self._is_valid_language(translated_text, target_lang):
|
||||
confidence_factors.append(0.7)
|
||||
else:
|
||||
confidence_factors.append(0.3)
|
||||
except:
|
||||
confidence_factors.append(0.5)
|
||||
|
||||
# Text structure preservation
|
||||
source_sentences = sent_tokenize(source_text)
|
||||
translated_sentences = sent_tokenize(translated_text)
|
||||
|
||||
if len(source_sentences) > 0:
|
||||
sentence_ratio = len(translated_sentences) / len(source_sentences)
|
||||
if 0.5 <= sentence_ratio <= 2.0:
|
||||
confidence_factors.append(0.6)
|
||||
else:
|
||||
confidence_factors.append(0.3)
|
||||
else:
|
||||
confidence_factors.append(0.5)
|
||||
|
||||
# Average confidence
|
||||
avg_confidence = np.mean(confidence_factors)
|
||||
|
||||
return QualityScore(
|
||||
metric=QualityMetric.CONFIDENCE,
|
||||
score=avg_confidence,
|
||||
weight=0.3,
|
||||
description=f"Confidence based on text completeness, language detection, and structure preservation"
|
||||
)
|
||||
|
||||
async def _evaluate_length_ratio(self, source_text: str, translated_text: str,
|
||||
source_lang: str, target_lang: str) -> QualityScore:
|
||||
"""Evaluate appropriate length ratio between source and target"""
|
||||
|
||||
source_length = len(source_text.strip())
|
||||
translated_length = len(translated_text.strip())
|
||||
|
||||
if source_length == 0:
|
||||
return QualityScore(
|
||||
metric=QualityMetric.LENGTH_RATIO,
|
||||
score=0.0,
|
||||
weight=0.2,
|
||||
description="Empty source text"
|
||||
)
|
||||
|
||||
ratio = translated_length / source_length
|
||||
|
||||
# Expected length ratios by language pair (simplified)
|
||||
expected_ratios = {
|
||||
("en", "zh"): 0.8, # Chinese typically shorter
|
||||
("en", "ja"): 0.9,
|
||||
("en", "ko"): 0.9,
|
||||
("zh", "en"): 1.2, # English typically longer
|
||||
("ja", "en"): 1.1,
|
||||
("ko", "en"): 1.1,
|
||||
}
|
||||
|
||||
expected_ratio = expected_ratios.get((source_lang, target_lang), 1.0)
|
||||
|
||||
# Calculate score based on deviation from expected ratio
|
||||
deviation = abs(ratio - expected_ratio)
|
||||
score = max(0.0, 1.0 - deviation)
|
||||
|
||||
return QualityScore(
|
||||
metric=QualityMetric.LENGTH_RATIO,
|
||||
score=score,
|
||||
weight=0.2,
|
||||
description=f"Length ratio: {ratio:.2f} (expected: {expected_ratio:.2f})"
|
||||
)
|
||||
|
||||
async def _evaluate_semantic_similarity(self, source_text: str, translated_text: str,
|
||||
source_lang: str, target_lang: str) -> QualityScore:
|
||||
"""Evaluate semantic similarity using NLP models"""
|
||||
|
||||
try:
|
||||
# Get appropriate NLP models
|
||||
source_nlp = self.nlp_models.get(source_lang, self.nlp_models.get("en"))
|
||||
target_nlp = self.nlp_models.get(target_lang, self.nlp_models.get("en"))
|
||||
|
||||
# Process texts
|
||||
source_doc = source_nlp(source_text)
|
||||
target_doc = target_nlp(translated_text)
|
||||
|
||||
# Extract key features
|
||||
source_features = self._extract_text_features(source_doc)
|
||||
target_features = self._extract_text_features(target_doc)
|
||||
|
||||
# Calculate similarity
|
||||
similarity = self._calculate_feature_similarity(source_features, target_features)
|
||||
|
||||
return QualityScore(
|
||||
metric=QualityMetric.SEMANTIC_SIMILARITY,
|
||||
score=similarity,
|
||||
weight=0.3,
|
||||
description=f"Semantic similarity based on NLP features"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Semantic similarity evaluation failed: {e}")
|
||||
# Fallback to basic similarity
|
||||
return QualityScore(
|
||||
metric=QualityMetric.SEMANTIC_SIMILARITY,
|
||||
score=0.5,
|
||||
weight=0.3,
|
||||
description="Fallback similarity score"
|
||||
)
|
||||
|
||||
async def _evaluate_bleu_score(self, translated_text: str, reference_text: str) -> QualityScore:
|
||||
"""Calculate BLEU score against reference translation"""
|
||||
|
||||
try:
|
||||
# Tokenize texts
|
||||
reference_tokens = word_tokenize(reference_text.lower())
|
||||
candidate_tokens = word_tokenize(translated_text.lower())
|
||||
|
||||
# Calculate BLEU score with smoothing
|
||||
smoothing = SmoothingFunction().method1
|
||||
bleu_score = sentence_bleu([reference_tokens], candidate_tokens, smoothing_function=smoothing)
|
||||
|
||||
return QualityScore(
|
||||
metric=QualityMetric.BLEU,
|
||||
score=bleu_score,
|
||||
weight=0.2,
|
||||
description=f"BLEU score against reference translation"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"BLEU score calculation failed: {e}")
|
||||
return QualityScore(
|
||||
metric=QualityMetric.BLEU,
|
||||
score=0.0,
|
||||
weight=0.2,
|
||||
description="BLEU score calculation failed"
|
||||
)
|
||||
|
||||
    async def _evaluate_consistency(self, source_text: str, translated_text: str) -> QualityScore:
        """Evaluate internal consistency of translation.

        Combines two heuristics (weight 0.1): rough preservation of the
        source's high-frequency words (proxied by overall word count) and
        preservation of punctuation density.
        """

        consistency_factors = []

        # Check for repeated patterns
        source_words = word_tokenize(source_text.lower())
        translated_words = word_tokenize(translated_text.lower())

        source_word_freq = Counter(source_words)
        # NOTE(review): translated_word_freq is currently unused — confirm
        # whether a per-word comparison was intended here.
        translated_word_freq = Counter(translated_words)

        # Check if high-frequency words are preserved
        common_words = [word for word, freq in source_word_freq.most_common(5) if freq > 1]

        if common_words:
            preserved_count = 0
            for word in common_words:
                # Simplified check - in reality, this would be more complex
                # NOTE(review): the condition does not use `word`, so the
                # resulting score is all-or-nothing based on overall word
                # count — confirm this proxy is intended.
                if len(translated_words) >= len(source_words) * 0.8:
                    preserved_count += 1

            consistency_score = preserved_count / len(common_words)
            consistency_factors.append(consistency_score)
        else:
            consistency_factors.append(0.8)  # No repetition issues

        # Check for formatting consistency
        source_punctuation = re.findall(r'[.!?;:,]', source_text)
        translated_punctuation = re.findall(r'[.!?;:,]', translated_text)

        if len(source_punctuation) > 0:
            punctuation_ratio = len(translated_punctuation) / len(source_punctuation)
            # Up to a 2x change in punctuation density is tolerated.
            if 0.5 <= punctuation_ratio <= 2.0:
                consistency_factors.append(0.7)
            else:
                consistency_factors.append(0.4)
        else:
            consistency_factors.append(0.8)

        avg_consistency = np.mean(consistency_factors)

        return QualityScore(
            metric=QualityMetric.CONSISTENCY,
            score=avg_consistency,
            weight=0.1,
            description="Internal consistency of translation"
        )
    def _extract_text_features(self, doc) -> Dict[str, Any]:
        """Extract linguistic features from spaCy document.

        Args:
            doc: A processed spaCy ``Doc``.

        Returns:
            Dict with POS tags, named entities, noun chunks, verb lemmas,
            and sentence/token counts, used for cross-language comparison in
            _calculate_feature_similarity().
        """
        features = {
            "pos_tags": [token.pos_ for token in doc],
            "entities": [(ent.text, ent.label_) for ent in doc.ents],
            "noun_chunks": [chunk.text for chunk in doc.noun_chunks],
            "verbs": [token.lemma_ for token in doc if token.pos_ == "VERB"],
            "sentence_count": len(list(doc.sents)),
            "token_count": len(doc),
        }
        return features
def _calculate_feature_similarity(self, source_features: Dict, target_features: Dict) -> float:
|
||||
"""Calculate similarity between text features"""
|
||||
|
||||
similarities = []
|
||||
|
||||
# POS tag similarity
|
||||
source_pos = Counter(source_features["pos_tags"])
|
||||
target_pos = Counter(target_features["pos_tags"])
|
||||
|
||||
if source_pos and target_pos:
|
||||
pos_similarity = self._calculate_counter_similarity(source_pos, target_pos)
|
||||
similarities.append(pos_similarity)
|
||||
|
||||
# Entity similarity
|
||||
source_entities = set([ent[0].lower() for ent in source_features["entities"]])
|
||||
target_entities = set([ent[0].lower() for ent in target_features["entities"]])
|
||||
|
||||
if source_entities and target_entities:
|
||||
entity_similarity = len(source_entities & target_entities) / len(source_entities | target_entities)
|
||||
similarities.append(entity_similarity)
|
||||
|
||||
# Length similarity
|
||||
source_len = source_features["token_count"]
|
||||
target_len = target_features["token_count"]
|
||||
|
||||
if source_len > 0 and target_len > 0:
|
||||
length_similarity = min(source_len, target_len) / max(source_len, target_len)
|
||||
similarities.append(length_similarity)
|
||||
|
||||
return np.mean(similarities) if similarities else 0.5
|
||||
|
||||
def _calculate_counter_similarity(self, counter1: Counter, counter2: Counter) -> float:
|
||||
"""Calculate similarity between two Counters"""
|
||||
all_items = set(counter1.keys()) | set(counter2.keys())
|
||||
|
||||
if not all_items:
|
||||
return 1.0
|
||||
|
||||
dot_product = sum(counter1[item] * counter2[item] for item in all_items)
|
||||
magnitude1 = sum(counter1[item] ** 2 for item in all_items) ** 0.5
|
||||
magnitude2 = sum(counter2[item] ** 2 for item in all_items) ** 0.5
|
||||
|
||||
if magnitude1 == 0 or magnitude2 == 0:
|
||||
return 0.0
|
||||
|
||||
return dot_product / (magnitude1 * magnitude2)
|
||||
|
||||
def _is_valid_language(self, text: str, expected_lang: str) -> bool:
|
||||
"""Basic language validation (simplified)"""
|
||||
# This is a placeholder - in reality, you'd use a proper language detector
|
||||
lang_patterns = {
|
||||
"zh": r"[\u4e00-\u9fff]",
|
||||
"ja": r"[\u3040-\u309f\u30a0-\u30ff]",
|
||||
"ko": r"[\uac00-\ud7af]",
|
||||
"ar": r"[\u0600-\u06ff]",
|
||||
"ru": r"[\u0400-\u04ff]",
|
||||
}
|
||||
|
||||
pattern = lang_patterns.get(expected_lang, r"[a-zA-Z]")
|
||||
matches = re.findall(pattern, text)
|
||||
|
||||
return len(matches) > len(text) * 0.1 # At least 10% of characters should match
|
||||
|
||||
def _calculate_overall_score(self, scores: List[QualityScore]) -> float:
|
||||
"""Calculate weighted overall quality score"""
|
||||
|
||||
if not scores:
|
||||
return 0.0
|
||||
|
||||
weighted_sum = sum(score.score * score.weight for score in scores)
|
||||
total_weight = sum(score.weight for score in scores)
|
||||
|
||||
return weighted_sum / total_weight if total_weight > 0 else 0.0
|
||||
|
||||
    def _generate_recommendations(self, scores: List[QualityScore], overall_score: float) -> List[str]:
        """Generate improvement recommendations based on quality assessment.

        Emits a generic warning when the overall score is below the
        "overall" threshold, plus one targeted suggestion for each metric
        that falls below its own configured threshold (0.5 default for
        metrics with no configured threshold).
        """

        recommendations = []

        if overall_score < self.thresholds["overall"]:
            recommendations.append("Translation quality below threshold - consider manual review")

        for score in scores:
            # Compare each metric against its own threshold (0.5 fallback).
            if score.score < self.thresholds.get(score.metric.value, 0.5):
                if score.metric == QualityMetric.LENGTH_RATIO:
                    recommendations.append("Translation length seems inappropriate - check for truncation or expansion")
                elif score.metric == QualityMetric.SEMANTIC_SIMILARITY:
                    recommendations.append("Semantic meaning may be lost - verify key concepts are preserved")
                elif score.metric == QualityMetric.CONSISTENCY:
                    recommendations.append("Translation lacks consistency - check for repeated patterns and formatting")
                elif score.metric == QualityMetric.CONFIDENCE:
                    recommendations.append("Low confidence detected - verify translation accuracy")

        return recommendations
    async def batch_evaluate(self, translations: List[Tuple[str, str, str, str, Optional[str]]]) -> List[QualityAssessment]:
        """Evaluate multiple translations in parallel.

        Args:
            translations: Tuples of (source_text, translated_text,
                source_lang, target_lang, optional reference translation).

        Returns:
            One QualityAssessment per input, in order; a failed evaluation
            is replaced with a neutral fallback assessment rather than
            raised.
        """

        tasks = []
        for source_text, translated_text, source_lang, target_lang, reference in translations:
            task = self.evaluate_translation(source_text, translated_text, source_lang, target_lang, reference)
            tasks.append(task)

        # return_exceptions=True keeps one failure from cancelling the batch.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Handle exceptions
        processed_results = []
        for i, result in enumerate(results):
            if isinstance(result, QualityAssessment):
                processed_results.append(result)
            else:
                logger.error(f"Quality assessment failed for translation {i}: {result}")
                # Add fallback assessment
                processed_results.append(QualityAssessment(
                    overall_score=0.5,
                    individual_scores=[],
                    passed_threshold=False,
                    recommendations=["Quality assessment failed"],
                    processing_time_ms=0
                ))

        return processed_results
    async def health_check(self) -> Dict[str, bool]:
        """Health check for quality checker.

        Returns:
            Mapping of check name -> pass/fail: a smoke-test evaluation of a
            known translation pair, and whether any NLP models are loaded.
        """

        health_status = {}

        # Test with sample translation
        try:
            sample_assessment = await self.evaluate_translation(
                "Hello world", "Hola mundo", "en", "es"
            )
            health_status["basic_assessment"] = sample_assessment.overall_score > 0
        except Exception as e:
            logger.error(f"Quality checker health check failed: {e}")
            health_status["basic_assessment"] = False

        # Check model availability
        health_status["nlp_models_loaded"] = len(self.nlp_models) > 0

        return health_status
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Multi-Language Service Requirements
|
||||
Dependencies and requirements for multi-language support
|
||||
"""
|
||||
|
||||
# Core dependencies
|
||||
fastapi>=0.104.0
|
||||
uvicorn[standard]>=0.24.0
|
||||
pydantic>=2.5.0
|
||||
python-multipart>=0.0.6
|
||||
|
||||
# Translation providers
|
||||
openai>=1.3.0
|
||||
google-cloud-translate>=3.11.0
|
||||
deepl>=1.16.0
|
||||
|
||||
# Language detection
|
||||
langdetect>=1.0.9
|
||||
polyglot>=16.10.0
|
||||
fasttext>=0.9.2
|
||||
|
||||
# Quality assessment
|
||||
nltk>=3.8.1
|
||||
spacy>=3.7.0
|
||||
numpy>=1.24.0
|
||||
|
||||
# Caching
|
||||
redis[hiredis]>=5.0.0
|
||||
aioredis>=2.0.1
|
||||
|
||||
# Database
|
||||
asyncpg>=0.29.0
|
||||
sqlalchemy[asyncio]>=2.0.0
|
||||
alembic>=1.13.0
|
||||
|
||||
# Testing
|
||||
pytest>=7.4.0
|
||||
pytest-asyncio>=0.21.0
|
||||
pytest-mock>=3.12.0
|
||||
httpx>=0.25.0
|
||||
|
||||
# Monitoring and logging
|
||||
structlog>=23.2.0
|
||||
prometheus-client>=0.19.0
|
||||
|
||||
# Utilities
|
||||
python-dotenv>=1.0.0
|
||||
click>=8.1.0
|
||||
rich>=13.7.0
|
||||
tqdm>=4.66.0
|
||||
|
||||
# Security
|
||||
cryptography>=41.0.0
|
||||
python-jose[cryptography]>=3.3.0
|
||||
passlib[bcrypt]>=1.7.4
|
||||
|
||||
# Performance
|
||||
orjson>=3.9.0
|
||||
lz4>=4.3.0
|
||||
@@ -0,0 +1,641 @@
|
||||
"""
|
||||
Multi-Language Service Tests
|
||||
Comprehensive test suite for multi-language functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
# Import all modules to test
|
||||
from .translation_engine import TranslationEngine, TranslationRequest, TranslationResponse, TranslationProvider
|
||||
from .language_detector import LanguageDetector, DetectionMethod, DetectionResult
|
||||
from .translation_cache import TranslationCache
|
||||
from .quality_assurance import TranslationQualityChecker, QualityAssessment
|
||||
from .agent_communication import MultilingualAgentCommunication, AgentMessage, MessageType, AgentLanguageProfile
|
||||
from .marketplace_localization import MarketplaceLocalization, LocalizedListing, ListingType
|
||||
from .config import MultiLanguageConfig
|
||||
|
||||
class TestTranslationEngine:
    """Test suite for TranslationEngine.

    Provider clients are patched at the ``translators`` mapping level, so no
    network calls are made.
    """

    @pytest.fixture
    def mock_config(self):
        # Minimal credentials for each provider; values are never sent anywhere.
        return {
            "openai": {"api_key": "test-key"},
            "google": {"api_key": "test-key"},
            "deepl": {"api_key": "test-key"}
        }

    @pytest.fixture
    def translation_engine(self, mock_config):
        return TranslationEngine(mock_config)

    @pytest.mark.asyncio
    async def test_translate_with_openai(self, translation_engine):
        """Test translation using OpenAI provider"""
        request = TranslationRequest(
            text="Hello world",
            source_language="en",
            target_language="es"
        )

        # Mock OpenAI response
        with patch.object(translation_engine.translators[TranslationProvider.OPENAI], 'translate') as mock_translate:
            mock_translate.return_value = TranslationResponse(
                translated_text="Hola mundo",
                confidence=0.95,
                provider=TranslationProvider.OPENAI,
                processing_time_ms=120,
                source_language="en",
                target_language="es"
            )

            result = await translation_engine.translate(request)

            assert result.translated_text == "Hola mundo"
            assert result.confidence == 0.95
            assert result.provider == TranslationProvider.OPENAI

    @pytest.mark.asyncio
    async def test_translate_fallback_strategy(self, translation_engine):
        """Test fallback strategy when primary provider fails"""
        request = TranslationRequest(
            text="Hello world",
            source_language="en",
            target_language="es"
        )

        # Mock primary provider failure
        with patch.object(translation_engine.translators[TranslationProvider.OPENAI], 'translate') as mock_openai:
            mock_openai.side_effect = Exception("OpenAI failed")

            # Mock secondary provider success
            with patch.object(translation_engine.translators[TranslationProvider.GOOGLE], 'translate') as mock_google:
                mock_google.return_value = TranslationResponse(
                    translated_text="Hola mundo",
                    confidence=0.85,
                    provider=TranslationProvider.GOOGLE,
                    processing_time_ms=100,
                    source_language="en",
                    target_language="es"
                )

                result = await translation_engine.translate(request)

                # The engine should have fallen through to Google.
                assert result.translated_text == "Hola mundo"
                assert result.provider == TranslationProvider.GOOGLE

    def test_get_preferred_providers(self, translation_engine):
        """Test provider preference logic"""
        request = TranslationRequest(
            text="Hello world",
            source_language="en",
            target_language="de"
        )

        providers = translation_engine._get_preferred_providers(request)

        # Should prefer DeepL for European languages
        assert TranslationProvider.DEEPL in providers
        assert providers[0] == TranslationProvider.DEEPL
class TestLanguageDetector:
    """Test suite for LanguageDetector.

    Individual detection backends are mocked; only the ensemble/aggregation
    logic is exercised.
    """

    @pytest.fixture
    def detector(self):
        config = {"fasttext": {"model_path": "test-model.bin"}}
        return LanguageDetector(config)

    @pytest.mark.asyncio
    async def test_detect_language_ensemble(self, detector):
        """Test ensemble language detection"""
        text = "Bonjour le monde"

        # Mock individual methods — one result per detection backend.
        with patch.object(detector, '_detect_with_method') as mock_detect:
            mock_detect.side_effect = [
                DetectionResult("fr", 0.9, DetectionMethod.LANGDETECT, [], 50),
                DetectionResult("fr", 0.85, DetectionMethod.POLYGLOT, [], 60),
                DetectionResult("fr", 0.95, DetectionMethod.FASTTEXT, [], 40)
            ]

            result = await detector.detect_language(text)

            # Unanimous backends should yield a confident ensemble result.
            assert result.language == "fr"
            assert result.method == DetectionMethod.ENSEMBLE
            assert result.confidence > 0.8

    @pytest.mark.asyncio
    async def test_batch_detection(self, detector):
        """Test batch language detection"""
        texts = ["Hello world", "Bonjour le monde", "Hola mundo"]

        with patch.object(detector, 'detect_language') as mock_detect:
            mock_detect.side_effect = [
                DetectionResult("en", 0.95, DetectionMethod.LANGDETECT, [], 50),
                DetectionResult("fr", 0.90, DetectionMethod.LANGDETECT, [], 60),
                DetectionResult("es", 0.92, DetectionMethod.LANGDETECT, [], 55)
            ]

            results = await detector.batch_detect(texts)

            # Results must preserve input order.
            assert len(results) == 3
            assert results[0].language == "en"
            assert results[1].language == "fr"
            assert results[2].language == "es"
class TestTranslationCache:
    """Test suite for TranslationCache.

    The Redis client is replaced with an AsyncMock, and pickle
    (de)serialization is patched, so no Redis server is required.
    """

    @pytest.fixture
    def mock_redis(self):
        redis_mock = AsyncMock()
        redis_mock.ping.return_value = True
        return redis_mock

    @pytest.fixture
    def cache(self, mock_redis):
        cache = TranslationCache("redis://localhost:6379")
        # Swap the real client for the mock after construction.
        cache.redis = mock_redis
        return cache

    @pytest.mark.asyncio
    async def test_cache_hit(self, cache, mock_redis):
        """Test cache hit scenario"""
        # Mock cache hit — the deserialized payload the cache should return.
        mock_response = Mock()
        mock_response.translated_text = "Hola mundo"
        mock_response.confidence = 0.95
        mock_response.provider = TranslationProvider.OPENAI
        mock_response.processing_time_ms = 120
        mock_response.source_language = "en"
        mock_response.target_language = "es"

        with patch('pickle.loads', return_value=mock_response):
            mock_redis.get.return_value = b"serialized_data"

            result = await cache.get("Hello world", "en", "es")

            assert result.translated_text == "Hola mundo"
            assert result.confidence == 0.95

    @pytest.mark.asyncio
    async def test_cache_miss(self, cache, mock_redis):
        """Test cache miss scenario"""
        mock_redis.get.return_value = None

        result = await cache.get("Hello world", "en", "es")

        # A miss is reported as None, not an exception.
        assert result is None

    @pytest.mark.asyncio
    async def test_cache_set(self, cache, mock_redis):
        """Test cache set operation"""
        response = TranslationResponse(
            translated_text="Hola mundo",
            confidence=0.95,
            provider=TranslationProvider.OPENAI,
            processing_time_ms=120,
            source_language="en",
            target_language="es"
        )

        with patch('pickle.dumps', return_value=b"serialized_data"):
            result = await cache.set("Hello world", "en", "es", response)

            # Entries are written with a TTL via setex.
            assert result is True
            mock_redis.setex.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_cache_stats(self, cache, mock_redis):
        """Test cache statistics"""
        mock_redis.info.return_value = {
            "used_memory": 1000000,
            "db_size": 1000
        }
        mock_redis.dbsize.return_value = 1000

        stats = await cache.get_cache_stats()

        # Shape check only — exact values depend on prior cache activity.
        assert "hits" in stats
        assert "misses" in stats
        assert "cache_size" in stats
        assert "memory_used" in stats
class TestTranslationQualityChecker:
    """Test suite for TranslationQualityChecker"""

    @pytest.fixture
    def quality_checker(self):
        # Thresholds mirror the production defaults in MultiLanguageConfig.
        config = {
            "thresholds": {
                "overall": 0.7,
                "bleu": 0.3,
                "semantic_similarity": 0.6,
                "length_ratio": 0.5,
                "confidence": 0.6
            }
        }
        return TranslationQualityChecker(config)

    @pytest.mark.asyncio
    async def test_evaluate_translation(self, quality_checker):
        """Test translation quality evaluation"""
        # All four sub-metric evaluators are patched so this test only
        # exercises the aggregation/weighting logic of evaluate_translation.
        with patch.object(quality_checker, '_evaluate_confidence') as mock_confidence, \
             patch.object(quality_checker, '_evaluate_length_ratio') as mock_length, \
             patch.object(quality_checker, '_evaluate_semantic_similarity') as mock_semantic, \
             patch.object(quality_checker, '_evaluate_consistency') as mock_consistency:

            # Mock individual evaluations
            from .quality_assurance import QualityScore, QualityMetric
            mock_confidence.return_value = QualityScore(
                metric=QualityMetric.CONFIDENCE,
                score=0.8,
                weight=0.3,
                description="Test"
            )
            mock_length.return_value = QualityScore(
                metric=QualityMetric.LENGTH_RATIO,
                score=0.7,
                weight=0.2,
                description="Test"
            )
            mock_semantic.return_value = QualityScore(
                metric=QualityMetric.SEMANTIC_SIMILARITY,
                score=0.75,
                weight=0.3,
                description="Test"
            )
            mock_consistency.return_value = QualityScore(
                metric=QualityMetric.CONSISTENCY,
                score=0.9,
                weight=0.1,
                description="Test"
            )

            assessment = await quality_checker.evaluate_translation(
                "Hello world", "Hola mundo", "en", "es"
            )

            # Weighted mean of the mocked scores should clear the 0.7 bar,
            # and one QualityScore per mocked metric must be present.
            assert isinstance(assessment, QualityAssessment)
            assert assessment.overall_score > 0.7
            assert len(assessment.individual_scores) == 4
class TestMultilingualAgentCommunication:
    """Test suite for MultilingualAgentCommunication"""

    @pytest.fixture
    def mock_services(self):
        # Bundle of the four collaborator services, all mocked.
        translation_engine = Mock()
        language_detector = Mock()
        translation_cache = Mock()
        quality_checker = Mock()

        return {
            "translation_engine": translation_engine,
            "language_detector": language_detector,
            "translation_cache": translation_cache,
            "quality_checker": quality_checker
        }

    @pytest.fixture
    def agent_comm(self, mock_services):
        return MultilingualAgentCommunication(
            mock_services["translation_engine"],
            mock_services["language_detector"],
            mock_services["translation_cache"],
            mock_services["quality_checker"]
        )

    @pytest.mark.asyncio
    async def test_register_agent_language_profile(self, agent_comm):
        """Test agent language profile registration"""
        profile = AgentLanguageProfile(
            agent_id="agent1",
            preferred_language="es",
            supported_languages=["es", "en"],
            auto_translate_enabled=True,
            translation_quality_threshold=0.7,
            cultural_preferences={}
        )

        result = await agent_comm.register_agent_language_profile(profile)

        # Registration should succeed and index the profile by agent_id.
        assert result is True
        assert "agent1" in agent_comm.agent_profiles
        assert agent_comm.agent_profiles["agent1"].preferred_language == "es"

    @pytest.mark.asyncio
    async def test_send_message_with_translation(self, agent_comm, mock_services):
        """Test sending message with automatic translation"""
        # Setup agent profile: receiver prefers Spanish with auto-translate on.
        profile = AgentLanguageProfile(
            agent_id="agent2",
            preferred_language="es",
            supported_languages=["es", "en"],
            auto_translate_enabled=True,
            translation_quality_threshold=0.7,
            cultural_preferences={}
        )
        await agent_comm.register_agent_language_profile(profile)

        # Mock language detection
        mock_services["language_detector"].detect_language.return_value = DetectionResult(
            "en", 0.95, DetectionMethod.LANGDETECT, [], 50
        )

        # Mock translation
        mock_services["translation_engine"].translate.return_value = TranslationResponse(
            translated_text="Hola mundo",
            confidence=0.9,
            provider=TranslationProvider.OPENAI,
            processing_time_ms=120,
            source_language="en",
            target_language="es"
        )

        message = AgentMessage(
            id="msg1",
            sender_id="agent1",
            receiver_id="agent2",
            message_type=MessageType.AGENT_TO_AGENT,
            content="Hello world"
        )

        result = await agent_comm.send_message(message)

        # English content should be auto-translated into the receiver's
        # preferred language with the mocked confidence preserved.
        assert result.translated_content == "Hola mundo"
        assert result.translation_confidence == 0.9
        assert result.target_language == "es"
class TestMarketplaceLocalization:
    """Test suite for MarketplaceLocalization"""

    @pytest.fixture
    def mock_services(self):
        # Same four collaborators as the agent-communication tests.
        translation_engine = Mock()
        language_detector = Mock()
        translation_cache = Mock()
        quality_checker = Mock()

        return {
            "translation_engine": translation_engine,
            "language_detector": language_detector,
            "translation_cache": translation_cache,
            "quality_checker": quality_checker
        }

    @pytest.fixture
    def marketplace_loc(self, mock_services):
        return MarketplaceLocalization(
            mock_services["translation_engine"],
            mock_services["language_detector"],
            mock_services["translation_cache"],
            mock_services["quality_checker"]
        )

    @pytest.mark.asyncio
    async def test_create_localized_listing(self, marketplace_loc, mock_services):
        """Test creating localized listings"""
        original_listing = {
            "id": "listing1",
            "type": "service",
            "title": "AI Translation Service",
            "description": "High-quality translation service",
            "keywords": ["translation", "AI", "service"],
            "features": ["Fast translation", "High accuracy"],
            "requirements": ["API key", "Internet connection"],
            "pricing_info": {"price": 0.01, "unit": "character"}
        }

        # Mock translation: every translate() call returns the same response,
        # so only the title assertion below is meaningful for field mapping.
        mock_services["translation_engine"].translate.return_value = TranslationResponse(
            translated_text="Servicio de Traducción IA",
            confidence=0.9,
            provider=TranslationProvider.OPENAI,
            processing_time_ms=150,
            source_language="en",
            target_language="es"
        )

        result = await marketplace_loc.create_localized_listing(original_listing, ["es"])

        # One localized listing per requested target language.
        assert len(result) == 1
        assert result[0].language == "es"
        assert result[0].title == "Servicio de Traducción IA"
        assert result[0].original_id == "listing1"

    @pytest.mark.asyncio
    async def test_search_localized_listings(self, marketplace_loc):
        """Test searching localized listings"""
        # Setup test data: one pre-localized Spanish listing.
        localized_listing = LocalizedListing(
            id="listing1_es",
            original_id="listing1",
            listing_type=ListingType.SERVICE,
            language="es",
            title="Servicio de Traducción",
            description="Servicio de alta calidad",
            keywords=["traducción", "servicio"],
            features=["Rápido", "Preciso"],
            requirements=["API", "Internet"],
            pricing_info={"price": 0.01}
        )

        marketplace_loc.localized_listings["listing1"] = [localized_listing]

        results = await marketplace_loc.search_localized_listings("traducción", "es")

        assert len(results) == 1
        assert results[0].language == "es"
        assert "traducción" in results[0].title.lower()
class TestMultiLanguageConfig:
    """Test suite for MultiLanguageConfig"""

    def test_default_config(self):
        """Test default configuration"""
        config = MultiLanguageConfig()

        # All three translation providers must be configured out of the box.
        assert "openai" in config.translation["providers"]
        assert "google" in config.translation["providers"]
        assert "deepl" in config.translation["providers"]
        assert config.cache["redis"]["url"] is not None
        assert config.quality["thresholds"]["overall"] == 0.7

    def test_config_validation(self):
        """Test configuration validation"""
        config = MultiLanguageConfig()

        # Should have issues with missing API keys in test environment
        issues = config.validate()
        assert len(issues) > 0
        assert any("API key" in issue for issue in issues)

    def test_environment_specific_configs(self):
        """Test environment-specific configurations"""
        from .config import DevelopmentConfig, ProductionConfig, TestingConfig

        dev_config = DevelopmentConfig()
        prod_config = ProductionConfig()
        test_config = TestingConfig()

        # Debug flags and Redis DB selection differ per environment;
        # testing uses the dedicated DB index 15.
        assert dev_config.deployment["debug"] is True
        assert prod_config.deployment["debug"] is False
        assert test_config.cache["redis"]["url"] == "redis://localhost:6379/15"
class TestIntegration:
    """Integration tests for multi-language services"""

    @pytest.mark.asyncio
    async def test_end_to_end_translation_workflow(self):
        """Test complete translation workflow"""
        # This would be a comprehensive integration test
        # mocking all external dependencies

        # Setup mock services: patch the provider SDKs and Redis at the
        # module boundaries so real services are never contacted.
        with patch('app.services.multi_language.translation_engine.openai') as mock_openai, \
             patch('app.services.multi_language.language_detector.langdetect') as mock_langdetect, \
             patch('redis.asyncio.from_url') as mock_redis:

            # Configure mocks
            mock_openai.AsyncOpenAI.return_value.chat.completions.create.return_value = Mock(
                choices=[Mock(message=Mock(content="Hola mundo"))]
            )

            mock_langdetect.detect.return_value = Mock(lang="en", prob=0.95)
            mock_redis.return_value.ping.return_value = True
            mock_redis.return_value.get.return_value = None  # Cache miss

            # Initialize services from the default configuration.
            config = MultiLanguageConfig()
            translation_engine = TranslationEngine(config.translation)
            language_detector = LanguageDetector(config.detection)
            translation_cache = TranslationCache(config.cache["redis"]["url"])

            await translation_cache.initialize()

            # Test translation
            request = TranslationRequest(
                text="Hello world",
                source_language="en",
                target_language="es"
            )

            result = await translation_engine.translate(request)

            # The mocked OpenAI completion should flow through unmodified.
            assert result.translated_text == "Hola mundo"
            assert result.provider == TranslationProvider.OPENAI

            await translation_cache.close()
# Performance tests
|
||||
class TestPerformance:
    """Performance tests for multi-language services"""

    # NOTE(review): both tests below are placeholders — they pass vacuously.
    # Implement with concurrent-load harnesses or remove before relying on
    # this suite for performance regressions.

    @pytest.mark.asyncio
    async def test_translation_performance(self):
        """Test translation performance under load"""
        # This would test performance with concurrent requests
        pass

    @pytest.mark.asyncio
    async def test_cache_performance(self):
        """Test cache performance under load"""
        # This would test cache performance with many concurrent operations
        pass
# Error handling tests
|
||||
class TestErrorHandling:
    """Test error handling and edge cases"""

    @pytest.mark.asyncio
    async def test_translation_engine_failure(self):
        """Test translation engine failure handling"""
        # Invalid API key: the engine is expected to raise rather than
        # silently return an empty/partial translation.
        config = {"openai": {"api_key": "invalid"}}
        engine = TranslationEngine(config)

        request = TranslationRequest(
            text="Hello world",
            source_language="en",
            target_language="es"
        )

        with pytest.raises(Exception):
            await engine.translate(request)

    @pytest.mark.asyncio
    async def test_empty_text_handling(self):
        """Test handling of empty or invalid text"""
        detector = LanguageDetector({})

        # Empty input must be rejected up front, not passed to a backend.
        with pytest.raises(ValueError):
            await detector.detect_language("")

    @pytest.mark.asyncio
    async def test_unsupported_language_handling(self):
        """Test handling of unsupported languages"""
        config = MultiLanguageConfig()
        engine = TranslationEngine(config.translation)

        request = TranslationRequest(
            text="Hello world",
            source_language="invalid_lang",
            target_language="es"
        )

        # Should handle gracefully or raise appropriate error
        # NOTE(review): this accepts both outcomes, so it can never fail —
        # consider pinning the intended behavior once it is decided.
        try:
            result = await engine.translate(request)
            # If successful, should have fallback behavior
            assert result is not None
        except Exception:
            # If failed, should be appropriate error
            pass
# Test utilities
|
||||
class TestUtils:
    """Test utilities and helpers"""

    # NOTE(review): these helpers are not referenced by the tests above;
    # they are factory methods intended for reuse when writing new tests.

    def create_sample_translation_request(self):
        """Create sample translation request for testing"""
        return TranslationRequest(
            text="Hello world, this is a test message",
            source_language="en",
            target_language="es",
            context="General communication",
            domain="general"
        )

    def create_sample_agent_profile(self):
        """Create sample agent profile for testing"""
        return AgentLanguageProfile(
            agent_id="test_agent",
            preferred_language="es",
            supported_languages=["es", "en", "fr"],
            auto_translate_enabled=True,
            translation_quality_threshold=0.7,
            cultural_preferences={"formality": "formal"}
        )

    def create_sample_marketplace_listing(self):
        """Create sample marketplace listing for testing"""
        return {
            "id": "test_listing",
            "type": "service",
            "title": "AI Translation Service",
            "description": "High-quality AI-powered translation service",
            "keywords": ["translation", "AI", "service"],
            "features": ["Fast", "Accurate", "Multi-language"],
            "requirements": ["API key", "Internet"],
            "pricing_info": {"price": 0.01, "unit": "character"}
        }
if __name__ == "__main__":
    # Run tests directly (equivalent to `pytest -v` on this file).
    pytest.main([__file__, "-v"])
@@ -0,0 +1,471 @@
|
||||
"""
|
||||
Translation Cache Service
|
||||
Redis-based caching for translation results to improve performance
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import pickle
|
||||
from typing import Optional, Dict, Any, List
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime, timedelta
|
||||
import redis.asyncio as redis
|
||||
from redis.asyncio import Redis
|
||||
import hashlib
|
||||
import time
|
||||
|
||||
from .translation_engine import TranslationResponse, TranslationProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
class CacheEntry:
    """Cache entry for translation results.

    This is the pickled payload stored under each main cache key; the
    access-tracking fields are duplicated in the companion ``:stats``
    Redis hash so they can be updated without re-pickling.
    """
    translated_text: str          # the translation output
    confidence: float             # provider-reported confidence (0..1)
    provider: str                 # TranslationProvider.value (stored as str for pickling stability)
    processing_time_ms: int       # original translation latency
    source_language: str
    target_language: str
    created_at: float             # epoch seconds when cached
    access_count: int = 0         # number of cache hits
    last_accessed: float = 0      # epoch seconds of the last hit
class TranslationCache:
    """Redis-based translation cache with intelligent eviction and statistics.

    Each translation is stored as a pickled :class:`CacheEntry` under a
    deterministic key derived from (text, source, target, context, domain).
    A companion ``<key>:stats`` hash carries access counters and reporting
    fields so statistics can be updated/read without unpickling the entry.
    """

    def __init__(self, redis_url: str, config: Optional[Dict] = None):
        self.redis_url = redis_url
        self.config = config or {}
        self.redis: Optional[Redis] = None
        # TTL (seconds) applied to each entry unless the caller overrides it.
        self.default_ttl = self.config.get("default_ttl", 86400)  # 24 hours
        # Soft cap consulted by optimize_cache(); Redis itself is unbounded here.
        self.max_cache_size = self.config.get("max_cache_size", 100000)
        # In-process counters only; they reset when this service restarts.
        self.stats = {
            "hits": 0,
            "misses": 0,
            "sets": 0,
            "evictions": 0
        }

    async def initialize(self):
        """Initialize the Redis connection and verify it with a PING.

        Raises:
            Exception: whatever the Redis client raises when the server
                is unreachable; callers should treat the cache as unusable.
        """
        try:
            # decode_responses=False: values are pickled bytes, not text.
            self.redis = redis.from_url(self.redis_url, decode_responses=False)
            await self.redis.ping()
            logger.info("Translation cache Redis connection established")
        except Exception as e:
            logger.error(f"Failed to connect to Redis: {e}")
            raise

    async def close(self):
        """Close the Redis connection (no-op if never initialized)."""
        if self.redis:
            await self.redis.close()

    def _generate_cache_key(self, text: str, source_lang: str, target_lang: str,
                            context: Optional[str] = None, domain: Optional[str] = None) -> str:
        """Generate a deterministic cache key for a translation request.

        Text and context are hashed (MD5 — non-cryptographic, key-shortening
        only) so keys stay bounded regardless of input size; language codes
        and domain are lower-cased for case-insensitive matching.
        """
        key_parts = [
            "translate",
            source_lang.lower(),
            target_lang.lower(),
            hashlib.md5(text.encode()).hexdigest()
        ]

        if context:
            key_parts.append(hashlib.md5(context.encode()).hexdigest())

        if domain:
            key_parts.append(domain.lower())

        return ":".join(key_parts)

    async def get(self, text: str, source_lang: str, target_lang: str,
                  context: Optional[str] = None, domain: Optional[str] = None) -> Optional[TranslationResponse]:
        """Get a translation from the cache.

        Returns:
            The cached :class:`TranslationResponse`, or ``None`` on a miss,
            when Redis is not connected, or on any Redis/deserialization error
            (errors are logged and counted as misses — the cache is best-effort).
        """
        if not self.redis:
            return None

        cache_key = self._generate_cache_key(text, source_lang, target_lang, context, domain)

        try:
            cached_data = await self.redis.get(cache_key)

            if cached_data:
                # NOTE: pickle.loads is only safe because this Redis instance
                # is trusted/internal; never point this cache at untrusted data.
                cache_entry = pickle.loads(cached_data)

                # Update access statistics in Redis. HINCRBY is used instead
                # of rewriting the counter from the pickled entry: the pickled
                # value never changes, so the old code reset the count on
                # every hit instead of accumulating it.
                stats_key = f"{cache_key}:stats"
                await self.redis.hincrby(stats_key, "access_count", 1)
                await self.redis.hset(stats_key, "last_accessed", time.time())

                self.stats["hits"] += 1

                # Convert back to TranslationResponse
                return TranslationResponse(
                    translated_text=cache_entry.translated_text,
                    confidence=cache_entry.confidence,
                    provider=TranslationProvider(cache_entry.provider),
                    processing_time_ms=cache_entry.processing_time_ms,
                    source_language=cache_entry.source_language,
                    target_language=cache_entry.target_language
                )

            self.stats["misses"] += 1
            return None

        except Exception as e:
            logger.error(f"Cache get error: {e}")
            self.stats["misses"] += 1
            return None

    async def set(self, text: str, source_lang: str, target_lang: str,
                  response: TranslationResponse, ttl: Optional[int] = None,
                  context: Optional[str] = None, domain: Optional[str] = None) -> bool:
        """Store a translation in the cache.

        Writes the pickled entry and its ``:stats`` hash atomically via a
        pipeline, both with the same TTL.

        Returns:
            True on success, False when Redis is unavailable or errors.
        """
        if not self.redis:
            return False

        cache_key = self._generate_cache_key(text, source_lang, target_lang, context, domain)
        ttl = ttl or self.default_ttl

        try:
            # Create cache entry
            cache_entry = CacheEntry(
                translated_text=response.translated_text,
                confidence=response.confidence,
                provider=response.provider.value,
                processing_time_ms=response.processing_time_ms,
                source_language=response.source_language,
                target_language=response.target_language,
                created_at=time.time(),
                access_count=1,
                last_accessed=time.time()
            )

            # Serialize and store
            serialized_entry = pickle.dumps(cache_entry)

            # Use pipeline for atomic operations
            pipe = self.redis.pipeline()

            # Set main cache entry
            pipe.setex(cache_key, ttl, serialized_entry)

            # Set statistics. BUG FIX: redis-py requires the dict to be passed
            # as the ``mapping=`` keyword — passing it positionally (as the
            # field name) raises DataError on every set. The reporting fields
            # (translated_text / source_language / target_language) are also
            # stored here so get_top_translations() can read them; previously
            # they were never written, so that method always returned [].
            stats_key = f"{cache_key}:stats"
            pipe.hset(stats_key, mapping={
                "access_count": 1,
                "last_accessed": cache_entry.last_accessed,
                "created_at": cache_entry.created_at,
                "confidence": response.confidence,
                "provider": response.provider.value,
                "translated_text": response.translated_text,
                "source_language": response.source_language,
                "target_language": response.target_language
            })
            pipe.expire(stats_key, ttl)

            await pipe.execute()

            self.stats["sets"] += 1
            return True

        except Exception as e:
            logger.error(f"Cache set error: {e}")
            return False

    async def delete(self, text: str, source_lang: str, target_lang: str,
                     context: Optional[str] = None, domain: Optional[str] = None) -> bool:
        """Delete a translation (entry + stats hash) from the cache."""
        if not self.redis:
            return False

        cache_key = self._generate_cache_key(text, source_lang, target_lang, context, domain)

        try:
            pipe = self.redis.pipeline()
            pipe.delete(cache_key)
            pipe.delete(f"{cache_key}:stats")
            await pipe.execute()
            return True
        except Exception as e:
            logger.error(f"Cache delete error: {e}")
            return False

    async def clear_by_language_pair(self, source_lang: str, target_lang: str) -> int:
        """Clear all cache entries for a specific language pair.

        Returns:
            Number of main entries removed (stats hashes are removed too).
        """
        if not self.redis:
            return 0

        pattern = f"translate:{source_lang.lower()}:{target_lang.lower()}:*"

        try:
            # NOTE: KEYS is O(N) over the whole keyspace and blocks Redis;
            # acceptable for an admin operation, but SCAN would be safer on
            # large production instances.
            keys = await self.redis.keys(pattern)
            if keys:
                # Also delete stats keys
                stats_keys = [f"{key.decode()}:stats" for key in keys]
                all_keys = keys + stats_keys
                await self.redis.delete(*all_keys)
                return len(keys)
            return 0
        except Exception as e:
            logger.error(f"Cache clear by language pair error: {e}")
            return 0

    async def get_cache_stats(self) -> Dict[str, Any]:
        """Get comprehensive cache statistics.

        Combines in-process hit/miss counters with server-side size and
        memory figures from Redis INFO.
        """
        if not self.redis:
            return {"error": "Redis not connected"}

        try:
            info = await self.redis.info()

            # Hit ratio over the lifetime of this process only.
            total_requests = self.stats["hits"] + self.stats["misses"]
            hit_ratio = self.stats["hits"] / total_requests if total_requests > 0 else 0

            cache_size = await self.redis.dbsize()

            memory_used = info.get("used_memory", 0)
            memory_human = self._format_bytes(memory_used)

            return {
                "hits": self.stats["hits"],
                "misses": self.stats["misses"],
                "sets": self.stats["sets"],
                "evictions": self.stats["evictions"],
                "hit_ratio": hit_ratio,
                "cache_size": cache_size,
                "memory_used": memory_used,
                "memory_human": memory_human,
                "redis_connected": True
            }

        except Exception as e:
            logger.error(f"Cache stats error: {e}")
            return {"error": str(e), "redis_connected": False}

    async def get_top_translations(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Get the most-accessed translations, ordered by access count.

        Reads the reporting fields from each ``:stats`` hash (written by
        :meth:`set`); entries missing those fields are skipped.
        """
        if not self.redis:
            return []

        try:
            # NOTE: KEYS scan — see clear_by_language_pair() about SCAN.
            stats_keys = await self.redis.keys("translate:*:stats")

            if not stats_keys:
                return []

            # Fetch five fields per entry in one pipeline round-trip.
            pipe = self.redis.pipeline()
            for key in stats_keys:
                pipe.hget(key, "access_count")
                pipe.hget(key, "translated_text")
                pipe.hget(key, "source_language")
                pipe.hget(key, "target_language")
                pipe.hget(key, "confidence")

            results = await pipe.execute()

            # Results arrive flattened, five per stats key, in request order.
            translations = []
            for i in range(0, len(results), 5):
                access_count = results[i]
                translated_text = results[i + 1]
                source_lang = results[i + 2]
                target_lang = results[i + 3]
                confidence = results[i + 4]

                if access_count and translated_text:
                    translations.append({
                        "access_count": int(access_count),
                        "translated_text": translated_text.decode() if isinstance(translated_text, bytes) else translated_text,
                        "source_language": source_lang.decode() if isinstance(source_lang, bytes) else source_lang,
                        "target_language": target_lang.decode() if isinstance(target_lang, bytes) else target_lang,
                        "confidence": float(confidence) if confidence else 0.0
                    })

            # Sort by access count (descending) and apply the limit.
            translations.sort(key=lambda x: x["access_count"], reverse=True)
            return translations[:limit]

        except Exception as e:
            logger.error(f"Get top translations error: {e}")
            return []

    async def cleanup_expired(self) -> int:
        """Report current cache size; Redis TTLs expire entries automatically.

        Kept as a hook for manual cleanup policies; currently it only
        returns DBSIZE.
        """
        if not self.redis:
            return 0

        try:
            cache_size = await self.redis.dbsize()
            return cache_size
        except Exception as e:
            logger.error(f"Cleanup error: {e}")
            return 0

    async def optimize_cache(self) -> Dict[str, Any]:
        """Optimize the cache by evicting the least-accessed entries.

        When DBSIZE exceeds ``max_cache_size``, removes the bottom 25% of
        entries by access count (both the stats hash and the main entry).

        Returns:
            A status dict describing what, if anything, was removed.
        """
        if not self.redis:
            return {"error": "Redis not connected"}

        try:
            current_size = await self.redis.dbsize()

            if current_size <= self.max_cache_size:
                return {"status": "no_optimization_needed", "current_size": current_size}

            stats_keys = await self.redis.keys("translate:*:stats")

            if not stats_keys:
                return {"status": "no_stats_found", "current_size": current_size}

            # Gather access counts in a single pipeline round-trip.
            pipe = self.redis.pipeline()
            for key in stats_keys:
                pipe.hget(key, "access_count")

            access_counts = await pipe.execute()

            # Pair keys with counts, ascending by count (coldest first).
            entries_with_counts = []
            for i, key in enumerate(stats_keys):
                count = access_counts[i]
                if count:
                    entries_with_counts.append((key, int(count)))

            entries_with_counts.sort(key=lambda x: x[1])

            # Remove entries with lowest access counts (bottom 25%).
            entries_to_remove = entries_with_counts[:len(entries_with_counts) // 4]

            if entries_to_remove:
                keys_to_delete = []
                for key, _ in entries_to_remove:
                    key_str = key.decode() if isinstance(key, bytes) else key
                    keys_to_delete.append(key_str)
                    keys_to_delete.append(key_str.replace(":stats", ""))  # also delete main entry

                await self.redis.delete(*keys_to_delete)
                self.stats["evictions"] += len(entries_to_remove)

            new_size = await self.redis.dbsize()

            return {
                "status": "optimization_completed",
                "entries_removed": len(entries_to_remove),
                "previous_size": current_size,
                "new_size": new_size
            }

        except Exception as e:
            logger.error(f"Cache optimization error: {e}")
            return {"error": str(e)}

    def _format_bytes(self, bytes_value: int) -> str:
        """Format a byte count as a human-readable string (B/KB/MB/GB/TB)."""
        for unit in ['B', 'KB', 'MB', 'GB']:
            if bytes_value < 1024.0:
                return f"{bytes_value:.2f} {unit}"
            bytes_value /= 1024.0
        return f"{bytes_value:.2f} TB"

    async def health_check(self) -> Dict[str, Any]:
        """Health check for the cache service.

        Status levels: "healthy" (connected, hit ratio > 0.7),
        "degraded" (hit ratio > 0.5), otherwise "unhealthy".
        """
        health_status = {
            "redis_connected": False,
            "cache_size": 0,
            "hit_ratio": 0.0,
            "memory_usage": 0,
            "status": "unhealthy"
        }

        if not self.redis:
            return health_status

        try:
            await self.redis.ping()
            health_status["redis_connected"] = True

            stats = await self.get_cache_stats()
            health_status.update(stats)

            # Determine health status from the observed hit ratio.
            if stats.get("hit_ratio", 0) > 0.7 and stats.get("redis_connected", False):
                health_status["status"] = "healthy"
            elif stats.get("hit_ratio", 0) > 0.5:
                health_status["status"] = "degraded"

            return health_status

        except Exception as e:
            logger.error(f"Cache health check failed: {e}")
            health_status["error"] = str(e)
            return health_status

    async def export_cache_data(self, output_file: str) -> bool:
        """Export all cache entries to a JSON file for backup or analysis.

        Stats hashes are skipped; unreadable entries are logged and ignored.

        Returns:
            True on success (including an empty cache), False on failure.
        """
        if not self.redis:
            return False

        try:
            keys = await self.redis.keys("translate:*")

            if not keys:
                return True

            export_data = []

            for key in keys:
                if b":stats" in key:
                    continue  # skip stats hashes; only pickled entries are exported

                try:
                    cached_data = await self.redis.get(key)
                    if cached_data:
                        cache_entry = pickle.loads(cached_data)
                        export_data.append(asdict(cache_entry))
                except Exception as e:
                    logger.warning(f"Failed to export key {key}: {e}")
                    continue

            with open(output_file, 'w') as f:
                json.dump(export_data, f, indent=2)

            logger.info(f"Exported {len(export_data)} cache entries to {output_file}")
            return True

        except Exception as e:
            logger.error(f"Cache export failed: {e}")
            return False
@@ -0,0 +1,352 @@
|
||||
"""
|
||||
Multi-Language Translation Engine
|
||||
Core translation orchestration service for AITBC platform
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import openai
|
||||
import google.cloud.translate_v2 as translate
|
||||
import deepl
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TranslationProvider(Enum):
    """Supported translation backends; values are the wire/config identifiers."""
    OPENAI = "openai"
    GOOGLE = "google"
    DEEPL = "deepl"
    LOCAL = "local"
@dataclass
class TranslationRequest:
    """A single translation job passed to a translator backend."""
    text: str                        # source text to translate
    source_language: str             # source language code (e.g. "en")
    target_language: str             # target language code (e.g. "es")
    context: Optional[str] = None    # optional surrounding context for disambiguation
    domain: Optional[str] = None     # optional subject domain (e.g. "legal")
||||
@dataclass
class TranslationResponse:
    """Result returned by a translator backend."""
    translated_text: str             # translation output
    confidence: float                # backend-reported confidence (0..1)
    provider: TranslationProvider    # which backend produced this result
    processing_time_ms: int          # wall-clock latency of the call
    source_language: str
    target_language: str
||||
class BaseTranslator(ABC):
    """Base class for translation providers"""

    @abstractmethod
    async def translate(self, request: TranslationRequest) -> TranslationResponse:
        """Translate *request* and return the result; raise on provider failure."""
        pass

    @abstractmethod
    def get_supported_languages(self) -> List[str]:
        """Return the language codes this provider can handle."""
        pass
|
||||
|
||||
class OpenAITranslator(BaseTranslator):
    """OpenAI chat-completion based translation.

    Uses a low temperature so translations stay faithful rather than
    creative. The model is configurable but defaults to the previously
    hard-coded "gpt-4", so existing callers are unaffected.
    """

    def __init__(self, api_key: str, model: str = "gpt-4"):
        """
        Args:
            api_key: OpenAI API key.
            model: Chat model name; defaults to "gpt-4" (the original
                hard-coded value) for backward compatibility.
        """
        self.client = openai.AsyncOpenAI(api_key=api_key)
        self.model = model

    async def translate(self, request: TranslationRequest) -> TranslationResponse:
        """Translate *request* via the chat-completions API.

        Raises:
            ValueError: if the API returns an empty completion (e.g. a
                content-filtered response with ``content is None``).
            Exception: any OpenAI client error (logged, then re-raised).
        """
        start_time = asyncio.get_event_loop().time()

        prompt = self._build_prompt(request)

        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a professional translator. Translate the given text accurately while preserving context and cultural nuances."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.3,  # low temperature: favor faithful over creative output
                max_tokens=2000
            )

            content = response.choices[0].message.content
            if content is None:
                # The API may return a null message; the original code would
                # have crashed with AttributeError on .strip() here.
                raise ValueError("OpenAI returned an empty completion")
            translated_text = content.strip()
            processing_time = int((asyncio.get_event_loop().time() - start_time) * 1000)

            return TranslationResponse(
                translated_text=translated_text,
                confidence=0.95,  # GPT-4 typically high confidence
                provider=TranslationProvider.OPENAI,
                processing_time_ms=processing_time,
                source_language=request.source_language,
                target_language=request.target_language
            )

        except Exception as e:
            logger.error(f"OpenAI translation error: {e}")
            raise

    def _build_prompt(self, request: TranslationRequest) -> str:
        """Assemble the user prompt, appending optional context/domain hints."""
        prompt = f"Translate the following text from {request.source_language} to {request.target_language}:\n\n"
        prompt += f"Text: {request.text}\n\n"

        if request.context:
            prompt += f"Context: {request.context}\n"

        if request.domain:
            prompt += f"Domain: {request.domain}\n"

        prompt += "Provide only the translation without additional commentary."
        return prompt

    def get_supported_languages(self) -> List[str]:
        """Curated list of language codes this integration advertises."""
        return ["en", "zh", "es", "fr", "de", "ja", "ko", "ru", "ar", "hi", "pt", "it", "nl", "sv", "da", "no", "fi"]
|
||||
|
||||
class GoogleTranslator(BaseTranslator):
    """Google Translate API integration"""

    def __init__(self, api_key: str):
        # NOTE(review): google.cloud.translate_v2.Client authenticates via
        # Google credentials; an ``api_key`` keyword is not part of its
        # documented constructor — confirm this works with the pinned
        # client version.
        self.client = translate.Client(api_key=api_key)

    async def translate(self, request: TranslationRequest) -> TranslationResponse:
        """Translate *request* via Google Translate v2.

        The synchronous client call is pushed onto the default executor so
        the event loop is not blocked while the HTTP request runs.
        """
        start_time = asyncio.get_event_loop().time()

        try:
            # Run the blocking HTTP call off the event loop.
            result = await asyncio.get_event_loop().run_in_executor(
                None,
                lambda: self.client.translate(
                    request.text,
                    source_language=request.source_language,
                    target_language=request.target_language
                )
            )

            # NOTE(review): Translate v2 returns HTML-escaped text by default;
            # verify whether plain-text callers need unescaping here.
            translated_text = result['translatedText']
            processing_time = int((asyncio.get_event_loop().time() - start_time) * 1000)

            return TranslationResponse(
                translated_text=translated_text,
                confidence=0.85,  # Google Translate moderate confidence
                provider=TranslationProvider.GOOGLE,
                processing_time_ms=processing_time,
                source_language=request.source_language,
                target_language=request.target_language
            )

        except Exception as e:
            logger.error(f"Google translation error: {e}")
            raise

    def get_supported_languages(self) -> List[str]:
        """Subset of Google-supported language codes exposed by this integration."""
        return ["en", "zh", "zh-cn", "zh-tw", "es", "fr", "de", "ja", "ko", "ru", "ar", "hi", "pt", "it", "nl", "sv", "da", "no", "fi", "th", "vi"]
|
||||
|
||||
class DeepLTranslator(BaseTranslator):
    """DeepL API integration for European languages"""

    def __init__(self, api_key: str):
        self.translator = deepl.Translator(api_key)

    async def translate(self, request: TranslationRequest) -> TranslationResponse:
        """Translate *request* by running the blocking DeepL client off-loop."""
        loop = asyncio.get_event_loop()
        started = loop.time()

        def _call():
            # DeepL expects upper-cased language codes (e.g. "EN", "DE").
            return self.translator.translate_text(
                request.text,
                source_lang=request.source_language.upper(),
                target_lang=request.target_language.upper(),
            )

        try:
            outcome = await loop.run_in_executor(None, _call)
            elapsed_ms = int((loop.time() - started) * 1000)

            return TranslationResponse(
                translated_text=outcome.text,
                confidence=0.90,  # DeepL high confidence for European languages
                provider=TranslationProvider.DEEPL,
                processing_time_ms=elapsed_ms,
                source_language=request.source_language,
                target_language=request.target_language,
            )
        except Exception as e:
            logger.error(f"DeepL translation error: {e}")
            raise

    def get_supported_languages(self) -> List[str]:
        """Language codes this DeepL integration advertises."""
        return ["en", "de", "fr", "es", "pt", "it", "nl", "sv", "da", "fi", "pl", "ru", "ja", "zh"]
|
||||
|
||||
class LocalTranslator(BaseTranslator):
    """Local MarianMT models for privacy-preserving translation"""

    def __init__(self):
        # Placeholder for local model initialization.
        # In production this would load MarianMT models keyed by language pair.
        self.models = {}

    async def translate(self, request: TranslationRequest) -> TranslationResponse:
        """Stub translation: tags the input text instead of translating it."""
        loop = asyncio.get_event_loop()
        began = loop.time()

        # Placeholder implementation — production code would invoke a local model.
        await asyncio.sleep(0.1)  # Simulate processing time

        elapsed = int((loop.time() - began) * 1000)
        return TranslationResponse(
            translated_text=f"[LOCAL] {request.text}",
            confidence=0.75,  # Local models moderate confidence
            provider=TranslationProvider.LOCAL,
            processing_time_ms=elapsed,
            source_language=request.source_language,
            target_language=request.target_language,
        )

    def get_supported_languages(self) -> List[str]:
        """Language pairs the bundled local models would cover."""
        return ["en", "de", "fr", "es"]
|
||||
|
||||
class TranslationEngine:
    """Main translation orchestration engine.

    Routes each request to the best-suited provider (with fallbacks),
    applies an optional quality check, and caches high-confidence results.
    The ``cache`` and ``quality_checker`` collaborators are injected after
    construction and are optional.
    """

    # Language groups used only for provider routing; frozensets make the
    # membership tests O(1) and document that the groups are immutable.
    _EUROPEAN_LANGUAGES = frozenset(["de", "fr", "es", "pt", "it", "nl", "sv", "da", "fi", "pl"])
    _ASIAN_LANGUAGES = frozenset(["zh", "ja", "ko", "hi", "th", "vi"])

    def __init__(self, config: Dict):
        """
        Args:
            config: Per-provider settings; a provider is enabled when
                ``config[<name>]["api_key"]`` is present and truthy.
        """
        self.config = config
        self.translators = self._initialize_translators()
        self.cache = None  # Will be injected
        self.quality_checker = None  # Will be injected

    def _initialize_translators(self) -> Dict[TranslationProvider, BaseTranslator]:
        """Instantiate every provider that has an API key configured."""
        translators = {}

        if self.config.get("openai", {}).get("api_key"):
            translators[TranslationProvider.OPENAI] = OpenAITranslator(
                self.config["openai"]["api_key"]
            )

        if self.config.get("google", {}).get("api_key"):
            translators[TranslationProvider.GOOGLE] = GoogleTranslator(
                self.config["google"]["api_key"]
            )

        if self.config.get("deepl", {}).get("api_key"):
            translators[TranslationProvider.DEEPL] = DeepLTranslator(
                self.config["deepl"]["api_key"]
            )

        # Always include local translator as fallback
        translators[TranslationProvider.LOCAL] = LocalTranslator()

        return translators

    async def translate(self, request: TranslationRequest) -> TranslationResponse:
        """Translate *request*, trying providers in preference order.

        Returns the first successful (quality-adjusted) response; results
        with confidence above 0.8 are cached for 24 hours.

        Raises:
            Exception: if every configured provider fails; the last
                provider error is attached as the exception's cause.
        """
        # Check cache first
        cache_key = self._generate_cache_key(request)
        if self.cache:
            cached_result = await self.cache.get(cache_key)
            if cached_result:
                logger.info(f"Cache hit for translation: {cache_key}")
                return cached_result

        # Determine optimal translator for this request
        preferred_providers = self._get_preferred_providers(request)

        last_error = None
        for provider in preferred_providers:
            if provider not in self.translators:
                continue

            try:
                translator = self.translators[provider]
                result = await translator.translate(request)

                # Quality check can only lower confidence, never raise it.
                if self.quality_checker:
                    quality_score = await self.quality_checker.evaluate_translation(
                        request.text, result.translated_text,
                        request.source_language, request.target_language
                    )
                    result.confidence = min(result.confidence, quality_score)

                # Cache only trustworthy results.
                if self.cache and result.confidence > 0.8:
                    await self.cache.set(cache_key, result, ttl=86400)  # 24 hours

                logger.info(f"Translation successful using {provider.value}")
                return result

            except Exception as e:
                last_error = e
                logger.warning(f"Translation failed with {provider.value}: {e}")
                continue

        # All providers failed. Chain the last provider error so callers can
        # diagnose the root cause (the original raise discarded it).
        logger.error(f"All translation providers failed. Last error: {last_error}")
        raise Exception("Translation failed with all providers") from last_error

    def _get_preferred_providers(self, request: TranslationRequest) -> List[TranslationProvider]:
        """Determine provider preference based on language pair and requirements"""
        source_lang = request.source_language
        target_lang = request.target_language

        # DeepL for European languages (checked first, so it wins even when
        # the request also carries context/domain hints).
        if (source_lang in self._EUROPEAN_LANGUAGES or target_lang in self._EUROPEAN_LANGUAGES) and TranslationProvider.DEEPL in self.translators:
            return [TranslationProvider.DEEPL, TranslationProvider.OPENAI, TranslationProvider.GOOGLE, TranslationProvider.LOCAL]

        # OpenAI for complex translations with context
        if request.context or request.domain:
            return [TranslationProvider.OPENAI, TranslationProvider.GOOGLE, TranslationProvider.DEEPL, TranslationProvider.LOCAL]

        # Google for speed and Asian languages
        if (source_lang in self._ASIAN_LANGUAGES or target_lang in self._ASIAN_LANGUAGES) and TranslationProvider.GOOGLE in self.translators:
            return [TranslationProvider.GOOGLE, TranslationProvider.OPENAI, TranslationProvider.DEEPL, TranslationProvider.LOCAL]

        # Default preference
        return [TranslationProvider.OPENAI, TranslationProvider.GOOGLE, TranslationProvider.DEEPL, TranslationProvider.LOCAL]

    def _generate_cache_key(self, request: TranslationRequest) -> str:
        """Generate cache key for translation request.

        MD5 is used purely as a compact, non-cryptographic fingerprint of
        the request fields; it is never exposed to untrusted verification.
        """
        content = f"{request.text}:{request.source_language}:{request.target_language}"
        if request.context:
            content += f":{request.context}"
        if request.domain:
            content += f":{request.domain}"

        return hashlib.md5(content.encode()).hexdigest()

    def get_supported_languages(self) -> Dict[str, List[str]]:
        """Get all supported languages, keyed by provider enum value."""
        supported = {}
        for provider, translator in self.translators.items():
            supported[provider.value] = translator.get_supported_languages()
        return supported

    async def health_check(self) -> Dict[str, bool]:
        """Check health of all translation providers.

        Performs a tiny en→es test translation against each provider; a
        provider is healthy iff the call does not raise.
        """
        health_status = {}

        for provider, translator in self.translators.items():
            try:
                # Simple test translation
                test_request = TranslationRequest(
                    text="Hello",
                    source_language="en",
                    target_language="es"
                )
                await translator.translate(test_request)
                health_status[provider.value] = True
            except Exception as e:
                logger.error(f"Health check failed for {provider.value}: {e}")
                health_status[provider.value] = False

        return health_status
|
||||
Reference in New Issue
Block a user