mypy: fix type errors in zk_proofs, fhe_service, websocket_stream_manager, audit, global_cdn, fhe_enhanced, confidential_service, receipts, portfolio, multimodal apps, agent_integration_factory
This commit is contained in:
@@ -31,7 +31,7 @@ _MODULE_BY_EXPORT = {
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
def __getattr__(name: str) -> object:
|
||||
"""Lazy load services on first access."""
|
||||
module_name = _MODULE_BY_EXPORT.get(name)
|
||||
if module_name is None:
|
||||
|
||||
@@ -14,29 +14,28 @@ from .adapters.agent_core_adapters import (
|
||||
SessionProviderAdapter,
|
||||
)
|
||||
from .agent_coordination.security import AgentSecurityManager, AgentAuditor
|
||||
from .agent_coordination.agent_service import AIAgentOrchestrator
|
||||
from ..database import get_session
|
||||
from .agent_coordination.agent_service import AIAgentOrchestrator, CoordinatorClient
|
||||
from ..storage.db import get_session
|
||||
|
||||
|
||||
def create_agent_integration_service() -> AgentIntegrationService:
|
||||
def create_agent_integration_service(session: Session) -> AgentIntegrationService:
|
||||
"""
|
||||
Factory to create shared AgentIntegrationService with app-specific adapters.
|
||||
|
||||
Returns:
|
||||
Configured AgentIntegrationService instance
|
||||
"""
|
||||
# Create app-specific service instances
|
||||
security_manager = AgentSecurityManager()
|
||||
auditor = AgentAuditor()
|
||||
orchestrator = AIAgentOrchestrator()
|
||||
security_manager = AgentSecurityManager(session=session)
|
||||
auditor = AgentAuditor(session=session)
|
||||
coordinator_client = CoordinatorClient()
|
||||
orchestrator = AIAgentOrchestrator(session=session, coordinator_client=coordinator_client)
|
||||
|
||||
# Wrap with protocol adapters
|
||||
return AgentIntegrationService(
|
||||
session_provider=SessionProviderAdapter(get_session),
|
||||
security_manager=AgentSecurityManagerAdapter(security_manager),
|
||||
auditor=AgentAuditorAdapter(auditor),
|
||||
orchestrator=AgentOrchestratorAdapter(orchestrator),
|
||||
zk_proof_service=ZKProofServiceAdapter(get_session()),
|
||||
zk_proof_service=ZKProofServiceAdapter(session),
|
||||
)
|
||||
|
||||
|
||||
@@ -53,5 +52,8 @@ def get_shared_agent_integration_service() -> AgentIntegrationService:
|
||||
"""
|
||||
global _shared_service
|
||||
if _shared_service is None:
|
||||
_shared_service = create_agent_integration_service()
|
||||
from sqlmodel import Session as SQLModelSession
|
||||
from ..storage.db import get_engine
|
||||
with SQLModelSession(get_engine()) as _sess:
|
||||
_shared_service = create_agent_integration_service(_sess)
|
||||
return _shared_service
|
||||
|
||||
@@ -18,7 +18,7 @@ from ..blockchain.contract_interactions import ContractInteractionService
|
||||
from ..domain.atomic_swap import AtomicSwapOrder, SwapStatus
|
||||
from ..schemas.atomic_swap import SwapActionRequest, SwapCompleteRequest, SwapCreateRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AtomicSwapService:
|
||||
@@ -78,11 +78,11 @@ class AtomicSwapService:
|
||||
|
||||
async def get_agent_swaps(self, agent_id: str) -> list[AtomicSwapOrder]:
|
||||
"""Get all swaps where the agent is either initiator or participant"""
|
||||
return self.session.execute(
|
||||
return list(self.session.scalars(
|
||||
select(AtomicSwapOrder).where(
|
||||
(AtomicSwapOrder.initiator_agent_id == agent_id) | (AtomicSwapOrder.participant_agent_id == agent_id)
|
||||
)
|
||||
).all()
|
||||
).all())
|
||||
|
||||
async def mark_initiated(self, swap_id: str, request: SwapActionRequest) -> AtomicSwapOrder:
|
||||
"""Mark that the initiator has locked funds on the source chain"""
|
||||
|
||||
@@ -6,14 +6,17 @@ import asyncio
|
||||
import gzip
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from ...config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuditEvent:
|
||||
@@ -37,7 +40,7 @@ class AuditEvent:
|
||||
class AuditLogger:
|
||||
"""Tamper-evident audit logging for privacy compliance"""
|
||||
|
||||
def __init__(self, log_dir: str = None):
|
||||
def __init__(self, log_dir: str | None = None) -> None:
|
||||
# Use test-specific directory if in test environment
|
||||
if os.getenv("PYTEST_CURRENT_TEST"):
|
||||
# Use project logs directory for tests
|
||||
@@ -53,25 +56,25 @@ class AuditLogger:
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Current log file
|
||||
self.current_file = None
|
||||
self.current_file: Path | None = None
|
||||
self.current_hash = None
|
||||
|
||||
# In-memory events for tests
|
||||
self._in_memory_events: list[AuditEvent] = []
|
||||
|
||||
# Async writer task (unused in tests when sync write is used)
|
||||
self.write_queue = asyncio.Queue(maxsize=10000)
|
||||
self.writer_task = None
|
||||
self.write_queue: asyncio.Queue[AuditEvent] = asyncio.Queue(maxsize=10000)
|
||||
self.writer_task: asyncio.Task[None] | None = None
|
||||
|
||||
# Chain of hashes for integrity
|
||||
self.chain_hash = self._load_chain_hash()
|
||||
|
||||
async def start(self):
|
||||
async def start(self) -> None:
|
||||
"""Start the background writer task"""
|
||||
if self.writer_task is None:
|
||||
self.writer_task = asyncio.create_task(self._background_writer())
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self) -> None:
|
||||
"""Stop the background writer task"""
|
||||
if self.writer_task:
|
||||
self.writer_task.cancel()
|
||||
@@ -91,7 +94,7 @@ class AuditLogger:
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
authorization: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Log access to confidential data (synchronous for tests)."""
|
||||
event = AuditEvent(
|
||||
event_id=self._generate_event_id(),
|
||||
@@ -123,7 +126,7 @@ class AuditLogger:
|
||||
key_version: int,
|
||||
outcome: str,
|
||||
details: dict[str, Any] | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Log key management operations (synchronous for tests)."""
|
||||
event = AuditEvent(
|
||||
event_id=self._generate_event_id(),
|
||||
@@ -145,7 +148,7 @@ class AuditLogger:
|
||||
self._write_event_sync(event)
|
||||
self._in_memory_events.append(event)
|
||||
|
||||
def _write_event_sync(self, event: AuditEvent):
|
||||
def _write_event_sync(self, event: AuditEvent) -> None:
|
||||
"""Write event immediately (used in tests)."""
|
||||
log_file = self.log_dir / "audit.log"
|
||||
payload = asdict(event)
|
||||
@@ -161,7 +164,7 @@ class AuditLogger:
|
||||
change_type: str,
|
||||
outcome: str,
|
||||
details: dict[str, Any] | None = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Log access policy changes"""
|
||||
event = AuditEvent(
|
||||
event_id=self._generate_event_id(),
|
||||
@@ -223,31 +226,31 @@ class AuditLogger:
|
||||
if log_file.suffix == ".gz":
|
||||
with gzip.open(log_file, "rt") as f:
|
||||
for line in f:
|
||||
event = self._parse_log_line(line.strip())
|
||||
parsed = self._parse_log_line(line.strip())
|
||||
if self._matches_query(
|
||||
event,
|
||||
parsed,
|
||||
participant_id,
|
||||
transaction_id,
|
||||
event_type,
|
||||
start_time,
|
||||
end_time,
|
||||
):
|
||||
results.append(event)
|
||||
) and parsed is not None:
|
||||
results.append(parsed)
|
||||
if len(results) >= limit:
|
||||
return results
|
||||
else:
|
||||
with open(log_file) as f:
|
||||
for line in f:
|
||||
event = self._parse_log_line(line.strip())
|
||||
parsed = self._parse_log_line(line.strip())
|
||||
if self._matches_query(
|
||||
event,
|
||||
parsed,
|
||||
participant_id,
|
||||
transaction_id,
|
||||
event_type,
|
||||
start_time,
|
||||
end_time,
|
||||
):
|
||||
results.append(event)
|
||||
) and parsed is not None:
|
||||
results.append(parsed)
|
||||
if len(results) >= limit:
|
||||
return results
|
||||
except Exception as e:
|
||||
@@ -264,43 +267,45 @@ class AuditLogger:
|
||||
if start_date is None:
|
||||
start_date = datetime.now(timezone.utc) - timedelta(days=30)
|
||||
|
||||
results = {
|
||||
"verified_files": 0,
|
||||
"total_files": 0,
|
||||
"integrity_violations": [],
|
||||
"chain_valid": True,
|
||||
}
|
||||
verified_files = 0
|
||||
total_files = 0
|
||||
integrity_violations: list[dict[str, Any]] = []
|
||||
chain_valid = True
|
||||
|
||||
log_files = self._get_log_files(start_date)
|
||||
log_files = self._get_log_files(start_date, None)
|
||||
|
||||
for log_file in log_files:
|
||||
results["total_files"] += 1
|
||||
total_files += 1
|
||||
|
||||
try:
|
||||
# Verify file hash
|
||||
file_hash = self._calculate_file_hash(log_file)
|
||||
stored_hash = self._get_stored_hash(log_file)
|
||||
|
||||
if file_hash != stored_hash:
|
||||
results["integrity_violations"].append(
|
||||
integrity_violations.append(
|
||||
{
|
||||
"file": str(log_file),
|
||||
"expected": stored_hash,
|
||||
"actual": file_hash,
|
||||
}
|
||||
)
|
||||
results["chain_valid"] = False
|
||||
chain_valid = False
|
||||
else:
|
||||
results["verified_files"] += 1
|
||||
verified_files += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to verify {log_file}: {e}")
|
||||
results["integrity_violations"].append({"file": str(log_file), "error": str(e)})
|
||||
results["chain_valid"] = False
|
||||
integrity_violations.append({"file": str(log_file), "error": str(e)})
|
||||
chain_valid = False
|
||||
|
||||
return results
|
||||
return {
|
||||
"verified_files": verified_files,
|
||||
"total_files": total_files,
|
||||
"integrity_violations": integrity_violations,
|
||||
"chain_valid": chain_valid,
|
||||
}
|
||||
|
||||
def export_logs(
|
||||
def export_logs( # noqa: A002
|
||||
self,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
@@ -329,7 +334,7 @@ class AuditLogger:
|
||||
if not include_signatures:
|
||||
event_dict.pop("signature", None)
|
||||
|
||||
export_data["events"].append(event_dict)
|
||||
cast(list[Any], export_data["events"]).append(event_dict)
|
||||
|
||||
return json.dumps(export_data, indent=2)
|
||||
|
||||
@@ -380,12 +385,12 @@ class AuditLogger:
|
||||
else:
|
||||
raise ValueError(f"Unsupported export format: {format}")
|
||||
|
||||
async def _background_writer(self):
|
||||
async def _background_writer(self) -> None:
|
||||
"""Background task for writing audit events"""
|
||||
while True:
|
||||
try:
|
||||
# Get batch of events
|
||||
events = []
|
||||
events: list[AuditEvent] = []
|
||||
while len(events) < 100:
|
||||
try:
|
||||
# Use asyncio.wait_for for timeout
|
||||
@@ -405,10 +410,11 @@ class AuditLogger:
|
||||
# Brief pause to avoid error loops
|
||||
await asyncio.sleep(1)
|
||||
|
||||
def _write_events(self, events: list[AuditEvent]):
|
||||
def _write_events(self, events: list[AuditEvent]) -> None:
|
||||
"""Write events to current log file"""
|
||||
try:
|
||||
self._rotate_if_needed()
|
||||
assert self.current_file is not None
|
||||
|
||||
with open(self.current_file, "a") as f:
|
||||
for event in events:
|
||||
@@ -427,7 +433,7 @@ class AuditLogger:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write audit events: {e}")
|
||||
|
||||
def _rotate_if_needed(self):
|
||||
def _rotate_if_needed(self) -> None:
|
||||
"""Rotate log file if needed"""
|
||||
now = datetime.now(timezone.utc)
|
||||
today = now.date()
|
||||
@@ -441,7 +447,7 @@ class AuditLogger:
|
||||
if file_date != today:
|
||||
self._new_log_file(today)
|
||||
|
||||
def _new_log_file(self, date):
|
||||
def _new_log_file(self, date: Any) -> None:
|
||||
"""Create new log file for date"""
|
||||
filename = f"audit_{date.isoformat()}.log"
|
||||
self.current_file = self.log_dir / filename
|
||||
@@ -479,7 +485,7 @@ class AuditLogger:
|
||||
|
||||
return hashlib.sha256(combined).hexdigest()
|
||||
|
||||
def _update_chain_hash(self, last_event: AuditEvent):
|
||||
def _update_chain_hash(self, last_event: AuditEvent) -> None:
|
||||
"""Update chain hash with new event"""
|
||||
self.chain_hash = last_event.signature or self.chain_hash
|
||||
|
||||
|
||||
@@ -5,17 +5,17 @@ Confidential Transaction Service - Wrapper for existing confidential functionali
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from ..models.confidential import ConfidentialTransaction
|
||||
from ..models.confidential import ConfidentialTransactionDB as ConfidentialTransaction
|
||||
from ..contexts.security.services.encryption import EncryptionService
|
||||
from ..contexts.security.services.key_management import KeyManager
|
||||
from ..contexts.security.services.key_management import KeyManager, MockHSMStorage
|
||||
|
||||
|
||||
class ConfidentialTransactionService:
|
||||
"""Service for handling confidential transactions using existing encryption and key management"""
|
||||
|
||||
def __init__(self):
|
||||
self.encryption_service = EncryptionService()
|
||||
self.key_manager = KeyManager()
|
||||
def __init__(self) -> None:
|
||||
self.key_manager = KeyManager(storage_backend=MockHSMStorage())
|
||||
self.encryption_service = EncryptionService(key_manager=self.key_manager)
|
||||
|
||||
def create_confidential_transaction(
|
||||
self,
|
||||
@@ -26,34 +26,45 @@ class ConfidentialTransactionService:
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> ConfidentialTransaction:
|
||||
"""Create a new confidential transaction"""
|
||||
# Generate viewing key if not provided
|
||||
if not viewing_key:
|
||||
viewing_key = self.key_manager.generate_viewing_key()
|
||||
import secrets
|
||||
|
||||
# Encrypt transaction data
|
||||
encrypted_data = self.encryption_service.encrypt_transaction_data(
|
||||
{"sender": sender, "recipient": recipient, "amount": amount, "metadata": metadata or {}}
|
||||
if not viewing_key:
|
||||
viewing_key = secrets.token_hex(32)
|
||||
|
||||
encrypted = self.encryption_service.encrypt(
|
||||
{"sender": sender, "recipient": recipient, "amount": amount, "metadata": metadata or {}},
|
||||
participants=[sender, recipient],
|
||||
)
|
||||
|
||||
return ConfidentialTransaction(
|
||||
sender=sender,
|
||||
recipient=recipient,
|
||||
encrypted_payload=encrypted_data,
|
||||
viewing_key=viewing_key,
|
||||
participants=[sender, recipient],
|
||||
encrypted_data=str(encrypted.to_dict()).encode(),
|
||||
status="created",
|
||||
confidential=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
def decrypt_transaction(self, transaction: ConfidentialTransaction, viewing_key: str) -> dict[str, Any]:
|
||||
"""Decrypt a confidential transaction using viewing key"""
|
||||
return self.encryption_service.decrypt_transaction_data(transaction.encrypted_payload, viewing_key)
|
||||
from ..contexts.security.services.encryption import EncryptedData
|
||||
|
||||
raw = transaction.encrypted_data
|
||||
if not raw:
|
||||
return {}
|
||||
import ast
|
||||
encrypted = EncryptedData.from_dict(ast.literal_eval(raw.decode()))
|
||||
participants: list[str] = list(transaction.participants) if transaction.participants else []
|
||||
requester = participants[0] if participants else ""
|
||||
result: dict[str, Any] = self.encryption_service.decrypt(encrypted, requester)
|
||||
return result
|
||||
|
||||
def verify_transaction_access(self, transaction: ConfidentialTransaction, requester: str) -> bool:
|
||||
"""Verify if requester has access to view transaction"""
|
||||
return requester in [transaction.sender, transaction.recipient]
|
||||
return requester in (transaction.participants or [])
|
||||
|
||||
def get_transaction_summary(self, transaction: ConfidentialTransaction, viewer: str) -> dict[str, Any]:
|
||||
"""Get transaction summary based on viewer permissions"""
|
||||
if self.verify_transaction_access(transaction, viewer):
|
||||
return self.decrypt_transaction(transaction, transaction.viewing_key)
|
||||
return self.decrypt_transaction(transaction, viewer)
|
||||
else:
|
||||
return {"transaction_id": transaction.id, "encrypted": True, "accessible": False}
|
||||
return {"transaction_id": str(transaction.id), "encrypted": True, "accessible": False}
|
||||
|
||||
@@ -138,7 +138,7 @@ class DisputeResolutionService:
|
||||
MIN_ARBITRATORS = 3
|
||||
MIN_STAKE_AMOUNT = 1000
|
||||
|
||||
def __init__(self, session_factory = None):
|
||||
def __init__(self, session_factory: Any = None) -> None:
|
||||
self._session_factory = session_factory
|
||||
self._disputes: Dict[str, DisputeCase] = {}
|
||||
self._arbitrators: set = set()
|
||||
@@ -223,7 +223,7 @@ class DisputeResolutionService:
|
||||
raise ValueError(f"Cannot submit evidence, dispute is {dispute.status.value}")
|
||||
|
||||
# Check deadline
|
||||
if datetime.now(timezone.utc) > dispute.evidence_deadline:
|
||||
if dispute.evidence_deadline and datetime.now(timezone.utc) > dispute.evidence_deadline:
|
||||
raise ValueError("Evidence submission deadline has passed")
|
||||
|
||||
# Verify submitter is involved
|
||||
@@ -272,11 +272,11 @@ class DisputeResolutionService:
|
||||
|
||||
# Auto-advance to voting if evidence period ended
|
||||
if dispute.status == DisputeStatus.evidence_phase:
|
||||
if datetime.now(timezone.utc) >= dispute.evidence_deadline:
|
||||
if dispute.evidence_deadline and datetime.now(timezone.utc) >= dispute.evidence_deadline:
|
||||
dispute.status = DisputeStatus.voting_phase
|
||||
|
||||
# Check voting deadline
|
||||
if datetime.now(timezone.utc) > dispute.voting_deadline:
|
||||
if dispute.voting_deadline and datetime.now(timezone.utc) > dispute.voting_deadline:
|
||||
raise ValueError("Voting deadline has passed")
|
||||
|
||||
# Verify arbitrator is valid
|
||||
@@ -316,7 +316,7 @@ class DisputeResolutionService:
|
||||
|
||||
return True
|
||||
|
||||
def _resolve_dispute(self, dispute: DisputeCase):
|
||||
def _resolve_dispute(self, dispute: DisputeCase) -> None:
|
||||
"""Resolve dispute based on votes"""
|
||||
if not dispute.votes:
|
||||
return
|
||||
@@ -396,7 +396,7 @@ class DisputeResolutionService:
|
||||
_dispute_service: Optional[DisputeResolutionService] = None
|
||||
|
||||
|
||||
def init_dispute_service(session_factory) -> DisputeResolutionService:
|
||||
def init_dispute_service(session_factory: Any) -> DisputeResolutionService:
|
||||
"""Initialize global dispute service"""
|
||||
global _dispute_service
|
||||
_dispute_service = DisputeResolutionService(session_factory)
|
||||
|
||||
@@ -3,12 +3,15 @@ Ecosystem Analytics Service
|
||||
Business logic for developer ecosystem metrics and analytics
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import and_, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from ..domain.bounty import (
|
||||
AgentMetrics,
|
||||
AgentStake,
|
||||
|
||||
@@ -23,7 +23,7 @@ from ..domain.federated_learning import (
|
||||
)
|
||||
from ..schemas.federated_learning import FederatedSessionCreate, JoinSessionRequest, SubmitUpdateRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class FederatedLearningService:
|
||||
@@ -90,13 +90,14 @@ class FederatedLearningService:
|
||||
self.session.refresh(participant)
|
||||
|
||||
# Check if we have enough participants to start
|
||||
current_count = len(fl_session.participants) + 1 # +1 for the newly added but not refreshed one
|
||||
from sqlalchemy import func
|
||||
current_count = (self.session.scalar(select(func.count()).select_from(TrainingParticipant).where(TrainingParticipant.session_id == fl_session.id)) or 0) + 1
|
||||
if current_count >= fl_session.target_participants:
|
||||
await self._start_training(fl_session)
|
||||
|
||||
return participant
|
||||
|
||||
async def _start_training(self, fl_session: FederatedLearningSession):
|
||||
async def _start_training(self, fl_session: FederatedLearningSession) -> None:
|
||||
"""Internal method to transition from gathering to active training"""
|
||||
fl_session.status = TrainingStatus.TRAINING
|
||||
fl_session.current_round = 1
|
||||
@@ -149,21 +150,24 @@ class FederatedLearningService:
|
||||
self.session.refresh(update)
|
||||
|
||||
# Check if we should trigger aggregation
|
||||
updates_count = len(current_round.updates) + 1
|
||||
from sqlalchemy import func
|
||||
updates_count = (self.session.scalar(select(func.count()).select_from(LocalModelUpdate).where(LocalModelUpdate.round_id == current_round.id)) or 0) + 1
|
||||
if updates_count >= fl_session.min_participants_per_round:
|
||||
# Note: In a real system, this might be triggered asynchronously via a Celery task
|
||||
await self._aggregate_round(fl_session, current_round)
|
||||
|
||||
return update
|
||||
|
||||
async def _aggregate_round(self, fl_session: FederatedLearningSession, current_round: TrainingRound):
|
||||
async def _aggregate_round(self, fl_session: FederatedLearningSession, current_round: TrainingRound) -> None:
|
||||
"""Mock aggregation process"""
|
||||
current_round.status = "aggregating"
|
||||
fl_session.status = TrainingStatus.AGGREGATING
|
||||
self.session.commit()
|
||||
|
||||
# Mocking the actual heavy ML aggregation that would happen elsewhere
|
||||
logger.info(f"Aggregating {len(current_round.updates)} updates for round {current_round.round_number}")
|
||||
from sqlalchemy import func
|
||||
round_updates_count = self.session.scalar(select(func.count()).select_from(LocalModelUpdate).where(LocalModelUpdate.round_id == current_round.id)) or 0
|
||||
logger.info(f"Aggregating {round_updates_count} updates for round {current_round.round_number}")
|
||||
|
||||
# Assume successful aggregation creates a new global CID
|
||||
import hashlib
|
||||
@@ -199,7 +203,7 @@ class FederatedLearningService:
|
||||
self.session.add(next_round)
|
||||
|
||||
# Reset participant statuses
|
||||
for p in fl_session.participants:
|
||||
for p in self.session.scalars(select(TrainingParticipant).where(TrainingParticipant.session_id == fl_session.id)).all():
|
||||
if p.status == ParticipantStatus.SUBMITTED:
|
||||
p.status = ParticipantStatus.TRAINING
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ class BFVContext:
|
||||
scale: float = 1.0
|
||||
|
||||
@classmethod
|
||||
def generate(cls, poly_modulus_degree: int = 4096, plain_modulus: int = 1032193):
|
||||
def generate(cls, poly_modulus_degree: int = 4096, plain_modulus: int = 1032193) -> "BFVContext":
|
||||
"""Generate new BFV context with keys"""
|
||||
# Simplified key generation for demonstration
|
||||
# In production, use proper cryptographic libraries
|
||||
@@ -100,7 +100,7 @@ class BFVProvider:
|
||||
- Plaintext-ciphertext operations
|
||||
"""
|
||||
|
||||
def __init__(self, session = None):
|
||||
def __init__(self, session: Any = None) -> None:
|
||||
self.available = True
|
||||
self.contexts: Dict[str, BFVContext] = {}
|
||||
self._next_context_id = 0
|
||||
@@ -111,7 +111,7 @@ class BFVProvider:
|
||||
self,
|
||||
scheme: str = "bfv",
|
||||
poly_modulus_degree: int = 4096,
|
||||
**kwargs
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate new FHE encryption context"""
|
||||
try:
|
||||
@@ -148,7 +148,7 @@ class BFVProvider:
|
||||
self,
|
||||
data: Union[np.ndarray, List[float]],
|
||||
context_id: str,
|
||||
**kwargs
|
||||
**kwargs: Any,
|
||||
) -> EncryptedVector:
|
||||
"""
|
||||
Encrypt data using BFV scheme.
|
||||
@@ -204,8 +204,8 @@ class BFVProvider:
|
||||
def decrypt(
|
||||
self,
|
||||
encrypted_data: EncryptedVector,
|
||||
**kwargs
|
||||
) -> np.ndarray:
|
||||
**kwargs: Any,
|
||||
) -> np.ndarray[tuple[int, ...], np.dtype[np.float64]]:
|
||||
"""
|
||||
Decrypt data using BFV scheme.
|
||||
"""
|
||||
@@ -229,8 +229,8 @@ class BFVProvider:
|
||||
decoded = plaintext.astype(np.float64) / context.scale
|
||||
|
||||
# Reshape to original shape
|
||||
size = np.prod(encrypted_data.shape)
|
||||
result = decoded[:size].reshape(encrypted_data.shape)
|
||||
size = int(np.prod(encrypted_data.shape))
|
||||
result: np.ndarray[tuple[int, ...], np.dtype[np.float64]] = decoded[:size].reshape(encrypted_data.shape)
|
||||
|
||||
logger.debug(f"Decrypted vector to shape {encrypted_data.shape}")
|
||||
|
||||
|
||||
@@ -35,8 +35,10 @@ class EncryptedData:
|
||||
class FHEProvider(ABC):
|
||||
"""Abstract base class for FHE providers"""
|
||||
|
||||
available: bool = False
|
||||
|
||||
@abstractmethod
|
||||
def generate_context(self, scheme: str, **kwargs) -> FHEContext:
|
||||
def generate_context(self, scheme: str, **kwargs: Any) -> FHEContext:
|
||||
"""Generate FHE encryption context"""
|
||||
pass
|
||||
|
||||
@@ -59,11 +61,11 @@ class FHEProvider(ABC):
|
||||
class MockFHEProvider(FHEProvider):
|
||||
"""Mock FHE provider for testing without real FHE libraries"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.available = True
|
||||
logger.info("Mock FHE provider initialized")
|
||||
|
||||
def generate_context(self, scheme: str, **kwargs) -> FHEContext:
|
||||
def generate_context(self, scheme: str, **kwargs: Any) -> FHEContext:
|
||||
"""Generate mock FHE context"""
|
||||
return FHEContext(
|
||||
scheme="mock",
|
||||
@@ -77,8 +79,6 @@ class MockFHEProvider(FHEProvider):
|
||||
|
||||
def encrypt(self, data: np.ndarray, context: FHEContext) -> EncryptedData:
|
||||
"""Mock encryption - just serialize data"""
|
||||
if isinstance(data, list):
|
||||
data = np.array(data)
|
||||
|
||||
# Simple mock encryption: serialize the data
|
||||
import pickle
|
||||
@@ -132,9 +132,9 @@ class MockFHEProvider(FHEProvider):
|
||||
class TenSEALProvider(FHEProvider):
|
||||
"""TenSEAL-based FHE provider for rapid prototyping"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.available = False
|
||||
self.ts = None
|
||||
self.ts: Any = None
|
||||
|
||||
try:
|
||||
import tenseal as ts
|
||||
@@ -144,10 +144,11 @@ class TenSEALProvider(FHEProvider):
|
||||
except ImportError as e:
|
||||
logger.warning(f"TenSEAL not available: {e}")
|
||||
|
||||
def generate_context(self, scheme: str, **kwargs) -> FHEContext:
|
||||
def generate_context(self, scheme: str, **kwargs: Any) -> FHEContext:
|
||||
"""Generate TenSEAL context"""
|
||||
if not self.available:
|
||||
raise RuntimeError("TenSEAL provider is not available")
|
||||
assert self.ts is not None
|
||||
|
||||
if scheme.lower() == "ckks":
|
||||
context = self.ts.context(
|
||||
@@ -180,10 +181,7 @@ class TenSEALProvider(FHEProvider):
|
||||
"""Encrypt data using TenSEAL"""
|
||||
if not self.available:
|
||||
raise RuntimeError("TenSEAL provider is not available")
|
||||
|
||||
# Convert list to numpy array if needed
|
||||
if isinstance(data, list):
|
||||
data = np.array(data)
|
||||
assert self.ts is not None
|
||||
|
||||
# Deserialize context
|
||||
ts_context = self.ts.context_from(context.public_key)
|
||||
@@ -207,6 +205,7 @@ class TenSEALProvider(FHEProvider):
|
||||
"""Decrypt TenSEAL data"""
|
||||
if not self.available:
|
||||
raise RuntimeError("TenSEAL provider is not available")
|
||||
assert self.ts is not None
|
||||
|
||||
# Deserialize context
|
||||
ts_context = self.ts.context_from(encrypted_data.context.public_key)
|
||||
@@ -227,6 +226,7 @@ class TenSEALProvider(FHEProvider):
|
||||
"""Perform basic encrypted inference"""
|
||||
if not self.available:
|
||||
raise RuntimeError("TenSEAL provider is not available")
|
||||
assert self.ts is not None
|
||||
|
||||
# Deserialize context and input
|
||||
ts_context = self.ts.context_from(encrypted_input.context.public_key)
|
||||
@@ -271,9 +271,9 @@ class TenSEALProvider(FHEProvider):
|
||||
class ConcreteMLProvider(FHEProvider):
|
||||
"""Concrete ML provider for neural network inference"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.available = False
|
||||
self.cnp = None
|
||||
self.cnp: Any = None
|
||||
|
||||
# Concrete ML requires Python < 3.13
|
||||
if sys.version_info >= (3, 13):
|
||||
@@ -281,8 +281,7 @@ class ConcreteMLProvider(FHEProvider):
|
||||
"Concrete ML requires Python <3.13. Current version: %s",
|
||||
sys.version.split()[0]
|
||||
)
|
||||
return
|
||||
|
||||
else:
|
||||
try:
|
||||
import concrete.numpy as cnp
|
||||
self.cnp = cnp
|
||||
@@ -290,7 +289,7 @@ class ConcreteMLProvider(FHEProvider):
|
||||
except ImportError as e:
|
||||
logger.warning(f"Concrete ML not available: {e}")
|
||||
|
||||
def generate_context(self, scheme: str, **kwargs) -> FHEContext:
|
||||
def generate_context(self, scheme: str, **kwargs: Any) -> FHEContext:
|
||||
"""Generate Concrete ML context"""
|
||||
if not self.available:
|
||||
raise RuntimeError("Concrete ML provider is not available")
|
||||
@@ -314,6 +313,7 @@ class ConcreteMLProvider(FHEProvider):
|
||||
"""Encrypt using Concrete ML"""
|
||||
if not self.available:
|
||||
raise RuntimeError("Concrete ML provider is not available")
|
||||
assert self.cnp is not None
|
||||
|
||||
# Concrete ML encryption happens during circuit execution
|
||||
# For now, return a placeholder that can be used in circuit compilation
|
||||
@@ -357,9 +357,9 @@ class ConcreteMLProvider(FHEProvider):
|
||||
class FHEService:
|
||||
"""Main FHE service for AITBC"""
|
||||
|
||||
def __init__(self):
|
||||
self.providers = {}
|
||||
self.default_provider = None
|
||||
def __init__(self) -> None:
|
||||
self.providers: Dict[str, FHEProvider] = {}
|
||||
self.default_provider: str = "mock"
|
||||
|
||||
# Mock provider (always available as fallback)
|
||||
self.providers["mock"] = MockFHEProvider()
|
||||
@@ -367,7 +367,7 @@ class FHEService:
|
||||
logger.info("Mock FHE provider initialized as default")
|
||||
|
||||
# TenSEAL provider (optional)
|
||||
tenseal_provider = TenSEALProvider()
|
||||
tenseal_provider: FHEProvider = TenSEALProvider()
|
||||
if tenseal_provider.available:
|
||||
self.providers["tenseal"] = tenseal_provider
|
||||
logger.info("TenSEAL provider initialized")
|
||||
@@ -375,7 +375,7 @@ class FHEService:
|
||||
logger.info("TenSEAL provider not available")
|
||||
|
||||
# Concrete ML provider (optional)
|
||||
concrete_provider = ConcreteMLProvider()
|
||||
concrete_provider: FHEProvider = ConcreteMLProvider()
|
||||
if concrete_provider.available:
|
||||
self.providers["concrete"] = concrete_provider
|
||||
logger.info("Concrete ML provider initialized")
|
||||
@@ -395,7 +395,7 @@ class FHEService:
|
||||
)
|
||||
return self.providers[provider_name]
|
||||
|
||||
def generate_fhe_context(self, scheme: str = "ckks", provider: Optional[str] = None, **kwargs) -> FHEContext:
|
||||
def generate_fhe_context(self, scheme: str = "ckks", provider: Optional[str] = None, **kwargs: Any) -> FHEContext:
|
||||
"""Generate FHE context"""
|
||||
fhe_provider = self.get_provider(provider)
|
||||
return fhe_provider.generate_context(scheme, **kwargs)
|
||||
|
||||
@@ -106,9 +106,9 @@ class EdgeCache:
|
||||
def __init__(self, location_id: str, max_size_gb: int = 100):
|
||||
self.location_id = location_id
|
||||
self.max_size_bytes = max_size_gb * 1024 * 1024 * 1024
|
||||
self.cache = {}
|
||||
self.cache: dict[str, CacheEntry] = {}
|
||||
self.cache_size_bytes = 0
|
||||
self.access_times = {}
|
||||
self.access_times: dict[str, datetime] = {}
|
||||
self.logger = get_logger(f"edge_cache_{location_id}")
|
||||
|
||||
async def get(self, cache_key: str) -> CacheEntry | None:
|
||||
@@ -221,14 +221,14 @@ class EdgeCache:
|
||||
else:
|
||||
return content
|
||||
|
||||
async def _evict_lru(self):
|
||||
async def _evict_lru(self) -> None:
|
||||
"""Evict least recently used entry"""
|
||||
|
||||
if not self.access_times:
|
||||
return
|
||||
|
||||
# Find least recently used key
|
||||
lru_key = min(self.access_times, key=self.access_times.get)
|
||||
lru_key = min(self.access_times, key=lambda k: self.access_times[k])
|
||||
|
||||
await self.remove(lru_key)
|
||||
|
||||
@@ -259,10 +259,10 @@ class CDNManager:
|
||||
|
||||
def __init__(self, config: CDNConfig):
|
||||
self.config = config
|
||||
self.edge_caches = {}
|
||||
self.global_cache = {}
|
||||
self.purge_queue = []
|
||||
self.analytics = {"total_requests": 0, "cache_hits": 0, "cache_misses": 0, "edge_requests": {}, "bandwidth_saved": 0}
|
||||
self.edge_caches: dict[str, EdgeCache] = {}
|
||||
self.global_cache: dict[str, CacheEntry] = {}
|
||||
self.purge_queue: list[str] = []
|
||||
self.analytics: dict[str, Any] = {"total_requests": 0, "cache_hits": 0, "cache_misses": 0, "edge_requests": {}, "bandwidth_saved": 0}
|
||||
self.logger = get_logger("cdn_manager")
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
@@ -304,7 +304,7 @@ class CDNManager:
|
||||
entry = await edge_cache.get(cache_key)
|
||||
if entry:
|
||||
# Decompress if needed
|
||||
content = await self._decompress_content(entry.content, entry.compression_type)
|
||||
content = await edge_cache._decompress_content(entry.content, entry.compression_type)
|
||||
|
||||
self.analytics["cache_hits"] += 1
|
||||
self.analytics["edge_requests"][edge_location.location_id] = (
|
||||
@@ -333,7 +333,8 @@ class CDNManager:
|
||||
global_entry.compression_type,
|
||||
)
|
||||
|
||||
content = await self._decompress_content(global_entry.content, global_entry.compression_type)
|
||||
first_edge = next(iter(self.edge_caches.values()), None)
|
||||
content = await first_edge._decompress_content(global_entry.content, global_entry.compression_type) if first_edge else global_entry.content
|
||||
|
||||
self.analytics["cache_hits"] += 1
|
||||
|
||||
@@ -438,7 +439,7 @@ class CDNManager:
|
||||
lat_diff = lat2 - lat1
|
||||
lng_diff = lng2 - lng1
|
||||
|
||||
return (lat_diff**2 + lng_diff**2) ** 0.5
|
||||
return float((lat_diff**2 + lng_diff**2) ** 0.5)
|
||||
|
||||
async def _select_compression_type(self, content: bytes, content_type: str) -> CompressionType:
|
||||
"""Select best compression type"""
|
||||
@@ -494,7 +495,7 @@ class CDNManager:
|
||||
self.logger.error(f"Content purge failed: {e}")
|
||||
return False
|
||||
|
||||
async def _purge_expired_cache(self):
|
||||
async def _purge_expired_cache(self) -> None:
|
||||
"""Background task to purge expired cache entries"""
|
||||
|
||||
while True:
|
||||
@@ -522,7 +523,7 @@ class CDNManager:
|
||||
except Exception as e:
|
||||
self.logger.error(f"Cache purge failed: {e}")
|
||||
|
||||
async def _health_check_loop(self):
|
||||
async def _health_check_loop(self) -> None:
|
||||
"""Background task for health checks"""
|
||||
|
||||
while True:
|
||||
@@ -560,7 +561,7 @@ class CDNManager:
|
||||
hit_rate = stats["hit_rate"]
|
||||
|
||||
# Calculate health score
|
||||
health_score = (hit_rate * 0.6) + ((1 - utilization) * 0.4)
|
||||
health_score = float(hit_rate) * 0.6 + (1 - float(utilization)) * 0.4
|
||||
|
||||
return max(0.0, min(1.0, health_score))
|
||||
|
||||
@@ -571,9 +572,9 @@ class CDNManager:
|
||||
async def get_analytics(self) -> dict[str, Any]:
|
||||
"""Get CDN analytics"""
|
||||
|
||||
total_requests = self.analytics["total_requests"]
|
||||
cache_hits = self.analytics["cache_hits"]
|
||||
cache_misses = self.analytics["cache_misses"]
|
||||
total_requests = int(self.analytics["total_requests"])
|
||||
cache_hits = int(self.analytics["cache_hits"])
|
||||
cache_misses = int(self.analytics["cache_misses"])
|
||||
|
||||
hit_rate = (cache_hits / total_requests) if total_requests > 0 else 0.0
|
||||
|
||||
@@ -583,7 +584,7 @@ class CDNManager:
|
||||
edge_stats[edge_id] = await edge_cache.get_cache_stats()
|
||||
|
||||
# Calculate bandwidth savings
|
||||
bandwidth_saved = 0
|
||||
bandwidth_saved: float = 0.0
|
||||
for edge_cache in self.edge_caches.values():
|
||||
for entry in edge_cache.cache.values():
|
||||
if entry.compressed:
|
||||
@@ -610,8 +611,8 @@ class EdgeComputingManager:
|
||||
|
||||
def __init__(self, cdn_manager: CDNManager):
|
||||
self.cdn_manager = cdn_manager
|
||||
self.edge_functions = {}
|
||||
self.function_executions = {}
|
||||
self.edge_functions: dict[str, Any] = {}
|
||||
self.function_executions: dict[str, Any] = {}
|
||||
self.logger = get_logger("edge_computing")
|
||||
|
||||
async def deploy_edge_function(
|
||||
@@ -683,7 +684,7 @@ class EdgeComputingManager:
|
||||
"edge_location": edge_location.location_id,
|
||||
"execution_time_ms": execution_time,
|
||||
"result": f"Function {function_id} executed successfully",
|
||||
"timestamp": execution_record["timestamp"].isoformat(),
|
||||
"timestamp": str(execution_record["timestamp"]),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -125,7 +125,7 @@ class GovernanceService:
|
||||
QUORUM_PERCENTAGE = 20 # 20% of total stake must vote
|
||||
APPROVAL_THRESHOLD = 50 # 50% approval required
|
||||
|
||||
def __init__(self, session_factory):
|
||||
def __init__(self, session_factory: Any) -> None:
|
||||
self._session_factory = session_factory
|
||||
self._proposals: Dict[str, Proposal] = {}
|
||||
self._votes: Dict[str, List[Vote]] = {} # proposal_id -> votes
|
||||
@@ -265,7 +265,7 @@ class GovernanceService:
|
||||
|
||||
return True
|
||||
|
||||
def _check_proposal_resolution(self, proposal: Proposal):
|
||||
def _check_proposal_resolution(self, proposal: Proposal) -> None:
|
||||
"""Check if proposal meets resolution criteria"""
|
||||
total_votes = proposal.votes_for + proposal.votes_against + proposal.votes_abstain
|
||||
|
||||
@@ -280,7 +280,7 @@ class GovernanceService:
|
||||
# Calculate approval percentage
|
||||
total_for_against = proposal.votes_for + proposal.votes_against
|
||||
if total_for_against == 0:
|
||||
approval_pct = 0
|
||||
approval_pct = 0.0
|
||||
else:
|
||||
approval_pct = (proposal.votes_for / total_for_against) * 100
|
||||
|
||||
@@ -379,7 +379,7 @@ class GovernanceService:
|
||||
_governance_service: Optional[GovernanceService] = None
|
||||
|
||||
|
||||
def init_governance_service(session_factory) -> GovernanceService:
|
||||
def init_governance_service(session_factory: Any) -> GovernanceService:
|
||||
"""Initialize global governance service"""
|
||||
global _governance_service
|
||||
_governance_service = GovernanceService(session_factory)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -41,15 +41,16 @@ app.include_router(health_router, tags=["health"])
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
async def health() -> dict[str, Any]:
|
||||
return {"status": "ok", "service": "gpu-multimodal", "cuda_available": True}
|
||||
|
||||
|
||||
@app.post("/attention")
|
||||
async def cross_modal_attention(
|
||||
modality_features: dict, attention_config: dict = None, session: Annotated[Session, Depends(get_session)] = None
|
||||
):
|
||||
modality_features: dict[str, Any], attention_config: dict[str, Any] | None = None, session: Annotated[Session | None, Depends(get_session)] = None
|
||||
) -> dict[str, Any]:
|
||||
"""GPU-accelerated cross-modal attention"""
|
||||
assert session is not None, "DB session required"
|
||||
service = GPUAcceleratedMultiModal(session)
|
||||
result = await service.accelerated_cross_modal_attention(
|
||||
modality_features=modality_features, attention_config=attention_config
|
||||
|
||||
@@ -19,6 +19,7 @@ from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from aitbc.aitbc_logging import get_logger
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -99,7 +100,7 @@ class HermesService:
|
||||
- Delivery tracking
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._messages: Dict[str, AgentMessage] = {}
|
||||
self._agent_profiles: Dict[str, AgentProfile] = {}
|
||||
self._message_queues: Dict[str, List[str]] = {} # agent_id -> message_ids
|
||||
@@ -372,10 +373,6 @@ class HermesService:
|
||||
"queued_messages": sum(len(q) for q in self._message_queues.values())
|
||||
}
|
||||
|
||||
def __init__(self, session: Session = None):
|
||||
self.session = session
|
||||
# ... (rest of the code remains the same)
|
||||
|
||||
# Global instance
|
||||
_hermes_service: Optional[HermesService] = None
|
||||
|
||||
|
||||
@@ -61,8 +61,8 @@ class IPFSClient:
|
||||
gateway_url: str = "https://ipfs.io",
|
||||
pinning_service: Optional[str] = None,
|
||||
pinning_key: Optional[str] = None,
|
||||
session = None
|
||||
):
|
||||
session: Any = None,
|
||||
) -> None:
|
||||
self.api_url = api_url.rstrip("/")
|
||||
self.gateway_url = gateway_url.rstrip("/")
|
||||
self.pinning_service = pinning_service
|
||||
@@ -320,7 +320,7 @@ class IPFSService:
|
||||
- Archiving transaction data
|
||||
"""
|
||||
|
||||
def __init__(self, session = None):
|
||||
def __init__(self, session: Any = None) -> None:
|
||||
self.client = IPFSClient()
|
||||
self._uploads: Dict[str, IPFSUploadResult] = {}
|
||||
self.session = session
|
||||
|
||||
@@ -97,7 +97,7 @@ class IPFSAdapterService:
|
||||
query = query.where(AgentMemoryNode.memory_type == memory_type)
|
||||
|
||||
# Execute query and filter by tags in Python (since SQLite JSON JSON_CONTAINS is complex via pure SQLAlchemy without specific dialects)
|
||||
results = self.session.execute(query).all()
|
||||
results = list(self.session.scalars(query).all())
|
||||
|
||||
if tags and len(tags) > 0:
|
||||
filtered_results = []
|
||||
|
||||
@@ -6,6 +6,8 @@ from typing import Any
|
||||
|
||||
from sqlmodel import Session, select
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from ..domain import Job, JobReceipt, Miner
|
||||
from ..schemas import AssignedJob, Constraints, JobCreate, JobResult, JobState, JobView
|
||||
from ..contexts.payments.services.payments import PaymentService
|
||||
|
||||
@@ -108,7 +108,7 @@ class MinerService:
|
||||
return miner
|
||||
|
||||
def list_records(self) -> list[Miner]:
|
||||
return list(self.session.execute(select(Miner)).all())
|
||||
return list(self.session.scalars(select(Miner)).all())
|
||||
|
||||
def online_count(self) -> int:
|
||||
result = self.session.execute(select(Miner).where(Miner.status == "ONLINE"))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -41,15 +41,16 @@ app.include_router(health_router, tags=["health"])
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
async def health() -> dict[str, Any]:
|
||||
return {"status": "ok", "service": "modality-optimization"}
|
||||
|
||||
|
||||
@app.post("/optimize")
|
||||
async def optimize_modality(
|
||||
modality: str, data: dict, strategy: str = "balanced", session: Annotated[Session, Depends(get_session)] = None
|
||||
):
|
||||
modality: str, data: dict[str, Any], strategy: str = "balanced", session: Annotated[Session | None, Depends(get_session)] = None
|
||||
) -> dict[str, Any]:
|
||||
"""Optimize specific modality"""
|
||||
assert session is not None, "DB session required"
|
||||
manager = ModalityOptimizationManager(session)
|
||||
result = await manager.optimize_modality(
|
||||
modality=ModalityType(modality), data=data, strategy=OptimizationStrategy(strategy)
|
||||
@@ -59,9 +60,10 @@ async def optimize_modality(
|
||||
|
||||
@app.post("/optimize-multimodal")
|
||||
async def optimize_multimodal(
|
||||
multimodal_data: dict, strategy: str = "balanced", session: Annotated[Session, Depends(get_session)] = None
|
||||
):
|
||||
multimodal_data: dict[str, Any], strategy: str = "balanced", session: Annotated[Session | None, Depends(get_session)] = None
|
||||
) -> dict[str, Any]:
|
||||
"""Optimize multiple modalities"""
|
||||
assert session is not None, "DB session required"
|
||||
manager = ModalityOptimizationManager(session)
|
||||
|
||||
# Convert string keys to ModalityType enum
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -41,17 +41,19 @@ app.include_router(health_router, tags=["health"])
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
async def health() -> dict[str, Any]:
|
||||
return {"status": "ok", "service": "multimodal-agent"}
|
||||
|
||||
|
||||
@app.post("/process")
|
||||
async def process_multimodal(
|
||||
agent_id: str, inputs: dict, processing_mode: str = "fusion", session: Annotated[Session, Depends(get_session)] = None
|
||||
):
|
||||
agent_id: str, inputs: dict[str, Any], processing_mode: str = "fusion", session: Annotated[Session | None, Depends(get_session)] = None
|
||||
) -> dict[str, Any]:
|
||||
"""Process multi-modal input"""
|
||||
assert session is not None, "DB session required"
|
||||
from ..contexts.multimodal.services.multimodal_agent import ProcessingMode
|
||||
service = MultiModalAgentService(session)
|
||||
result = await service.process_multimodal_input(agent_id=agent_id, inputs=inputs, processing_mode=processing_mode)
|
||||
result = await service.process_multimodal_input(agent_id=agent_id, inputs=inputs, processing_mode=ProcessingMode(processing_mode))
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -152,7 +152,7 @@ class AggregatedPriceFeed:
|
||||
- Local database
|
||||
"""
|
||||
|
||||
def __init__(self, session = None):
|
||||
def __init__(self, session: Any = None) -> None:
|
||||
self.chainlink = ChainlinkAdapter(enabled=False) # Disabled by default
|
||||
self._prices: Dict[str, PriceData] = {}
|
||||
self._last_update: Dict[str, datetime] = {}
|
||||
@@ -276,13 +276,13 @@ class OracleService:
|
||||
- Multi-source aggregation
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.feed = AggregatedPriceFeed()
|
||||
self._subscribers: List[Callable] = []
|
||||
self._running = False
|
||||
self._update_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def start(self):
|
||||
async def start(self) -> None:
|
||||
"""Start background price updates"""
|
||||
if self._running:
|
||||
return
|
||||
@@ -291,14 +291,14 @@ class OracleService:
|
||||
self._update_task = asyncio.create_task(self._update_loop())
|
||||
logger.info("Oracle service started")
|
||||
|
||||
def stop(self):
|
||||
def stop(self) -> None:
|
||||
"""Stop background updates"""
|
||||
self._running = False
|
||||
if self._update_task:
|
||||
self._update_task.cancel()
|
||||
logger.info("Oracle service stopped")
|
||||
|
||||
async def _update_loop(self):
|
||||
async def _update_loop(self) -> None:
|
||||
"""Background loop for price updates"""
|
||||
while self._running:
|
||||
try:
|
||||
@@ -341,11 +341,11 @@ class OracleService:
|
||||
|
||||
return data.to_dict()
|
||||
|
||||
def subscribe(self, callback: Callable):
|
||||
def subscribe(self, callback: Callable[..., Any]) -> None:
|
||||
"""Subscribe to price updates"""
|
||||
self._subscribers.append(callback)
|
||||
|
||||
def unsubscribe(self, callback: Callable):
|
||||
def unsubscribe(self, callback: Callable[..., Any]) -> None:
|
||||
"""Unsubscribe from price updates"""
|
||||
if callback in self._subscribers:
|
||||
self._subscribers.remove(callback)
|
||||
|
||||
@@ -106,7 +106,7 @@ class PaymentsService:
|
||||
- Refunds and cancellations
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._payments: Dict[str, Payment] = {}
|
||||
self._escrows: Dict[str, Dict[str, Any]] = {}
|
||||
self._payment_counter = 0
|
||||
|
||||
@@ -14,7 +14,7 @@ logger = get_logger(__name__)
|
||||
class PortfolioAggregationService:
|
||||
"""Service to aggregate portfolio data from multiple AITBC services"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
# Service base URLs (these should be configurable)
|
||||
self.wallet_service_url = "http://localhost:8003"
|
||||
self.exchange_service_url = "http://localhost:8011"
|
||||
@@ -93,7 +93,7 @@ class PortfolioAggregationService:
|
||||
try:
|
||||
response = await self.http_client.get(f"{self.exchange_service_url}/v1/exchange/rates")
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
return dict(response.json())
|
||||
else:
|
||||
logger.warning(f"Exchange service returned status {response.status_code}")
|
||||
return {"rates": {}, "error": "Exchange service unavailable"}
|
||||
@@ -130,7 +130,7 @@ class PortfolioAggregationService:
|
||||
url += f"?agent_address={agent_address}"
|
||||
response = await self.http_client.get(url)
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
return dict(response.json())
|
||||
else:
|
||||
logger.warning(f"Trading service returned status {response.status_code}")
|
||||
return {"trades": [], "analytics": {}, "error": "Trading service unavailable"}
|
||||
@@ -241,6 +241,6 @@ class PortfolioAggregationService:
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
"""Close HTTP client"""
|
||||
await self.http_client.aclose()
|
||||
|
||||
@@ -5,7 +5,7 @@ from aitbc import get_logger
|
||||
logger = get_logger(__name__)
|
||||
from datetime import datetime, timezone
|
||||
from secrets import token_hex
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from aitbc_crypto.signing import ReceiptSigner
|
||||
from sqlmodel import Session
|
||||
@@ -126,7 +126,7 @@ class ReceiptService:
|
||||
attestation_payload = dict(payload)
|
||||
attestation_payload.pop("attestations", None)
|
||||
attestation_payload.pop("signature", None)
|
||||
payload["attestations"].append(self._attestation_signer.sign(attestation_payload))
|
||||
cast(list[Any], payload["attestations"]).append(self._attestation_signer.sign(attestation_payload))
|
||||
|
||||
# Skip async ZK proof generation in synchronous context; log intent
|
||||
if privacy_level and zk_proof_service.is_enabled():
|
||||
|
||||
@@ -128,7 +128,7 @@ class SwarmCluster:
|
||||
nodes: Set[str] = field(default_factory=set)
|
||||
tasks: List[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self, node_service) -> Dict[str, Any]:
|
||||
def to_dict(self, node_service: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"cluster_id": self.cluster_id,
|
||||
"name": self.name,
|
||||
@@ -156,7 +156,7 @@ class SwarmService:
|
||||
HEARTBEAT_TIMEOUT_SECONDS = 60
|
||||
MAX_RETRIES = 3
|
||||
|
||||
def __init__(self, session = None):
|
||||
def __init__(self, session: Any = None) -> None:
|
||||
self._nodes: Dict[str, SwarmNode] = {}
|
||||
self._tasks: Dict[str, SwarmTask] = {}
|
||||
self._clusters: Dict[str, SwarmCluster] = {}
|
||||
|
||||
@@ -111,7 +111,7 @@ class TrainingService:
|
||||
- Model checkpointing
|
||||
"""
|
||||
|
||||
def __init__(self, session = None):
|
||||
def __init__(self, session: Any = None) -> None:
|
||||
self._jobs: Dict[str, TrainingJob] = {}
|
||||
self._job_counter = 0
|
||||
self._active_jobs: set = set()
|
||||
@@ -294,7 +294,7 @@ class TrainingService:
|
||||
|
||||
return job
|
||||
|
||||
def _process_queue(self):
|
||||
def _process_queue(self) -> None:
|
||||
"""Process queued jobs"""
|
||||
# Find next queued job
|
||||
for job_id, job in self._jobs.items():
|
||||
|
||||
@@ -48,7 +48,7 @@ class TranslationCache:
|
||||
entry = self.cache.get(key)
|
||||
if not entry:
|
||||
return None
|
||||
return entry["translation"]
|
||||
return str(entry["translation"])
|
||||
|
||||
def set(self, source_text: str, source_lang: str, target_lang: str, translation: str) -> None:
|
||||
key = f"{source_lang}:{target_lang}:{source_text}"
|
||||
|
||||
@@ -13,10 +13,11 @@ import weakref
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from typing import Any, AsyncGenerator, TYPE_CHECKING
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from websockets.server import WebSocketServerProtocol
|
||||
if TYPE_CHECKING:
|
||||
from websockets.legacy.server import WebSocketServerProtocol
|
||||
WebSocketServerProtocol = Any # type: ignore[assignment,misc]
|
||||
|
||||
from aitbc import get_logger
|
||||
|
||||
@@ -71,7 +72,7 @@ class StreamMetrics:
|
||||
backpressure_events: int = 0
|
||||
slow_consumer_events: int = 0
|
||||
|
||||
def update_send_metrics(self, send_time: float, message_size: int):
|
||||
def update_send_metrics(self, send_time: float, message_size: int) -> None:
|
||||
"""Update send performance metrics"""
|
||||
self.messages_sent += 1
|
||||
self.bytes_sent += message_size
|
||||
@@ -103,7 +104,7 @@ class BoundedMessageQueue:
|
||||
|
||||
def __init__(self, max_size: int = 1000):
|
||||
self.max_size = max_size
|
||||
self.queues = {
|
||||
self.queues: dict[MessageType, deque[StreamMessage]] = {
|
||||
MessageType.CRITICAL: deque(maxlen=max_size // 4),
|
||||
MessageType.IMPORTANT: deque(maxlen=max_size // 2),
|
||||
MessageType.BULK: deque(maxlen=max_size // 4),
|
||||
@@ -174,14 +175,14 @@ class WebSocketStream:
|
||||
|
||||
# Event loop protection
|
||||
self._send_lock = asyncio.Lock()
|
||||
self._sender_task = None
|
||||
self._heartbeat_task = None
|
||||
self._sender_task: asyncio.Task[None] | None = None
|
||||
self._heartbeat_task: asyncio.Task[None] | None = None
|
||||
self._running = False
|
||||
|
||||
# Weak reference for cleanup
|
||||
self._finalizer = weakref.finalize(self, self._cleanup)
|
||||
|
||||
async def start(self):
|
||||
async def start(self) -> None:
|
||||
"""Start stream processing"""
|
||||
if self._running:
|
||||
return
|
||||
@@ -197,7 +198,7 @@ class WebSocketStream:
|
||||
|
||||
logger.info(f"Stream {self.stream_id} started")
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self) -> None:
|
||||
"""Stop stream processing"""
|
||||
if not self._running:
|
||||
return
|
||||
@@ -247,7 +248,7 @@ class WebSocketStream:
|
||||
|
||||
return success
|
||||
|
||||
async def _sender_loop(self):
|
||||
async def _sender_loop(self) -> None:
|
||||
"""Main sender loop with backpressure control"""
|
||||
while self._running:
|
||||
try:
|
||||
@@ -321,14 +322,14 @@ class WebSocketStream:
|
||||
logger.error(f"Send error for stream {self.stream_id}: {e}")
|
||||
return False
|
||||
|
||||
async def _heartbeat_loop(self):
|
||||
async def _heartbeat_loop(self) -> None:
|
||||
"""Heartbeat loop for connection health monitoring"""
|
||||
while self._running:
|
||||
try:
|
||||
await asyncio.sleep(self.config.heartbeat_interval)
|
||||
|
||||
if not self._running:
|
||||
break
|
||||
break # type: ignore[unreachable]
|
||||
|
||||
# Send heartbeat
|
||||
heartbeat_msg = {
|
||||
@@ -363,7 +364,7 @@ class WebSocketStream:
|
||||
"last_heartbeat": self.last_heartbeat,
|
||||
}
|
||||
|
||||
def _cleanup(self):
|
||||
def _cleanup(self) -> None:
|
||||
"""Cleanup resources"""
|
||||
if self._running:
|
||||
# This should be called by garbage collector
|
||||
@@ -385,14 +386,14 @@ class WebSocketStreamManager:
|
||||
|
||||
# Event loop protection
|
||||
self._manager_lock = asyncio.Lock()
|
||||
self._cleanup_task = None
|
||||
self._cleanup_task: asyncio.Task[None] | None = None
|
||||
self._running = False
|
||||
|
||||
# Message broadcasting
|
||||
self._broadcast_queue = asyncio.Queue(maxsize=10000)
|
||||
self._broadcast_task = None
|
||||
self._broadcast_queue: asyncio.Queue[tuple[Any, MessageType]] = asyncio.Queue(maxsize=10000)
|
||||
self._broadcast_task: asyncio.Task[None] | None = None
|
||||
|
||||
async def start(self):
|
||||
async def start(self) -> None:
|
||||
"""Start the stream manager"""
|
||||
if self._running:
|
||||
return
|
||||
@@ -407,7 +408,7 @@ class WebSocketStreamManager:
|
||||
|
||||
logger.info("WebSocket Stream Manager started")
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self) -> None:
|
||||
"""Stop the stream manager"""
|
||||
if not self._running:
|
||||
return
|
||||
@@ -436,7 +437,7 @@ class WebSocketStreamManager:
|
||||
|
||||
logger.info("WebSocket Stream Manager stopped")
|
||||
|
||||
async def manage_stream(self, websocket: WebSocketServerProtocol, config: StreamConfig | None = None):
|
||||
async def manage_stream(self, websocket: Any, config: StreamConfig | None = None) -> AsyncGenerator["WebSocketStream", None]:
|
||||
"""Context manager for stream lifecycle"""
|
||||
stream_id = str(uuid.uuid4())
|
||||
stream_config = config or self.default_config
|
||||
@@ -472,7 +473,7 @@ class WebSocketStreamManager:
|
||||
|
||||
logger.info(f"Stream {stream_id} removed from manager")
|
||||
|
||||
async def broadcast_to_all(self, data: Any, message_type: MessageType = MessageType.IMPORTANT):
|
||||
async def broadcast_to_all(self, data: Any, message_type: MessageType = MessageType.IMPORTANT) -> None:
|
||||
"""Broadcast message to all streams"""
|
||||
if not self._running:
|
||||
return
|
||||
@@ -483,14 +484,14 @@ class WebSocketStreamManager:
|
||||
logger.warning("Broadcast queue full, dropping message")
|
||||
self.total_messages_dropped += 1
|
||||
|
||||
async def broadcast_to_stream(self, stream_id: str, data: Any, message_type: MessageType = MessageType.IMPORTANT):
|
||||
async def broadcast_to_stream(self, stream_id: str, data: Any, message_type: MessageType = MessageType.IMPORTANT) -> None:
|
||||
"""Send message to specific stream"""
|
||||
async with self._manager_lock:
|
||||
stream = self.streams.get(stream_id)
|
||||
if stream:
|
||||
await stream.send_message(data, message_type)
|
||||
|
||||
async def _broadcast_loop(self):
|
||||
async def _broadcast_loop(self) -> None:
|
||||
"""Broadcast messages to all streams"""
|
||||
while self._running:
|
||||
try:
|
||||
@@ -521,7 +522,7 @@ class WebSocketStreamManager:
|
||||
logger.error(f"Error in broadcast loop: {e}")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
async def _cleanup_loop(self) -> None:
|
||||
"""Cleanup disconnected streams"""
|
||||
while self._running:
|
||||
try:
|
||||
@@ -563,7 +564,7 @@ class WebSocketStreamManager:
|
||||
total_bytes_sent = sum(m["bytes_sent"] for m in stream_metrics)
|
||||
|
||||
# Status distribution
|
||||
status_counts = {}
|
||||
status_counts: dict[str, int] = {}
|
||||
for stream in self.streams.values():
|
||||
status = stream.status.value
|
||||
status_counts[status] = status_counts.get(status, 0) + 1
|
||||
@@ -581,7 +582,7 @@ class WebSocketStreamManager:
|
||||
"stream_metrics": stream_metrics,
|
||||
}
|
||||
|
||||
async def update_stream_config(self, stream_id: str, config: StreamConfig):
|
||||
async def update_stream_config(self, stream_id: str, config: StreamConfig) -> None:
|
||||
"""Update configuration for specific stream"""
|
||||
async with self._manager_lock:
|
||||
if stream_id in self.streams:
|
||||
@@ -597,7 +598,7 @@ class WebSocketStreamManager:
|
||||
slow_streams.append(stream_id)
|
||||
return slow_streams
|
||||
|
||||
async def handle_slow_consumer(self, stream_id: str, action: str = "warn"):
|
||||
async def handle_slow_consumer(self, stream_id: str, action: str = "warn") -> None:
|
||||
"""Handle slow consumer streams"""
|
||||
async with self._manager_lock:
|
||||
stream = self.streams.get(stream_id)
|
||||
|
||||
@@ -19,7 +19,7 @@ logger = get_logger(__name__)
|
||||
class ZKProofService:
|
||||
"""Service for generating zero-knowledge proofs for receipts and ML operations"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.circuits_dir = Path(__file__).parent.parent / "zk-circuits"
|
||||
|
||||
# Circuit configurations for different types
|
||||
@@ -123,7 +123,7 @@ class ZKProofService:
|
||||
return None
|
||||
|
||||
async def verify_proof(
|
||||
self, proof: dict[str, Any], public_signals: list[str], verification_key: dict[str, Any] = None, test_mode: bool = False
|
||||
self, proof: dict[str, Any], public_signals: list[str], verification_key: dict[str, Any] | None = None, test_mode: bool = False
|
||||
) -> dict[str, Any]:
|
||||
"""Verify a ZK proof using Groth16 verification
|
||||
|
||||
@@ -217,20 +217,20 @@ main();
|
||||
if privacy_level == "basic":
|
||||
# Hide computation details, reveal settlement amount
|
||||
return {
|
||||
"data": [str(receipt.job_id), str(receipt.miner_id), str(job_result.result_hash), str(receipt.pricing.rate)],
|
||||
"data": [str(receipt.receiptId), str(receipt.miner), str(getattr(job_result, 'output_hash', '')), str((receipt.payload or {}).get('rate', 0))],
|
||||
"hash": await self._hash_receipt(receipt),
|
||||
}
|
||||
|
||||
elif privacy_level == "enhanced":
|
||||
# Hide all amounts, prove correctness
|
||||
payload = receipt.payload or {}
|
||||
return {
|
||||
"settlementAmount": receipt.settlement_amount,
|
||||
"timestamp": receipt.timestamp,
|
||||
"settlementAmount": payload.get("settlement_amount", 0),
|
||||
"timestamp": receipt.issuedAt.isoformat(),
|
||||
"receipt": self._serialize_receipt(receipt),
|
||||
"computationResult": job_result.result_hash,
|
||||
"pricingRate": receipt.pricing.rate,
|
||||
"minerReward": receipt.miner_reward,
|
||||
"coordinatorFee": receipt.coordinator_fee,
|
||||
"computationResult": getattr(job_result, 'output_hash', ''),
|
||||
"pricingRate": payload.get("rate", 0),
|
||||
"minerReward": payload.get("miner_reward", 0),
|
||||
"coordinatorFee": payload.get("coordinator_fee", 0),
|
||||
}
|
||||
|
||||
else:
|
||||
@@ -241,11 +241,12 @@ main();
|
||||
# In a real implementation, use Poseidon or the same hash as circuit
|
||||
import hashlib
|
||||
|
||||
payload = receipt.payload or {}
|
||||
receipt_data = {
|
||||
"job_id": receipt.job_id,
|
||||
"miner_id": receipt.miner_id,
|
||||
"timestamp": receipt.timestamp,
|
||||
"pricing": receipt.pricing.dict(),
|
||||
"receipt_id": receipt.receiptId,
|
||||
"miner": receipt.miner,
|
||||
"timestamp": receipt.issuedAt.isoformat(),
|
||||
"pricing": payload.get("pricing", {}),
|
||||
}
|
||||
|
||||
receipt_str = json.dumps(receipt_data, sort_keys=True)
|
||||
@@ -254,15 +255,16 @@ main();
|
||||
def _serialize_receipt(self, receipt: Receipt) -> list[str]:
|
||||
"""Serialize receipt for circuit input"""
|
||||
# Convert receipt to field elements for circuit
|
||||
payload = receipt.payload or {}
|
||||
return [
|
||||
str(receipt.job_id)[:32], # Truncate for field size
|
||||
str(receipt.miner_id)[:32],
|
||||
str(receipt.timestamp)[:32],
|
||||
str(receipt.settlement_amount)[:32],
|
||||
str(receipt.miner_reward)[:32],
|
||||
str(receipt.coordinator_fee)[:32],
|
||||
str(receipt.receiptId)[:32],
|
||||
str(receipt.miner)[:32],
|
||||
str(receipt.issuedAt)[:32],
|
||||
str(payload.get("settlement_amount", 0))[:32],
|
||||
str(payload.get("miner_reward", 0))[:32],
|
||||
str(payload.get("coordinator_fee", 0))[:32],
|
||||
"0",
|
||||
"0",
|
||||
"0", # Padding
|
||||
]
|
||||
|
||||
async def _generate_proof(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
@@ -285,8 +287,8 @@ async function main() {{
|
||||
const inputs = JSON.parse(fs.readFileSync('{inputs_file}', 'utf8'));
|
||||
|
||||
// Load circuit
|
||||
const wasm = fs.readFileSync('{self.wasm_path}');
|
||||
const zkey = fs.readFileSync('{self.zkey_path}');
|
||||
const wasm = fs.readFileSync('{list(self.available_circuits.values())[0]["wasm_path"]}');
|
||||
const zkey = fs.readFileSync('{list(self.available_circuits.values())[0]["zkey_path"]}');
|
||||
|
||||
// Calculate witness
|
||||
const {{ witness }} = await snarkjs.wtns.calculate(inputs, wasm, wasm);
|
||||
@@ -318,7 +320,7 @@ main();
|
||||
raise Exception(f"Proof generation failed: {result.stderr}")
|
||||
|
||||
# Parse result
|
||||
return json.loads(result.stdout)
|
||||
return dict(json.loads(result.stdout))
|
||||
|
||||
finally:
|
||||
os.unlink(script_file)
|
||||
@@ -395,7 +397,7 @@ main();
|
||||
stdout, stderr = await result.communicate()
|
||||
|
||||
if result.returncode == 0:
|
||||
proof_data = json.loads(stdout.decode())
|
||||
proof_data: dict[str, Any] = json.loads(stdout.decode())
|
||||
return proof_data
|
||||
else:
|
||||
error_msg = stderr.decode() or stdout.decode()
|
||||
|
||||
@@ -254,7 +254,7 @@ class EnhancedZKProofService:
|
||||
- Privacy-preserving settlement verification
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.circuit = ZKCircuit("ai_computation")
|
||||
|
||||
async def generate_proof(
|
||||
|
||||
Reference in New Issue
Block a user