docs: update CLI command syntax across workflow documentation
Some checks failed
API Endpoint Tests / test-api-endpoints (push) Waiting to run
Documentation Validation / validate-docs (push) Waiting to run
Integration Tests / test-service-integration (push) Waiting to run
Python Tests / test-python (push) Waiting to run
CLI Tests / test-cli (push) Has been cancelled
Security Scanning / security-scan (push) Has been cancelled
Some checks failed
API Endpoint Tests / test-api-endpoints (push) Waiting to run
Documentation Validation / validate-docs (push) Waiting to run
Integration Tests / test-service-integration (push) Waiting to run
Python Tests / test-python (push) Waiting to run
CLI Tests / test-cli (push) Has been cancelled
Security Scanning / security-scan (push) Has been cancelled
- Updated marketplace commands: `marketplace --action` → `market` subcommands - Updated wallet commands: direct flags → `wallet` subcommands - Updated AI commands: `ai-submit`, `ai-status` → `ai submit`, `ai status` - Updated blockchain commands: `chain` → `blockchain info` - Standardized command structure across all workflow files - Affected files: MULTI_NODE_MASTER_INDEX.md, TEST_MASTER_INDEX.md, multi-node-blockchain-marketplace
This commit is contained in:
@@ -189,7 +189,7 @@ sudo systemctl start aitbc-blockchain-node-production.service
|
|||||||
**Quick Start**:
|
**Quick Start**:
|
||||||
```bash
|
```bash
|
||||||
# Create marketplace service
|
# Create marketplace service
|
||||||
./aitbc-cli marketplace --action create --name "AI Service" --price 100 --wallet provider
|
./aitbc-cli market create --type ai-inference --price 100 --description "AI Service" --wallet provider
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -297,10 +297,10 @@ curl -s http://localhost:8006/health | jq .
|
|||||||
curl -s http://localhost:8006/rpc/head | jq .height
|
curl -s http://localhost:8006/rpc/head | jq .height
|
||||||
|
|
||||||
# List wallets
|
# List wallets
|
||||||
./aitbc-cli list
|
./aitbc-cli wallet list
|
||||||
|
|
||||||
# Send transaction
|
# Send transaction
|
||||||
./aitbc-cli send --from wallet1 --to wallet2 --amount 100 --password 123
|
./aitbc-cli wallet send wallet1 wallet2 100 123
|
||||||
```
|
```
|
||||||
|
|
||||||
### Operations Commands (From Operations Module)
|
### Operations Commands (From Operations Module)
|
||||||
@@ -342,10 +342,10 @@ curl -s http://localhost:9090/metrics
|
|||||||
### Marketplace Commands (From Marketplace Module)
|
### Marketplace Commands (From Marketplace Module)
|
||||||
```bash
|
```bash
|
||||||
# Create service
|
# Create service
|
||||||
./aitbc-cli marketplace --action create --name "Service" --price 100 --wallet provider
|
./aitbc-cli market create --type ai-inference --price 100 --description "Service" --wallet provider
|
||||||
|
|
||||||
# Submit AI job
|
# Submit AI job
|
||||||
./aitbc-cli ai-submit --wallet wallet --type inference --prompt "Generate image" --payment 100
|
./aitbc-cli ai submit --wallet wallet --type inference --prompt "Generate image" --payment 100
|
||||||
|
|
||||||
# Check resource status
|
# Check resource status
|
||||||
./aitbc-cli resource status
|
./aitbc-cli resource status
|
||||||
|
|||||||
@@ -95,8 +95,8 @@ openclaw agent --agent FollowerAgent --session-id test --message "Test response"
|
|||||||
**Quick Start**:
|
**Quick Start**:
|
||||||
```bash
|
```bash
|
||||||
# Test AI operations
|
# Test AI operations
|
||||||
./aitbc-cli ai-submit --wallet genesis-ops --type inference --prompt "Test AI job" --payment 100
|
./aitbc-cli ai submit --wallet genesis-ops --type inference --prompt "Test AI job" --payment 100
|
||||||
./aitbc-cli ai-ops --action status --job-id latest
|
./aitbc-cli ai status --job-id latest
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -117,8 +117,8 @@ openclaw agent --agent FollowerAgent --session-id test --message "Test response"
|
|||||||
**Quick Start**:
|
**Quick Start**:
|
||||||
```bash
|
```bash
|
||||||
# Test advanced AI operations
|
# Test advanced AI operations
|
||||||
./aitbc-cli ai-submit --wallet genesis-ops --type parallel --prompt "Complex pipeline test" --payment 500
|
./aitbc-cli ai submit --wallet genesis-ops --type parallel --prompt "Complex pipeline test" --payment 500
|
||||||
./aitbc-cli ai-submit --wallet genesis-ops --type multimodal --prompt "Multi-modal test" --payment 1000
|
./aitbc-cli ai submit --wallet genesis-ops --type multimodal --prompt "Multi-modal test" --payment 1000
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -139,7 +139,7 @@ openclaw agent --agent FollowerAgent --session-id test --message "Test response"
|
|||||||
**Quick Start**:
|
**Quick Start**:
|
||||||
```bash
|
```bash
|
||||||
# Test cross-node operations
|
# Test cross-node operations
|
||||||
ssh aitbc1 'cd /opt/aitbc && ./aitbc-cli chain'
|
ssh aitbc1 'cd /opt/aitbc && ./aitbc-cli blockchain info'
|
||||||
./aitbc-cli resource status
|
./aitbc-cli resource status
|
||||||
ssh aitbc1 'cd /opt/aitbc && ./aitbc-cli resource status'
|
ssh aitbc1 'cd /opt/aitbc && ./aitbc-cli resource status'
|
||||||
```
|
```
|
||||||
@@ -223,16 +223,16 @@ test-basic.md (foundation)
|
|||||||
### 🚀 Quick Test Commands
|
### 🚀 Quick Test Commands
|
||||||
```bash
|
```bash
|
||||||
# Basic functionality test
|
# Basic functionality test
|
||||||
./aitbc-cli --version && ./aitbc-cli chain
|
./aitbc-cli --version && ./aitbc-cli blockchain info
|
||||||
|
|
||||||
# OpenClaw agent test
|
# OpenClaw agent test
|
||||||
openclaw agent --agent GenesisAgent --session-id quick-test --message "Quick test" --thinking low
|
openclaw agent --agent GenesisAgent --session-id quick-test --message "Quick test" --thinking low
|
||||||
|
|
||||||
# AI operations test
|
# AI operations test
|
||||||
./aitbc-cli ai-submit --wallet genesis-ops --type inference --prompt "Quick test" --payment 50
|
./aitbc-cli ai submit --wallet genesis-ops --type inference --prompt "Quick test" --payment 50
|
||||||
|
|
||||||
# Cross-node test
|
# Cross-node test
|
||||||
ssh aitbc1 'cd /opt/aitbc && ./aitbc-cli chain'
|
ssh aitbc1 'cd /opt/aitbc && ./aitbc-cli blockchain info'
|
||||||
|
|
||||||
# Performance test
|
# Performance test
|
||||||
./aitbc-cli simulate blockchain --blocks 10 --transactions 50 --delay 0
|
./aitbc-cli simulate blockchain --blocks 10 --transactions 50 --delay 0
|
||||||
|
|||||||
@@ -25,77 +25,69 @@ This module covers marketplace scenario testing, GPU provider testing, transacti
|
|||||||
cd /opt/aitbc && source venv/bin/activate
|
cd /opt/aitbc && source venv/bin/activate
|
||||||
|
|
||||||
# Create marketplace service provider wallet
|
# Create marketplace service provider wallet
|
||||||
./aitbc-cli create --name marketplace-provider --password 123
|
./aitbc-cli wallet create marketplace-provider 123
|
||||||
|
|
||||||
# Fund marketplace provider wallet
|
# Fund marketplace provider wallet
|
||||||
./aitbc-cli send --from genesis-ops --to $(./aitbc-cli list | grep "marketplace-provider:" | cut -d" " -f2) --amount 10000 --password 123
|
./aitbc-cli wallet send genesis-ops $(./aitbc-cli wallet list | grep "marketplace-provider:" | cut -d" " -f2) 10000 123
|
||||||
|
|
||||||
# Create AI service provider wallet
|
# Create AI service provider wallet
|
||||||
./aitbc-cli create --name ai-service-provider --password 123
|
./aitbc-cli wallet create ai-service-provider 123
|
||||||
|
|
||||||
# Fund AI service provider wallet
|
# Fund AI service provider wallet
|
||||||
./aitbc-cli send --from genesis-ops --to $(./aitbc-cli list | grep "ai-service-provider:" | cut -d" " -f2) --amount 5000 --password 123
|
./aitbc-cli wallet send genesis-ops $(./aitbc-cli wallet list | grep "ai-service-provider:" | cut -d" " -f2) 5000 123
|
||||||
|
|
||||||
# Create GPU provider wallet
|
# Create GPU provider wallet
|
||||||
./aitbc-cli create --name gpu-provider --password 123
|
./aitbc-cli wallet create gpu-provider 123
|
||||||
|
|
||||||
# Fund GPU provider wallet
|
# Fund GPU provider wallet
|
||||||
./aitbc-cli send --from genesis-ops --to $(./aitbc-cli list | grep "gpu-provider:" | cut -d" " -f2) --amount 5000 --password 123
|
./aitbc-cli wallet send genesis-ops $(./aitbc-cli wallet list | grep "gpu-provider:" | cut -d" " -f2) 5000 123
|
||||||
```
|
```
|
||||||
|
|
||||||
### Create Marketplace Services
|
### Create Marketplace Services
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Create AI inference service
|
# Create AI inference service
|
||||||
./aitbc-cli marketplace --action create \
|
./aitbc-cli market create \
|
||||||
--name "AI Image Generation Service" \
|
|
||||||
--type ai-inference \
|
--type ai-inference \
|
||||||
--price 100 \
|
--price 100 \
|
||||||
--wallet marketplace-provider \
|
--wallet marketplace-provider \
|
||||||
--description "High-quality image generation using advanced AI models" \
|
--description "High-quality image generation using advanced AI models"
|
||||||
--parameters "resolution:512x512,style:photorealistic,quality:high"
|
|
||||||
|
|
||||||
# Create AI training service
|
# Create AI training service
|
||||||
./aitbc-cli marketplace --action create \
|
./aitbc-cli market create \
|
||||||
--name "Custom Model Training Service" \
|
|
||||||
--type ai-training \
|
--type ai-training \
|
||||||
--price 500 \
|
--price 500 \
|
||||||
--wallet ai-service-provider \
|
--wallet ai-service-provider \
|
||||||
--description "Custom AI model training on your datasets" \
|
--description "Custom AI model training on your datasets"
|
||||||
--parameters "model_type:custom,epochs:100,batch_size:32"
|
|
||||||
|
|
||||||
# Create GPU rental service
|
# Create GPU rental service
|
||||||
./aitbc-cli marketplace --action create \
|
./aitbc-cli market create \
|
||||||
--name "GPU Cloud Computing" \
|
|
||||||
--type gpu-rental \
|
--type gpu-rental \
|
||||||
--price 50 \
|
--price 50 \
|
||||||
--wallet gpu-provider \
|
--wallet gpu-provider \
|
||||||
--description "High-performance GPU rental for AI workloads" \
|
--description "High-performance GPU rental for AI workloads"
|
||||||
--parameters "gpu_type:rtx4090,memory:24gb,bandwidth:high"
|
|
||||||
|
|
||||||
# Create data processing service
|
# Create data processing service
|
||||||
./aitbc-cli marketplace --action create \
|
./aitbc-cli market create \
|
||||||
--name "Data Analysis Pipeline" \
|
|
||||||
--type data-processing \
|
--type data-processing \
|
||||||
--price 25 \
|
--price 25 \
|
||||||
--wallet marketplace-provider \
|
--wallet marketplace-provider \
|
||||||
--description "Automated data analysis and processing" \
|
--description "Automated data analysis and processing"
|
||||||
--parameters "data_format:csv,json,xml,output_format:reports"
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Verify Marketplace Services
|
### Verify Marketplace Services
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# List all marketplace services
|
# List all marketplace services
|
||||||
./aitbc-cli marketplace --action list
|
./aitbc-cli market list
|
||||||
|
|
||||||
# Check service details
|
# Check service details
|
||||||
./aitbc-cli marketplace --action search --query "AI"
|
./aitbc-cli market search --query "AI"
|
||||||
|
|
||||||
# Verify provider listings
|
# Verify provider listings
|
||||||
./aitbc-cli marketplace --action my-listings --wallet marketplace-provider
|
./aitbc-cli market my-listings --wallet marketplace-provider
|
||||||
./aitbc-cli marketplace --action my-listings --wallet ai-service-provider
|
./aitbc-cli market my-listings --wallet ai-service-provider
|
||||||
./aitbc-cli marketplace --action my-listings --wallet gpu-provider
|
./aitbc-cli market my-listings --wallet gpu-provider
|
||||||
```
|
```
|
||||||
|
|
||||||
## Scenario Testing
|
## Scenario Testing
|
||||||
@@ -104,88 +96,88 @@ cd /opt/aitbc && source venv/bin/activate
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Customer creates wallet and funds it
|
# Customer creates wallet and funds it
|
||||||
./aitbc-cli create --name customer-1 --password 123
|
./aitbc-cli wallet create customer-1 123
|
||||||
./aitbc-cli send --from genesis-ops --to $(./aitbc-cli list | grep "customer-1:" | cut -d" " -f2) --amount 1000 --password 123
|
./aitbc-cli wallet send genesis-ops $(./aitbc-cli wallet list | grep "customer-1:" | cut -d" " -f2) 1000 123
|
||||||
|
|
||||||
# Customer browses marketplace
|
# Customer browses marketplace
|
||||||
./aitbc-cli marketplace --action search --query "image generation"
|
./aitbc-cli market search --query "image generation"
|
||||||
|
|
||||||
# Customer bids on AI image generation service
|
# Customer bids on AI image generation service
|
||||||
SERVICE_ID=$(./aitbc-cli marketplace --action search --query "AI Image Generation" | grep "service_id" | head -1 | cut -d" " -f2)
|
SERVICE_ID=$(./aitbc-cli market search --query "AI Image Generation" | grep "service_id" | head -1 | cut -d" " -f2)
|
||||||
./aitbc-cli marketplace --action bid --service-id $SERVICE_ID --amount 120 --wallet customer-1
|
./aitbc-cli market bid --service-id $SERVICE_ID --amount 120 --wallet customer-1
|
||||||
|
|
||||||
# Service provider accepts bid
|
# Service provider accepts bid
|
||||||
./aitbc-cli marketplace --action accept-bid --service-id $SERVICE_ID --bid-id "bid_123" --wallet marketplace-provider
|
./aitbc-cli market accept-bid --service-id $SERVICE_ID --bid-id "bid_123" --wallet marketplace-provider
|
||||||
|
|
||||||
# Customer submits AI job
|
# Customer submits AI job
|
||||||
./aitbc-cli ai-submit --wallet customer-1 --type inference \
|
./aitbc-cli ai submit --wallet customer-1 --type inference \
|
||||||
--prompt "Generate a futuristic cityscape with flying cars" \
|
--prompt "Generate a futuristic cityscape with flying cars" \
|
||||||
--payment 120 --service-id $SERVICE_ID
|
--payment 120 --service-id $SERVICE_ID
|
||||||
|
|
||||||
# Monitor job completion
|
# Monitor job completion
|
||||||
./aitbc-cli ai-status --job-id "ai_job_123"
|
./aitbc-cli ai status --job-id "ai_job_123"
|
||||||
|
|
||||||
# Customer receives results
|
# Customer receives results
|
||||||
./aitbc-cli ai-results --job-id "ai_job_123"
|
./aitbc-cli ai results --job-id "ai_job_123"
|
||||||
|
|
||||||
# Verify transaction completed
|
# Verify transaction completed
|
||||||
./aitbc-cli balance --name customer-1
|
./aitbc-cli wallet balance customer-1
|
||||||
./aitbc-cli balance --name marketplace-provider
|
./aitbc-cli wallet balance marketplace-provider
|
||||||
```
|
```
|
||||||
|
|
||||||
### Scenario 2: GPU Rental + AI Training
|
### Scenario 2: GPU Rental + AI Training
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Researcher creates wallet and funds it
|
# Researcher creates wallet and funds it
|
||||||
./aitbc-cli create --name researcher-1 --password 123
|
./aitbc-cli wallet create researcher-1 123
|
||||||
./aitbc-cli send --from genesis-ops --to $(./aitbc-cli list | grep "researcher-1:" | cut -d" " -f2) --amount 2000 --password 123
|
./aitbc-cli wallet send genesis-ops $(./aitbc-cli wallet list | grep "researcher-1:" | cut -d" " -f2) 2000 123
|
||||||
|
|
||||||
# Researcher rents GPU for training
|
# Researcher rents GPU for training
|
||||||
GPU_SERVICE_ID=$(./aitbc-cli marketplace --action search --query "GPU" | grep "service_id" | head -1 | cut -d" " -f2)
|
GPU_SERVICE_ID=$(./aitbc-cli market search --query "GPU" | grep "service_id" | head -1 | cut -d" " -f2)
|
||||||
./aitbc-cli marketplace --action bid --service-id $GPU_SERVICE_ID --amount 60 --wallet researcher-1
|
./aitbc-cli market bid --service-id $GPU_SERVICE_ID --amount 60 --wallet researcher-1
|
||||||
|
|
||||||
# GPU provider accepts and allocates GPU
|
# GPU provider accepts and allocates GPU
|
||||||
./aitbc-cli marketplace --action accept-bid --service-id $GPU_SERVICE_ID --bid-id "bid_456" --wallet gpu-provider
|
./aitbc-cli market accept-bid --service-id $GPU_SERVICE_ID --bid-id "bid_456" --wallet gpu-provider
|
||||||
|
|
||||||
# Researcher submits training job with allocated GPU
|
# Researcher submits training job with allocated GPU
|
||||||
./aitbc-cli ai-submit --wallet researcher-1 --type training \
|
./aitbc-cli ai submit --wallet researcher-1 --type training \
|
||||||
--model "custom-classifier" --dataset "/data/training_data.csv" \
|
--model "custom-classifier" --dataset "/data/training_data.csv" \
|
||||||
--payment 500 --gpu-allocated 1 --memory 8192
|
--payment 500 --gpu-allocated 1 --memory 8192
|
||||||
|
|
||||||
# Monitor training progress
|
# Monitor training progress
|
||||||
./aitbc-cli ai-status --job-id "ai_job_456"
|
./aitbc-cli ai status --job-id "ai_job_456"
|
||||||
|
|
||||||
# Verify GPU utilization
|
# Verify GPU utilization
|
||||||
./aitbc-cli resource status --agent-id "gpu-worker-1"
|
./aitbc-cli resource status --agent-id "gpu-worker-1"
|
||||||
|
|
||||||
# Training completes and researcher gets model
|
# Training completes and researcher gets model
|
||||||
./aitbc-cli ai-results --job-id "ai_job_456"
|
./aitbc-cli ai results --job-id "ai_job_456"
|
||||||
```
|
```
|
||||||
|
|
||||||
### Scenario 3: Multi-Service Pipeline
|
### Scenario 3: Multi-Service Pipeline
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Enterprise creates wallet and funds it
|
# Enterprise creates wallet and funds it
|
||||||
./aitbc-cli create --name enterprise-1 --password 123
|
./aitbc-cli wallet create enterprise-1 123
|
||||||
./aitbc-cli send --from genesis-ops --to $(./aitbc-cli list | grep "enterprise-1:" | cut -d" " -f2) --amount 5000 --password 123
|
./aitbc-cli wallet send genesis-ops $(./aitbc-cli wallet list | grep "enterprise-1:" | cut -d" " -f2) 5000 123
|
||||||
|
|
||||||
# Enterprise creates data processing pipeline
|
# Enterprise creates data processing pipeline
|
||||||
DATA_SERVICE_ID=$(./aitbc-cli marketplace --action search --query "data processing" | grep "service_id" | head -1 | cut -d" " -f2)
|
DATA_SERVICE_ID=$(./aitbc-cli market search --query "data processing" | grep "service_id" | head -1 | cut -d" " -f2)
|
||||||
./aitbc-cli marketplace --action bid --service-id $DATA_SERVICE_ID --amount 30 --wallet enterprise-1
|
./aitbc-cli market bid --service-id $DATA_SERVICE_ID --amount 30 --wallet enterprise-1
|
||||||
|
|
||||||
# Data provider processes raw data
|
# Data provider processes raw data
|
||||||
./aitbc-cli marketplace --action accept-bid --service-id $DATA_SERVICE_ID --bid-id "bid_789" --wallet marketplace-provider
|
./aitbc-cli market accept-bid --service-id $DATA_SERVICE_ID --bid-id "bid_789" --wallet marketplace-provider
|
||||||
|
|
||||||
# Enterprise submits AI analysis on processed data
|
# Enterprise submits AI analysis on processed data
|
||||||
./aitbc-cli ai-submit --wallet enterprise-1 --type inference \
|
./aitbc-cli ai submit --wallet enterprise-1 --type inference \
|
||||||
--prompt "Analyze processed data for trends and patterns" \
|
--prompt "Analyze processed data for trends and patterns" \
|
||||||
--payment 200 --input-data "/data/processed_data.csv"
|
--payment 200 --input-data "/data/processed_data.csv"
|
||||||
|
|
||||||
# Results are delivered and verified
|
# Results are delivered and verified
|
||||||
./aitbc-cli ai-results --job-id "ai_job_789"
|
./aitbc-cli ai results --job-id "ai_job_789"
|
||||||
|
|
||||||
# Enterprise pays for services
|
# Enterprise pays for services
|
||||||
./aitbc-cli marketplace --action settle-payment --service-id $DATA_SERVICE_ID --amount 30 --wallet enterprise-1
|
./aitbc-cli market settle-payment --service-id $DATA_SERVICE_ID --amount 30 --wallet enterprise-1
|
||||||
```
|
```
|
||||||
|
|
||||||
## GPU Provider Testing
|
## GPU Provider Testing
|
||||||
@@ -194,7 +186,7 @@ DATA_SERVICE_ID=$(./aitbc-cli marketplace --action search --query "data processi
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Test GPU allocation and deallocation
|
# Test GPU allocation and deallocation
|
||||||
./aitbc-cli resource allocate --agent-id "gpu-worker-1" --gpu 1 --memory 8192 --duration 3600
|
./aitbc-cli resource allocate --agent-id "gpu-worker-1" --memory 8192 --duration 3600
|
||||||
|
|
||||||
# Verify GPU allocation
|
# Verify GPU allocation
|
||||||
./aitbc-cli resource status --agent-id "gpu-worker-1"
|
./aitbc-cli resource status --agent-id "gpu-worker-1"
|
||||||
@@ -207,7 +199,7 @@ DATA_SERVICE_ID=$(./aitbc-cli marketplace --action search --query "data processi
|
|||||||
|
|
||||||
# Test concurrent GPU allocations
|
# Test concurrent GPU allocations
|
||||||
for i in {1..5}; do
|
for i in {1..5}; do
|
||||||
./aitbc-cli resource allocate --agent-id "gpu-worker-$i" --gpu 1 --memory 8192 --duration 1800 &
|
./aitbc-cli resource allocate --agent-id "gpu-worker-$i" --memory 8192 --duration 1800 &
|
||||||
done
|
done
|
||||||
wait
|
wait
|
||||||
|
|
||||||
@@ -219,16 +211,16 @@ wait
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Test GPU performance with different workloads
|
# Test GPU performance with different workloads
|
||||||
./aitbc-cli ai-submit --wallet gpu-provider --type inference \
|
./aitbc-cli ai submit --wallet gpu-provider --type inference \
|
||||||
--prompt "Generate high-resolution image" --payment 100 \
|
--prompt "Generate high-resolution image" --payment 100 \
|
||||||
--gpu-allocated 1 --resolution "1024x1024"
|
--gpu-allocated 1 --resolution "1024x1024"
|
||||||
|
|
||||||
./aitbc-cli ai-submit --wallet gpu-provider --type training \
|
./aitbc-cli ai submit --wallet gpu-provider --type training \
|
||||||
--model "large-model" --dataset "/data/large_dataset.csv" --payment 500 \
|
--model "large-model" --dataset "/data/large_dataset.csv" --payment 500 \
|
||||||
--gpu-allocated 1 --batch-size 64
|
--gpu-allocated 1 --batch-size 64
|
||||||
|
|
||||||
# Monitor GPU performance metrics
|
# Monitor GPU performance metrics
|
||||||
./aitbc-cli ai-metrics --agent-id "gpu-worker-1" --period "1h"
|
./aitbc-cli ai metrics --agent-id "gpu-worker-1" --period "1h"
|
||||||
|
|
||||||
# Test GPU memory management
|
# Test GPU memory management
|
||||||
./aitbc-cli resource test --type gpu --memory-stress --duration 300
|
./aitbc-cli resource test --type gpu --memory-stress --duration 300
|
||||||
@@ -238,13 +230,13 @@ wait
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Test GPU provider revenue tracking
|
# Test GPU provider revenue tracking
|
||||||
./aitbc-cli marketplace --action revenue --wallet gpu-provider --period "24h"
|
./aitbc-cli market revenue --wallet gpu-provider --period "24h"
|
||||||
|
|
||||||
# Test GPU utilization optimization
|
# Test GPU utilization optimization
|
||||||
./aitbc-cli marketplace --action optimize --wallet gpu-provider --metric "utilization"
|
./aitbc-cli market optimize --wallet gpu-provider --metric "utilization"
|
||||||
|
|
||||||
# Test GPU pricing strategy
|
# Test GPU pricing strategy
|
||||||
./aitbc-cli marketplace --action pricing --service-id $GPU_SERVICE_ID --strategy "dynamic"
|
./aitbc-cli market pricing --service-id $GPU_SERVICE_ID --strategy "dynamic"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Transaction Tracking
|
## Transaction Tracking
|
||||||
@@ -253,45 +245,45 @@ wait
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Monitor all marketplace transactions
|
# Monitor all marketplace transactions
|
||||||
./aitbc-cli marketplace --action transactions --period "1h"
|
./aitbc-cli market transactions --period "1h"
|
||||||
|
|
||||||
# Track specific service transactions
|
# Track specific service transactions
|
||||||
./aitbc-cli marketplace --action transactions --service-id $SERVICE_ID
|
./aitbc-cli market transactions --service-id $SERVICE_ID
|
||||||
|
|
||||||
# Monitor customer transaction history
|
# Monitor customer transaction history
|
||||||
./aitbc-cli transactions --name customer-1 --limit 50
|
./aitbc-cli wallet transactions customer-1 --limit 50
|
||||||
|
|
||||||
# Track provider revenue
|
# Track provider revenue
|
||||||
./aitbc-cli marketplace --action revenue --wallet marketplace-provider --period "24h"
|
./aitbc-cli market revenue --wallet marketplace-provider --period "24h"
|
||||||
```
|
```
|
||||||
|
|
||||||
### Transaction Verification
|
### Transaction Verification
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Verify transaction integrity
|
# Verify transaction integrity
|
||||||
./aitbc-cli transaction verify --tx-id "tx_123"
|
./aitbc-cli wallet transaction verify --tx-id "tx_123"
|
||||||
|
|
||||||
# Check transaction confirmation status
|
# Check transaction confirmation status
|
||||||
./aitbc-cli transaction status --tx-id "tx_123"
|
./aitbc-cli wallet transaction status --tx-id "tx_123"
|
||||||
|
|
||||||
# Verify marketplace settlement
|
# Verify marketplace settlement
|
||||||
./aitbc-cli marketplace --action verify-settlement --service-id $SERVICE_ID
|
./aitbc-cli market verify-settlement --service-id $SERVICE_ID
|
||||||
|
|
||||||
# Audit transaction trail
|
# Audit transaction trail
|
||||||
./aitbc-cli marketplace --action audit --period "24h"
|
./aitbc-cli market audit --period "24h"
|
||||||
```
|
```
|
||||||
|
|
||||||
### Cross-Node Transaction Tracking
|
### Cross-Node Transaction Tracking
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Monitor transactions across both nodes
|
# Monitor transactions across both nodes
|
||||||
./aitbc-cli transactions --cross-node --period "1h"
|
./aitbc-cli wallet transactions --cross-node --period "1h"
|
||||||
|
|
||||||
# Verify transaction propagation
|
# Verify transaction propagation
|
||||||
./aitbc-cli transaction verify-propagation --tx-id "tx_123"
|
./aitbc-cli wallet transaction verify-propagation --tx-id "tx_123"
|
||||||
|
|
||||||
# Track cross-node marketplace activity
|
# Track cross-node marketplace activity
|
||||||
./aitbc-cli marketplace --action cross-node-stats --period "24h"
|
./aitbc-cli market cross-node-stats --period "24h"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Verification Procedures
|
## Verification Procedures
|
||||||
@@ -300,39 +292,39 @@ wait
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Verify service provider performance
|
# Verify service provider performance
|
||||||
./aitbc-cli marketplace --action verify-provider --wallet ai-service-provider
|
./aitbc-cli market verify-provider --wallet ai-service-provider
|
||||||
|
|
||||||
# Check service quality metrics
|
# Check service quality metrics
|
||||||
./aitbc-cli marketplace --action quality-metrics --service-id $SERVICE_ID
|
./aitbc-cli market quality-metrics --service-id $SERVICE_ID
|
||||||
|
|
||||||
# Verify customer satisfaction
|
# Verify customer satisfaction
|
||||||
./aitbc-cli marketplace --action satisfaction --wallet customer-1 --period "7d"
|
./aitbc-cli market satisfaction --wallet customer-1 --period "7d"
|
||||||
```
|
```
|
||||||
|
|
||||||
### Compliance Verification
|
### Compliance Verification
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Verify marketplace compliance
|
# Verify marketplace compliance
|
||||||
./aitbc-cli marketplace --action compliance-check --period "24h"
|
./aitbc-cli market compliance-check --period "24h"
|
||||||
|
|
||||||
# Check regulatory compliance
|
# Check regulatory compliance
|
||||||
./aitbc-cli marketplace --action regulatory-audit --period "30d"
|
./aitbc-cli market regulatory-audit --period "30d"
|
||||||
|
|
||||||
# Verify data privacy compliance
|
# Verify data privacy compliance
|
||||||
./aitbc-cli marketplace --action privacy-audit --service-id $SERVICE_ID
|
./aitbc-cli market privacy-audit --service-id $SERVICE_ID
|
||||||
```
|
```
|
||||||
|
|
||||||
### Financial Verification
|
### Financial Verification
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Verify financial transactions
|
# Verify financial transactions
|
||||||
./aitbc-cli marketplace --action financial-audit --period "24h"
|
./aitbc-cli market financial-audit --period "24h"
|
||||||
|
|
||||||
# Check payment processing
|
# Check payment processing
|
||||||
./aitbc-cli marketplace --action payment-verify --period "1h"
|
./aitbc-cli market payment-verify --period "1h"
|
||||||
|
|
||||||
# Reconcile marketplace accounts
|
# Reconcile marketplace accounts
|
||||||
./aitbc-cli marketplace --action reconcile --period "24h"
|
./aitbc-cli market reconcile --period "24h"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Performance Testing
|
## Performance Testing
|
||||||
@@ -342,41 +334,41 @@ wait
|
|||||||
```bash
|
```bash
|
||||||
# Simulate high transaction volume
|
# Simulate high transaction volume
|
||||||
for i in {1..100}; do
|
for i in {1..100}; do
|
||||||
./aitbc-cli marketplace --action bid --service-id $SERVICE_ID --amount 100 --wallet test-wallet-$i &
|
./aitbc-cli market bid --service-id $SERVICE_ID --amount 100 --wallet test-wallet-$i &
|
||||||
done
|
done
|
||||||
wait
|
wait
|
||||||
|
|
||||||
# Monitor system performance under load
|
# Monitor system performance under load
|
||||||
./aitbc-cli marketplace --action performance-metrics --period "5m"
|
./aitbc-cli market performance-metrics --period "5m"
|
||||||
|
|
||||||
# Test marketplace scalability
|
# Test marketplace scalability
|
||||||
./aitbc-cli marketplace --action stress-test --transactions 1000 --concurrent 50
|
./aitbc-cli market stress-test --transactions 1000 --concurrent 50
|
||||||
```
|
```
|
||||||
|
|
||||||
### Latency Testing
|
### Latency Testing
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Test transaction processing latency
|
# Test transaction processing latency
|
||||||
time ./aitbc-cli marketplace --action bid --service-id $SERVICE_ID --amount 100 --wallet test-wallet
|
time ./aitbc-cli market bid --service-id $SERVICE_ID --amount 100 --wallet test-wallet
|
||||||
|
|
||||||
# Test AI job submission latency
|
# Test AI job submission latency
|
||||||
time ./aitbc-cli ai-submit --wallet test-wallet --type inference --prompt "test" --payment 50
|
time ./aitbc-cli ai submit --wallet test-wallet --type inference --prompt "test" --payment 50
|
||||||
|
|
||||||
# Monitor overall system latency
|
# Monitor overall system latency
|
||||||
./aitbc-cli marketplace --action latency-metrics --period "1h"
|
./aitbc-cli market latency-metrics --period "1h"
|
||||||
```
|
```
|
||||||
|
|
||||||
### Throughput Testing
|
### Throughput Testing
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Test marketplace throughput
|
# Test marketplace throughput
|
||||||
./aitbc-cli marketplace --action throughput-test --duration 300 --transactions-per-second 10
|
./aitbc-cli market throughput-test --duration 300 --transactions-per-second 10
|
||||||
|
|
||||||
# Test AI job throughput
|
# Test AI job throughput
|
||||||
./aitbc-cli marketplace --action ai-throughput-test --duration 300 --jobs-per-minute 5
|
./aitbc-cli market ai-throughput-test --duration 300 --jobs-per-minute 5
|
||||||
|
|
||||||
# Monitor system capacity
|
# Monitor system capacity
|
||||||
./aitbc-cli marketplace --action capacity-metrics --period "24h"
|
./aitbc-cli market capacity-metrics --period "24h"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Troubleshooting Marketplace Issues
|
## Troubleshooting Marketplace Issues
|
||||||
@@ -395,16 +387,16 @@ time ./aitbc-cli ai-submit --wallet test-wallet --type inference --prompt "test"
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Diagnose marketplace connectivity
|
# Diagnose marketplace connectivity
|
||||||
./aitbc-cli marketplace --action connectivity-test
|
./aitbc-cli market connectivity-test
|
||||||
|
|
||||||
# Check marketplace service health
|
# Check marketplace service health
|
||||||
./aitbc-cli marketplace --action health-check
|
./aitbc-cli market health-check
|
||||||
|
|
||||||
# Verify marketplace data integrity
|
# Verify marketplace data integrity
|
||||||
./aitbc-cli marketplace --action integrity-check
|
./aitbc-cli market integrity-check
|
||||||
|
|
||||||
# Debug marketplace transactions
|
# Debug marketplace transactions
|
||||||
./aitbc-cli marketplace --action debug --transaction-id "tx_123"
|
./aitbc-cli market debug --transaction-id "tx_123"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Automation Scripts
|
## Automation Scripts
|
||||||
@@ -418,31 +410,30 @@ time ./aitbc-cli ai-submit --wallet test-wallet --type inference --prompt "test"
|
|||||||
echo "Starting automated marketplace testing..."
|
echo "Starting automated marketplace testing..."
|
||||||
|
|
||||||
# Create test wallets
|
# Create test wallets
|
||||||
./aitbc-cli create --name test-customer --password 123
|
./aitbc-cli wallet create test-customer 123
|
||||||
./aitbc-cli create --name test-provider --password 123
|
./aitbc-cli wallet create test-provider 123
|
||||||
|
|
||||||
# Fund test wallets
|
# Fund test wallets
|
||||||
CUSTOMER_ADDR=$(./aitbc-cli list | grep "test-customer:" | cut -d" " -f2)
|
CUSTOMER_ADDR=$(./aitbc-cli wallet list | grep "test-customer:" | cut -d" " -f2)
|
||||||
PROVIDER_ADDR=$(./aitbc-cli list | grep "test-provider:" | cut -d" " -f2)
|
PROVIDER_ADDR=$(./aitbc-cli wallet list | grep "test-provider:" | cut -d" " -f2)
|
||||||
|
|
||||||
./aitbc-cli send --from genesis-ops --to $CUSTOMER_ADDR --amount 1000 --password 123
|
./aitbc-cli wallet send genesis-ops $CUSTOMER_ADDR 1000 123
|
||||||
./aitbc-cli send --from genesis-ops --to $PROVIDER_ADDR --amount 1000 --password 123
|
./aitbc-cli wallet send genesis-ops $PROVIDER_ADDR 1000 123
|
||||||
|
|
||||||
# Create test service
|
# Create test service
|
||||||
./aitbc-cli marketplace --action create \
|
./aitbc-cli market create \
|
||||||
--name "Test AI Service" \
|
|
||||||
--type ai-inference \
|
--type ai-inference \
|
||||||
--price 50 \
|
--price 50 \
|
||||||
--wallet test-provider \
|
--wallet test-provider \
|
||||||
--description "Automated test service"
|
--description "Test AI Service"
|
||||||
|
|
||||||
# Test complete workflow
|
# Test complete workflow
|
||||||
SERVICE_ID=$(./aitbc-cli marketplace --action list | grep "Test AI Service" | grep "service_id" | cut -d" " -f2)
|
SERVICE_ID=$(./aitbc-cli market list | grep "Test AI Service" | grep "service_id" | cut -d" " -f2)
|
||||||
|
|
||||||
./aitbc-cli marketplace --action bid --service-id $SERVICE_ID --amount 60 --wallet test-customer
|
./aitbc-cli market bid --service-id $SERVICE_ID --amount 60 --wallet test-customer
|
||||||
./aitbc-cli marketplace --action accept-bid --service-id $SERVICE_ID --bid-id "test_bid" --wallet test-provider
|
./aitbc-cli market accept-bid --service-id $SERVICE_ID --bid-id "test_bid" --wallet test-provider
|
||||||
|
|
||||||
./aitbc-cli ai-submit --wallet test-customer --type inference --prompt "test image" --payment 60
|
./aitbc-cli ai submit --wallet test-customer --type inference --prompt "test image" --payment 60
|
||||||
|
|
||||||
# Verify results
|
# Verify results
|
||||||
echo "Test completed successfully!"
|
echo "Test completed successfully!"
|
||||||
@@ -458,9 +449,9 @@ while true; do
|
|||||||
TIMESTAMP=$(date +%Y-%m-%d_%H:%M:%S)
|
TIMESTAMP=$(date +%Y-%m-%d_%H:%M:%S)
|
||||||
|
|
||||||
# Collect metrics
|
# Collect metrics
|
||||||
ACTIVE_SERVICES=$(./aitbc-cli marketplace --action list | grep -c "service_id")
|
ACTIVE_SERVICES=$(./aitbc-cli market list | grep -c "service_id")
|
||||||
PENDING_BIDS=$(./aitbc-cli marketplace --action pending-bids | grep -c "bid_id")
|
PENDING_BIDS=$(./aitbc-cli market pending-bids | grep -c "bid_id")
|
||||||
TOTAL_VOLUME=$(./aitbc-cli marketplace --action volume --period "1h")
|
TOTAL_VOLUME=$(./aitbc-cli market volume --period "1h")
|
||||||
|
|
||||||
# Log metrics
|
# Log metrics
|
||||||
echo "$TIMESTAMP,services:$ACTIVE_SERVICES,bids:$PENDING_BIDS,volume:$TOTAL_VOLUME" >> /var/log/aitbc/marketplace_performance.log
|
echo "$TIMESTAMP,services:$ACTIVE_SERVICES,bids:$PENDING_BIDS,volume:$TOTAL_VOLUME" >> /var/log/aitbc/marketplace_performance.log
|
||||||
|
|||||||
@@ -53,18 +53,18 @@ watch -n 10 'curl -s http://localhost:8006/rpc/head | jq "{height: .height, time
|
|||||||
```bash
|
```bash
|
||||||
# Check wallet balances
|
# Check wallet balances
|
||||||
cd /opt/aitbc && source venv/bin/activate
|
cd /opt/aitbc && source venv/bin/activate
|
||||||
./aitbc-cli balance --name genesis-ops
|
./aitbc-cli wallet balance genesis-ops
|
||||||
./aitbc-cli balance --name user-wallet
|
./aitbc-cli wallet balance user-wallet
|
||||||
|
|
||||||
# Send transactions
|
# Send transactions
|
||||||
./aitbc-cli send --from genesis-ops --to user-wallet --amount 100 --password 123
|
./aitbc-cli wallet send genesis-ops user-wallet 100 123
|
||||||
|
|
||||||
# Check transaction history
|
# Check transaction history
|
||||||
./aitbc-cli transactions --name genesis-ops --limit 10
|
./aitbc-cli wallet transactions genesis-ops --limit 10
|
||||||
|
|
||||||
# Cross-node transaction
|
# Cross-node transaction
|
||||||
FOLLOWER_ADDR=$(ssh aitbc1 'cd /opt/aitbc && source venv/bin/activate && ./aitbc-cli list | grep "follower-ops:" | cut -d" " -f2')
|
FOLLOWER_ADDR=$(ssh aitbc1 'cd /opt/aitbc && source venv/bin/activate && ./aitbc-cli wallet list | grep "follower-ops:" | cut -d" " -f2')
|
||||||
./aitbc-cli send --from genesis-ops --to $FOLLOWER_ADDR --amount 50 --password 123
|
./aitbc-cli wallet send genesis-ops $FOLLOWER_ADDR 50 123
|
||||||
```
|
```
|
||||||
|
|
||||||
## Health Monitoring
|
## Health Monitoring
|
||||||
@@ -216,7 +216,7 @@ curl -s http://localhost:8006/rpc/head | jq .height
|
|||||||
sudo grep "Failed password" /var/log/auth.log | tail -10
|
sudo grep "Failed password" /var/log/auth.log | tail -10
|
||||||
|
|
||||||
# Monitor blockchain for suspicious activity
|
# Monitor blockchain for suspicious activity
|
||||||
./aitbc-cli transactions --name genesis-ops --limit 20 | grep -E "(large|unusual)"
|
./aitbc-cli wallet transactions genesis-ops --limit 20 | grep -E "(large|unusual)"
|
||||||
|
|
||||||
# Check file permissions
|
# Check file permissions
|
||||||
ls -la /var/lib/aitbc/
|
ls -la /var/lib/aitbc/
|
||||||
|
|||||||
@@ -111,17 +111,17 @@ echo "Height difference: $((FOLLOWER_HEIGHT - GENESIS_HEIGHT))"
|
|||||||
```bash
|
```bash
|
||||||
# List all wallets
|
# List all wallets
|
||||||
cd /opt/aitbc && source venv/bin/activate
|
cd /opt/aitbc && source venv/bin/activate
|
||||||
./aitbc-cli list
|
./aitbc-cli wallet list
|
||||||
|
|
||||||
# Check specific wallet balance
|
# Check specific wallet balance
|
||||||
./aitbc-cli balance --name genesis-ops
|
./aitbc-cli wallet balance genesis-ops
|
||||||
./aitbc-cli balance --name follower-ops
|
./aitbc-cli wallet balance follower-ops
|
||||||
|
|
||||||
# Verify wallet addresses
|
# Verify wallet addresses
|
||||||
./aitbc-cli list | grep -E "(genesis-ops|follower-ops)"
|
./aitbc-cli wallet list | grep -E "(genesis-ops|follower-ops)"
|
||||||
|
|
||||||
# Test wallet operations
|
# Test wallet operations
|
||||||
./aitbc-cli send --from genesis-ops --to follower-ops --amount 10 --password 123
|
./aitbc-cli wallet send genesis-ops follower-ops 10 123
|
||||||
```
|
```
|
||||||
|
|
||||||
### Network Verification
|
### Network Verification
|
||||||
@@ -133,7 +133,7 @@ ssh aitbc1 'ping -c 3 localhost'
|
|||||||
|
|
||||||
# Test RPC endpoints
|
# Test RPC endpoints
|
||||||
curl -s http://localhost:8006/rpc/head > /dev/null && echo "Local RPC OK"
|
curl -s http://localhost:8006/rpc/head > /dev/null && echo "Local RPC OK"
|
||||||
ssh aitbc1 'curl -s http://localhost:8006/rpc/head > /dev/null && echo "Remote RPC OK"'
|
ssh aitbc1 'curl -s http://localhost:8007/rpc/head > /dev/null && echo "Remote RPC OK"'
|
||||||
|
|
||||||
# Test P2P connectivity
|
# Test P2P connectivity
|
||||||
telnet aitbc1 7070
|
telnet aitbc1 7070
|
||||||
@@ -146,16 +146,16 @@ ping -c 5 aitbc1 | tail -1
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Check AI services
|
# Check AI services
|
||||||
./aitbc-cli marketplace --action list
|
./aitbc-cli market list
|
||||||
|
|
||||||
# Test AI job submission
|
# Test AI job submission
|
||||||
./aitbc-cli ai-submit --wallet genesis-ops --type inference --prompt "test" --payment 10
|
./aitbc-cli ai submit --wallet genesis-ops --type inference --prompt "test" --payment 10
|
||||||
|
|
||||||
# Verify resource allocation
|
# Verify resource allocation
|
||||||
./aitbc-cli resource status
|
./aitbc-cli resource status
|
||||||
|
|
||||||
# Check AI job status
|
# Check AI job status
|
||||||
./aitbc-cli ai-status --job-id "latest"
|
./aitbc-cli ai status --job-id "latest"
|
||||||
```
|
```
|
||||||
|
|
||||||
### Smart Contract Verification
|
### Smart Contract Verification
|
||||||
@@ -263,16 +263,16 @@ Redis Service (for gossip)
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Quick health check
|
# Quick health check
|
||||||
./aitbc-cli chain && ./aitbc-cli network
|
./aitbc-cli blockchain info && ./aitbc-cli network status
|
||||||
|
|
||||||
# Service status
|
# Service status
|
||||||
systemctl status aitbc-blockchain-node.service aitbc-blockchain-rpc.service
|
systemctl status aitbc-blockchain-node.service aitbc-blockchain-rpc.service
|
||||||
|
|
||||||
# Cross-node sync check
|
# Cross-node sync check
|
||||||
curl -s http://localhost:8006/rpc/head | jq .height && ssh aitbc1 'curl -s http://localhost:8006/rpc/head | jq .height'
|
curl -s http://localhost:8006/rpc/head | jq .height && ssh aitbc1 'curl -s http://localhost:8007/rpc/head | jq .height'
|
||||||
|
|
||||||
# Wallet balance check
|
# Wallet balance check
|
||||||
./aitbc-cli balance --name genesis-ops
|
./aitbc-cli wallet balance genesis-ops
|
||||||
```
|
```
|
||||||
|
|
||||||
### Troubleshooting
|
### Troubleshooting
|
||||||
@@ -347,20 +347,20 @@ SESSION_ID="task-$(date +%s)"
|
|||||||
openclaw agent --agent main --session-id $SESSION_ID --message "Task description"
|
openclaw agent --agent main --session-id $SESSION_ID --message "Task description"
|
||||||
|
|
||||||
# Always verify transactions
|
# Always verify transactions
|
||||||
./aitbc-cli transactions --name wallet-name --limit 5
|
./aitbc-cli wallet transactions wallet-name --limit 5
|
||||||
|
|
||||||
# Monitor cross-node synchronization
|
# Monitor cross-node synchronization
|
||||||
watch -n 10 'curl -s http://localhost:8006/rpc/head | jq .height && ssh aitbc1 "curl -s http://localhost:8006/rpc/head | jq .height"'
|
watch -n 10 'curl -s http://localhost:8006/rpc/head | jq .height && ssh aitbc1 "curl -s http://localhost:8007/rpc/head | jq .height"'
|
||||||
```
|
```
|
||||||
|
|
||||||
### Development Best Practices
|
### Development Best Practices
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Test in development environment first
|
# Test in development environment first
|
||||||
./aitbc-cli send --from test-wallet --to test-wallet --amount 1 --password test
|
./aitbc-cli wallet send test-wallet test-wallet 1 test
|
||||||
|
|
||||||
# Use meaningful wallet names
|
# Use meaningful wallet names
|
||||||
./aitbc-cli create --name "genesis-operations" --password "strong_password"
|
./aitbc-cli wallet create "genesis-operations" "strong_password"
|
||||||
|
|
||||||
# Document all configuration changes
|
# Document all configuration changes
|
||||||
git add /etc/aitbc/.env
|
git add /etc/aitbc/.env
|
||||||
@@ -424,14 +424,14 @@ sudo systemctl restart aitbc-blockchain-node.service
|
|||||||
**Problem**: Wallet balance incorrect
|
**Problem**: Wallet balance incorrect
|
||||||
```bash
|
```bash
|
||||||
# Check correct node
|
# Check correct node
|
||||||
./aitbc-cli balance --name wallet-name
|
./aitbc-cli wallet balance wallet-name
|
||||||
ssh aitbc1 './aitbc-cli balance --name wallet-name'
|
ssh aitbc1 './aitbc-cli wallet balance wallet-name'
|
||||||
|
|
||||||
# Verify wallet address
|
# Verify wallet address
|
||||||
./aitbc-cli list | grep "wallet-name"
|
./aitbc-cli wallet list | grep "wallet-name"
|
||||||
|
|
||||||
# Check transaction history
|
# Check transaction history
|
||||||
./aitbc-cli transactions --name wallet-name --limit 10
|
./aitbc-cli wallet transactions wallet-name --limit 10
|
||||||
```
|
```
|
||||||
|
|
||||||
#### AI Operations Issues
|
#### AI Operations Issues
|
||||||
@@ -439,16 +439,16 @@ ssh aitbc1 './aitbc-cli balance --name wallet-name'
|
|||||||
**Problem**: AI jobs not processing
|
**Problem**: AI jobs not processing
|
||||||
```bash
|
```bash
|
||||||
# Check AI services
|
# Check AI services
|
||||||
./aitbc-cli marketplace --action list
|
./aitbc-cli market list
|
||||||
|
|
||||||
# Check resource allocation
|
# Check resource allocation
|
||||||
./aitbc-cli resource status
|
./aitbc-cli resource status
|
||||||
|
|
||||||
# Check job status
|
# Check AI job status
|
||||||
./aitbc-cli ai-status --job-id "job_id"
|
./aitbc-cli ai status --job-id "job_id"
|
||||||
|
|
||||||
# Verify wallet balance
|
# Verify wallet balance
|
||||||
./aitbc-cli balance --name wallet-name
|
./aitbc-cli wallet balance wallet-name
|
||||||
```
|
```
|
||||||
|
|
||||||
### Emergency Procedures
|
### Emergency Procedures
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ ssh aitbc1 '/opt/aitbc/scripts/workflow/03_follower_node_setup.sh'
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Monitor sync progress on both nodes
|
# Monitor sync progress on both nodes
|
||||||
watch -n 5 'echo "=== Genesis Node ===" && curl -s http://localhost:8006/rpc/head | jq .height && echo "=== Follower Node ===" && ssh aitbc1 "curl -s http://localhost:8006/rpc/head | jq .height"'
|
watch -n 5 'echo "=== Genesis Node ===" && curl -s http://localhost:8006/rpc/head | jq .height && echo "=== Follower Node ===" && ssh aitbc1 "curl -s http://localhost:8007/rpc/head | jq .height"'
|
||||||
```
|
```
|
||||||
|
|
||||||
### 5. Basic Wallet Operations
|
### 5. Basic Wallet Operations
|
||||||
@@ -113,30 +113,30 @@ watch -n 5 'echo "=== Genesis Node ===" && curl -s http://localhost:8006/rpc/hea
|
|||||||
cd /opt/aitbc && source venv/bin/activate
|
cd /opt/aitbc && source venv/bin/activate
|
||||||
|
|
||||||
# Create genesis operations wallet
|
# Create genesis operations wallet
|
||||||
./aitbc-cli create --name genesis-ops --password 123
|
./aitbc-cli wallet create genesis-ops 123
|
||||||
|
|
||||||
# Create user wallet
|
# Create user wallet
|
||||||
./aitbc-cli create --name user-wallet --password 123
|
./aitbc-cli wallet create user-wallet 123
|
||||||
|
|
||||||
# List wallets
|
# List wallets
|
||||||
./aitbc-cli list
|
./aitbc-cli wallet list
|
||||||
|
|
||||||
# Check balances
|
# Check balances
|
||||||
./aitbc-cli balance --name genesis-ops
|
./aitbc-cli wallet balance genesis-ops
|
||||||
./aitbc-cli balance --name user-wallet
|
./aitbc-cli wallet balance user-wallet
|
||||||
```
|
```
|
||||||
|
|
||||||
### 6. Cross-Node Transaction Test
|
### 6. Cross-Node Transaction Test
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Get follower node wallet address
|
# Get follower node wallet address
|
||||||
FOLLOWER_WALLET_ADDR=$(ssh aitbc1 'cd /opt/aitbc && source venv/bin/activate && ./aitbc-cli create --name follower-ops --password 123 | grep "Address:" | cut -d" " -f2')
|
FOLLOWER_WALLET_ADDR=$(ssh aitbc1 'cd /opt/aitbc && source venv/bin/activate && ./aitbc-cli wallet create follower-ops 123 | grep "Address:" | cut -d" " -f2')
|
||||||
|
|
||||||
# Send transaction from genesis to follower
|
# Send transaction from genesis to follower
|
||||||
./aitbc-cli send --from genesis-ops --to $FOLLOWER_WALLET_ADDR --amount 1000 --password 123
|
./aitbc-cli wallet send genesis-ops $FOLLOWER_WALLET_ADDR 1000 123
|
||||||
|
|
||||||
# Verify transaction on follower node
|
# Verify transaction on follower node
|
||||||
ssh aitbc1 'cd /opt/aitbc && source venv/bin/activate && ./aitbc-cli balance --name follower-ops'
|
ssh aitbc1 'cd /opt/aitbc && source venv/bin/activate && ./aitbc-cli wallet balance follower-ops'
|
||||||
```
|
```
|
||||||
|
|
||||||
## Verification Commands
|
## Verification Commands
|
||||||
@@ -148,15 +148,15 @@ ssh aitbc1 'systemctl status aitbc-blockchain-node.service aitbc-blockchain-rpc.
|
|||||||
|
|
||||||
# Check blockchain heights match
|
# Check blockchain heights match
|
||||||
curl -s http://localhost:8006/rpc/head | jq .height
|
curl -s http://localhost:8006/rpc/head | jq .height
|
||||||
ssh aitbc1 'curl -s http://localhost:8006/rpc/head | jq .height'
|
ssh aitbc1 'curl -s http://localhost:8007/rpc/head | jq .height'
|
||||||
|
|
||||||
# Check network connectivity
|
# Check network connectivity
|
||||||
ping -c 3 aitbc1
|
ping -c 3 aitbc1
|
||||||
ssh aitbc1 'ping -c 3 localhost'
|
ssh aitbc1 'ping -c 3 localhost'
|
||||||
|
|
||||||
# Verify wallet creation
|
# Verify wallet creation
|
||||||
./aitbc-cli list
|
./aitbc-cli wallet list
|
||||||
ssh aitbc1 'cd /opt/aitbc && source venv/bin/activate && ./aitbc-cli list'
|
ssh aitbc1 'cd /opt/aitbc && source venv/bin/activate && ./aitbc-cli wallet list'
|
||||||
```
|
```
|
||||||
|
|
||||||
## Troubleshooting Core Setup
|
## Troubleshooting Core Setup
|
||||||
|
|||||||
@@ -33,25 +33,25 @@ openclaw agent --agent main --session-id $SESSION_ID --message "Report progress"
|
|||||||
|
|
||||||
# AITBC CLI — always from /opt/aitbc with venv
|
# AITBC CLI — always from /opt/aitbc with venv
|
||||||
cd /opt/aitbc && source venv/bin/activate
|
cd /opt/aitbc && source venv/bin/activate
|
||||||
./aitbc-cli create --name wallet-name
|
./aitbc-cli wallet create wallet-name
|
||||||
./aitbc-cli list
|
./aitbc-cli wallet list
|
||||||
./aitbc-cli balance --name wallet-name
|
./aitbc-cli wallet balance wallet-name
|
||||||
./aitbc-cli send --from wallet1 --to address --amount 100 --password pass
|
./aitbc-cli wallet send wallet1 address 100 pass
|
||||||
./aitbc-cli chain
|
./aitbc-cli blockchain info
|
||||||
./aitbc-cli network
|
./aitbc-cli network status
|
||||||
|
|
||||||
# AI Operations (NEW)
|
# AI Operations (NEW)
|
||||||
./aitbc-cli ai-submit --wallet wallet --type inference --prompt "Generate image" --payment 100
|
./aitbc-cli ai submit --wallet wallet --type inference --prompt "Generate image" --payment 100
|
||||||
./aitbc-cli agent create --name ai-agent --description "AI agent"
|
./aitbc-cli agent create --name ai-agent --description "AI agent"
|
||||||
./aitbc-cli resource allocate --agent-id ai-agent --gpu 1 --memory 8192 --duration 3600
|
./aitbc-cli resource allocate --agent-id ai-agent --memory 8192 --duration 3600
|
||||||
./aitbc-cli marketplace --action create --name "AI Service" --price 50 --wallet wallet
|
./aitbc-cli market create --type ai-inference --price 50 --description "AI Service" --wallet wallet
|
||||||
|
|
||||||
# Cross-node — always activate venv on remote
|
# Cross-node — always activate venv on remote
|
||||||
ssh aitbc1 'cd /opt/aitbc && source venv/bin/activate && ./aitbc-cli list'
|
ssh aitbc1 'cd /opt/aitbc && source venv/bin/activate && ./aitbc-cli wallet list'
|
||||||
|
|
||||||
# RPC checks
|
# RPC checks
|
||||||
curl -s http://localhost:8006/rpc/head | jq '.height'
|
curl -s http://localhost:8006/rpc/head | jq '.height'
|
||||||
ssh aitbc1 'curl -s http://localhost:8006/rpc/head | jq .height'
|
ssh aitbc1 'curl -s http://localhost:8007/rpc/head | jq .height'
|
||||||
|
|
||||||
# Smart Contract Messaging (NEW)
|
# Smart Contract Messaging (NEW)
|
||||||
curl -X POST http://localhost:8006/rpc/messaging/topics/create \
|
curl -X POST http://localhost:8006/rpc/messaging/topics/create \
|
||||||
@@ -219,11 +219,11 @@ openclaw agent --agent main --message "Teach me AITBC Agent Messaging Contract f
|
|||||||
```bash
|
```bash
|
||||||
# Blockchain height (both nodes)
|
# Blockchain height (both nodes)
|
||||||
curl -s http://localhost:8006/rpc/head | jq '.height'
|
curl -s http://localhost:8006/rpc/head | jq '.height'
|
||||||
ssh aitbc1 'curl -s http://localhost:8006/rpc/head | jq .height'
|
ssh aitbc1 'curl -s http://localhost:8007/rpc/head | jq .height'
|
||||||
|
|
||||||
# Wallets
|
# Wallets
|
||||||
cd /opt/aitbc && source venv/bin/activate && ./aitbc-cli list
|
cd /opt/aitbc && source venv/bin/activate && ./aitbc-cli wallet list
|
||||||
ssh aitbc1 'cd /opt/aitbc && source venv/bin/activate && ./aitbc-cli list'
|
ssh aitbc1 'cd /opt/aitbc && source venv/bin/activate && ./aitbc-cli wallet list'
|
||||||
|
|
||||||
# Services
|
# Services
|
||||||
systemctl is-active aitbc-blockchain-{node,rpc}.service
|
systemctl is-active aitbc-blockchain-{node,rpc}.service
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from .poa import PoAProposer, ProposerConfig, CircuitBreaker
|
|
||||||
|
|
||||||
__all__ = ["PoAProposer", "ProposerConfig", "CircuitBreaker"]
|
|
||||||
@@ -1,345 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable, ContextManager, Optional
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
|
|
||||||
from ..logger import get_logger
|
|
||||||
from ..metrics import metrics_registry
|
|
||||||
from ..config import ProposerConfig
|
|
||||||
from ..models import Block, Account
|
|
||||||
from ..gossip import gossip_broker
|
|
||||||
|
|
||||||
_METRIC_KEY_SANITIZE = re.compile(r"[^a-zA-Z0-9_]")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_metric_suffix(value: str) -> str:
|
|
||||||
sanitized = _METRIC_KEY_SANITIZE.sub("_", value).strip("_")
|
|
||||||
return sanitized or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
class CircuitBreaker:
|
|
||||||
def __init__(self, threshold: int, timeout: int):
|
|
||||||
self._threshold = threshold
|
|
||||||
self._timeout = timeout
|
|
||||||
self._failures = 0
|
|
||||||
self._last_failure_time = 0.0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self) -> str:
|
|
||||||
if self._state == "open":
|
|
||||||
if time.time() - self._last_failure_time > self._timeout:
|
|
||||||
self._state = "half-open"
|
|
||||||
return self._state
|
|
||||||
|
|
||||||
def allow_request(self) -> bool:
|
|
||||||
state = self.state
|
|
||||||
if state == "closed":
|
|
||||||
return True
|
|
||||||
if state == "half-open":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def record_failure(self) -> None:
|
|
||||||
self._failures += 1
|
|
||||||
self._last_failure_time = time.time()
|
|
||||||
if self._failures >= self._threshold:
|
|
||||||
self._state = "open"
|
|
||||||
|
|
||||||
def record_success(self) -> None:
|
|
||||||
self._failures = 0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
class PoAProposer:
|
|
||||||
"""Proof-of-Authority block proposer.
|
|
||||||
|
|
||||||
Responsible for periodically proposing blocks if this node is configured as a proposer.
|
|
||||||
In the real implementation, this would involve checking the mempool, validating transactions,
|
|
||||||
and signing the block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
config: ProposerConfig,
|
|
||||||
session_factory: Callable[[], ContextManager[Session]],
|
|
||||||
) -> None:
|
|
||||||
self._config = config
|
|
||||||
self._session_factory = session_factory
|
|
||||||
self._logger = get_logger(__name__)
|
|
||||||
self._stop_event = asyncio.Event()
|
|
||||||
self._task: Optional[asyncio.Task[None]] = None
|
|
||||||
self._last_proposer_id: Optional[str] = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
if self._task is not None:
|
|
||||||
return
|
|
||||||
self._logger.info("Starting PoA proposer loop", extra={"interval": self._config.interval_seconds})
|
|
||||||
await self._ensure_genesis_block()
|
|
||||||
self._stop_event.clear()
|
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
if self._task is None:
|
|
||||||
return
|
|
||||||
self._logger.info("Stopping PoA proposer loop")
|
|
||||||
self._stop_event.set()
|
|
||||||
await self._task
|
|
||||||
self._task = None
|
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
await self._wait_until_next_slot()
|
|
||||||
if self._stop_event.is_set():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
await self._propose_block()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
self._logger.exception("Failed to propose block", extra={"error": str(exc)})
|
|
||||||
|
|
||||||
async def _wait_until_next_slot(self) -> None:
|
|
||||||
head = self._fetch_chain_head()
|
|
||||||
if head is None:
|
|
||||||
return
|
|
||||||
now = datetime.utcnow()
|
|
||||||
elapsed = (now - head.timestamp).total_seconds()
|
|
||||||
sleep_for = max(self._config.interval_seconds - elapsed, 0.1)
|
|
||||||
if sleep_for <= 0:
|
|
||||||
sleep_for = 0.1
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._stop_event.wait(), timeout=sleep_for)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _propose_block(self) -> None:
|
|
||||||
# Check internal mempool and include transactions
|
|
||||||
from ..mempool import get_mempool
|
|
||||||
from ..models import Transaction, Account
|
|
||||||
mempool = get_mempool()
|
|
||||||
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
next_height = 0
|
|
||||||
parent_hash = "0x00"
|
|
||||||
interval_seconds: Optional[float] = None
|
|
||||||
if head is not None:
|
|
||||||
next_height = head.height + 1
|
|
||||||
parent_hash = head.hash
|
|
||||||
interval_seconds = (datetime.utcnow() - head.timestamp).total_seconds()
|
|
||||||
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
# Pull transactions from mempool
|
|
||||||
max_txs = self._config.max_txs_per_block
|
|
||||||
max_bytes = self._config.max_block_size_bytes
|
|
||||||
pending_txs = mempool.drain(max_txs, max_bytes, self._config.chain_id)
|
|
||||||
self._logger.info(f"[PROPOSE] drained {len(pending_txs)} txs from mempool, chain={self._config.chain_id}")
|
|
||||||
|
|
||||||
# Process transactions and update balances
|
|
||||||
processed_txs = []
|
|
||||||
for tx in pending_txs:
|
|
||||||
try:
|
|
||||||
# Parse transaction data
|
|
||||||
tx_data = tx.content
|
|
||||||
sender = tx_data.get("from")
|
|
||||||
recipient = tx_data.get("to")
|
|
||||||
value = tx_data.get("amount", 0)
|
|
||||||
fee = tx_data.get("fee", 0)
|
|
||||||
|
|
||||||
if not sender or not recipient:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get sender account
|
|
||||||
sender_account = session.get(Account, (self._config.chain_id, sender))
|
|
||||||
if not sender_account:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check sufficient balance
|
|
||||||
total_cost = value + fee
|
|
||||||
if sender_account.balance < total_cost:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get or create recipient account
|
|
||||||
recipient_account = session.get(Account, (self._config.chain_id, recipient))
|
|
||||||
if not recipient_account:
|
|
||||||
recipient_account = Account(chain_id=self._config.chain_id, address=recipient, balance=0, nonce=0)
|
|
||||||
session.add(recipient_account)
|
|
||||||
session.flush()
|
|
||||||
|
|
||||||
# Update balances
|
|
||||||
sender_account.balance -= total_cost
|
|
||||||
sender_account.nonce += 1
|
|
||||||
recipient_account.balance += value
|
|
||||||
|
|
||||||
# Create transaction record
|
|
||||||
transaction = Transaction(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
tx_hash=tx.tx_hash,
|
|
||||||
sender=sender,
|
|
||||||
recipient=recipient,
|
|
||||||
payload=tx_data,
|
|
||||||
value=value,
|
|
||||||
fee=fee,
|
|
||||||
nonce=sender_account.nonce - 1,
|
|
||||||
timestamp=timestamp,
|
|
||||||
block_height=next_height,
|
|
||||||
status="confirmed"
|
|
||||||
)
|
|
||||||
session.add(transaction)
|
|
||||||
processed_txs.append(tx)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self._logger.warning(f"Failed to process transaction {tx.tx_hash}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Compute block hash with transaction data
|
|
||||||
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp, processed_txs)
|
|
||||||
|
|
||||||
block = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=next_height,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash=parent_hash,
|
|
||||||
proposer=self._config.proposer_id,
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=len(processed_txs),
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(block)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
metrics_registry.increment("blocks_proposed_total")
|
|
||||||
metrics_registry.set_gauge("chain_head_height", float(next_height))
|
|
||||||
if interval_seconds is not None and interval_seconds >= 0:
|
|
||||||
metrics_registry.observe("block_interval_seconds", interval_seconds)
|
|
||||||
metrics_registry.set_gauge("poa_last_block_interval_seconds", float(interval_seconds))
|
|
||||||
|
|
||||||
proposer_suffix = _sanitize_metric_suffix(self._config.proposer_id)
|
|
||||||
metrics_registry.increment(f"poa_blocks_proposed_total_{proposer_suffix}")
|
|
||||||
if self._last_proposer_id is not None and self._last_proposer_id != self._config.proposer_id:
|
|
||||||
metrics_registry.increment("poa_proposer_switches_total")
|
|
||||||
self._last_proposer_id = self._config.proposer_id
|
|
||||||
|
|
||||||
self._logger.info(
|
|
||||||
"Proposed block",
|
|
||||||
extra={
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast the new block
|
|
||||||
tx_list = [tx.content for tx in processed_txs] if processed_txs else []
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"chain_id": self._config.chain_id,
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"parent_hash": block.parent_hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
"timestamp": block.timestamp.isoformat(),
|
|
||||||
"tx_count": block.tx_count,
|
|
||||||
"state_root": block.state_root,
|
|
||||||
"transactions": tx_list,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_genesis_block(self) -> None:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
if head is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use a deterministic genesis timestamp so all nodes agree on the genesis block hash
|
|
||||||
timestamp = datetime(2025, 1, 1, 0, 0, 0)
|
|
||||||
block_hash = self._compute_block_hash(0, "0x00", timestamp)
|
|
||||||
genesis = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=0,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash="0x00",
|
|
||||||
proposer=self._config.proposer_id, # Use configured proposer as genesis proposer
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(genesis)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Initialize accounts from genesis allocations file (if present)
|
|
||||||
await self._initialize_genesis_allocations(session)
|
|
||||||
|
|
||||||
# Broadcast genesis block for initial sync
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"chain_id": self._config.chain_id,
|
|
||||||
"height": genesis.height,
|
|
||||||
"hash": genesis.hash,
|
|
||||||
"parent_hash": genesis.parent_hash,
|
|
||||||
"proposer": genesis.proposer,
|
|
||||||
"timestamp": genesis.timestamp.isoformat(),
|
|
||||||
"tx_count": genesis.tx_count,
|
|
||||||
"state_root": genesis.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _initialize_genesis_allocations(self, session: Session) -> None:
|
|
||||||
"""Create Account entries from the genesis allocations file."""
|
|
||||||
# Use standardized data directory from configuration
|
|
||||||
from ..config import settings
|
|
||||||
|
|
||||||
genesis_paths = [
|
|
||||||
Path(f"/var/lib/aitbc/data/{self._config.chain_id}/genesis.json"), # Standard location
|
|
||||||
]
|
|
||||||
|
|
||||||
genesis_path = None
|
|
||||||
for path in genesis_paths:
|
|
||||||
if path.exists():
|
|
||||||
genesis_path = path
|
|
||||||
break
|
|
||||||
|
|
||||||
if not genesis_path:
|
|
||||||
self._logger.warning("Genesis allocations file not found; skipping account initialization", extra={"paths": str(genesis_paths)})
|
|
||||||
return
|
|
||||||
|
|
||||||
with open(genesis_path) as f:
|
|
||||||
genesis_data = json.load(f)
|
|
||||||
|
|
||||||
allocations = genesis_data.get("allocations", [])
|
|
||||||
created = 0
|
|
||||||
for alloc in allocations:
|
|
||||||
addr = alloc["address"]
|
|
||||||
balance = int(alloc["balance"])
|
|
||||||
nonce = int(alloc.get("nonce", 0))
|
|
||||||
# Check if account already exists (idempotent)
|
|
||||||
acct = session.get(Account, (self._config.chain_id, addr))
|
|
||||||
if acct is None:
|
|
||||||
acct = Account(chain_id=self._config.chain_id, address=addr, balance=balance, nonce=nonce)
|
|
||||||
session.add(acct)
|
|
||||||
created += 1
|
|
||||||
session.commit()
|
|
||||||
self._logger.info("Initialized genesis accounts", extra={"count": created, "total": len(allocations), "path": str(genesis_path)})
|
|
||||||
|
|
||||||
def _fetch_chain_head(self) -> Optional[Block]:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
|
|
||||||
def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime, transactions: list = None) -> str:
|
|
||||||
# Include transaction hashes in block hash computation
|
|
||||||
tx_hashes = []
|
|
||||||
if transactions:
|
|
||||||
tx_hashes = [tx.tx_hash for tx in transactions]
|
|
||||||
|
|
||||||
payload = f"{self._config.chain_id}|{height}|{parent_hash}|{timestamp.isoformat()}|{'|'.join(sorted(tx_hashes))}".encode()
|
|
||||||
return "0x" + hashlib.sha256(payload).hexdigest()
|
|
||||||
@@ -1,229 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Callable, ContextManager, Optional
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
|
|
||||||
from ..logger import get_logger
|
|
||||||
from ..metrics import metrics_registry
|
|
||||||
from ..config import ProposerConfig
|
|
||||||
from ..models import Block
|
|
||||||
from ..gossip import gossip_broker
|
|
||||||
|
|
||||||
_METRIC_KEY_SANITIZE = re.compile(r"[^a-zA-Z0-9_]")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_metric_suffix(value: str) -> str:
|
|
||||||
sanitized = _METRIC_KEY_SANITIZE.sub("_", value).strip("_")
|
|
||||||
return sanitized or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
class CircuitBreaker:
|
|
||||||
def __init__(self, threshold: int, timeout: int):
|
|
||||||
self._threshold = threshold
|
|
||||||
self._timeout = timeout
|
|
||||||
self._failures = 0
|
|
||||||
self._last_failure_time = 0.0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self) -> str:
|
|
||||||
if self._state == "open":
|
|
||||||
if time.time() - self._last_failure_time > self._timeout:
|
|
||||||
self._state = "half-open"
|
|
||||||
return self._state
|
|
||||||
|
|
||||||
def allow_request(self) -> bool:
|
|
||||||
state = self.state
|
|
||||||
if state == "closed":
|
|
||||||
return True
|
|
||||||
if state == "half-open":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def record_failure(self) -> None:
|
|
||||||
self._failures += 1
|
|
||||||
self._last_failure_time = time.time()
|
|
||||||
if self._failures >= self._threshold:
|
|
||||||
self._state = "open"
|
|
||||||
|
|
||||||
def record_success(self) -> None:
|
|
||||||
self._failures = 0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
class PoAProposer:
|
|
||||||
"""Proof-of-Authority block proposer.
|
|
||||||
|
|
||||||
Responsible for periodically proposing blocks if this node is configured as a proposer.
|
|
||||||
In the real implementation, this would involve checking the mempool, validating transactions,
|
|
||||||
and signing the block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
config: ProposerConfig,
|
|
||||||
session_factory: Callable[[], ContextManager[Session]],
|
|
||||||
) -> None:
|
|
||||||
self._config = config
|
|
||||||
self._session_factory = session_factory
|
|
||||||
self._logger = get_logger(__name__)
|
|
||||||
self._stop_event = asyncio.Event()
|
|
||||||
self._task: Optional[asyncio.Task[None]] = None
|
|
||||||
self._last_proposer_id: Optional[str] = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
if self._task is not None:
|
|
||||||
return
|
|
||||||
self._logger.info("Starting PoA proposer loop", extra={"interval": self._config.interval_seconds})
|
|
||||||
self._ensure_genesis_block()
|
|
||||||
self._stop_event.clear()
|
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
if self._task is None:
|
|
||||||
return
|
|
||||||
self._logger.info("Stopping PoA proposer loop")
|
|
||||||
self._stop_event.set()
|
|
||||||
await self._task
|
|
||||||
self._task = None
|
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
await self._wait_until_next_slot()
|
|
||||||
if self._stop_event.is_set():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
self._propose_block()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
self._logger.exception("Failed to propose block", extra={"error": str(exc)})
|
|
||||||
|
|
||||||
async def _wait_until_next_slot(self) -> None:
|
|
||||||
head = self._fetch_chain_head()
|
|
||||||
if head is None:
|
|
||||||
return
|
|
||||||
now = datetime.utcnow()
|
|
||||||
elapsed = (now - head.timestamp).total_seconds()
|
|
||||||
sleep_for = max(self._config.interval_seconds - elapsed, 0.1)
|
|
||||||
if sleep_for <= 0:
|
|
||||||
sleep_for = 0.1
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._stop_event.wait(), timeout=sleep_for)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _propose_block(self) -> None:
|
|
||||||
# Check internal mempool
|
|
||||||
from ..mempool import get_mempool
|
|
||||||
if get_mempool().size(self._config.chain_id) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
next_height = 0
|
|
||||||
parent_hash = "0x00"
|
|
||||||
interval_seconds: Optional[float] = None
|
|
||||||
if head is not None:
|
|
||||||
next_height = head.height + 1
|
|
||||||
parent_hash = head.hash
|
|
||||||
interval_seconds = (datetime.utcnow() - head.timestamp).total_seconds()
|
|
||||||
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp)
|
|
||||||
|
|
||||||
block = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=next_height,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash=parent_hash,
|
|
||||||
proposer=self._config.proposer_id,
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(block)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
metrics_registry.increment("blocks_proposed_total")
|
|
||||||
metrics_registry.set_gauge("chain_head_height", float(next_height))
|
|
||||||
if interval_seconds is not None and interval_seconds >= 0:
|
|
||||||
metrics_registry.observe("block_interval_seconds", interval_seconds)
|
|
||||||
metrics_registry.set_gauge("poa_last_block_interval_seconds", float(interval_seconds))
|
|
||||||
|
|
||||||
proposer_suffix = _sanitize_metric_suffix(self._config.proposer_id)
|
|
||||||
metrics_registry.increment(f"poa_blocks_proposed_total_{proposer_suffix}")
|
|
||||||
if self._last_proposer_id is not None and self._last_proposer_id != self._config.proposer_id:
|
|
||||||
metrics_registry.increment("poa_proposer_switches_total")
|
|
||||||
self._last_proposer_id = self._config.proposer_id
|
|
||||||
|
|
||||||
self._logger.info(
|
|
||||||
"Proposed block",
|
|
||||||
extra={
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast the new block
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"parent_hash": block.parent_hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
"timestamp": block.timestamp.isoformat(),
|
|
||||||
"tx_count": block.tx_count,
|
|
||||||
"state_root": block.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_genesis_block(self) -> None:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
if head is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use a deterministic genesis timestamp so all nodes agree on the genesis block hash
|
|
||||||
timestamp = datetime(2025, 1, 1, 0, 0, 0)
|
|
||||||
block_hash = self._compute_block_hash(0, "0x00", timestamp)
|
|
||||||
genesis = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=0,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash="0x00",
|
|
||||||
proposer="genesis",
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(genesis)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Broadcast genesis block for initial sync
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"height": genesis.height,
|
|
||||||
"hash": genesis.hash,
|
|
||||||
"parent_hash": genesis.parent_hash,
|
|
||||||
"proposer": genesis.proposer,
|
|
||||||
"timestamp": genesis.timestamp.isoformat(),
|
|
||||||
"tx_count": genesis.tx_count,
|
|
||||||
"state_root": genesis.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _fetch_chain_head(self) -> Optional[Block]:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
|
|
||||||
def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime) -> str:
|
|
||||||
payload = f"{self._config.chain_id}|{height}|{parent_hash}|{timestamp.isoformat()}".encode()
|
|
||||||
return "0x" + hashlib.sha256(payload).hexdigest()
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
--- apps/blockchain-node/src/aitbc_chain/consensus/poa.py
|
|
||||||
+++ apps/blockchain-node/src/aitbc_chain/consensus/poa.py
|
|
||||||
@@ -101,7 +101,7 @@
|
|
||||||
# Wait for interval before proposing next block
|
|
||||||
await asyncio.sleep(self.config.interval_seconds)
|
|
||||||
|
|
||||||
- self._propose_block()
|
|
||||||
+ await self._propose_block()
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from .poa import PoAProposer, ProposerConfig, CircuitBreaker
|
|
||||||
|
|
||||||
__all__ = ["PoAProposer", "ProposerConfig", "CircuitBreaker"]
|
|
||||||
@@ -1,210 +0,0 @@
|
|||||||
"""
|
|
||||||
Validator Key Management
|
|
||||||
Handles cryptographic key operations for validators
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from typing import Dict, Optional, Tuple
|
|
||||||
from cryptography.hazmat.primitives import hashes, serialization
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
||||||
from cryptography.hazmat.backends import default_backend
|
|
||||||
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ValidatorKeyPair:
|
|
||||||
address: str
|
|
||||||
private_key_pem: str
|
|
||||||
public_key_pem: str
|
|
||||||
created_at: float
|
|
||||||
last_rotated: float
|
|
||||||
|
|
||||||
class KeyManager:
|
|
||||||
"""Manages validator cryptographic keys"""
|
|
||||||
|
|
||||||
def __init__(self, keys_dir: str = "/opt/aitbc/keys"):
|
|
||||||
self.keys_dir = keys_dir
|
|
||||||
self.key_pairs: Dict[str, ValidatorKeyPair] = {}
|
|
||||||
self._ensure_keys_directory()
|
|
||||||
self._load_existing_keys()
|
|
||||||
|
|
||||||
def _ensure_keys_directory(self):
|
|
||||||
"""Ensure keys directory exists and has proper permissions"""
|
|
||||||
os.makedirs(self.keys_dir, mode=0o700, exist_ok=True)
|
|
||||||
|
|
||||||
def _load_existing_keys(self):
|
|
||||||
"""Load existing key pairs from disk"""
|
|
||||||
keys_file = os.path.join(self.keys_dir, "validator_keys.json")
|
|
||||||
|
|
||||||
if os.path.exists(keys_file):
|
|
||||||
try:
|
|
||||||
with open(keys_file, 'r') as f:
|
|
||||||
keys_data = json.load(f)
|
|
||||||
|
|
||||||
for address, key_data in keys_data.items():
|
|
||||||
self.key_pairs[address] = ValidatorKeyPair(
|
|
||||||
address=address,
|
|
||||||
private_key_pem=key_data['private_key_pem'],
|
|
||||||
public_key_pem=key_data['public_key_pem'],
|
|
||||||
created_at=key_data['created_at'],
|
|
||||||
last_rotated=key_data['last_rotated']
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading keys: {e}")
|
|
||||||
|
|
||||||
def generate_key_pair(self, address: str) -> ValidatorKeyPair:
|
|
||||||
"""Generate new RSA key pair for validator"""
|
|
||||||
# Generate private key
|
|
||||||
private_key = rsa.generate_private_key(
|
|
||||||
public_exponent=65537,
|
|
||||||
key_size=2048,
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Serialize private key
|
|
||||||
private_key_pem = private_key.private_bytes(
|
|
||||||
encoding=Encoding.PEM,
|
|
||||||
format=PrivateFormat.PKCS8,
|
|
||||||
encryption_algorithm=NoEncryption()
|
|
||||||
).decode('utf-8')
|
|
||||||
|
|
||||||
# Get public key
|
|
||||||
public_key = private_key.public_key()
|
|
||||||
public_key_pem = public_key.public_bytes(
|
|
||||||
encoding=Encoding.PEM,
|
|
||||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
|
||||||
).decode('utf-8')
|
|
||||||
|
|
||||||
# Create key pair object
|
|
||||||
current_time = time.time()
|
|
||||||
key_pair = ValidatorKeyPair(
|
|
||||||
address=address,
|
|
||||||
private_key_pem=private_key_pem,
|
|
||||||
public_key_pem=public_key_pem,
|
|
||||||
created_at=current_time,
|
|
||||||
last_rotated=current_time
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store key pair
|
|
||||||
self.key_pairs[address] = key_pair
|
|
||||||
self._save_keys()
|
|
||||||
|
|
||||||
return key_pair
|
|
||||||
|
|
||||||
def get_key_pair(self, address: str) -> Optional[ValidatorKeyPair]:
|
|
||||||
"""Get key pair for validator"""
|
|
||||||
return self.key_pairs.get(address)
|
|
||||||
|
|
||||||
def rotate_key(self, address: str) -> Optional[ValidatorKeyPair]:
|
|
||||||
"""Rotate validator keys"""
|
|
||||||
if address not in self.key_pairs:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Generate new key pair
|
|
||||||
new_key_pair = self.generate_key_pair(address)
|
|
||||||
|
|
||||||
# Update rotation time
|
|
||||||
new_key_pair.created_at = self.key_pairs[address].created_at
|
|
||||||
new_key_pair.last_rotated = time.time()
|
|
||||||
|
|
||||||
self._save_keys()
|
|
||||||
return new_key_pair
|
|
||||||
|
|
||||||
def sign_message(self, address: str, message: str) -> Optional[str]:
|
|
||||||
"""Sign message with validator private key"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load private key from PEM
|
|
||||||
private_key = serialization.load_pem_private_key(
|
|
||||||
key_pair.private_key_pem.encode(),
|
|
||||||
password=None,
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sign message
|
|
||||||
signature = private_key.sign(
|
|
||||||
message.encode('utf-8'),
|
|
||||||
hashes.SHA256(),
|
|
||||||
default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
return signature.hex()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error signing message: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def verify_signature(self, address: str, message: str, signature: str) -> bool:
|
|
||||||
"""Verify message signature"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load public key from PEM
|
|
||||||
public_key = serialization.load_pem_public_key(
|
|
||||||
key_pair.public_key_pem.encode(),
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify signature
|
|
||||||
public_key.verify(
|
|
||||||
bytes.fromhex(signature),
|
|
||||||
message.encode('utf-8'),
|
|
||||||
hashes.SHA256(),
|
|
||||||
default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error verifying signature: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_public_key_pem(self, address: str) -> Optional[str]:
|
|
||||||
"""Get public key PEM for validator"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
return key_pair.public_key_pem if key_pair else None
|
|
||||||
|
|
||||||
def _save_keys(self):
|
|
||||||
"""Save key pairs to disk"""
|
|
||||||
keys_file = os.path.join(self.keys_dir, "validator_keys.json")
|
|
||||||
|
|
||||||
keys_data = {}
|
|
||||||
for address, key_pair in self.key_pairs.items():
|
|
||||||
keys_data[address] = {
|
|
||||||
'private_key_pem': key_pair.private_key_pem,
|
|
||||||
'public_key_pem': key_pair.public_key_pem,
|
|
||||||
'created_at': key_pair.created_at,
|
|
||||||
'last_rotated': key_pair.last_rotated
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(keys_file, 'w') as f:
|
|
||||||
json.dump(keys_data, f, indent=2)
|
|
||||||
|
|
||||||
# Set secure permissions
|
|
||||||
os.chmod(keys_file, 0o600)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error saving keys: {e}")
|
|
||||||
|
|
||||||
def should_rotate_key(self, address: str, rotation_interval: int = 86400) -> bool:
|
|
||||||
"""Check if key should be rotated (default: 24 hours)"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return (time.time() - key_pair.last_rotated) >= rotation_interval
|
|
||||||
|
|
||||||
def get_key_age(self, address: str) -> Optional[float]:
|
|
||||||
"""Get age of key in seconds"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return time.time() - key_pair.created_at
|
|
||||||
|
|
||||||
# Global key manager
|
|
||||||
key_manager = KeyManager()
|
|
||||||
@@ -1,119 +0,0 @@
|
|||||||
"""
|
|
||||||
Multi-Validator Proof of Authority Consensus Implementation
|
|
||||||
Extends single validator PoA to support multiple validators with rotation
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
from typing import List, Dict, Optional, Set
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from ..config import settings
|
|
||||||
from ..models import Block, Transaction
|
|
||||||
from ..database import session_scope
|
|
||||||
|
|
||||||
class ValidatorRole(Enum):
|
|
||||||
PROPOSER = "proposer"
|
|
||||||
VALIDATOR = "validator"
|
|
||||||
STANDBY = "standby"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Validator:
|
|
||||||
address: str
|
|
||||||
stake: float
|
|
||||||
reputation: float
|
|
||||||
role: ValidatorRole
|
|
||||||
last_proposed: int
|
|
||||||
is_active: bool
|
|
||||||
|
|
||||||
class MultiValidatorPoA:
|
|
||||||
"""Multi-Validator Proof of Authority consensus mechanism"""
|
|
||||||
|
|
||||||
def __init__(self, chain_id: str):
|
|
||||||
self.chain_id = chain_id
|
|
||||||
self.validators: Dict[str, Validator] = {}
|
|
||||||
self.current_proposer_index = 0
|
|
||||||
self.round_robin_enabled = True
|
|
||||||
self.consensus_timeout = 30 # seconds
|
|
||||||
|
|
||||||
def add_validator(self, address: str, stake: float = 1000.0) -> bool:
|
|
||||||
"""Add a new validator to the consensus"""
|
|
||||||
if address in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.validators[address] = Validator(
|
|
||||||
address=address,
|
|
||||||
stake=stake,
|
|
||||||
reputation=1.0,
|
|
||||||
role=ValidatorRole.STANDBY,
|
|
||||||
last_proposed=0,
|
|
||||||
is_active=True
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def remove_validator(self, address: str) -> bool:
|
|
||||||
"""Remove a validator from the consensus"""
|
|
||||||
if address not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[address]
|
|
||||||
validator.is_active = False
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
return True
|
|
||||||
|
|
||||||
def select_proposer(self, block_height: int) -> Optional[str]:
|
|
||||||
"""Select proposer for the current block using round-robin"""
|
|
||||||
active_validators = [
|
|
||||||
v for v in self.validators.values()
|
|
||||||
if v.is_active and v.role in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]
|
|
||||||
]
|
|
||||||
|
|
||||||
if not active_validators:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Round-robin selection
|
|
||||||
proposer_index = block_height % len(active_validators)
|
|
||||||
return active_validators[proposer_index].address
|
|
||||||
|
|
||||||
def validate_block(self, block: Block, proposer: str) -> bool:
|
|
||||||
"""Validate a proposed block"""
|
|
||||||
if proposer not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[proposer]
|
|
||||||
if not validator.is_active:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check if validator is allowed to propose
|
|
||||||
if validator.role not in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Additional validation logic here
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_consensus_participants(self) -> List[str]:
|
|
||||||
"""Get list of active consensus participants"""
|
|
||||||
return [
|
|
||||||
v.address for v in self.validators.values()
|
|
||||||
if v.is_active and v.role in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]
|
|
||||||
]
|
|
||||||
|
|
||||||
def update_validator_reputation(self, address: str, delta: float) -> bool:
|
|
||||||
"""Update validator reputation"""
|
|
||||||
if address not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[address]
|
|
||||||
validator.reputation = max(0.0, min(1.0, validator.reputation + delta))
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Global consensus instance
|
|
||||||
consensus_instances: Dict[str, MultiValidatorPoA] = {}
|
|
||||||
|
|
||||||
def get_consensus(chain_id: str) -> MultiValidatorPoA:
|
|
||||||
"""Get or create consensus instance for chain"""
|
|
||||||
if chain_id not in consensus_instances:
|
|
||||||
consensus_instances[chain_id] = MultiValidatorPoA(chain_id)
|
|
||||||
return consensus_instances[chain_id]
|
|
||||||
@@ -1,193 +0,0 @@
|
|||||||
"""
|
|
||||||
Practical Byzantine Fault Tolerance (PBFT) Consensus Implementation
|
|
||||||
Provides Byzantine fault tolerance for up to 1/3 faulty validators
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
from typing import List, Dict, Optional, Set, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import MultiValidatorPoA, Validator
|
|
||||||
|
|
||||||
class PBFTPhase(Enum):
|
|
||||||
PRE_PREPARE = "pre_prepare"
|
|
||||||
PREPARE = "prepare"
|
|
||||||
COMMIT = "commit"
|
|
||||||
EXECUTE = "execute"
|
|
||||||
|
|
||||||
class PBFTMessageType(Enum):
|
|
||||||
PRE_PREPARE = "pre_prepare"
|
|
||||||
PREPARE = "prepare"
|
|
||||||
COMMIT = "commit"
|
|
||||||
VIEW_CHANGE = "view_change"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PBFTMessage:
|
|
||||||
message_type: PBFTMessageType
|
|
||||||
sender: str
|
|
||||||
view_number: int
|
|
||||||
sequence_number: int
|
|
||||||
digest: str
|
|
||||||
signature: str
|
|
||||||
timestamp: float
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PBFTState:
|
|
||||||
current_view: int
|
|
||||||
current_sequence: int
|
|
||||||
prepared_messages: Dict[str, List[PBFTMessage]]
|
|
||||||
committed_messages: Dict[str, List[PBFTMessage]]
|
|
||||||
pre_prepare_messages: Dict[str, PBFTMessage]
|
|
||||||
|
|
||||||
class PBFTConsensus:
|
|
||||||
"""PBFT consensus implementation"""
|
|
||||||
|
|
||||||
def __init__(self, consensus: MultiValidatorPoA):
|
|
||||||
self.consensus = consensus
|
|
||||||
self.state = PBFTState(
|
|
||||||
current_view=0,
|
|
||||||
current_sequence=0,
|
|
||||||
prepared_messages={},
|
|
||||||
committed_messages={},
|
|
||||||
pre_prepare_messages={}
|
|
||||||
)
|
|
||||||
self.fault_tolerance = max(1, len(consensus.get_consensus_participants()) // 3)
|
|
||||||
self.required_messages = 2 * self.fault_tolerance + 1
|
|
||||||
|
|
||||||
def get_message_digest(self, block_hash: str, sequence: int, view: int) -> str:
|
|
||||||
"""Generate message digest for PBFT"""
|
|
||||||
content = f"{block_hash}:{sequence}:{view}"
|
|
||||||
return hashlib.sha256(content.encode()).hexdigest()
|
|
||||||
|
|
||||||
async def pre_prepare_phase(self, proposer: str, block_hash: str) -> bool:
|
|
||||||
"""Phase 1: Pre-prepare"""
|
|
||||||
sequence = self.state.current_sequence + 1
|
|
||||||
view = self.state.current_view
|
|
||||||
digest = self.get_message_digest(block_hash, sequence, view)
|
|
||||||
|
|
||||||
message = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.PRE_PREPARE,
|
|
||||||
sender=proposer,
|
|
||||||
view_number=view,
|
|
||||||
sequence_number=sequence,
|
|
||||||
digest=digest,
|
|
||||||
signature="", # Would be signed in real implementation
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store pre-prepare message
|
|
||||||
key = f"{sequence}:{view}"
|
|
||||||
self.state.pre_prepare_messages[key] = message
|
|
||||||
|
|
||||||
# Broadcast to all validators
|
|
||||||
await self._broadcast_message(message)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def prepare_phase(self, validator: str, pre_prepare_msg: PBFTMessage) -> bool:
|
|
||||||
"""Phase 2: Prepare"""
|
|
||||||
key = f"{pre_prepare_msg.sequence_number}:{pre_prepare_msg.view_number}"
|
|
||||||
|
|
||||||
if key not in self.state.pre_prepare_messages:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Create prepare message
|
|
||||||
prepare_msg = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.PREPARE,
|
|
||||||
sender=validator,
|
|
||||||
view_number=pre_prepare_msg.view_number,
|
|
||||||
sequence_number=pre_prepare_msg.sequence_number,
|
|
||||||
digest=pre_prepare_msg.digest,
|
|
||||||
signature="", # Would be signed
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store prepare message
|
|
||||||
if key not in self.state.prepared_messages:
|
|
||||||
self.state.prepared_messages[key] = []
|
|
||||||
self.state.prepared_messages[key].append(prepare_msg)
|
|
||||||
|
|
||||||
# Broadcast prepare message
|
|
||||||
await self._broadcast_message(prepare_msg)
|
|
||||||
|
|
||||||
# Check if we have enough prepare messages
|
|
||||||
return len(self.state.prepared_messages[key]) >= self.required_messages
|
|
||||||
|
|
||||||
async def commit_phase(self, validator: str, prepare_msg: PBFTMessage) -> bool:
|
|
||||||
"""Phase 3: Commit"""
|
|
||||||
key = f"{prepare_msg.sequence_number}:{prepare_msg.view_number}"
|
|
||||||
|
|
||||||
# Create commit message
|
|
||||||
commit_msg = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.COMMIT,
|
|
||||||
sender=validator,
|
|
||||||
view_number=prepare_msg.view_number,
|
|
||||||
sequence_number=prepare_msg.sequence_number,
|
|
||||||
digest=prepare_msg.digest,
|
|
||||||
signature="", # Would be signed
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store commit message
|
|
||||||
if key not in self.state.committed_messages:
|
|
||||||
self.state.committed_messages[key] = []
|
|
||||||
self.state.committed_messages[key].append(commit_msg)
|
|
||||||
|
|
||||||
# Broadcast commit message
|
|
||||||
await self._broadcast_message(commit_msg)
|
|
||||||
|
|
||||||
# Check if we have enough commit messages
|
|
||||||
if len(self.state.committed_messages[key]) >= self.required_messages:
|
|
||||||
return await self.execute_phase(key)
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def execute_phase(self, key: str) -> bool:
|
|
||||||
"""Phase 4: Execute"""
|
|
||||||
# Extract sequence and view from key
|
|
||||||
sequence, view = map(int, key.split(':'))
|
|
||||||
|
|
||||||
# Update state
|
|
||||||
self.state.current_sequence = sequence
|
|
||||||
|
|
||||||
# Clean up old messages
|
|
||||||
self._cleanup_messages(sequence)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _broadcast_message(self, message: PBFTMessage):
|
|
||||||
"""Broadcast message to all validators"""
|
|
||||||
validators = self.consensus.get_consensus_participants()
|
|
||||||
|
|
||||||
for validator in validators:
|
|
||||||
if validator != message.sender:
|
|
||||||
# In real implementation, this would send over network
|
|
||||||
await self._send_to_validator(validator, message)
|
|
||||||
|
|
||||||
async def _send_to_validator(self, validator: str, message: PBFTMessage):
|
|
||||||
"""Send message to specific validator"""
|
|
||||||
# Network communication would be implemented here
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _cleanup_messages(self, sequence: int):
|
|
||||||
"""Clean up old messages to prevent memory leaks"""
|
|
||||||
old_keys = [
|
|
||||||
key for key in self.state.prepared_messages.keys()
|
|
||||||
if int(key.split(':')[0]) < sequence
|
|
||||||
]
|
|
||||||
|
|
||||||
for key in old_keys:
|
|
||||||
self.state.prepared_messages.pop(key, None)
|
|
||||||
self.state.committed_messages.pop(key, None)
|
|
||||||
self.state.pre_prepare_messages.pop(key, None)
|
|
||||||
|
|
||||||
def handle_view_change(self, new_view: int) -> bool:
|
|
||||||
"""Handle view change when proposer fails"""
|
|
||||||
self.state.current_view = new_view
|
|
||||||
# Reset state for new view
|
|
||||||
self.state.prepared_messages.clear()
|
|
||||||
self.state.committed_messages.clear()
|
|
||||||
self.state.pre_prepare_messages.clear()
|
|
||||||
return True
|
|
||||||
@@ -1,345 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable, ContextManager, Optional
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
|
|
||||||
from ..logger import get_logger
|
|
||||||
from ..metrics import metrics_registry
|
|
||||||
from ..config import ProposerConfig
|
|
||||||
from ..models import Block, Account
|
|
||||||
from ..gossip import gossip_broker
|
|
||||||
|
|
||||||
_METRIC_KEY_SANITIZE = re.compile(r"[^a-zA-Z0-9_]")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_metric_suffix(value: str) -> str:
|
|
||||||
sanitized = _METRIC_KEY_SANITIZE.sub("_", value).strip("_")
|
|
||||||
return sanitized or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
class CircuitBreaker:
|
|
||||||
def __init__(self, threshold: int, timeout: int):
|
|
||||||
self._threshold = threshold
|
|
||||||
self._timeout = timeout
|
|
||||||
self._failures = 0
|
|
||||||
self._last_failure_time = 0.0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self) -> str:
|
|
||||||
if self._state == "open":
|
|
||||||
if time.time() - self._last_failure_time > self._timeout:
|
|
||||||
self._state = "half-open"
|
|
||||||
return self._state
|
|
||||||
|
|
||||||
def allow_request(self) -> bool:
|
|
||||||
state = self.state
|
|
||||||
if state == "closed":
|
|
||||||
return True
|
|
||||||
if state == "half-open":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def record_failure(self) -> None:
|
|
||||||
self._failures += 1
|
|
||||||
self._last_failure_time = time.time()
|
|
||||||
if self._failures >= self._threshold:
|
|
||||||
self._state = "open"
|
|
||||||
|
|
||||||
def record_success(self) -> None:
|
|
||||||
self._failures = 0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
class PoAProposer:
|
|
||||||
"""Proof-of-Authority block proposer.
|
|
||||||
|
|
||||||
Responsible for periodically proposing blocks if this node is configured as a proposer.
|
|
||||||
In the real implementation, this would involve checking the mempool, validating transactions,
|
|
||||||
and signing the block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
config: ProposerConfig,
|
|
||||||
session_factory: Callable[[], ContextManager[Session]],
|
|
||||||
) -> None:
|
|
||||||
self._config = config
|
|
||||||
self._session_factory = session_factory
|
|
||||||
self._logger = get_logger(__name__)
|
|
||||||
self._stop_event = asyncio.Event()
|
|
||||||
self._task: Optional[asyncio.Task[None]] = None
|
|
||||||
self._last_proposer_id: Optional[str] = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
if self._task is not None:
|
|
||||||
return
|
|
||||||
self._logger.info("Starting PoA proposer loop", extra={"interval": self._config.interval_seconds})
|
|
||||||
await self._ensure_genesis_block()
|
|
||||||
self._stop_event.clear()
|
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
if self._task is None:
|
|
||||||
return
|
|
||||||
self._logger.info("Stopping PoA proposer loop")
|
|
||||||
self._stop_event.set()
|
|
||||||
await self._task
|
|
||||||
self._task = None
|
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
await self._wait_until_next_slot()
|
|
||||||
if self._stop_event.is_set():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
await self._propose_block()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
self._logger.exception("Failed to propose block", extra={"error": str(exc)})
|
|
||||||
|
|
||||||
async def _wait_until_next_slot(self) -> None:
|
|
||||||
head = self._fetch_chain_head()
|
|
||||||
if head is None:
|
|
||||||
return
|
|
||||||
now = datetime.utcnow()
|
|
||||||
elapsed = (now - head.timestamp).total_seconds()
|
|
||||||
sleep_for = max(self._config.interval_seconds - elapsed, 0.1)
|
|
||||||
if sleep_for <= 0:
|
|
||||||
sleep_for = 0.1
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._stop_event.wait(), timeout=sleep_for)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _propose_block(self) -> None:
|
|
||||||
# Check internal mempool and include transactions
|
|
||||||
from ..mempool import get_mempool
|
|
||||||
from ..models import Transaction, Account
|
|
||||||
mempool = get_mempool()
|
|
||||||
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
next_height = 0
|
|
||||||
parent_hash = "0x00"
|
|
||||||
interval_seconds: Optional[float] = None
|
|
||||||
if head is not None:
|
|
||||||
next_height = head.height + 1
|
|
||||||
parent_hash = head.hash
|
|
||||||
interval_seconds = (datetime.utcnow() - head.timestamp).total_seconds()
|
|
||||||
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
# Pull transactions from mempool
|
|
||||||
max_txs = self._config.max_txs_per_block
|
|
||||||
max_bytes = self._config.max_block_size_bytes
|
|
||||||
pending_txs = mempool.drain(max_txs, max_bytes, self._config.chain_id)
|
|
||||||
self._logger.info(f"[PROPOSE] drained {len(pending_txs)} txs from mempool, chain={self._config.chain_id}")
|
|
||||||
|
|
||||||
# Process transactions and update balances
|
|
||||||
processed_txs = []
|
|
||||||
for tx in pending_txs:
|
|
||||||
try:
|
|
||||||
# Parse transaction data
|
|
||||||
tx_data = tx.content
|
|
||||||
sender = tx_data.get("from")
|
|
||||||
recipient = tx_data.get("to")
|
|
||||||
value = tx_data.get("amount", 0)
|
|
||||||
fee = tx_data.get("fee", 0)
|
|
||||||
|
|
||||||
if not sender or not recipient:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get sender account
|
|
||||||
sender_account = session.get(Account, (self._config.chain_id, sender))
|
|
||||||
if not sender_account:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check sufficient balance
|
|
||||||
total_cost = value + fee
|
|
||||||
if sender_account.balance < total_cost:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get or create recipient account
|
|
||||||
recipient_account = session.get(Account, (self._config.chain_id, recipient))
|
|
||||||
if not recipient_account:
|
|
||||||
recipient_account = Account(chain_id=self._config.chain_id, address=recipient, balance=0, nonce=0)
|
|
||||||
session.add(recipient_account)
|
|
||||||
session.flush()
|
|
||||||
|
|
||||||
# Update balances
|
|
||||||
sender_account.balance -= total_cost
|
|
||||||
sender_account.nonce += 1
|
|
||||||
recipient_account.balance += value
|
|
||||||
|
|
||||||
# Create transaction record
|
|
||||||
transaction = Transaction(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
tx_hash=tx.tx_hash,
|
|
||||||
sender=sender,
|
|
||||||
recipient=recipient,
|
|
||||||
payload=tx_data,
|
|
||||||
value=value,
|
|
||||||
fee=fee,
|
|
||||||
nonce=sender_account.nonce - 1,
|
|
||||||
timestamp=timestamp,
|
|
||||||
block_height=next_height,
|
|
||||||
status="confirmed"
|
|
||||||
)
|
|
||||||
session.add(transaction)
|
|
||||||
processed_txs.append(tx)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self._logger.warning(f"Failed to process transaction {tx.tx_hash}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Compute block hash with transaction data
|
|
||||||
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp, processed_txs)
|
|
||||||
|
|
||||||
block = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=next_height,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash=parent_hash,
|
|
||||||
proposer=self._config.proposer_id,
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=len(processed_txs),
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(block)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
metrics_registry.increment("blocks_proposed_total")
|
|
||||||
metrics_registry.set_gauge("chain_head_height", float(next_height))
|
|
||||||
if interval_seconds is not None and interval_seconds >= 0:
|
|
||||||
metrics_registry.observe("block_interval_seconds", interval_seconds)
|
|
||||||
metrics_registry.set_gauge("poa_last_block_interval_seconds", float(interval_seconds))
|
|
||||||
|
|
||||||
proposer_suffix = _sanitize_metric_suffix(self._config.proposer_id)
|
|
||||||
metrics_registry.increment(f"poa_blocks_proposed_total_{proposer_suffix}")
|
|
||||||
if self._last_proposer_id is not None and self._last_proposer_id != self._config.proposer_id:
|
|
||||||
metrics_registry.increment("poa_proposer_switches_total")
|
|
||||||
self._last_proposer_id = self._config.proposer_id
|
|
||||||
|
|
||||||
self._logger.info(
|
|
||||||
"Proposed block",
|
|
||||||
extra={
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast the new block
|
|
||||||
tx_list = [tx.content for tx in processed_txs] if processed_txs else []
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"chain_id": self._config.chain_id,
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"parent_hash": block.parent_hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
"timestamp": block.timestamp.isoformat(),
|
|
||||||
"tx_count": block.tx_count,
|
|
||||||
"state_root": block.state_root,
|
|
||||||
"transactions": tx_list,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_genesis_block(self) -> None:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
if head is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use a deterministic genesis timestamp so all nodes agree on the genesis block hash
|
|
||||||
timestamp = datetime(2025, 1, 1, 0, 0, 0)
|
|
||||||
block_hash = self._compute_block_hash(0, "0x00", timestamp)
|
|
||||||
genesis = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=0,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash="0x00",
|
|
||||||
proposer=self._config.proposer_id, # Use configured proposer as genesis proposer
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(genesis)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Initialize accounts from genesis allocations file (if present)
|
|
||||||
await self._initialize_genesis_allocations(session)
|
|
||||||
|
|
||||||
# Broadcast genesis block for initial sync
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"chain_id": self._config.chain_id,
|
|
||||||
"height": genesis.height,
|
|
||||||
"hash": genesis.hash,
|
|
||||||
"parent_hash": genesis.parent_hash,
|
|
||||||
"proposer": genesis.proposer,
|
|
||||||
"timestamp": genesis.timestamp.isoformat(),
|
|
||||||
"tx_count": genesis.tx_count,
|
|
||||||
"state_root": genesis.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _initialize_genesis_allocations(self, session: Session) -> None:
|
|
||||||
"""Create Account entries from the genesis allocations file."""
|
|
||||||
# Use standardized data directory from configuration
|
|
||||||
from ..config import settings
|
|
||||||
|
|
||||||
genesis_paths = [
|
|
||||||
Path(f"/var/lib/aitbc/data/{self._config.chain_id}/genesis.json"), # Standard location
|
|
||||||
]
|
|
||||||
|
|
||||||
genesis_path = None
|
|
||||||
for path in genesis_paths:
|
|
||||||
if path.exists():
|
|
||||||
genesis_path = path
|
|
||||||
break
|
|
||||||
|
|
||||||
if not genesis_path:
|
|
||||||
self._logger.warning("Genesis allocations file not found; skipping account initialization", extra={"paths": str(genesis_paths)})
|
|
||||||
return
|
|
||||||
|
|
||||||
with open(genesis_path) as f:
|
|
||||||
genesis_data = json.load(f)
|
|
||||||
|
|
||||||
allocations = genesis_data.get("allocations", [])
|
|
||||||
created = 0
|
|
||||||
for alloc in allocations:
|
|
||||||
addr = alloc["address"]
|
|
||||||
balance = int(alloc["balance"])
|
|
||||||
nonce = int(alloc.get("nonce", 0))
|
|
||||||
# Check if account already exists (idempotent)
|
|
||||||
acct = session.get(Account, (self._config.chain_id, addr))
|
|
||||||
if acct is None:
|
|
||||||
acct = Account(chain_id=self._config.chain_id, address=addr, balance=balance, nonce=nonce)
|
|
||||||
session.add(acct)
|
|
||||||
created += 1
|
|
||||||
session.commit()
|
|
||||||
self._logger.info("Initialized genesis accounts", extra={"count": created, "total": len(allocations), "path": str(genesis_path)})
|
|
||||||
|
|
||||||
def _fetch_chain_head(self) -> Optional[Block]:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
|
|
||||||
def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime, transactions: list = None) -> str:
|
|
||||||
# Include transaction hashes in block hash computation
|
|
||||||
tx_hashes = []
|
|
||||||
if transactions:
|
|
||||||
tx_hashes = [tx.tx_hash for tx in transactions]
|
|
||||||
|
|
||||||
payload = f"{self._config.chain_id}|{height}|{parent_hash}|{timestamp.isoformat()}|{'|'.join(sorted(tx_hashes))}".encode()
|
|
||||||
return "0x" + hashlib.sha256(payload).hexdigest()
|
|
||||||
@@ -1,229 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Callable, ContextManager, Optional
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
|
|
||||||
from ..logger import get_logger
|
|
||||||
from ..metrics import metrics_registry
|
|
||||||
from ..config import ProposerConfig
|
|
||||||
from ..models import Block
|
|
||||||
from ..gossip import gossip_broker
|
|
||||||
|
|
||||||
_METRIC_KEY_SANITIZE = re.compile(r"[^a-zA-Z0-9_]")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_metric_suffix(value: str) -> str:
|
|
||||||
sanitized = _METRIC_KEY_SANITIZE.sub("_", value).strip("_")
|
|
||||||
return sanitized or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
class CircuitBreaker:
|
|
||||||
def __init__(self, threshold: int, timeout: int):
|
|
||||||
self._threshold = threshold
|
|
||||||
self._timeout = timeout
|
|
||||||
self._failures = 0
|
|
||||||
self._last_failure_time = 0.0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self) -> str:
|
|
||||||
if self._state == "open":
|
|
||||||
if time.time() - self._last_failure_time > self._timeout:
|
|
||||||
self._state = "half-open"
|
|
||||||
return self._state
|
|
||||||
|
|
||||||
def allow_request(self) -> bool:
|
|
||||||
state = self.state
|
|
||||||
if state == "closed":
|
|
||||||
return True
|
|
||||||
if state == "half-open":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def record_failure(self) -> None:
|
|
||||||
self._failures += 1
|
|
||||||
self._last_failure_time = time.time()
|
|
||||||
if self._failures >= self._threshold:
|
|
||||||
self._state = "open"
|
|
||||||
|
|
||||||
def record_success(self) -> None:
|
|
||||||
self._failures = 0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
class PoAProposer:
|
|
||||||
"""Proof-of-Authority block proposer.
|
|
||||||
|
|
||||||
Responsible for periodically proposing blocks if this node is configured as a proposer.
|
|
||||||
In the real implementation, this would involve checking the mempool, validating transactions,
|
|
||||||
and signing the block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
config: ProposerConfig,
|
|
||||||
session_factory: Callable[[], ContextManager[Session]],
|
|
||||||
) -> None:
|
|
||||||
self._config = config
|
|
||||||
self._session_factory = session_factory
|
|
||||||
self._logger = get_logger(__name__)
|
|
||||||
self._stop_event = asyncio.Event()
|
|
||||||
self._task: Optional[asyncio.Task[None]] = None
|
|
||||||
self._last_proposer_id: Optional[str] = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
if self._task is not None:
|
|
||||||
return
|
|
||||||
self._logger.info("Starting PoA proposer loop", extra={"interval": self._config.interval_seconds})
|
|
||||||
self._ensure_genesis_block()
|
|
||||||
self._stop_event.clear()
|
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
if self._task is None:
|
|
||||||
return
|
|
||||||
self._logger.info("Stopping PoA proposer loop")
|
|
||||||
self._stop_event.set()
|
|
||||||
await self._task
|
|
||||||
self._task = None
|
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
await self._wait_until_next_slot()
|
|
||||||
if self._stop_event.is_set():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
self._propose_block()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
self._logger.exception("Failed to propose block", extra={"error": str(exc)})
|
|
||||||
|
|
||||||
async def _wait_until_next_slot(self) -> None:
|
|
||||||
head = self._fetch_chain_head()
|
|
||||||
if head is None:
|
|
||||||
return
|
|
||||||
now = datetime.utcnow()
|
|
||||||
elapsed = (now - head.timestamp).total_seconds()
|
|
||||||
sleep_for = max(self._config.interval_seconds - elapsed, 0.1)
|
|
||||||
if sleep_for <= 0:
|
|
||||||
sleep_for = 0.1
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._stop_event.wait(), timeout=sleep_for)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _propose_block(self) -> None:
|
|
||||||
# Check internal mempool
|
|
||||||
from ..mempool import get_mempool
|
|
||||||
if get_mempool().size(self._config.chain_id) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
next_height = 0
|
|
||||||
parent_hash = "0x00"
|
|
||||||
interval_seconds: Optional[float] = None
|
|
||||||
if head is not None:
|
|
||||||
next_height = head.height + 1
|
|
||||||
parent_hash = head.hash
|
|
||||||
interval_seconds = (datetime.utcnow() - head.timestamp).total_seconds()
|
|
||||||
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp)
|
|
||||||
|
|
||||||
block = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=next_height,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash=parent_hash,
|
|
||||||
proposer=self._config.proposer_id,
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(block)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
metrics_registry.increment("blocks_proposed_total")
|
|
||||||
metrics_registry.set_gauge("chain_head_height", float(next_height))
|
|
||||||
if interval_seconds is not None and interval_seconds >= 0:
|
|
||||||
metrics_registry.observe("block_interval_seconds", interval_seconds)
|
|
||||||
metrics_registry.set_gauge("poa_last_block_interval_seconds", float(interval_seconds))
|
|
||||||
|
|
||||||
proposer_suffix = _sanitize_metric_suffix(self._config.proposer_id)
|
|
||||||
metrics_registry.increment(f"poa_blocks_proposed_total_{proposer_suffix}")
|
|
||||||
if self._last_proposer_id is not None and self._last_proposer_id != self._config.proposer_id:
|
|
||||||
metrics_registry.increment("poa_proposer_switches_total")
|
|
||||||
self._last_proposer_id = self._config.proposer_id
|
|
||||||
|
|
||||||
self._logger.info(
|
|
||||||
"Proposed block",
|
|
||||||
extra={
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast the new block
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"parent_hash": block.parent_hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
"timestamp": block.timestamp.isoformat(),
|
|
||||||
"tx_count": block.tx_count,
|
|
||||||
"state_root": block.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_genesis_block(self) -> None:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
if head is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use a deterministic genesis timestamp so all nodes agree on the genesis block hash
|
|
||||||
timestamp = datetime(2025, 1, 1, 0, 0, 0)
|
|
||||||
block_hash = self._compute_block_hash(0, "0x00", timestamp)
|
|
||||||
genesis = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=0,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash="0x00",
|
|
||||||
proposer="genesis",
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(genesis)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Broadcast genesis block for initial sync
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"height": genesis.height,
|
|
||||||
"hash": genesis.hash,
|
|
||||||
"parent_hash": genesis.parent_hash,
|
|
||||||
"proposer": genesis.proposer,
|
|
||||||
"timestamp": genesis.timestamp.isoformat(),
|
|
||||||
"tx_count": genesis.tx_count,
|
|
||||||
"state_root": genesis.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _fetch_chain_head(self) -> Optional[Block]:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
|
|
||||||
def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime) -> str:
|
|
||||||
payload = f"{self._config.chain_id}|{height}|{parent_hash}|{timestamp.isoformat()}".encode()
|
|
||||||
return "0x" + hashlib.sha256(payload).hexdigest()
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
--- apps/blockchain-node/src/aitbc_chain/consensus/poa.py
|
|
||||||
+++ apps/blockchain-node/src/aitbc_chain/consensus/poa.py
|
|
||||||
@@ -101,7 +101,7 @@
|
|
||||||
# Wait for interval before proposing next block
|
|
||||||
await asyncio.sleep(self.config.interval_seconds)
|
|
||||||
|
|
||||||
- self._propose_block()
|
|
||||||
+ await self._propose_block()
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
@@ -1,146 +0,0 @@
|
|||||||
"""
|
|
||||||
Validator Rotation Mechanism
|
|
||||||
Handles automatic rotation of validators based on performance and stake
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
from typing import List, Dict, Optional
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import MultiValidatorPoA, Validator, ValidatorRole
|
|
||||||
|
|
||||||
class RotationStrategy(Enum):
|
|
||||||
ROUND_ROBIN = "round_robin"
|
|
||||||
STAKE_WEIGHTED = "stake_weighted"
|
|
||||||
REPUTATION_BASED = "reputation_based"
|
|
||||||
HYBRID = "hybrid"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RotationConfig:
|
|
||||||
strategy: RotationStrategy
|
|
||||||
rotation_interval: int # blocks
|
|
||||||
min_stake: float
|
|
||||||
reputation_threshold: float
|
|
||||||
max_validators: int
|
|
||||||
|
|
||||||
class ValidatorRotation:
|
|
||||||
"""Manages validator rotation based on various strategies"""
|
|
||||||
|
|
||||||
def __init__(self, consensus: MultiValidatorPoA, config: RotationConfig):
|
|
||||||
self.consensus = consensus
|
|
||||||
self.config = config
|
|
||||||
self.last_rotation_height = 0
|
|
||||||
|
|
||||||
def should_rotate(self, current_height: int) -> bool:
|
|
||||||
"""Check if rotation should occur at current height"""
|
|
||||||
return (current_height - self.last_rotation_height) >= self.config.rotation_interval
|
|
||||||
|
|
||||||
def rotate_validators(self, current_height: int) -> bool:
|
|
||||||
"""Perform validator rotation based on configured strategy"""
|
|
||||||
if not self.should_rotate(current_height):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if self.config.strategy == RotationStrategy.ROUND_ROBIN:
|
|
||||||
return self._rotate_round_robin()
|
|
||||||
elif self.config.strategy == RotationStrategy.STAKE_WEIGHTED:
|
|
||||||
return self._rotate_stake_weighted()
|
|
||||||
elif self.config.strategy == RotationStrategy.REPUTATION_BASED:
|
|
||||||
return self._rotate_reputation_based()
|
|
||||||
elif self.config.strategy == RotationStrategy.HYBRID:
|
|
||||||
return self._rotate_hybrid()
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _rotate_round_robin(self) -> bool:
|
|
||||||
"""Round-robin rotation of validator roles"""
|
|
||||||
validators = list(self.consensus.validators.values())
|
|
||||||
active_validators = [v for v in validators if v.is_active]
|
|
||||||
|
|
||||||
# Rotate roles among active validators
|
|
||||||
for i, validator in enumerate(active_validators):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 3: # Top 3 become validators
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_stake_weighted(self) -> bool:
|
|
||||||
"""Stake-weighted rotation"""
|
|
||||||
validators = sorted(
|
|
||||||
[v for v in self.consensus.validators.values() if v.is_active],
|
|
||||||
key=lambda v: v.stake,
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, validator in enumerate(validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_reputation_based(self) -> bool:
|
|
||||||
"""Reputation-based rotation"""
|
|
||||||
validators = sorted(
|
|
||||||
[v for v in self.consensus.validators.values() if v.is_active],
|
|
||||||
key=lambda v: v.reputation,
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Filter by reputation threshold
|
|
||||||
qualified_validators = [
|
|
||||||
v for v in validators
|
|
||||||
if v.reputation >= self.config.reputation_threshold
|
|
||||||
]
|
|
||||||
|
|
||||||
for i, validator in enumerate(qualified_validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_hybrid(self) -> bool:
|
|
||||||
"""Hybrid rotation considering both stake and reputation"""
|
|
||||||
validators = [v for v in self.consensus.validators.values() if v.is_active]
|
|
||||||
|
|
||||||
# Calculate hybrid score
|
|
||||||
for validator in validators:
|
|
||||||
validator.hybrid_score = validator.stake * validator.reputation
|
|
||||||
|
|
||||||
# Sort by hybrid score
|
|
||||||
validators.sort(key=lambda v: v.hybrid_score, reverse=True)
|
|
||||||
|
|
||||||
for i, validator in enumerate(validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Default rotation configuration
|
|
||||||
DEFAULT_ROTATION_CONFIG = RotationConfig(
|
|
||||||
strategy=RotationStrategy.HYBRID,
|
|
||||||
rotation_interval=100, # Rotate every 100 blocks
|
|
||||||
min_stake=1000.0,
|
|
||||||
reputation_threshold=0.7,
|
|
||||||
max_validators=10
|
|
||||||
)
|
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
"""
|
|
||||||
Slashing Conditions Implementation
|
|
||||||
Handles detection and penalties for validator misbehavior
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
from typing import Dict, List, Optional, Set
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import Validator, ValidatorRole
|
|
||||||
|
|
||||||
class SlashingCondition(Enum):
|
|
||||||
DOUBLE_SIGN = "double_sign"
|
|
||||||
UNAVAILABLE = "unavailable"
|
|
||||||
INVALID_BLOCK = "invalid_block"
|
|
||||||
SLOW_RESPONSE = "slow_response"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SlashingEvent:
|
|
||||||
validator_address: str
|
|
||||||
condition: SlashingCondition
|
|
||||||
evidence: str
|
|
||||||
block_height: int
|
|
||||||
timestamp: float
|
|
||||||
slash_amount: float
|
|
||||||
|
|
||||||
class SlashingManager:
|
|
||||||
"""Manages validator slashing conditions and penalties"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.slashing_events: List[SlashingEvent] = []
|
|
||||||
self.slash_rates = {
|
|
||||||
SlashingCondition.DOUBLE_SIGN: 0.5, # 50% slash
|
|
||||||
SlashingCondition.UNAVAILABLE: 0.1, # 10% slash
|
|
||||||
SlashingCondition.INVALID_BLOCK: 0.3, # 30% slash
|
|
||||||
SlashingCondition.SLOW_RESPONSE: 0.05 # 5% slash
|
|
||||||
}
|
|
||||||
self.slash_thresholds = {
|
|
||||||
SlashingCondition.DOUBLE_SIGN: 1, # Immediate slash
|
|
||||||
SlashingCondition.UNAVAILABLE: 3, # After 3 offenses
|
|
||||||
SlashingCondition.INVALID_BLOCK: 1, # Immediate slash
|
|
||||||
SlashingCondition.SLOW_RESPONSE: 5 # After 5 offenses
|
|
||||||
}
|
|
||||||
|
|
||||||
def detect_double_sign(self, validator: str, block_hash1: str, block_hash2: str, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect double signing (validator signed two different blocks at same height)"""
|
|
||||||
if block_hash1 == block_hash2:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.DOUBLE_SIGN,
|
|
||||||
evidence=f"Double sign detected: {block_hash1} vs {block_hash2} at height {height}",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.DOUBLE_SIGN]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_unavailability(self, validator: str, missed_blocks: int, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect validator unavailability (missing consensus participation)"""
|
|
||||||
if missed_blocks < self.slash_thresholds[SlashingCondition.UNAVAILABLE]:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.UNAVAILABLE,
|
|
||||||
evidence=f"Missed {missed_blocks} consecutive blocks",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.UNAVAILABLE]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_invalid_block(self, validator: str, block_hash: str, reason: str, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect invalid block proposal"""
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.INVALID_BLOCK,
|
|
||||||
evidence=f"Invalid block {block_hash}: {reason}",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.INVALID_BLOCK]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_slow_response(self, validator: str, response_time: float, threshold: float, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect slow consensus participation"""
|
|
||||||
if response_time <= threshold:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.SLOW_RESPONSE,
|
|
||||||
evidence=f"Slow response: {response_time}s (threshold: {threshold}s)",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.SLOW_RESPONSE]
|
|
||||||
)
|
|
||||||
|
|
||||||
def apply_slashing(self, validator: Validator, event: SlashingEvent) -> bool:
|
|
||||||
"""Apply slashing penalty to validator"""
|
|
||||||
slash_amount = validator.stake * event.slash_amount
|
|
||||||
validator.stake -= slash_amount
|
|
||||||
|
|
||||||
# Demote validator role if stake is too low
|
|
||||||
if validator.stake < 100: # Minimum stake threshold
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
# Record slashing event
|
|
||||||
self.slashing_events.append(event)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_validator_slash_count(self, validator_address: str, condition: SlashingCondition) -> int:
|
|
||||||
"""Get count of slashing events for validator and condition"""
|
|
||||||
return len([
|
|
||||||
event for event in self.slashing_events
|
|
||||||
if event.validator_address == validator_address and event.condition == condition
|
|
||||||
])
|
|
||||||
|
|
||||||
def should_slash(self, validator: str, condition: SlashingCondition) -> bool:
|
|
||||||
"""Check if validator should be slashed for condition"""
|
|
||||||
current_count = self.get_validator_slash_count(validator, condition)
|
|
||||||
threshold = self.slash_thresholds.get(condition, 1)
|
|
||||||
return current_count >= threshold
|
|
||||||
|
|
||||||
def get_slashing_history(self, validator_address: Optional[str] = None) -> List[SlashingEvent]:
|
|
||||||
"""Get slashing history for validator or all validators"""
|
|
||||||
if validator_address:
|
|
||||||
return [event for event in self.slashing_events if event.validator_address == validator_address]
|
|
||||||
return self.slashing_events.copy()
|
|
||||||
|
|
||||||
def calculate_total_slashed(self, validator_address: str) -> float:
|
|
||||||
"""Calculate total amount slashed for validator"""
|
|
||||||
events = self.get_slashing_history(validator_address)
|
|
||||||
return sum(event.slash_amount for event in events)
|
|
||||||
|
|
||||||
# Global slashing manager
|
|
||||||
slashing_manager = SlashingManager()
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from .poa import PoAProposer, ProposerConfig, CircuitBreaker
|
|
||||||
|
|
||||||
__all__ = ["PoAProposer", "ProposerConfig", "CircuitBreaker"]
|
|
||||||
@@ -1,210 +0,0 @@
|
|||||||
"""
|
|
||||||
Validator Key Management
|
|
||||||
Handles cryptographic key operations for validators
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from typing import Dict, Optional, Tuple
|
|
||||||
from cryptography.hazmat.primitives import hashes, serialization
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
||||||
from cryptography.hazmat.backends import default_backend
|
|
||||||
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ValidatorKeyPair:
|
|
||||||
address: str
|
|
||||||
private_key_pem: str
|
|
||||||
public_key_pem: str
|
|
||||||
created_at: float
|
|
||||||
last_rotated: float
|
|
||||||
|
|
||||||
class KeyManager:
|
|
||||||
"""Manages validator cryptographic keys"""
|
|
||||||
|
|
||||||
def __init__(self, keys_dir: str = "/opt/aitbc/keys"):
|
|
||||||
self.keys_dir = keys_dir
|
|
||||||
self.key_pairs: Dict[str, ValidatorKeyPair] = {}
|
|
||||||
self._ensure_keys_directory()
|
|
||||||
self._load_existing_keys()
|
|
||||||
|
|
||||||
def _ensure_keys_directory(self):
|
|
||||||
"""Ensure keys directory exists and has proper permissions"""
|
|
||||||
os.makedirs(self.keys_dir, mode=0o700, exist_ok=True)
|
|
||||||
|
|
||||||
def _load_existing_keys(self):
|
|
||||||
"""Load existing key pairs from disk"""
|
|
||||||
keys_file = os.path.join(self.keys_dir, "validator_keys.json")
|
|
||||||
|
|
||||||
if os.path.exists(keys_file):
|
|
||||||
try:
|
|
||||||
with open(keys_file, 'r') as f:
|
|
||||||
keys_data = json.load(f)
|
|
||||||
|
|
||||||
for address, key_data in keys_data.items():
|
|
||||||
self.key_pairs[address] = ValidatorKeyPair(
|
|
||||||
address=address,
|
|
||||||
private_key_pem=key_data['private_key_pem'],
|
|
||||||
public_key_pem=key_data['public_key_pem'],
|
|
||||||
created_at=key_data['created_at'],
|
|
||||||
last_rotated=key_data['last_rotated']
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading keys: {e}")
|
|
||||||
|
|
||||||
def generate_key_pair(self, address: str) -> ValidatorKeyPair:
|
|
||||||
"""Generate new RSA key pair for validator"""
|
|
||||||
# Generate private key
|
|
||||||
private_key = rsa.generate_private_key(
|
|
||||||
public_exponent=65537,
|
|
||||||
key_size=2048,
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Serialize private key
|
|
||||||
private_key_pem = private_key.private_bytes(
|
|
||||||
encoding=Encoding.PEM,
|
|
||||||
format=PrivateFormat.PKCS8,
|
|
||||||
encryption_algorithm=NoEncryption()
|
|
||||||
).decode('utf-8')
|
|
||||||
|
|
||||||
# Get public key
|
|
||||||
public_key = private_key.public_key()
|
|
||||||
public_key_pem = public_key.public_bytes(
|
|
||||||
encoding=Encoding.PEM,
|
|
||||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
|
||||||
).decode('utf-8')
|
|
||||||
|
|
||||||
# Create key pair object
|
|
||||||
current_time = time.time()
|
|
||||||
key_pair = ValidatorKeyPair(
|
|
||||||
address=address,
|
|
||||||
private_key_pem=private_key_pem,
|
|
||||||
public_key_pem=public_key_pem,
|
|
||||||
created_at=current_time,
|
|
||||||
last_rotated=current_time
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store key pair
|
|
||||||
self.key_pairs[address] = key_pair
|
|
||||||
self._save_keys()
|
|
||||||
|
|
||||||
return key_pair
|
|
||||||
|
|
||||||
def get_key_pair(self, address: str) -> Optional[ValidatorKeyPair]:
|
|
||||||
"""Get key pair for validator"""
|
|
||||||
return self.key_pairs.get(address)
|
|
||||||
|
|
||||||
def rotate_key(self, address: str) -> Optional[ValidatorKeyPair]:
|
|
||||||
"""Rotate validator keys"""
|
|
||||||
if address not in self.key_pairs:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Generate new key pair
|
|
||||||
new_key_pair = self.generate_key_pair(address)
|
|
||||||
|
|
||||||
# Update rotation time
|
|
||||||
new_key_pair.created_at = self.key_pairs[address].created_at
|
|
||||||
new_key_pair.last_rotated = time.time()
|
|
||||||
|
|
||||||
self._save_keys()
|
|
||||||
return new_key_pair
|
|
||||||
|
|
||||||
def sign_message(self, address: str, message: str) -> Optional[str]:
|
|
||||||
"""Sign message with validator private key"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load private key from PEM
|
|
||||||
private_key = serialization.load_pem_private_key(
|
|
||||||
key_pair.private_key_pem.encode(),
|
|
||||||
password=None,
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sign message
|
|
||||||
signature = private_key.sign(
|
|
||||||
message.encode('utf-8'),
|
|
||||||
hashes.SHA256(),
|
|
||||||
default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
return signature.hex()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error signing message: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def verify_signature(self, address: str, message: str, signature: str) -> bool:
|
|
||||||
"""Verify message signature"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load public key from PEM
|
|
||||||
public_key = serialization.load_pem_public_key(
|
|
||||||
key_pair.public_key_pem.encode(),
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify signature
|
|
||||||
public_key.verify(
|
|
||||||
bytes.fromhex(signature),
|
|
||||||
message.encode('utf-8'),
|
|
||||||
hashes.SHA256(),
|
|
||||||
default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error verifying signature: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_public_key_pem(self, address: str) -> Optional[str]:
|
|
||||||
"""Get public key PEM for validator"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
return key_pair.public_key_pem if key_pair else None
|
|
||||||
|
|
||||||
def _save_keys(self):
|
|
||||||
"""Save key pairs to disk"""
|
|
||||||
keys_file = os.path.join(self.keys_dir, "validator_keys.json")
|
|
||||||
|
|
||||||
keys_data = {}
|
|
||||||
for address, key_pair in self.key_pairs.items():
|
|
||||||
keys_data[address] = {
|
|
||||||
'private_key_pem': key_pair.private_key_pem,
|
|
||||||
'public_key_pem': key_pair.public_key_pem,
|
|
||||||
'created_at': key_pair.created_at,
|
|
||||||
'last_rotated': key_pair.last_rotated
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(keys_file, 'w') as f:
|
|
||||||
json.dump(keys_data, f, indent=2)
|
|
||||||
|
|
||||||
# Set secure permissions
|
|
||||||
os.chmod(keys_file, 0o600)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error saving keys: {e}")
|
|
||||||
|
|
||||||
def should_rotate_key(self, address: str, rotation_interval: int = 86400) -> bool:
|
|
||||||
"""Check if key should be rotated (default: 24 hours)"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return (time.time() - key_pair.last_rotated) >= rotation_interval
|
|
||||||
|
|
||||||
def get_key_age(self, address: str) -> Optional[float]:
|
|
||||||
"""Get age of key in seconds"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return time.time() - key_pair.created_at
|
|
||||||
|
|
||||||
# Global key manager
|
|
||||||
key_manager = KeyManager()
|
|
||||||
@@ -1,119 +0,0 @@
|
|||||||
"""
|
|
||||||
Multi-Validator Proof of Authority Consensus Implementation
|
|
||||||
Extends single validator PoA to support multiple validators with rotation
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
from typing import List, Dict, Optional, Set
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from ..config import settings
|
|
||||||
from ..models import Block, Transaction
|
|
||||||
from ..database import session_scope
|
|
||||||
|
|
||||||
class ValidatorRole(Enum):
|
|
||||||
PROPOSER = "proposer"
|
|
||||||
VALIDATOR = "validator"
|
|
||||||
STANDBY = "standby"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Validator:
|
|
||||||
address: str
|
|
||||||
stake: float
|
|
||||||
reputation: float
|
|
||||||
role: ValidatorRole
|
|
||||||
last_proposed: int
|
|
||||||
is_active: bool
|
|
||||||
|
|
||||||
class MultiValidatorPoA:
|
|
||||||
"""Multi-Validator Proof of Authority consensus mechanism"""
|
|
||||||
|
|
||||||
def __init__(self, chain_id: str):
|
|
||||||
self.chain_id = chain_id
|
|
||||||
self.validators: Dict[str, Validator] = {}
|
|
||||||
self.current_proposer_index = 0
|
|
||||||
self.round_robin_enabled = True
|
|
||||||
self.consensus_timeout = 30 # seconds
|
|
||||||
|
|
||||||
def add_validator(self, address: str, stake: float = 1000.0) -> bool:
|
|
||||||
"""Add a new validator to the consensus"""
|
|
||||||
if address in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.validators[address] = Validator(
|
|
||||||
address=address,
|
|
||||||
stake=stake,
|
|
||||||
reputation=1.0,
|
|
||||||
role=ValidatorRole.STANDBY,
|
|
||||||
last_proposed=0,
|
|
||||||
is_active=True
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def remove_validator(self, address: str) -> bool:
|
|
||||||
"""Remove a validator from the consensus"""
|
|
||||||
if address not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[address]
|
|
||||||
validator.is_active = False
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
return True
|
|
||||||
|
|
||||||
def select_proposer(self, block_height: int) -> Optional[str]:
|
|
||||||
"""Select proposer for the current block using round-robin"""
|
|
||||||
active_validators = [
|
|
||||||
v for v in self.validators.values()
|
|
||||||
if v.is_active and v.role in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]
|
|
||||||
]
|
|
||||||
|
|
||||||
if not active_validators:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Round-robin selection
|
|
||||||
proposer_index = block_height % len(active_validators)
|
|
||||||
return active_validators[proposer_index].address
|
|
||||||
|
|
||||||
def validate_block(self, block: Block, proposer: str) -> bool:
|
|
||||||
"""Validate a proposed block"""
|
|
||||||
if proposer not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[proposer]
|
|
||||||
if not validator.is_active:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check if validator is allowed to propose
|
|
||||||
if validator.role not in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Additional validation logic here
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_consensus_participants(self) -> List[str]:
|
|
||||||
"""Get list of active consensus participants"""
|
|
||||||
return [
|
|
||||||
v.address for v in self.validators.values()
|
|
||||||
if v.is_active and v.role in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]
|
|
||||||
]
|
|
||||||
|
|
||||||
def update_validator_reputation(self, address: str, delta: float) -> bool:
|
|
||||||
"""Update validator reputation"""
|
|
||||||
if address not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[address]
|
|
||||||
validator.reputation = max(0.0, min(1.0, validator.reputation + delta))
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Global consensus instance
|
|
||||||
consensus_instances: Dict[str, MultiValidatorPoA] = {}
|
|
||||||
|
|
||||||
def get_consensus(chain_id: str) -> MultiValidatorPoA:
|
|
||||||
"""Get or create consensus instance for chain"""
|
|
||||||
if chain_id not in consensus_instances:
|
|
||||||
consensus_instances[chain_id] = MultiValidatorPoA(chain_id)
|
|
||||||
return consensus_instances[chain_id]
|
|
||||||
@@ -1,193 +0,0 @@
|
|||||||
"""
|
|
||||||
Practical Byzantine Fault Tolerance (PBFT) Consensus Implementation
|
|
||||||
Provides Byzantine fault tolerance for up to 1/3 faulty validators
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
from typing import List, Dict, Optional, Set, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import MultiValidatorPoA, Validator
|
|
||||||
|
|
||||||
class PBFTPhase(Enum):
|
|
||||||
PRE_PREPARE = "pre_prepare"
|
|
||||||
PREPARE = "prepare"
|
|
||||||
COMMIT = "commit"
|
|
||||||
EXECUTE = "execute"
|
|
||||||
|
|
||||||
class PBFTMessageType(Enum):
|
|
||||||
PRE_PREPARE = "pre_prepare"
|
|
||||||
PREPARE = "prepare"
|
|
||||||
COMMIT = "commit"
|
|
||||||
VIEW_CHANGE = "view_change"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PBFTMessage:
|
|
||||||
message_type: PBFTMessageType
|
|
||||||
sender: str
|
|
||||||
view_number: int
|
|
||||||
sequence_number: int
|
|
||||||
digest: str
|
|
||||||
signature: str
|
|
||||||
timestamp: float
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PBFTState:
|
|
||||||
current_view: int
|
|
||||||
current_sequence: int
|
|
||||||
prepared_messages: Dict[str, List[PBFTMessage]]
|
|
||||||
committed_messages: Dict[str, List[PBFTMessage]]
|
|
||||||
pre_prepare_messages: Dict[str, PBFTMessage]
|
|
||||||
|
|
||||||
class PBFTConsensus:
|
|
||||||
"""PBFT consensus implementation"""
|
|
||||||
|
|
||||||
def __init__(self, consensus: MultiValidatorPoA):
|
|
||||||
self.consensus = consensus
|
|
||||||
self.state = PBFTState(
|
|
||||||
current_view=0,
|
|
||||||
current_sequence=0,
|
|
||||||
prepared_messages={},
|
|
||||||
committed_messages={},
|
|
||||||
pre_prepare_messages={}
|
|
||||||
)
|
|
||||||
self.fault_tolerance = max(1, len(consensus.get_consensus_participants()) // 3)
|
|
||||||
self.required_messages = 2 * self.fault_tolerance + 1
|
|
||||||
|
|
||||||
def get_message_digest(self, block_hash: str, sequence: int, view: int) -> str:
|
|
||||||
"""Generate message digest for PBFT"""
|
|
||||||
content = f"{block_hash}:{sequence}:{view}"
|
|
||||||
return hashlib.sha256(content.encode()).hexdigest()
|
|
||||||
|
|
||||||
async def pre_prepare_phase(self, proposer: str, block_hash: str) -> bool:
|
|
||||||
"""Phase 1: Pre-prepare"""
|
|
||||||
sequence = self.state.current_sequence + 1
|
|
||||||
view = self.state.current_view
|
|
||||||
digest = self.get_message_digest(block_hash, sequence, view)
|
|
||||||
|
|
||||||
message = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.PRE_PREPARE,
|
|
||||||
sender=proposer,
|
|
||||||
view_number=view,
|
|
||||||
sequence_number=sequence,
|
|
||||||
digest=digest,
|
|
||||||
signature="", # Would be signed in real implementation
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store pre-prepare message
|
|
||||||
key = f"{sequence}:{view}"
|
|
||||||
self.state.pre_prepare_messages[key] = message
|
|
||||||
|
|
||||||
# Broadcast to all validators
|
|
||||||
await self._broadcast_message(message)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def prepare_phase(self, validator: str, pre_prepare_msg: PBFTMessage) -> bool:
|
|
||||||
"""Phase 2: Prepare"""
|
|
||||||
key = f"{pre_prepare_msg.sequence_number}:{pre_prepare_msg.view_number}"
|
|
||||||
|
|
||||||
if key not in self.state.pre_prepare_messages:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Create prepare message
|
|
||||||
prepare_msg = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.PREPARE,
|
|
||||||
sender=validator,
|
|
||||||
view_number=pre_prepare_msg.view_number,
|
|
||||||
sequence_number=pre_prepare_msg.sequence_number,
|
|
||||||
digest=pre_prepare_msg.digest,
|
|
||||||
signature="", # Would be signed
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store prepare message
|
|
||||||
if key not in self.state.prepared_messages:
|
|
||||||
self.state.prepared_messages[key] = []
|
|
||||||
self.state.prepared_messages[key].append(prepare_msg)
|
|
||||||
|
|
||||||
# Broadcast prepare message
|
|
||||||
await self._broadcast_message(prepare_msg)
|
|
||||||
|
|
||||||
# Check if we have enough prepare messages
|
|
||||||
return len(self.state.prepared_messages[key]) >= self.required_messages
|
|
||||||
|
|
||||||
async def commit_phase(self, validator: str, prepare_msg: PBFTMessage) -> bool:
|
|
||||||
"""Phase 3: Commit"""
|
|
||||||
key = f"{prepare_msg.sequence_number}:{prepare_msg.view_number}"
|
|
||||||
|
|
||||||
# Create commit message
|
|
||||||
commit_msg = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.COMMIT,
|
|
||||||
sender=validator,
|
|
||||||
view_number=prepare_msg.view_number,
|
|
||||||
sequence_number=prepare_msg.sequence_number,
|
|
||||||
digest=prepare_msg.digest,
|
|
||||||
signature="", # Would be signed
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store commit message
|
|
||||||
if key not in self.state.committed_messages:
|
|
||||||
self.state.committed_messages[key] = []
|
|
||||||
self.state.committed_messages[key].append(commit_msg)
|
|
||||||
|
|
||||||
# Broadcast commit message
|
|
||||||
await self._broadcast_message(commit_msg)
|
|
||||||
|
|
||||||
# Check if we have enough commit messages
|
|
||||||
if len(self.state.committed_messages[key]) >= self.required_messages:
|
|
||||||
return await self.execute_phase(key)
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def execute_phase(self, key: str) -> bool:
|
|
||||||
"""Phase 4: Execute"""
|
|
||||||
# Extract sequence and view from key
|
|
||||||
sequence, view = map(int, key.split(':'))
|
|
||||||
|
|
||||||
# Update state
|
|
||||||
self.state.current_sequence = sequence
|
|
||||||
|
|
||||||
# Clean up old messages
|
|
||||||
self._cleanup_messages(sequence)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _broadcast_message(self, message: PBFTMessage):
|
|
||||||
"""Broadcast message to all validators"""
|
|
||||||
validators = self.consensus.get_consensus_participants()
|
|
||||||
|
|
||||||
for validator in validators:
|
|
||||||
if validator != message.sender:
|
|
||||||
# In real implementation, this would send over network
|
|
||||||
await self._send_to_validator(validator, message)
|
|
||||||
|
|
||||||
async def _send_to_validator(self, validator: str, message: PBFTMessage):
|
|
||||||
"""Send message to specific validator"""
|
|
||||||
# Network communication would be implemented here
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _cleanup_messages(self, sequence: int):
|
|
||||||
"""Clean up old messages to prevent memory leaks"""
|
|
||||||
old_keys = [
|
|
||||||
key for key in self.state.prepared_messages.keys()
|
|
||||||
if int(key.split(':')[0]) < sequence
|
|
||||||
]
|
|
||||||
|
|
||||||
for key in old_keys:
|
|
||||||
self.state.prepared_messages.pop(key, None)
|
|
||||||
self.state.committed_messages.pop(key, None)
|
|
||||||
self.state.pre_prepare_messages.pop(key, None)
|
|
||||||
|
|
||||||
def handle_view_change(self, new_view: int) -> bool:
|
|
||||||
"""Handle view change when proposer fails"""
|
|
||||||
self.state.current_view = new_view
|
|
||||||
# Reset state for new view
|
|
||||||
self.state.prepared_messages.clear()
|
|
||||||
self.state.committed_messages.clear()
|
|
||||||
self.state.pre_prepare_messages.clear()
|
|
||||||
return True
|
|
||||||
@@ -1,345 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable, ContextManager, Optional
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
|
|
||||||
from ..logger import get_logger
|
|
||||||
from ..metrics import metrics_registry
|
|
||||||
from ..config import ProposerConfig
|
|
||||||
from ..models import Block, Account
|
|
||||||
from ..gossip import gossip_broker
|
|
||||||
|
|
||||||
_METRIC_KEY_SANITIZE = re.compile(r"[^a-zA-Z0-9_]")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_metric_suffix(value: str) -> str:
|
|
||||||
sanitized = _METRIC_KEY_SANITIZE.sub("_", value).strip("_")
|
|
||||||
return sanitized or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
class CircuitBreaker:
|
|
||||||
def __init__(self, threshold: int, timeout: int):
|
|
||||||
self._threshold = threshold
|
|
||||||
self._timeout = timeout
|
|
||||||
self._failures = 0
|
|
||||||
self._last_failure_time = 0.0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self) -> str:
|
|
||||||
if self._state == "open":
|
|
||||||
if time.time() - self._last_failure_time > self._timeout:
|
|
||||||
self._state = "half-open"
|
|
||||||
return self._state
|
|
||||||
|
|
||||||
def allow_request(self) -> bool:
|
|
||||||
state = self.state
|
|
||||||
if state == "closed":
|
|
||||||
return True
|
|
||||||
if state == "half-open":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def record_failure(self) -> None:
|
|
||||||
self._failures += 1
|
|
||||||
self._last_failure_time = time.time()
|
|
||||||
if self._failures >= self._threshold:
|
|
||||||
self._state = "open"
|
|
||||||
|
|
||||||
def record_success(self) -> None:
|
|
||||||
self._failures = 0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
class PoAProposer:
|
|
||||||
"""Proof-of-Authority block proposer.
|
|
||||||
|
|
||||||
Responsible for periodically proposing blocks if this node is configured as a proposer.
|
|
||||||
In the real implementation, this would involve checking the mempool, validating transactions,
|
|
||||||
and signing the block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
config: ProposerConfig,
|
|
||||||
session_factory: Callable[[], ContextManager[Session]],
|
|
||||||
) -> None:
|
|
||||||
self._config = config
|
|
||||||
self._session_factory = session_factory
|
|
||||||
self._logger = get_logger(__name__)
|
|
||||||
self._stop_event = asyncio.Event()
|
|
||||||
self._task: Optional[asyncio.Task[None]] = None
|
|
||||||
self._last_proposer_id: Optional[str] = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
if self._task is not None:
|
|
||||||
return
|
|
||||||
self._logger.info("Starting PoA proposer loop", extra={"interval": self._config.interval_seconds})
|
|
||||||
await self._ensure_genesis_block()
|
|
||||||
self._stop_event.clear()
|
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
if self._task is None:
|
|
||||||
return
|
|
||||||
self._logger.info("Stopping PoA proposer loop")
|
|
||||||
self._stop_event.set()
|
|
||||||
await self._task
|
|
||||||
self._task = None
|
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
await self._wait_until_next_slot()
|
|
||||||
if self._stop_event.is_set():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
await self._propose_block()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
self._logger.exception("Failed to propose block", extra={"error": str(exc)})
|
|
||||||
|
|
||||||
async def _wait_until_next_slot(self) -> None:
|
|
||||||
head = self._fetch_chain_head()
|
|
||||||
if head is None:
|
|
||||||
return
|
|
||||||
now = datetime.utcnow()
|
|
||||||
elapsed = (now - head.timestamp).total_seconds()
|
|
||||||
sleep_for = max(self._config.interval_seconds - elapsed, 0.1)
|
|
||||||
if sleep_for <= 0:
|
|
||||||
sleep_for = 0.1
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._stop_event.wait(), timeout=sleep_for)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _propose_block(self) -> None:
|
|
||||||
# Check internal mempool and include transactions
|
|
||||||
from ..mempool import get_mempool
|
|
||||||
from ..models import Transaction, Account
|
|
||||||
mempool = get_mempool()
|
|
||||||
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
next_height = 0
|
|
||||||
parent_hash = "0x00"
|
|
||||||
interval_seconds: Optional[float] = None
|
|
||||||
if head is not None:
|
|
||||||
next_height = head.height + 1
|
|
||||||
parent_hash = head.hash
|
|
||||||
interval_seconds = (datetime.utcnow() - head.timestamp).total_seconds()
|
|
||||||
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
# Pull transactions from mempool
|
|
||||||
max_txs = self._config.max_txs_per_block
|
|
||||||
max_bytes = self._config.max_block_size_bytes
|
|
||||||
pending_txs = mempool.drain(max_txs, max_bytes, self._config.chain_id)
|
|
||||||
self._logger.info(f"[PROPOSE] drained {len(pending_txs)} txs from mempool, chain={self._config.chain_id}")
|
|
||||||
|
|
||||||
# Process transactions and update balances
|
|
||||||
processed_txs = []
|
|
||||||
for tx in pending_txs:
|
|
||||||
try:
|
|
||||||
# Parse transaction data
|
|
||||||
tx_data = tx.content
|
|
||||||
sender = tx_data.get("from")
|
|
||||||
recipient = tx_data.get("to")
|
|
||||||
value = tx_data.get("amount", 0)
|
|
||||||
fee = tx_data.get("fee", 0)
|
|
||||||
|
|
||||||
if not sender or not recipient:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get sender account
|
|
||||||
sender_account = session.get(Account, (self._config.chain_id, sender))
|
|
||||||
if not sender_account:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check sufficient balance
|
|
||||||
total_cost = value + fee
|
|
||||||
if sender_account.balance < total_cost:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get or create recipient account
|
|
||||||
recipient_account = session.get(Account, (self._config.chain_id, recipient))
|
|
||||||
if not recipient_account:
|
|
||||||
recipient_account = Account(chain_id=self._config.chain_id, address=recipient, balance=0, nonce=0)
|
|
||||||
session.add(recipient_account)
|
|
||||||
session.flush()
|
|
||||||
|
|
||||||
# Update balances
|
|
||||||
sender_account.balance -= total_cost
|
|
||||||
sender_account.nonce += 1
|
|
||||||
recipient_account.balance += value
|
|
||||||
|
|
||||||
# Create transaction record
|
|
||||||
transaction = Transaction(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
tx_hash=tx.tx_hash,
|
|
||||||
sender=sender,
|
|
||||||
recipient=recipient,
|
|
||||||
payload=tx_data,
|
|
||||||
value=value,
|
|
||||||
fee=fee,
|
|
||||||
nonce=sender_account.nonce - 1,
|
|
||||||
timestamp=timestamp,
|
|
||||||
block_height=next_height,
|
|
||||||
status="confirmed"
|
|
||||||
)
|
|
||||||
session.add(transaction)
|
|
||||||
processed_txs.append(tx)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self._logger.warning(f"Failed to process transaction {tx.tx_hash}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Compute block hash with transaction data
|
|
||||||
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp, processed_txs)
|
|
||||||
|
|
||||||
block = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=next_height,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash=parent_hash,
|
|
||||||
proposer=self._config.proposer_id,
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=len(processed_txs),
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(block)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
metrics_registry.increment("blocks_proposed_total")
|
|
||||||
metrics_registry.set_gauge("chain_head_height", float(next_height))
|
|
||||||
if interval_seconds is not None and interval_seconds >= 0:
|
|
||||||
metrics_registry.observe("block_interval_seconds", interval_seconds)
|
|
||||||
metrics_registry.set_gauge("poa_last_block_interval_seconds", float(interval_seconds))
|
|
||||||
|
|
||||||
proposer_suffix = _sanitize_metric_suffix(self._config.proposer_id)
|
|
||||||
metrics_registry.increment(f"poa_blocks_proposed_total_{proposer_suffix}")
|
|
||||||
if self._last_proposer_id is not None and self._last_proposer_id != self._config.proposer_id:
|
|
||||||
metrics_registry.increment("poa_proposer_switches_total")
|
|
||||||
self._last_proposer_id = self._config.proposer_id
|
|
||||||
|
|
||||||
self._logger.info(
|
|
||||||
"Proposed block",
|
|
||||||
extra={
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast the new block
|
|
||||||
tx_list = [tx.content for tx in processed_txs] if processed_txs else []
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"chain_id": self._config.chain_id,
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"parent_hash": block.parent_hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
"timestamp": block.timestamp.isoformat(),
|
|
||||||
"tx_count": block.tx_count,
|
|
||||||
"state_root": block.state_root,
|
|
||||||
"transactions": tx_list,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_genesis_block(self) -> None:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
if head is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use a deterministic genesis timestamp so all nodes agree on the genesis block hash
|
|
||||||
timestamp = datetime(2025, 1, 1, 0, 0, 0)
|
|
||||||
block_hash = self._compute_block_hash(0, "0x00", timestamp)
|
|
||||||
genesis = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=0,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash="0x00",
|
|
||||||
proposer=self._config.proposer_id, # Use configured proposer as genesis proposer
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(genesis)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Initialize accounts from genesis allocations file (if present)
|
|
||||||
await self._initialize_genesis_allocations(session)
|
|
||||||
|
|
||||||
# Broadcast genesis block for initial sync
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"chain_id": self._config.chain_id,
|
|
||||||
"height": genesis.height,
|
|
||||||
"hash": genesis.hash,
|
|
||||||
"parent_hash": genesis.parent_hash,
|
|
||||||
"proposer": genesis.proposer,
|
|
||||||
"timestamp": genesis.timestamp.isoformat(),
|
|
||||||
"tx_count": genesis.tx_count,
|
|
||||||
"state_root": genesis.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _initialize_genesis_allocations(self, session: Session) -> None:
|
|
||||||
"""Create Account entries from the genesis allocations file."""
|
|
||||||
# Use standardized data directory from configuration
|
|
||||||
from ..config import settings
|
|
||||||
|
|
||||||
genesis_paths = [
|
|
||||||
Path(f"/var/lib/aitbc/data/{self._config.chain_id}/genesis.json"), # Standard location
|
|
||||||
]
|
|
||||||
|
|
||||||
genesis_path = None
|
|
||||||
for path in genesis_paths:
|
|
||||||
if path.exists():
|
|
||||||
genesis_path = path
|
|
||||||
break
|
|
||||||
|
|
||||||
if not genesis_path:
|
|
||||||
self._logger.warning("Genesis allocations file not found; skipping account initialization", extra={"paths": str(genesis_paths)})
|
|
||||||
return
|
|
||||||
|
|
||||||
with open(genesis_path) as f:
|
|
||||||
genesis_data = json.load(f)
|
|
||||||
|
|
||||||
allocations = genesis_data.get("allocations", [])
|
|
||||||
created = 0
|
|
||||||
for alloc in allocations:
|
|
||||||
addr = alloc["address"]
|
|
||||||
balance = int(alloc["balance"])
|
|
||||||
nonce = int(alloc.get("nonce", 0))
|
|
||||||
# Check if account already exists (idempotent)
|
|
||||||
acct = session.get(Account, (self._config.chain_id, addr))
|
|
||||||
if acct is None:
|
|
||||||
acct = Account(chain_id=self._config.chain_id, address=addr, balance=balance, nonce=nonce)
|
|
||||||
session.add(acct)
|
|
||||||
created += 1
|
|
||||||
session.commit()
|
|
||||||
self._logger.info("Initialized genesis accounts", extra={"count": created, "total": len(allocations), "path": str(genesis_path)})
|
|
||||||
|
|
||||||
def _fetch_chain_head(self) -> Optional[Block]:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
|
|
||||||
def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime, transactions: list = None) -> str:
|
|
||||||
# Include transaction hashes in block hash computation
|
|
||||||
tx_hashes = []
|
|
||||||
if transactions:
|
|
||||||
tx_hashes = [tx.tx_hash for tx in transactions]
|
|
||||||
|
|
||||||
payload = f"{self._config.chain_id}|{height}|{parent_hash}|{timestamp.isoformat()}|{'|'.join(sorted(tx_hashes))}".encode()
|
|
||||||
return "0x" + hashlib.sha256(payload).hexdigest()
|
|
||||||
@@ -1,229 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Callable, ContextManager, Optional
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
|
|
||||||
from ..logger import get_logger
|
|
||||||
from ..metrics import metrics_registry
|
|
||||||
from ..config import ProposerConfig
|
|
||||||
from ..models import Block
|
|
||||||
from ..gossip import gossip_broker
|
|
||||||
|
|
||||||
_METRIC_KEY_SANITIZE = re.compile(r"[^a-zA-Z0-9_]")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_metric_suffix(value: str) -> str:
|
|
||||||
sanitized = _METRIC_KEY_SANITIZE.sub("_", value).strip("_")
|
|
||||||
return sanitized or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
class CircuitBreaker:
|
|
||||||
def __init__(self, threshold: int, timeout: int):
|
|
||||||
self._threshold = threshold
|
|
||||||
self._timeout = timeout
|
|
||||||
self._failures = 0
|
|
||||||
self._last_failure_time = 0.0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self) -> str:
|
|
||||||
if self._state == "open":
|
|
||||||
if time.time() - self._last_failure_time > self._timeout:
|
|
||||||
self._state = "half-open"
|
|
||||||
return self._state
|
|
||||||
|
|
||||||
def allow_request(self) -> bool:
|
|
||||||
state = self.state
|
|
||||||
if state == "closed":
|
|
||||||
return True
|
|
||||||
if state == "half-open":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def record_failure(self) -> None:
|
|
||||||
self._failures += 1
|
|
||||||
self._last_failure_time = time.time()
|
|
||||||
if self._failures >= self._threshold:
|
|
||||||
self._state = "open"
|
|
||||||
|
|
||||||
def record_success(self) -> None:
|
|
||||||
self._failures = 0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
class PoAProposer:
|
|
||||||
"""Proof-of-Authority block proposer.
|
|
||||||
|
|
||||||
Responsible for periodically proposing blocks if this node is configured as a proposer.
|
|
||||||
In the real implementation, this would involve checking the mempool, validating transactions,
|
|
||||||
and signing the block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
config: ProposerConfig,
|
|
||||||
session_factory: Callable[[], ContextManager[Session]],
|
|
||||||
) -> None:
|
|
||||||
self._config = config
|
|
||||||
self._session_factory = session_factory
|
|
||||||
self._logger = get_logger(__name__)
|
|
||||||
self._stop_event = asyncio.Event()
|
|
||||||
self._task: Optional[asyncio.Task[None]] = None
|
|
||||||
self._last_proposer_id: Optional[str] = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
if self._task is not None:
|
|
||||||
return
|
|
||||||
self._logger.info("Starting PoA proposer loop", extra={"interval": self._config.interval_seconds})
|
|
||||||
self._ensure_genesis_block()
|
|
||||||
self._stop_event.clear()
|
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
if self._task is None:
|
|
||||||
return
|
|
||||||
self._logger.info("Stopping PoA proposer loop")
|
|
||||||
self._stop_event.set()
|
|
||||||
await self._task
|
|
||||||
self._task = None
|
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
await self._wait_until_next_slot()
|
|
||||||
if self._stop_event.is_set():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
self._propose_block()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
self._logger.exception("Failed to propose block", extra={"error": str(exc)})
|
|
||||||
|
|
||||||
async def _wait_until_next_slot(self) -> None:
|
|
||||||
head = self._fetch_chain_head()
|
|
||||||
if head is None:
|
|
||||||
return
|
|
||||||
now = datetime.utcnow()
|
|
||||||
elapsed = (now - head.timestamp).total_seconds()
|
|
||||||
sleep_for = max(self._config.interval_seconds - elapsed, 0.1)
|
|
||||||
if sleep_for <= 0:
|
|
||||||
sleep_for = 0.1
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._stop_event.wait(), timeout=sleep_for)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _propose_block(self) -> None:
|
|
||||||
# Check internal mempool
|
|
||||||
from ..mempool import get_mempool
|
|
||||||
if get_mempool().size(self._config.chain_id) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
next_height = 0
|
|
||||||
parent_hash = "0x00"
|
|
||||||
interval_seconds: Optional[float] = None
|
|
||||||
if head is not None:
|
|
||||||
next_height = head.height + 1
|
|
||||||
parent_hash = head.hash
|
|
||||||
interval_seconds = (datetime.utcnow() - head.timestamp).total_seconds()
|
|
||||||
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp)
|
|
||||||
|
|
||||||
block = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=next_height,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash=parent_hash,
|
|
||||||
proposer=self._config.proposer_id,
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(block)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
metrics_registry.increment("blocks_proposed_total")
|
|
||||||
metrics_registry.set_gauge("chain_head_height", float(next_height))
|
|
||||||
if interval_seconds is not None and interval_seconds >= 0:
|
|
||||||
metrics_registry.observe("block_interval_seconds", interval_seconds)
|
|
||||||
metrics_registry.set_gauge("poa_last_block_interval_seconds", float(interval_seconds))
|
|
||||||
|
|
||||||
proposer_suffix = _sanitize_metric_suffix(self._config.proposer_id)
|
|
||||||
metrics_registry.increment(f"poa_blocks_proposed_total_{proposer_suffix}")
|
|
||||||
if self._last_proposer_id is not None and self._last_proposer_id != self._config.proposer_id:
|
|
||||||
metrics_registry.increment("poa_proposer_switches_total")
|
|
||||||
self._last_proposer_id = self._config.proposer_id
|
|
||||||
|
|
||||||
self._logger.info(
|
|
||||||
"Proposed block",
|
|
||||||
extra={
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast the new block
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"parent_hash": block.parent_hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
"timestamp": block.timestamp.isoformat(),
|
|
||||||
"tx_count": block.tx_count,
|
|
||||||
"state_root": block.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_genesis_block(self) -> None:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
if head is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use a deterministic genesis timestamp so all nodes agree on the genesis block hash
|
|
||||||
timestamp = datetime(2025, 1, 1, 0, 0, 0)
|
|
||||||
block_hash = self._compute_block_hash(0, "0x00", timestamp)
|
|
||||||
genesis = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=0,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash="0x00",
|
|
||||||
proposer="genesis",
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(genesis)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Broadcast genesis block for initial sync
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"height": genesis.height,
|
|
||||||
"hash": genesis.hash,
|
|
||||||
"parent_hash": genesis.parent_hash,
|
|
||||||
"proposer": genesis.proposer,
|
|
||||||
"timestamp": genesis.timestamp.isoformat(),
|
|
||||||
"tx_count": genesis.tx_count,
|
|
||||||
"state_root": genesis.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _fetch_chain_head(self) -> Optional[Block]:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
|
|
||||||
def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime) -> str:
|
|
||||||
payload = f"{self._config.chain_id}|{height}|{parent_hash}|{timestamp.isoformat()}".encode()
|
|
||||||
return "0x" + hashlib.sha256(payload).hexdigest()
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
--- apps/blockchain-node/src/aitbc_chain/consensus/poa.py
|
|
||||||
+++ apps/blockchain-node/src/aitbc_chain/consensus/poa.py
|
|
||||||
@@ -101,7 +101,7 @@
|
|
||||||
# Wait for interval before proposing next block
|
|
||||||
await asyncio.sleep(self.config.interval_seconds)
|
|
||||||
|
|
||||||
- self._propose_block()
|
|
||||||
+ await self._propose_block()
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
@@ -1,146 +0,0 @@
|
|||||||
"""
|
|
||||||
Validator Rotation Mechanism
|
|
||||||
Handles automatic rotation of validators based on performance and stake
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
from typing import List, Dict, Optional
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import MultiValidatorPoA, Validator, ValidatorRole
|
|
||||||
|
|
||||||
class RotationStrategy(Enum):
|
|
||||||
ROUND_ROBIN = "round_robin"
|
|
||||||
STAKE_WEIGHTED = "stake_weighted"
|
|
||||||
REPUTATION_BASED = "reputation_based"
|
|
||||||
HYBRID = "hybrid"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RotationConfig:
|
|
||||||
strategy: RotationStrategy
|
|
||||||
rotation_interval: int # blocks
|
|
||||||
min_stake: float
|
|
||||||
reputation_threshold: float
|
|
||||||
max_validators: int
|
|
||||||
|
|
||||||
class ValidatorRotation:
|
|
||||||
"""Manages validator rotation based on various strategies"""
|
|
||||||
|
|
||||||
def __init__(self, consensus: MultiValidatorPoA, config: RotationConfig):
|
|
||||||
self.consensus = consensus
|
|
||||||
self.config = config
|
|
||||||
self.last_rotation_height = 0
|
|
||||||
|
|
||||||
def should_rotate(self, current_height: int) -> bool:
|
|
||||||
"""Check if rotation should occur at current height"""
|
|
||||||
return (current_height - self.last_rotation_height) >= self.config.rotation_interval
|
|
||||||
|
|
||||||
def rotate_validators(self, current_height: int) -> bool:
|
|
||||||
"""Perform validator rotation based on configured strategy"""
|
|
||||||
if not self.should_rotate(current_height):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if self.config.strategy == RotationStrategy.ROUND_ROBIN:
|
|
||||||
return self._rotate_round_robin()
|
|
||||||
elif self.config.strategy == RotationStrategy.STAKE_WEIGHTED:
|
|
||||||
return self._rotate_stake_weighted()
|
|
||||||
elif self.config.strategy == RotationStrategy.REPUTATION_BASED:
|
|
||||||
return self._rotate_reputation_based()
|
|
||||||
elif self.config.strategy == RotationStrategy.HYBRID:
|
|
||||||
return self._rotate_hybrid()
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _rotate_round_robin(self) -> bool:
|
|
||||||
"""Round-robin rotation of validator roles"""
|
|
||||||
validators = list(self.consensus.validators.values())
|
|
||||||
active_validators = [v for v in validators if v.is_active]
|
|
||||||
|
|
||||||
# Rotate roles among active validators
|
|
||||||
for i, validator in enumerate(active_validators):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 3: # Top 3 become validators
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_stake_weighted(self) -> bool:
|
|
||||||
"""Stake-weighted rotation"""
|
|
||||||
validators = sorted(
|
|
||||||
[v for v in self.consensus.validators.values() if v.is_active],
|
|
||||||
key=lambda v: v.stake,
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, validator in enumerate(validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_reputation_based(self) -> bool:
|
|
||||||
"""Reputation-based rotation"""
|
|
||||||
validators = sorted(
|
|
||||||
[v for v in self.consensus.validators.values() if v.is_active],
|
|
||||||
key=lambda v: v.reputation,
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Filter by reputation threshold
|
|
||||||
qualified_validators = [
|
|
||||||
v for v in validators
|
|
||||||
if v.reputation >= self.config.reputation_threshold
|
|
||||||
]
|
|
||||||
|
|
||||||
for i, validator in enumerate(qualified_validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_hybrid(self) -> bool:
|
|
||||||
"""Hybrid rotation considering both stake and reputation"""
|
|
||||||
validators = [v for v in self.consensus.validators.values() if v.is_active]
|
|
||||||
|
|
||||||
# Calculate hybrid score
|
|
||||||
for validator in validators:
|
|
||||||
validator.hybrid_score = validator.stake * validator.reputation
|
|
||||||
|
|
||||||
# Sort by hybrid score
|
|
||||||
validators.sort(key=lambda v: v.hybrid_score, reverse=True)
|
|
||||||
|
|
||||||
for i, validator in enumerate(validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Default rotation configuration
|
|
||||||
DEFAULT_ROTATION_CONFIG = RotationConfig(
|
|
||||||
strategy=RotationStrategy.HYBRID,
|
|
||||||
rotation_interval=100, # Rotate every 100 blocks
|
|
||||||
min_stake=1000.0,
|
|
||||||
reputation_threshold=0.7,
|
|
||||||
max_validators=10
|
|
||||||
)
|
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
"""
|
|
||||||
Slashing Conditions Implementation
|
|
||||||
Handles detection and penalties for validator misbehavior
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
from typing import Dict, List, Optional, Set
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import Validator, ValidatorRole
|
|
||||||
|
|
||||||
class SlashingCondition(Enum):
|
|
||||||
DOUBLE_SIGN = "double_sign"
|
|
||||||
UNAVAILABLE = "unavailable"
|
|
||||||
INVALID_BLOCK = "invalid_block"
|
|
||||||
SLOW_RESPONSE = "slow_response"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SlashingEvent:
|
|
||||||
validator_address: str
|
|
||||||
condition: SlashingCondition
|
|
||||||
evidence: str
|
|
||||||
block_height: int
|
|
||||||
timestamp: float
|
|
||||||
slash_amount: float
|
|
||||||
|
|
||||||
class SlashingManager:
|
|
||||||
"""Manages validator slashing conditions and penalties"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.slashing_events: List[SlashingEvent] = []
|
|
||||||
self.slash_rates = {
|
|
||||||
SlashingCondition.DOUBLE_SIGN: 0.5, # 50% slash
|
|
||||||
SlashingCondition.UNAVAILABLE: 0.1, # 10% slash
|
|
||||||
SlashingCondition.INVALID_BLOCK: 0.3, # 30% slash
|
|
||||||
SlashingCondition.SLOW_RESPONSE: 0.05 # 5% slash
|
|
||||||
}
|
|
||||||
self.slash_thresholds = {
|
|
||||||
SlashingCondition.DOUBLE_SIGN: 1, # Immediate slash
|
|
||||||
SlashingCondition.UNAVAILABLE: 3, # After 3 offenses
|
|
||||||
SlashingCondition.INVALID_BLOCK: 1, # Immediate slash
|
|
||||||
SlashingCondition.SLOW_RESPONSE: 5 # After 5 offenses
|
|
||||||
}
|
|
||||||
|
|
||||||
def detect_double_sign(self, validator: str, block_hash1: str, block_hash2: str, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect double signing (validator signed two different blocks at same height)"""
|
|
||||||
if block_hash1 == block_hash2:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.DOUBLE_SIGN,
|
|
||||||
evidence=f"Double sign detected: {block_hash1} vs {block_hash2} at height {height}",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.DOUBLE_SIGN]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_unavailability(self, validator: str, missed_blocks: int, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect validator unavailability (missing consensus participation)"""
|
|
||||||
if missed_blocks < self.slash_thresholds[SlashingCondition.UNAVAILABLE]:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.UNAVAILABLE,
|
|
||||||
evidence=f"Missed {missed_blocks} consecutive blocks",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.UNAVAILABLE]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_invalid_block(self, validator: str, block_hash: str, reason: str, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect invalid block proposal"""
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.INVALID_BLOCK,
|
|
||||||
evidence=f"Invalid block {block_hash}: {reason}",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.INVALID_BLOCK]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_slow_response(self, validator: str, response_time: float, threshold: float, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect slow consensus participation"""
|
|
||||||
if response_time <= threshold:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.SLOW_RESPONSE,
|
|
||||||
evidence=f"Slow response: {response_time}s (threshold: {threshold}s)",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.SLOW_RESPONSE]
|
|
||||||
)
|
|
||||||
|
|
||||||
def apply_slashing(self, validator: Validator, event: SlashingEvent) -> bool:
|
|
||||||
"""Apply slashing penalty to validator"""
|
|
||||||
slash_amount = validator.stake * event.slash_amount
|
|
||||||
validator.stake -= slash_amount
|
|
||||||
|
|
||||||
# Demote validator role if stake is too low
|
|
||||||
if validator.stake < 100: # Minimum stake threshold
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
# Record slashing event
|
|
||||||
self.slashing_events.append(event)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_validator_slash_count(self, validator_address: str, condition: SlashingCondition) -> int:
|
|
||||||
"""Get count of slashing events for validator and condition"""
|
|
||||||
return len([
|
|
||||||
event for event in self.slashing_events
|
|
||||||
if event.validator_address == validator_address and event.condition == condition
|
|
||||||
])
|
|
||||||
|
|
||||||
def should_slash(self, validator: str, condition: SlashingCondition) -> bool:
|
|
||||||
"""Check if validator should be slashed for condition"""
|
|
||||||
current_count = self.get_validator_slash_count(validator, condition)
|
|
||||||
threshold = self.slash_thresholds.get(condition, 1)
|
|
||||||
return current_count >= threshold
|
|
||||||
|
|
||||||
def get_slashing_history(self, validator_address: Optional[str] = None) -> List[SlashingEvent]:
|
|
||||||
"""Get slashing history for validator or all validators"""
|
|
||||||
if validator_address:
|
|
||||||
return [event for event in self.slashing_events if event.validator_address == validator_address]
|
|
||||||
return self.slashing_events.copy()
|
|
||||||
|
|
||||||
def calculate_total_slashed(self, validator_address: str) -> float:
|
|
||||||
"""Calculate total amount slashed for validator"""
|
|
||||||
events = self.get_slashing_history(validator_address)
|
|
||||||
return sum(event.slash_amount for event in events)
|
|
||||||
|
|
||||||
# Global slashing manager
|
|
||||||
slashing_manager = SlashingManager()
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from .poa import PoAProposer, ProposerConfig, CircuitBreaker
|
|
||||||
|
|
||||||
__all__ = ["PoAProposer", "ProposerConfig", "CircuitBreaker"]
|
|
||||||
@@ -1,211 +0,0 @@
|
|||||||
"""
|
|
||||||
Validator Key Management
|
|
||||||
Handles cryptographic key operations for validators
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Dict, Optional, Tuple
|
|
||||||
from cryptography.hazmat.primitives import hashes, serialization
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
||||||
from cryptography.hazmat.backends import default_backend
|
|
||||||
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ValidatorKeyPair:
|
|
||||||
address: str
|
|
||||||
private_key_pem: str
|
|
||||||
public_key_pem: str
|
|
||||||
created_at: float
|
|
||||||
last_rotated: float
|
|
||||||
|
|
||||||
class KeyManager:
|
|
||||||
"""Manages validator cryptographic keys"""
|
|
||||||
|
|
||||||
def __init__(self, keys_dir: str = "/opt/aitbc/keys"):
|
|
||||||
self.keys_dir = keys_dir
|
|
||||||
self.key_pairs: Dict[str, ValidatorKeyPair] = {}
|
|
||||||
self._ensure_keys_directory()
|
|
||||||
self._load_existing_keys()
|
|
||||||
|
|
||||||
def _ensure_keys_directory(self):
|
|
||||||
"""Ensure keys directory exists and has proper permissions"""
|
|
||||||
os.makedirs(self.keys_dir, mode=0o700, exist_ok=True)
|
|
||||||
|
|
||||||
def _load_existing_keys(self):
|
|
||||||
"""Load existing key pairs from disk"""
|
|
||||||
keys_file = os.path.join(self.keys_dir, "validator_keys.json")
|
|
||||||
|
|
||||||
if os.path.exists(keys_file):
|
|
||||||
try:
|
|
||||||
with open(keys_file, 'r') as f:
|
|
||||||
keys_data = json.load(f)
|
|
||||||
|
|
||||||
for address, key_data in keys_data.items():
|
|
||||||
self.key_pairs[address] = ValidatorKeyPair(
|
|
||||||
address=address,
|
|
||||||
private_key_pem=key_data['private_key_pem'],
|
|
||||||
public_key_pem=key_data['public_key_pem'],
|
|
||||||
created_at=key_data['created_at'],
|
|
||||||
last_rotated=key_data['last_rotated']
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading keys: {e}")
|
|
||||||
|
|
||||||
def generate_key_pair(self, address: str) -> ValidatorKeyPair:
|
|
||||||
"""Generate new RSA key pair for validator"""
|
|
||||||
# Generate private key
|
|
||||||
private_key = rsa.generate_private_key(
|
|
||||||
public_exponent=65537,
|
|
||||||
key_size=2048,
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Serialize private key
|
|
||||||
private_key_pem = private_key.private_bytes(
|
|
||||||
encoding=Encoding.PEM,
|
|
||||||
format=PrivateFormat.PKCS8,
|
|
||||||
encryption_algorithm=NoEncryption()
|
|
||||||
).decode('utf-8')
|
|
||||||
|
|
||||||
# Get public key
|
|
||||||
public_key = private_key.public_key()
|
|
||||||
public_key_pem = public_key.public_bytes(
|
|
||||||
encoding=Encoding.PEM,
|
|
||||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
|
||||||
).decode('utf-8')
|
|
||||||
|
|
||||||
# Create key pair object
|
|
||||||
current_time = time.time()
|
|
||||||
key_pair = ValidatorKeyPair(
|
|
||||||
address=address,
|
|
||||||
private_key_pem=private_key_pem,
|
|
||||||
public_key_pem=public_key_pem,
|
|
||||||
created_at=current_time,
|
|
||||||
last_rotated=current_time
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store key pair
|
|
||||||
self.key_pairs[address] = key_pair
|
|
||||||
self._save_keys()
|
|
||||||
|
|
||||||
return key_pair
|
|
||||||
|
|
||||||
def get_key_pair(self, address: str) -> Optional[ValidatorKeyPair]:
|
|
||||||
"""Get key pair for validator"""
|
|
||||||
return self.key_pairs.get(address)
|
|
||||||
|
|
||||||
def rotate_key(self, address: str) -> Optional[ValidatorKeyPair]:
|
|
||||||
"""Rotate validator keys"""
|
|
||||||
if address not in self.key_pairs:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Generate new key pair
|
|
||||||
new_key_pair = self.generate_key_pair(address)
|
|
||||||
|
|
||||||
# Update rotation time
|
|
||||||
new_key_pair.created_at = self.key_pairs[address].created_at
|
|
||||||
new_key_pair.last_rotated = time.time()
|
|
||||||
|
|
||||||
self._save_keys()
|
|
||||||
return new_key_pair
|
|
||||||
|
|
||||||
def sign_message(self, address: str, message: str) -> Optional[str]:
|
|
||||||
"""Sign message with validator private key"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load private key from PEM
|
|
||||||
private_key = serialization.load_pem_private_key(
|
|
||||||
key_pair.private_key_pem.encode(),
|
|
||||||
password=None,
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sign message
|
|
||||||
signature = private_key.sign(
|
|
||||||
message.encode('utf-8'),
|
|
||||||
hashes.SHA256(),
|
|
||||||
default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
return signature.hex()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error signing message: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def verify_signature(self, address: str, message: str, signature: str) -> bool:
|
|
||||||
"""Verify message signature"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load public key from PEM
|
|
||||||
public_key = serialization.load_pem_public_key(
|
|
||||||
key_pair.public_key_pem.encode(),
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify signature
|
|
||||||
public_key.verify(
|
|
||||||
bytes.fromhex(signature),
|
|
||||||
message.encode('utf-8'),
|
|
||||||
hashes.SHA256(),
|
|
||||||
default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error verifying signature: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_public_key_pem(self, address: str) -> Optional[str]:
|
|
||||||
"""Get public key PEM for validator"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
return key_pair.public_key_pem if key_pair else None
|
|
||||||
|
|
||||||
def _save_keys(self):
|
|
||||||
"""Save key pairs to disk"""
|
|
||||||
keys_file = os.path.join(self.keys_dir, "validator_keys.json")
|
|
||||||
|
|
||||||
keys_data = {}
|
|
||||||
for address, key_pair in self.key_pairs.items():
|
|
||||||
keys_data[address] = {
|
|
||||||
'private_key_pem': key_pair.private_key_pem,
|
|
||||||
'public_key_pem': key_pair.public_key_pem,
|
|
||||||
'created_at': key_pair.created_at,
|
|
||||||
'last_rotated': key_pair.last_rotated
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(keys_file, 'w') as f:
|
|
||||||
json.dump(keys_data, f, indent=2)
|
|
||||||
|
|
||||||
# Set secure permissions
|
|
||||||
os.chmod(keys_file, 0o600)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error saving keys: {e}")
|
|
||||||
|
|
||||||
def should_rotate_key(self, address: str, rotation_interval: int = 86400) -> bool:
|
|
||||||
"""Check if key should be rotated (default: 24 hours)"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return (time.time() - key_pair.last_rotated) >= rotation_interval
|
|
||||||
|
|
||||||
def get_key_age(self, address: str) -> Optional[float]:
|
|
||||||
"""Get age of key in seconds"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return time.time() - key_pair.created_at
|
|
||||||
|
|
||||||
# Global key manager
|
|
||||||
key_manager = KeyManager()
|
|
||||||
@@ -1,119 +0,0 @@
|
|||||||
"""
|
|
||||||
Multi-Validator Proof of Authority Consensus Implementation
|
|
||||||
Extends single validator PoA to support multiple validators with rotation
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
from typing import List, Dict, Optional, Set
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from ..config import settings
|
|
||||||
from ..models import Block, Transaction
|
|
||||||
from ..database import session_scope
|
|
||||||
|
|
||||||
class ValidatorRole(Enum):
|
|
||||||
PROPOSER = "proposer"
|
|
||||||
VALIDATOR = "validator"
|
|
||||||
STANDBY = "standby"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Validator:
|
|
||||||
address: str
|
|
||||||
stake: float
|
|
||||||
reputation: float
|
|
||||||
role: ValidatorRole
|
|
||||||
last_proposed: int
|
|
||||||
is_active: bool
|
|
||||||
|
|
||||||
class MultiValidatorPoA:
|
|
||||||
"""Multi-Validator Proof of Authority consensus mechanism"""
|
|
||||||
|
|
||||||
def __init__(self, chain_id: str):
|
|
||||||
self.chain_id = chain_id
|
|
||||||
self.validators: Dict[str, Validator] = {}
|
|
||||||
self.current_proposer_index = 0
|
|
||||||
self.round_robin_enabled = True
|
|
||||||
self.consensus_timeout = 30 # seconds
|
|
||||||
|
|
||||||
def add_validator(self, address: str, stake: float = 1000.0) -> bool:
|
|
||||||
"""Add a new validator to the consensus"""
|
|
||||||
if address in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.validators[address] = Validator(
|
|
||||||
address=address,
|
|
||||||
stake=stake,
|
|
||||||
reputation=1.0,
|
|
||||||
role=ValidatorRole.STANDBY,
|
|
||||||
last_proposed=0,
|
|
||||||
is_active=True
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def remove_validator(self, address: str) -> bool:
|
|
||||||
"""Remove a validator from the consensus"""
|
|
||||||
if address not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[address]
|
|
||||||
validator.is_active = False
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
return True
|
|
||||||
|
|
||||||
def select_proposer(self, block_height: int) -> Optional[str]:
|
|
||||||
"""Select proposer for the current block using round-robin"""
|
|
||||||
active_validators = [
|
|
||||||
v for v in self.validators.values()
|
|
||||||
if v.is_active and v.role in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]
|
|
||||||
]
|
|
||||||
|
|
||||||
if not active_validators:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Round-robin selection
|
|
||||||
proposer_index = block_height % len(active_validators)
|
|
||||||
return active_validators[proposer_index].address
|
|
||||||
|
|
||||||
def validate_block(self, block: Block, proposer: str) -> bool:
|
|
||||||
"""Validate a proposed block"""
|
|
||||||
if proposer not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[proposer]
|
|
||||||
if not validator.is_active:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check if validator is allowed to propose
|
|
||||||
if validator.role not in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Additional validation logic here
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_consensus_participants(self) -> List[str]:
|
|
||||||
"""Get list of active consensus participants"""
|
|
||||||
return [
|
|
||||||
v.address for v in self.validators.values()
|
|
||||||
if v.is_active and v.role in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]
|
|
||||||
]
|
|
||||||
|
|
||||||
def update_validator_reputation(self, address: str, delta: float) -> bool:
|
|
||||||
"""Update validator reputation"""
|
|
||||||
if address not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[address]
|
|
||||||
validator.reputation = max(0.0, min(1.0, validator.reputation + delta))
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Global consensus instance
|
|
||||||
consensus_instances: Dict[str, MultiValidatorPoA] = {}
|
|
||||||
|
|
||||||
def get_consensus(chain_id: str) -> MultiValidatorPoA:
|
|
||||||
"""Get or create consensus instance for chain"""
|
|
||||||
if chain_id not in consensus_instances:
|
|
||||||
consensus_instances[chain_id] = MultiValidatorPoA(chain_id)
|
|
||||||
return consensus_instances[chain_id]
|
|
||||||
@@ -1,193 +0,0 @@
|
|||||||
"""
|
|
||||||
Practical Byzantine Fault Tolerance (PBFT) Consensus Implementation
|
|
||||||
Provides Byzantine fault tolerance for up to 1/3 faulty validators
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
from typing import List, Dict, Optional, Set, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import MultiValidatorPoA, Validator
|
|
||||||
|
|
||||||
class PBFTPhase(Enum):
|
|
||||||
PRE_PREPARE = "pre_prepare"
|
|
||||||
PREPARE = "prepare"
|
|
||||||
COMMIT = "commit"
|
|
||||||
EXECUTE = "execute"
|
|
||||||
|
|
||||||
class PBFTMessageType(Enum):
|
|
||||||
PRE_PREPARE = "pre_prepare"
|
|
||||||
PREPARE = "prepare"
|
|
||||||
COMMIT = "commit"
|
|
||||||
VIEW_CHANGE = "view_change"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PBFTMessage:
|
|
||||||
message_type: PBFTMessageType
|
|
||||||
sender: str
|
|
||||||
view_number: int
|
|
||||||
sequence_number: int
|
|
||||||
digest: str
|
|
||||||
signature: str
|
|
||||||
timestamp: float
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PBFTState:
|
|
||||||
current_view: int
|
|
||||||
current_sequence: int
|
|
||||||
prepared_messages: Dict[str, List[PBFTMessage]]
|
|
||||||
committed_messages: Dict[str, List[PBFTMessage]]
|
|
||||||
pre_prepare_messages: Dict[str, PBFTMessage]
|
|
||||||
|
|
||||||
class PBFTConsensus:
|
|
||||||
"""PBFT consensus implementation"""
|
|
||||||
|
|
||||||
def __init__(self, consensus: MultiValidatorPoA):
|
|
||||||
self.consensus = consensus
|
|
||||||
self.state = PBFTState(
|
|
||||||
current_view=0,
|
|
||||||
current_sequence=0,
|
|
||||||
prepared_messages={},
|
|
||||||
committed_messages={},
|
|
||||||
pre_prepare_messages={}
|
|
||||||
)
|
|
||||||
self.fault_tolerance = max(1, len(consensus.get_consensus_participants()) // 3)
|
|
||||||
self.required_messages = 2 * self.fault_tolerance + 1
|
|
||||||
|
|
||||||
def get_message_digest(self, block_hash: str, sequence: int, view: int) -> str:
|
|
||||||
"""Generate message digest for PBFT"""
|
|
||||||
content = f"{block_hash}:{sequence}:{view}"
|
|
||||||
return hashlib.sha256(content.encode()).hexdigest()
|
|
||||||
|
|
||||||
async def pre_prepare_phase(self, proposer: str, block_hash: str) -> bool:
|
|
||||||
"""Phase 1: Pre-prepare"""
|
|
||||||
sequence = self.state.current_sequence + 1
|
|
||||||
view = self.state.current_view
|
|
||||||
digest = self.get_message_digest(block_hash, sequence, view)
|
|
||||||
|
|
||||||
message = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.PRE_PREPARE,
|
|
||||||
sender=proposer,
|
|
||||||
view_number=view,
|
|
||||||
sequence_number=sequence,
|
|
||||||
digest=digest,
|
|
||||||
signature="", # Would be signed in real implementation
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store pre-prepare message
|
|
||||||
key = f"{sequence}:{view}"
|
|
||||||
self.state.pre_prepare_messages[key] = message
|
|
||||||
|
|
||||||
# Broadcast to all validators
|
|
||||||
await self._broadcast_message(message)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def prepare_phase(self, validator: str, pre_prepare_msg: PBFTMessage) -> bool:
|
|
||||||
"""Phase 2: Prepare"""
|
|
||||||
key = f"{pre_prepare_msg.sequence_number}:{pre_prepare_msg.view_number}"
|
|
||||||
|
|
||||||
if key not in self.state.pre_prepare_messages:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Create prepare message
|
|
||||||
prepare_msg = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.PREPARE,
|
|
||||||
sender=validator,
|
|
||||||
view_number=pre_prepare_msg.view_number,
|
|
||||||
sequence_number=pre_prepare_msg.sequence_number,
|
|
||||||
digest=pre_prepare_msg.digest,
|
|
||||||
signature="", # Would be signed
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store prepare message
|
|
||||||
if key not in self.state.prepared_messages:
|
|
||||||
self.state.prepared_messages[key] = []
|
|
||||||
self.state.prepared_messages[key].append(prepare_msg)
|
|
||||||
|
|
||||||
# Broadcast prepare message
|
|
||||||
await self._broadcast_message(prepare_msg)
|
|
||||||
|
|
||||||
# Check if we have enough prepare messages
|
|
||||||
return len(self.state.prepared_messages[key]) >= self.required_messages
|
|
||||||
|
|
||||||
async def commit_phase(self, validator: str, prepare_msg: PBFTMessage) -> bool:
|
|
||||||
"""Phase 3: Commit"""
|
|
||||||
key = f"{prepare_msg.sequence_number}:{prepare_msg.view_number}"
|
|
||||||
|
|
||||||
# Create commit message
|
|
||||||
commit_msg = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.COMMIT,
|
|
||||||
sender=validator,
|
|
||||||
view_number=prepare_msg.view_number,
|
|
||||||
sequence_number=prepare_msg.sequence_number,
|
|
||||||
digest=prepare_msg.digest,
|
|
||||||
signature="", # Would be signed
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store commit message
|
|
||||||
if key not in self.state.committed_messages:
|
|
||||||
self.state.committed_messages[key] = []
|
|
||||||
self.state.committed_messages[key].append(commit_msg)
|
|
||||||
|
|
||||||
# Broadcast commit message
|
|
||||||
await self._broadcast_message(commit_msg)
|
|
||||||
|
|
||||||
# Check if we have enough commit messages
|
|
||||||
if len(self.state.committed_messages[key]) >= self.required_messages:
|
|
||||||
return await self.execute_phase(key)
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def execute_phase(self, key: str) -> bool:
|
|
||||||
"""Phase 4: Execute"""
|
|
||||||
# Extract sequence and view from key
|
|
||||||
sequence, view = map(int, key.split(':'))
|
|
||||||
|
|
||||||
# Update state
|
|
||||||
self.state.current_sequence = sequence
|
|
||||||
|
|
||||||
# Clean up old messages
|
|
||||||
self._cleanup_messages(sequence)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _broadcast_message(self, message: PBFTMessage):
|
|
||||||
"""Broadcast message to all validators"""
|
|
||||||
validators = self.consensus.get_consensus_participants()
|
|
||||||
|
|
||||||
for validator in validators:
|
|
||||||
if validator != message.sender:
|
|
||||||
# In real implementation, this would send over network
|
|
||||||
await self._send_to_validator(validator, message)
|
|
||||||
|
|
||||||
async def _send_to_validator(self, validator: str, message: PBFTMessage):
|
|
||||||
"""Send message to specific validator"""
|
|
||||||
# Network communication would be implemented here
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _cleanup_messages(self, sequence: int):
|
|
||||||
"""Clean up old messages to prevent memory leaks"""
|
|
||||||
old_keys = [
|
|
||||||
key for key in self.state.prepared_messages.keys()
|
|
||||||
if int(key.split(':')[0]) < sequence
|
|
||||||
]
|
|
||||||
|
|
||||||
for key in old_keys:
|
|
||||||
self.state.prepared_messages.pop(key, None)
|
|
||||||
self.state.committed_messages.pop(key, None)
|
|
||||||
self.state.pre_prepare_messages.pop(key, None)
|
|
||||||
|
|
||||||
def handle_view_change(self, new_view: int) -> bool:
|
|
||||||
"""Handle view change when proposer fails"""
|
|
||||||
self.state.current_view = new_view
|
|
||||||
# Reset state for new view
|
|
||||||
self.state.prepared_messages.clear()
|
|
||||||
self.state.committed_messages.clear()
|
|
||||||
self.state.pre_prepare_messages.clear()
|
|
||||||
return True
|
|
||||||
@@ -1,345 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable, ContextManager, Optional
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
|
|
||||||
from ..logger import get_logger
|
|
||||||
from ..metrics import metrics_registry
|
|
||||||
from ..config import ProposerConfig
|
|
||||||
from ..models import Block, Account
|
|
||||||
from ..gossip import gossip_broker
|
|
||||||
|
|
||||||
_METRIC_KEY_SANITIZE = re.compile(r"[^a-zA-Z0-9_]")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_metric_suffix(value: str) -> str:
|
|
||||||
sanitized = _METRIC_KEY_SANITIZE.sub("_", value).strip("_")
|
|
||||||
return sanitized or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
class CircuitBreaker:
|
|
||||||
def __init__(self, threshold: int, timeout: int):
|
|
||||||
self._threshold = threshold
|
|
||||||
self._timeout = timeout
|
|
||||||
self._failures = 0
|
|
||||||
self._last_failure_time = 0.0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self) -> str:
|
|
||||||
if self._state == "open":
|
|
||||||
if time.time() - self._last_failure_time > self._timeout:
|
|
||||||
self._state = "half-open"
|
|
||||||
return self._state
|
|
||||||
|
|
||||||
def allow_request(self) -> bool:
|
|
||||||
state = self.state
|
|
||||||
if state == "closed":
|
|
||||||
return True
|
|
||||||
if state == "half-open":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def record_failure(self) -> None:
|
|
||||||
self._failures += 1
|
|
||||||
self._last_failure_time = time.time()
|
|
||||||
if self._failures >= self._threshold:
|
|
||||||
self._state = "open"
|
|
||||||
|
|
||||||
def record_success(self) -> None:
|
|
||||||
self._failures = 0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
class PoAProposer:
|
|
||||||
"""Proof-of-Authority block proposer.
|
|
||||||
|
|
||||||
Responsible for periodically proposing blocks if this node is configured as a proposer.
|
|
||||||
In the real implementation, this would involve checking the mempool, validating transactions,
|
|
||||||
and signing the block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
config: ProposerConfig,
|
|
||||||
session_factory: Callable[[], ContextManager[Session]],
|
|
||||||
) -> None:
|
|
||||||
self._config = config
|
|
||||||
self._session_factory = session_factory
|
|
||||||
self._logger = get_logger(__name__)
|
|
||||||
self._stop_event = asyncio.Event()
|
|
||||||
self._task: Optional[asyncio.Task[None]] = None
|
|
||||||
self._last_proposer_id: Optional[str] = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
if self._task is not None:
|
|
||||||
return
|
|
||||||
self._logger.info("Starting PoA proposer loop", extra={"interval": self._config.interval_seconds})
|
|
||||||
await self._ensure_genesis_block()
|
|
||||||
self._stop_event.clear()
|
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
if self._task is None:
|
|
||||||
return
|
|
||||||
self._logger.info("Stopping PoA proposer loop")
|
|
||||||
self._stop_event.set()
|
|
||||||
await self._task
|
|
||||||
self._task = None
|
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
await self._wait_until_next_slot()
|
|
||||||
if self._stop_event.is_set():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
await self._propose_block()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
self._logger.exception("Failed to propose block", extra={"error": str(exc)})
|
|
||||||
|
|
||||||
async def _wait_until_next_slot(self) -> None:
|
|
||||||
head = self._fetch_chain_head()
|
|
||||||
if head is None:
|
|
||||||
return
|
|
||||||
now = datetime.utcnow()
|
|
||||||
elapsed = (now - head.timestamp).total_seconds()
|
|
||||||
sleep_for = max(self._config.interval_seconds - elapsed, 0.1)
|
|
||||||
if sleep_for <= 0:
|
|
||||||
sleep_for = 0.1
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._stop_event.wait(), timeout=sleep_for)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _propose_block(self) -> None:
|
|
||||||
# Check internal mempool and include transactions
|
|
||||||
from ..mempool import get_mempool
|
|
||||||
from ..models import Transaction, Account
|
|
||||||
mempool = get_mempool()
|
|
||||||
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
next_height = 0
|
|
||||||
parent_hash = "0x00"
|
|
||||||
interval_seconds: Optional[float] = None
|
|
||||||
if head is not None:
|
|
||||||
next_height = head.height + 1
|
|
||||||
parent_hash = head.hash
|
|
||||||
interval_seconds = (datetime.utcnow() - head.timestamp).total_seconds()
|
|
||||||
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
# Pull transactions from mempool
|
|
||||||
max_txs = self._config.max_txs_per_block
|
|
||||||
max_bytes = self._config.max_block_size_bytes
|
|
||||||
pending_txs = mempool.drain(max_txs, max_bytes, self._config.chain_id)
|
|
||||||
self._logger.info(f"[PROPOSE] drained {len(pending_txs)} txs from mempool, chain={self._config.chain_id}")
|
|
||||||
|
|
||||||
# Process transactions and update balances
|
|
||||||
processed_txs = []
|
|
||||||
for tx in pending_txs:
|
|
||||||
try:
|
|
||||||
# Parse transaction data
|
|
||||||
tx_data = tx.content
|
|
||||||
sender = tx_data.get("from")
|
|
||||||
recipient = tx_data.get("to")
|
|
||||||
value = tx_data.get("amount", 0)
|
|
||||||
fee = tx_data.get("fee", 0)
|
|
||||||
|
|
||||||
if not sender or not recipient:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get sender account
|
|
||||||
sender_account = session.get(Account, (self._config.chain_id, sender))
|
|
||||||
if not sender_account:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check sufficient balance
|
|
||||||
total_cost = value + fee
|
|
||||||
if sender_account.balance < total_cost:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get or create recipient account
|
|
||||||
recipient_account = session.get(Account, (self._config.chain_id, recipient))
|
|
||||||
if not recipient_account:
|
|
||||||
recipient_account = Account(chain_id=self._config.chain_id, address=recipient, balance=0, nonce=0)
|
|
||||||
session.add(recipient_account)
|
|
||||||
session.flush()
|
|
||||||
|
|
||||||
# Update balances
|
|
||||||
sender_account.balance -= total_cost
|
|
||||||
sender_account.nonce += 1
|
|
||||||
recipient_account.balance += value
|
|
||||||
|
|
||||||
# Create transaction record
|
|
||||||
transaction = Transaction(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
tx_hash=tx.tx_hash,
|
|
||||||
sender=sender,
|
|
||||||
recipient=recipient,
|
|
||||||
payload=tx_data,
|
|
||||||
value=value,
|
|
||||||
fee=fee,
|
|
||||||
nonce=sender_account.nonce - 1,
|
|
||||||
timestamp=timestamp,
|
|
||||||
block_height=next_height,
|
|
||||||
status="confirmed"
|
|
||||||
)
|
|
||||||
session.add(transaction)
|
|
||||||
processed_txs.append(tx)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self._logger.warning(f"Failed to process transaction {tx.tx_hash}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Compute block hash with transaction data
|
|
||||||
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp, processed_txs)
|
|
||||||
|
|
||||||
block = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=next_height,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash=parent_hash,
|
|
||||||
proposer=self._config.proposer_id,
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=len(processed_txs),
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(block)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
metrics_registry.increment("blocks_proposed_total")
|
|
||||||
metrics_registry.set_gauge("chain_head_height", float(next_height))
|
|
||||||
if interval_seconds is not None and interval_seconds >= 0:
|
|
||||||
metrics_registry.observe("block_interval_seconds", interval_seconds)
|
|
||||||
metrics_registry.set_gauge("poa_last_block_interval_seconds", float(interval_seconds))
|
|
||||||
|
|
||||||
proposer_suffix = _sanitize_metric_suffix(self._config.proposer_id)
|
|
||||||
metrics_registry.increment(f"poa_blocks_proposed_total_{proposer_suffix}")
|
|
||||||
if self._last_proposer_id is not None and self._last_proposer_id != self._config.proposer_id:
|
|
||||||
metrics_registry.increment("poa_proposer_switches_total")
|
|
||||||
self._last_proposer_id = self._config.proposer_id
|
|
||||||
|
|
||||||
self._logger.info(
|
|
||||||
"Proposed block",
|
|
||||||
extra={
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast the new block
|
|
||||||
tx_list = [tx.content for tx in processed_txs] if processed_txs else []
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"chain_id": self._config.chain_id,
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"parent_hash": block.parent_hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
"timestamp": block.timestamp.isoformat(),
|
|
||||||
"tx_count": block.tx_count,
|
|
||||||
"state_root": block.state_root,
|
|
||||||
"transactions": tx_list,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_genesis_block(self) -> None:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
if head is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use a deterministic genesis timestamp so all nodes agree on the genesis block hash
|
|
||||||
timestamp = datetime(2025, 1, 1, 0, 0, 0)
|
|
||||||
block_hash = self._compute_block_hash(0, "0x00", timestamp)
|
|
||||||
genesis = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=0,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash="0x00",
|
|
||||||
proposer=self._config.proposer_id, # Use configured proposer as genesis proposer
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(genesis)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Initialize accounts from genesis allocations file (if present)
|
|
||||||
await self._initialize_genesis_allocations(session)
|
|
||||||
|
|
||||||
# Broadcast genesis block for initial sync
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"chain_id": self._config.chain_id,
|
|
||||||
"height": genesis.height,
|
|
||||||
"hash": genesis.hash,
|
|
||||||
"parent_hash": genesis.parent_hash,
|
|
||||||
"proposer": genesis.proposer,
|
|
||||||
"timestamp": genesis.timestamp.isoformat(),
|
|
||||||
"tx_count": genesis.tx_count,
|
|
||||||
"state_root": genesis.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _initialize_genesis_allocations(self, session: Session) -> None:
|
|
||||||
"""Create Account entries from the genesis allocations file."""
|
|
||||||
# Use standardized data directory from configuration
|
|
||||||
from ..config import settings
|
|
||||||
|
|
||||||
genesis_paths = [
|
|
||||||
Path(f"/var/lib/aitbc/data/{self._config.chain_id}/genesis.json"), # Standard location
|
|
||||||
]
|
|
||||||
|
|
||||||
genesis_path = None
|
|
||||||
for path in genesis_paths:
|
|
||||||
if path.exists():
|
|
||||||
genesis_path = path
|
|
||||||
break
|
|
||||||
|
|
||||||
if not genesis_path:
|
|
||||||
self._logger.warning("Genesis allocations file not found; skipping account initialization", extra={"paths": str(genesis_paths)})
|
|
||||||
return
|
|
||||||
|
|
||||||
with open(genesis_path) as f:
|
|
||||||
genesis_data = json.load(f)
|
|
||||||
|
|
||||||
allocations = genesis_data.get("allocations", [])
|
|
||||||
created = 0
|
|
||||||
for alloc in allocations:
|
|
||||||
addr = alloc["address"]
|
|
||||||
balance = int(alloc["balance"])
|
|
||||||
nonce = int(alloc.get("nonce", 0))
|
|
||||||
# Check if account already exists (idempotent)
|
|
||||||
acct = session.get(Account, (self._config.chain_id, addr))
|
|
||||||
if acct is None:
|
|
||||||
acct = Account(chain_id=self._config.chain_id, address=addr, balance=balance, nonce=nonce)
|
|
||||||
session.add(acct)
|
|
||||||
created += 1
|
|
||||||
session.commit()
|
|
||||||
self._logger.info("Initialized genesis accounts", extra={"count": created, "total": len(allocations), "path": str(genesis_path)})
|
|
||||||
|
|
||||||
def _fetch_chain_head(self) -> Optional[Block]:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
|
|
||||||
def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime, transactions: list = None) -> str:
|
|
||||||
# Include transaction hashes in block hash computation
|
|
||||||
tx_hashes = []
|
|
||||||
if transactions:
|
|
||||||
tx_hashes = [tx.tx_hash for tx in transactions]
|
|
||||||
|
|
||||||
payload = f"{self._config.chain_id}|{height}|{parent_hash}|{timestamp.isoformat()}|{'|'.join(sorted(tx_hashes))}".encode()
|
|
||||||
return "0x" + hashlib.sha256(payload).hexdigest()
|
|
||||||
@@ -1,229 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Callable, ContextManager, Optional
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
|
|
||||||
from ..logger import get_logger
|
|
||||||
from ..metrics import metrics_registry
|
|
||||||
from ..config import ProposerConfig
|
|
||||||
from ..models import Block
|
|
||||||
from ..gossip import gossip_broker
|
|
||||||
|
|
||||||
_METRIC_KEY_SANITIZE = re.compile(r"[^a-zA-Z0-9_]")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_metric_suffix(value: str) -> str:
|
|
||||||
sanitized = _METRIC_KEY_SANITIZE.sub("_", value).strip("_")
|
|
||||||
return sanitized or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
class CircuitBreaker:
|
|
||||||
def __init__(self, threshold: int, timeout: int):
|
|
||||||
self._threshold = threshold
|
|
||||||
self._timeout = timeout
|
|
||||||
self._failures = 0
|
|
||||||
self._last_failure_time = 0.0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self) -> str:
|
|
||||||
if self._state == "open":
|
|
||||||
if time.time() - self._last_failure_time > self._timeout:
|
|
||||||
self._state = "half-open"
|
|
||||||
return self._state
|
|
||||||
|
|
||||||
def allow_request(self) -> bool:
|
|
||||||
state = self.state
|
|
||||||
if state == "closed":
|
|
||||||
return True
|
|
||||||
if state == "half-open":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def record_failure(self) -> None:
|
|
||||||
self._failures += 1
|
|
||||||
self._last_failure_time = time.time()
|
|
||||||
if self._failures >= self._threshold:
|
|
||||||
self._state = "open"
|
|
||||||
|
|
||||||
def record_success(self) -> None:
|
|
||||||
self._failures = 0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
class PoAProposer:
|
|
||||||
"""Proof-of-Authority block proposer.
|
|
||||||
|
|
||||||
Responsible for periodically proposing blocks if this node is configured as a proposer.
|
|
||||||
In the real implementation, this would involve checking the mempool, validating transactions,
|
|
||||||
and signing the block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
config: ProposerConfig,
|
|
||||||
session_factory: Callable[[], ContextManager[Session]],
|
|
||||||
) -> None:
|
|
||||||
self._config = config
|
|
||||||
self._session_factory = session_factory
|
|
||||||
self._logger = get_logger(__name__)
|
|
||||||
self._stop_event = asyncio.Event()
|
|
||||||
self._task: Optional[asyncio.Task[None]] = None
|
|
||||||
self._last_proposer_id: Optional[str] = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
if self._task is not None:
|
|
||||||
return
|
|
||||||
self._logger.info("Starting PoA proposer loop", extra={"interval": self._config.interval_seconds})
|
|
||||||
self._ensure_genesis_block()
|
|
||||||
self._stop_event.clear()
|
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
if self._task is None:
|
|
||||||
return
|
|
||||||
self._logger.info("Stopping PoA proposer loop")
|
|
||||||
self._stop_event.set()
|
|
||||||
await self._task
|
|
||||||
self._task = None
|
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
await self._wait_until_next_slot()
|
|
||||||
if self._stop_event.is_set():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
self._propose_block()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
self._logger.exception("Failed to propose block", extra={"error": str(exc)})
|
|
||||||
|
|
||||||
async def _wait_until_next_slot(self) -> None:
|
|
||||||
head = self._fetch_chain_head()
|
|
||||||
if head is None:
|
|
||||||
return
|
|
||||||
now = datetime.utcnow()
|
|
||||||
elapsed = (now - head.timestamp).total_seconds()
|
|
||||||
sleep_for = max(self._config.interval_seconds - elapsed, 0.1)
|
|
||||||
if sleep_for <= 0:
|
|
||||||
sleep_for = 0.1
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._stop_event.wait(), timeout=sleep_for)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _propose_block(self) -> None:
|
|
||||||
# Check internal mempool
|
|
||||||
from ..mempool import get_mempool
|
|
||||||
if get_mempool().size(self._config.chain_id) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
next_height = 0
|
|
||||||
parent_hash = "0x00"
|
|
||||||
interval_seconds: Optional[float] = None
|
|
||||||
if head is not None:
|
|
||||||
next_height = head.height + 1
|
|
||||||
parent_hash = head.hash
|
|
||||||
interval_seconds = (datetime.utcnow() - head.timestamp).total_seconds()
|
|
||||||
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp)
|
|
||||||
|
|
||||||
block = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=next_height,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash=parent_hash,
|
|
||||||
proposer=self._config.proposer_id,
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(block)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
metrics_registry.increment("blocks_proposed_total")
|
|
||||||
metrics_registry.set_gauge("chain_head_height", float(next_height))
|
|
||||||
if interval_seconds is not None and interval_seconds >= 0:
|
|
||||||
metrics_registry.observe("block_interval_seconds", interval_seconds)
|
|
||||||
metrics_registry.set_gauge("poa_last_block_interval_seconds", float(interval_seconds))
|
|
||||||
|
|
||||||
proposer_suffix = _sanitize_metric_suffix(self._config.proposer_id)
|
|
||||||
metrics_registry.increment(f"poa_blocks_proposed_total_{proposer_suffix}")
|
|
||||||
if self._last_proposer_id is not None and self._last_proposer_id != self._config.proposer_id:
|
|
||||||
metrics_registry.increment("poa_proposer_switches_total")
|
|
||||||
self._last_proposer_id = self._config.proposer_id
|
|
||||||
|
|
||||||
self._logger.info(
|
|
||||||
"Proposed block",
|
|
||||||
extra={
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast the new block
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"parent_hash": block.parent_hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
"timestamp": block.timestamp.isoformat(),
|
|
||||||
"tx_count": block.tx_count,
|
|
||||||
"state_root": block.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_genesis_block(self) -> None:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
if head is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use a deterministic genesis timestamp so all nodes agree on the genesis block hash
|
|
||||||
timestamp = datetime(2025, 1, 1, 0, 0, 0)
|
|
||||||
block_hash = self._compute_block_hash(0, "0x00", timestamp)
|
|
||||||
genesis = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=0,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash="0x00",
|
|
||||||
proposer="genesis",
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(genesis)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Broadcast genesis block for initial sync
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"height": genesis.height,
|
|
||||||
"hash": genesis.hash,
|
|
||||||
"parent_hash": genesis.parent_hash,
|
|
||||||
"proposer": genesis.proposer,
|
|
||||||
"timestamp": genesis.timestamp.isoformat(),
|
|
||||||
"tx_count": genesis.tx_count,
|
|
||||||
"state_root": genesis.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _fetch_chain_head(self) -> Optional[Block]:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
|
|
||||||
def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime) -> str:
|
|
||||||
payload = f"{self._config.chain_id}|{height}|{parent_hash}|{timestamp.isoformat()}".encode()
|
|
||||||
return "0x" + hashlib.sha256(payload).hexdigest()
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
--- apps/blockchain-node/src/aitbc_chain/consensus/poa.py
|
|
||||||
+++ apps/blockchain-node/src/aitbc_chain/consensus/poa.py
|
|
||||||
@@ -101,7 +101,7 @@
|
|
||||||
# Wait for interval before proposing next block
|
|
||||||
await asyncio.sleep(self.config.interval_seconds)
|
|
||||||
|
|
||||||
- self._propose_block()
|
|
||||||
+ await self._propose_block()
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
@@ -1,146 +0,0 @@
|
|||||||
"""
|
|
||||||
Validator Rotation Mechanism
|
|
||||||
Handles automatic rotation of validators based on performance and stake
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
from typing import List, Dict, Optional
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import MultiValidatorPoA, Validator, ValidatorRole
|
|
||||||
|
|
||||||
class RotationStrategy(Enum):
|
|
||||||
ROUND_ROBIN = "round_robin"
|
|
||||||
STAKE_WEIGHTED = "stake_weighted"
|
|
||||||
REPUTATION_BASED = "reputation_based"
|
|
||||||
HYBRID = "hybrid"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RotationConfig:
|
|
||||||
strategy: RotationStrategy
|
|
||||||
rotation_interval: int # blocks
|
|
||||||
min_stake: float
|
|
||||||
reputation_threshold: float
|
|
||||||
max_validators: int
|
|
||||||
|
|
||||||
class ValidatorRotation:
|
|
||||||
"""Manages validator rotation based on various strategies"""
|
|
||||||
|
|
||||||
def __init__(self, consensus: MultiValidatorPoA, config: RotationConfig):
|
|
||||||
self.consensus = consensus
|
|
||||||
self.config = config
|
|
||||||
self.last_rotation_height = 0
|
|
||||||
|
|
||||||
def should_rotate(self, current_height: int) -> bool:
|
|
||||||
"""Check if rotation should occur at current height"""
|
|
||||||
return (current_height - self.last_rotation_height) >= self.config.rotation_interval
|
|
||||||
|
|
||||||
def rotate_validators(self, current_height: int) -> bool:
|
|
||||||
"""Perform validator rotation based on configured strategy"""
|
|
||||||
if not self.should_rotate(current_height):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if self.config.strategy == RotationStrategy.ROUND_ROBIN:
|
|
||||||
return self._rotate_round_robin()
|
|
||||||
elif self.config.strategy == RotationStrategy.STAKE_WEIGHTED:
|
|
||||||
return self._rotate_stake_weighted()
|
|
||||||
elif self.config.strategy == RotationStrategy.REPUTATION_BASED:
|
|
||||||
return self._rotate_reputation_based()
|
|
||||||
elif self.config.strategy == RotationStrategy.HYBRID:
|
|
||||||
return self._rotate_hybrid()
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _rotate_round_robin(self) -> bool:
|
|
||||||
"""Round-robin rotation of validator roles"""
|
|
||||||
validators = list(self.consensus.validators.values())
|
|
||||||
active_validators = [v for v in validators if v.is_active]
|
|
||||||
|
|
||||||
# Rotate roles among active validators
|
|
||||||
for i, validator in enumerate(active_validators):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 3: # Top 3 become validators
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_stake_weighted(self) -> bool:
|
|
||||||
"""Stake-weighted rotation"""
|
|
||||||
validators = sorted(
|
|
||||||
[v for v in self.consensus.validators.values() if v.is_active],
|
|
||||||
key=lambda v: v.stake,
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, validator in enumerate(validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_reputation_based(self) -> bool:
|
|
||||||
"""Reputation-based rotation"""
|
|
||||||
validators = sorted(
|
|
||||||
[v for v in self.consensus.validators.values() if v.is_active],
|
|
||||||
key=lambda v: v.reputation,
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Filter by reputation threshold
|
|
||||||
qualified_validators = [
|
|
||||||
v for v in validators
|
|
||||||
if v.reputation >= self.config.reputation_threshold
|
|
||||||
]
|
|
||||||
|
|
||||||
for i, validator in enumerate(qualified_validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_hybrid(self) -> bool:
|
|
||||||
"""Hybrid rotation considering both stake and reputation"""
|
|
||||||
validators = [v for v in self.consensus.validators.values() if v.is_active]
|
|
||||||
|
|
||||||
# Calculate hybrid score
|
|
||||||
for validator in validators:
|
|
||||||
validator.hybrid_score = validator.stake * validator.reputation
|
|
||||||
|
|
||||||
# Sort by hybrid score
|
|
||||||
validators.sort(key=lambda v: v.hybrid_score, reverse=True)
|
|
||||||
|
|
||||||
for i, validator in enumerate(validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Default rotation configuration
|
|
||||||
DEFAULT_ROTATION_CONFIG = RotationConfig(
|
|
||||||
strategy=RotationStrategy.HYBRID,
|
|
||||||
rotation_interval=100, # Rotate every 100 blocks
|
|
||||||
min_stake=1000.0,
|
|
||||||
reputation_threshold=0.7,
|
|
||||||
max_validators=10
|
|
||||||
)
|
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
"""
|
|
||||||
Slashing Conditions Implementation
|
|
||||||
Handles detection and penalties for validator misbehavior
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
from typing import Dict, List, Optional, Set
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import Validator, ValidatorRole
|
|
||||||
|
|
||||||
class SlashingCondition(Enum):
|
|
||||||
DOUBLE_SIGN = "double_sign"
|
|
||||||
UNAVAILABLE = "unavailable"
|
|
||||||
INVALID_BLOCK = "invalid_block"
|
|
||||||
SLOW_RESPONSE = "slow_response"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SlashingEvent:
|
|
||||||
validator_address: str
|
|
||||||
condition: SlashingCondition
|
|
||||||
evidence: str
|
|
||||||
block_height: int
|
|
||||||
timestamp: float
|
|
||||||
slash_amount: float
|
|
||||||
|
|
||||||
class SlashingManager:
|
|
||||||
"""Manages validator slashing conditions and penalties"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.slashing_events: List[SlashingEvent] = []
|
|
||||||
self.slash_rates = {
|
|
||||||
SlashingCondition.DOUBLE_SIGN: 0.5, # 50% slash
|
|
||||||
SlashingCondition.UNAVAILABLE: 0.1, # 10% slash
|
|
||||||
SlashingCondition.INVALID_BLOCK: 0.3, # 30% slash
|
|
||||||
SlashingCondition.SLOW_RESPONSE: 0.05 # 5% slash
|
|
||||||
}
|
|
||||||
self.slash_thresholds = {
|
|
||||||
SlashingCondition.DOUBLE_SIGN: 1, # Immediate slash
|
|
||||||
SlashingCondition.UNAVAILABLE: 3, # After 3 offenses
|
|
||||||
SlashingCondition.INVALID_BLOCK: 1, # Immediate slash
|
|
||||||
SlashingCondition.SLOW_RESPONSE: 5 # After 5 offenses
|
|
||||||
}
|
|
||||||
|
|
||||||
def detect_double_sign(self, validator: str, block_hash1: str, block_hash2: str, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect double signing (validator signed two different blocks at same height)"""
|
|
||||||
if block_hash1 == block_hash2:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.DOUBLE_SIGN,
|
|
||||||
evidence=f"Double sign detected: {block_hash1} vs {block_hash2} at height {height}",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.DOUBLE_SIGN]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_unavailability(self, validator: str, missed_blocks: int, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect validator unavailability (missing consensus participation)"""
|
|
||||||
if missed_blocks < self.slash_thresholds[SlashingCondition.UNAVAILABLE]:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.UNAVAILABLE,
|
|
||||||
evidence=f"Missed {missed_blocks} consecutive blocks",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.UNAVAILABLE]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_invalid_block(self, validator: str, block_hash: str, reason: str, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect invalid block proposal"""
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.INVALID_BLOCK,
|
|
||||||
evidence=f"Invalid block {block_hash}: {reason}",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.INVALID_BLOCK]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_slow_response(self, validator: str, response_time: float, threshold: float, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect slow consensus participation"""
|
|
||||||
if response_time <= threshold:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.SLOW_RESPONSE,
|
|
||||||
evidence=f"Slow response: {response_time}s (threshold: {threshold}s)",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.SLOW_RESPONSE]
|
|
||||||
)
|
|
||||||
|
|
||||||
def apply_slashing(self, validator: Validator, event: SlashingEvent) -> bool:
|
|
||||||
"""Apply slashing penalty to validator"""
|
|
||||||
slash_amount = validator.stake * event.slash_amount
|
|
||||||
validator.stake -= slash_amount
|
|
||||||
|
|
||||||
# Demote validator role if stake is too low
|
|
||||||
if validator.stake < 100: # Minimum stake threshold
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
# Record slashing event
|
|
||||||
self.slashing_events.append(event)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_validator_slash_count(self, validator_address: str, condition: SlashingCondition) -> int:
|
|
||||||
"""Get count of slashing events for validator and condition"""
|
|
||||||
return len([
|
|
||||||
event for event in self.slashing_events
|
|
||||||
if event.validator_address == validator_address and event.condition == condition
|
|
||||||
])
|
|
||||||
|
|
||||||
def should_slash(self, validator: str, condition: SlashingCondition) -> bool:
|
|
||||||
"""Check if validator should be slashed for condition"""
|
|
||||||
current_count = self.get_validator_slash_count(validator, condition)
|
|
||||||
threshold = self.slash_thresholds.get(condition, 1)
|
|
||||||
return current_count >= threshold
|
|
||||||
|
|
||||||
def get_slashing_history(self, validator_address: Optional[str] = None) -> List[SlashingEvent]:
|
|
||||||
"""Get slashing history for validator or all validators"""
|
|
||||||
if validator_address:
|
|
||||||
return [event for event in self.slashing_events if event.validator_address == validator_address]
|
|
||||||
return self.slashing_events.copy()
|
|
||||||
|
|
||||||
def calculate_total_slashed(self, validator_address: str) -> float:
|
|
||||||
"""Calculate total amount slashed for validator"""
|
|
||||||
events = self.get_slashing_history(validator_address)
|
|
||||||
return sum(event.slash_amount for event in events)
|
|
||||||
|
|
||||||
# Global slashing manager
|
|
||||||
slashing_manager = SlashingManager()
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from .poa import PoAProposer, ProposerConfig, CircuitBreaker
|
|
||||||
|
|
||||||
__all__ = ["PoAProposer", "ProposerConfig", "CircuitBreaker"]
|
|
||||||
@@ -1,210 +0,0 @@
|
|||||||
"""
|
|
||||||
Validator Key Management
|
|
||||||
Handles cryptographic key operations for validators
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from typing import Dict, Optional, Tuple
|
|
||||||
from cryptography.hazmat.primitives import hashes, serialization
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
||||||
from cryptography.hazmat.backends import default_backend
|
|
||||||
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ValidatorKeyPair:
|
|
||||||
address: str
|
|
||||||
private_key_pem: str
|
|
||||||
public_key_pem: str
|
|
||||||
created_at: float
|
|
||||||
last_rotated: float
|
|
||||||
|
|
||||||
class KeyManager:
|
|
||||||
"""Manages validator cryptographic keys"""
|
|
||||||
|
|
||||||
def __init__(self, keys_dir: str = "/opt/aitbc/keys"):
|
|
||||||
self.keys_dir = keys_dir
|
|
||||||
self.key_pairs: Dict[str, ValidatorKeyPair] = {}
|
|
||||||
self._ensure_keys_directory()
|
|
||||||
self._load_existing_keys()
|
|
||||||
|
|
||||||
def _ensure_keys_directory(self):
|
|
||||||
"""Ensure keys directory exists and has proper permissions"""
|
|
||||||
os.makedirs(self.keys_dir, mode=0o700, exist_ok=True)
|
|
||||||
|
|
||||||
def _load_existing_keys(self):
|
|
||||||
"""Load existing key pairs from disk"""
|
|
||||||
keys_file = os.path.join(self.keys_dir, "validator_keys.json")
|
|
||||||
|
|
||||||
if os.path.exists(keys_file):
|
|
||||||
try:
|
|
||||||
with open(keys_file, 'r') as f:
|
|
||||||
keys_data = json.load(f)
|
|
||||||
|
|
||||||
for address, key_data in keys_data.items():
|
|
||||||
self.key_pairs[address] = ValidatorKeyPair(
|
|
||||||
address=address,
|
|
||||||
private_key_pem=key_data['private_key_pem'],
|
|
||||||
public_key_pem=key_data['public_key_pem'],
|
|
||||||
created_at=key_data['created_at'],
|
|
||||||
last_rotated=key_data['last_rotated']
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading keys: {e}")
|
|
||||||
|
|
||||||
def generate_key_pair(self, address: str) -> ValidatorKeyPair:
|
|
||||||
"""Generate new RSA key pair for validator"""
|
|
||||||
# Generate private key
|
|
||||||
private_key = rsa.generate_private_key(
|
|
||||||
public_exponent=65537,
|
|
||||||
key_size=2048,
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Serialize private key
|
|
||||||
private_key_pem = private_key.private_bytes(
|
|
||||||
encoding=Encoding.PEM,
|
|
||||||
format=PrivateFormat.PKCS8,
|
|
||||||
encryption_algorithm=NoEncryption()
|
|
||||||
).decode('utf-8')
|
|
||||||
|
|
||||||
# Get public key
|
|
||||||
public_key = private_key.public_key()
|
|
||||||
public_key_pem = public_key.public_bytes(
|
|
||||||
encoding=Encoding.PEM,
|
|
||||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
|
||||||
).decode('utf-8')
|
|
||||||
|
|
||||||
# Create key pair object
|
|
||||||
current_time = time.time()
|
|
||||||
key_pair = ValidatorKeyPair(
|
|
||||||
address=address,
|
|
||||||
private_key_pem=private_key_pem,
|
|
||||||
public_key_pem=public_key_pem,
|
|
||||||
created_at=current_time,
|
|
||||||
last_rotated=current_time
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store key pair
|
|
||||||
self.key_pairs[address] = key_pair
|
|
||||||
self._save_keys()
|
|
||||||
|
|
||||||
return key_pair
|
|
||||||
|
|
||||||
def get_key_pair(self, address: str) -> Optional[ValidatorKeyPair]:
|
|
||||||
"""Get key pair for validator"""
|
|
||||||
return self.key_pairs.get(address)
|
|
||||||
|
|
||||||
def rotate_key(self, address: str) -> Optional[ValidatorKeyPair]:
|
|
||||||
"""Rotate validator keys"""
|
|
||||||
if address not in self.key_pairs:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Generate new key pair
|
|
||||||
new_key_pair = self.generate_key_pair(address)
|
|
||||||
|
|
||||||
# Update rotation time
|
|
||||||
new_key_pair.created_at = self.key_pairs[address].created_at
|
|
||||||
new_key_pair.last_rotated = time.time()
|
|
||||||
|
|
||||||
self._save_keys()
|
|
||||||
return new_key_pair
|
|
||||||
|
|
||||||
def sign_message(self, address: str, message: str) -> Optional[str]:
|
|
||||||
"""Sign message with validator private key"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load private key from PEM
|
|
||||||
private_key = serialization.load_pem_private_key(
|
|
||||||
key_pair.private_key_pem.encode(),
|
|
||||||
password=None,
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sign message
|
|
||||||
signature = private_key.sign(
|
|
||||||
message.encode('utf-8'),
|
|
||||||
hashes.SHA256(),
|
|
||||||
default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
return signature.hex()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error signing message: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def verify_signature(self, address: str, message: str, signature: str) -> bool:
|
|
||||||
"""Verify message signature"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load public key from PEM
|
|
||||||
public_key = serialization.load_pem_public_key(
|
|
||||||
key_pair.public_key_pem.encode(),
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify signature
|
|
||||||
public_key.verify(
|
|
||||||
bytes.fromhex(signature),
|
|
||||||
message.encode('utf-8'),
|
|
||||||
hashes.SHA256(),
|
|
||||||
default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error verifying signature: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_public_key_pem(self, address: str) -> Optional[str]:
|
|
||||||
"""Get public key PEM for validator"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
return key_pair.public_key_pem if key_pair else None
|
|
||||||
|
|
||||||
def _save_keys(self):
|
|
||||||
"""Save key pairs to disk"""
|
|
||||||
keys_file = os.path.join(self.keys_dir, "validator_keys.json")
|
|
||||||
|
|
||||||
keys_data = {}
|
|
||||||
for address, key_pair in self.key_pairs.items():
|
|
||||||
keys_data[address] = {
|
|
||||||
'private_key_pem': key_pair.private_key_pem,
|
|
||||||
'public_key_pem': key_pair.public_key_pem,
|
|
||||||
'created_at': key_pair.created_at,
|
|
||||||
'last_rotated': key_pair.last_rotated
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(keys_file, 'w') as f:
|
|
||||||
json.dump(keys_data, f, indent=2)
|
|
||||||
|
|
||||||
# Set secure permissions
|
|
||||||
os.chmod(keys_file, 0o600)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error saving keys: {e}")
|
|
||||||
|
|
||||||
def should_rotate_key(self, address: str, rotation_interval: int = 86400) -> bool:
|
|
||||||
"""Check if key should be rotated (default: 24 hours)"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return (time.time() - key_pair.last_rotated) >= rotation_interval
|
|
||||||
|
|
||||||
def get_key_age(self, address: str) -> Optional[float]:
|
|
||||||
"""Get age of key in seconds"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return time.time() - key_pair.created_at
|
|
||||||
|
|
||||||
# Global key manager
|
|
||||||
key_manager = KeyManager()
|
|
||||||
@@ -1,119 +0,0 @@
|
|||||||
"""
|
|
||||||
Multi-Validator Proof of Authority Consensus Implementation
|
|
||||||
Extends single validator PoA to support multiple validators with rotation
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
from typing import List, Dict, Optional, Set
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from ..config import settings
|
|
||||||
from ..models import Block, Transaction
|
|
||||||
from ..database import session_scope
|
|
||||||
|
|
||||||
class ValidatorRole(Enum):
|
|
||||||
PROPOSER = "proposer"
|
|
||||||
VALIDATOR = "validator"
|
|
||||||
STANDBY = "standby"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Validator:
|
|
||||||
address: str
|
|
||||||
stake: float
|
|
||||||
reputation: float
|
|
||||||
role: ValidatorRole
|
|
||||||
last_proposed: int
|
|
||||||
is_active: bool
|
|
||||||
|
|
||||||
class MultiValidatorPoA:
|
|
||||||
"""Multi-Validator Proof of Authority consensus mechanism"""
|
|
||||||
|
|
||||||
def __init__(self, chain_id: str):
|
|
||||||
self.chain_id = chain_id
|
|
||||||
self.validators: Dict[str, Validator] = {}
|
|
||||||
self.current_proposer_index = 0
|
|
||||||
self.round_robin_enabled = True
|
|
||||||
self.consensus_timeout = 30 # seconds
|
|
||||||
|
|
||||||
def add_validator(self, address: str, stake: float = 1000.0) -> bool:
|
|
||||||
"""Add a new validator to the consensus"""
|
|
||||||
if address in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.validators[address] = Validator(
|
|
||||||
address=address,
|
|
||||||
stake=stake,
|
|
||||||
reputation=1.0,
|
|
||||||
role=ValidatorRole.STANDBY,
|
|
||||||
last_proposed=0,
|
|
||||||
is_active=True
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def remove_validator(self, address: str) -> bool:
|
|
||||||
"""Remove a validator from the consensus"""
|
|
||||||
if address not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[address]
|
|
||||||
validator.is_active = False
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
return True
|
|
||||||
|
|
||||||
def select_proposer(self, block_height: int) -> Optional[str]:
|
|
||||||
"""Select proposer for the current block using round-robin"""
|
|
||||||
active_validators = [
|
|
||||||
v for v in self.validators.values()
|
|
||||||
if v.is_active and v.role in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]
|
|
||||||
]
|
|
||||||
|
|
||||||
if not active_validators:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Round-robin selection
|
|
||||||
proposer_index = block_height % len(active_validators)
|
|
||||||
return active_validators[proposer_index].address
|
|
||||||
|
|
||||||
def validate_block(self, block: Block, proposer: str) -> bool:
|
|
||||||
"""Validate a proposed block"""
|
|
||||||
if proposer not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[proposer]
|
|
||||||
if not validator.is_active:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check if validator is allowed to propose
|
|
||||||
if validator.role not in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Additional validation logic here
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_consensus_participants(self) -> List[str]:
|
|
||||||
"""Get list of active consensus participants"""
|
|
||||||
return [
|
|
||||||
v.address for v in self.validators.values()
|
|
||||||
if v.is_active and v.role in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]
|
|
||||||
]
|
|
||||||
|
|
||||||
def update_validator_reputation(self, address: str, delta: float) -> bool:
|
|
||||||
"""Update validator reputation"""
|
|
||||||
if address not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[address]
|
|
||||||
validator.reputation = max(0.0, min(1.0, validator.reputation + delta))
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Global consensus instance
|
|
||||||
consensus_instances: Dict[str, MultiValidatorPoA] = {}
|
|
||||||
|
|
||||||
def get_consensus(chain_id: str) -> MultiValidatorPoA:
|
|
||||||
"""Get or create consensus instance for chain"""
|
|
||||||
if chain_id not in consensus_instances:
|
|
||||||
consensus_instances[chain_id] = MultiValidatorPoA(chain_id)
|
|
||||||
return consensus_instances[chain_id]
|
|
||||||
@@ -1,193 +0,0 @@
|
|||||||
"""
|
|
||||||
Practical Byzantine Fault Tolerance (PBFT) Consensus Implementation
|
|
||||||
Provides Byzantine fault tolerance for up to 1/3 faulty validators
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
from typing import List, Dict, Optional, Set, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import MultiValidatorPoA, Validator
|
|
||||||
|
|
||||||
class PBFTPhase(Enum):
|
|
||||||
PRE_PREPARE = "pre_prepare"
|
|
||||||
PREPARE = "prepare"
|
|
||||||
COMMIT = "commit"
|
|
||||||
EXECUTE = "execute"
|
|
||||||
|
|
||||||
class PBFTMessageType(Enum):
|
|
||||||
PRE_PREPARE = "pre_prepare"
|
|
||||||
PREPARE = "prepare"
|
|
||||||
COMMIT = "commit"
|
|
||||||
VIEW_CHANGE = "view_change"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PBFTMessage:
|
|
||||||
message_type: PBFTMessageType
|
|
||||||
sender: str
|
|
||||||
view_number: int
|
|
||||||
sequence_number: int
|
|
||||||
digest: str
|
|
||||||
signature: str
|
|
||||||
timestamp: float
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PBFTState:
|
|
||||||
current_view: int
|
|
||||||
current_sequence: int
|
|
||||||
prepared_messages: Dict[str, List[PBFTMessage]]
|
|
||||||
committed_messages: Dict[str, List[PBFTMessage]]
|
|
||||||
pre_prepare_messages: Dict[str, PBFTMessage]
|
|
||||||
|
|
||||||
class PBFTConsensus:
|
|
||||||
"""PBFT consensus implementation"""
|
|
||||||
|
|
||||||
def __init__(self, consensus: MultiValidatorPoA):
|
|
||||||
self.consensus = consensus
|
|
||||||
self.state = PBFTState(
|
|
||||||
current_view=0,
|
|
||||||
current_sequence=0,
|
|
||||||
prepared_messages={},
|
|
||||||
committed_messages={},
|
|
||||||
pre_prepare_messages={}
|
|
||||||
)
|
|
||||||
self.fault_tolerance = max(1, len(consensus.get_consensus_participants()) // 3)
|
|
||||||
self.required_messages = 2 * self.fault_tolerance + 1
|
|
||||||
|
|
||||||
def get_message_digest(self, block_hash: str, sequence: int, view: int) -> str:
|
|
||||||
"""Generate message digest for PBFT"""
|
|
||||||
content = f"{block_hash}:{sequence}:{view}"
|
|
||||||
return hashlib.sha256(content.encode()).hexdigest()
|
|
||||||
|
|
||||||
async def pre_prepare_phase(self, proposer: str, block_hash: str) -> bool:
|
|
||||||
"""Phase 1: Pre-prepare"""
|
|
||||||
sequence = self.state.current_sequence + 1
|
|
||||||
view = self.state.current_view
|
|
||||||
digest = self.get_message_digest(block_hash, sequence, view)
|
|
||||||
|
|
||||||
message = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.PRE_PREPARE,
|
|
||||||
sender=proposer,
|
|
||||||
view_number=view,
|
|
||||||
sequence_number=sequence,
|
|
||||||
digest=digest,
|
|
||||||
signature="", # Would be signed in real implementation
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store pre-prepare message
|
|
||||||
key = f"{sequence}:{view}"
|
|
||||||
self.state.pre_prepare_messages[key] = message
|
|
||||||
|
|
||||||
# Broadcast to all validators
|
|
||||||
await self._broadcast_message(message)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def prepare_phase(self, validator: str, pre_prepare_msg: PBFTMessage) -> bool:
|
|
||||||
"""Phase 2: Prepare"""
|
|
||||||
key = f"{pre_prepare_msg.sequence_number}:{pre_prepare_msg.view_number}"
|
|
||||||
|
|
||||||
if key not in self.state.pre_prepare_messages:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Create prepare message
|
|
||||||
prepare_msg = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.PREPARE,
|
|
||||||
sender=validator,
|
|
||||||
view_number=pre_prepare_msg.view_number,
|
|
||||||
sequence_number=pre_prepare_msg.sequence_number,
|
|
||||||
digest=pre_prepare_msg.digest,
|
|
||||||
signature="", # Would be signed
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store prepare message
|
|
||||||
if key not in self.state.prepared_messages:
|
|
||||||
self.state.prepared_messages[key] = []
|
|
||||||
self.state.prepared_messages[key].append(prepare_msg)
|
|
||||||
|
|
||||||
# Broadcast prepare message
|
|
||||||
await self._broadcast_message(prepare_msg)
|
|
||||||
|
|
||||||
# Check if we have enough prepare messages
|
|
||||||
return len(self.state.prepared_messages[key]) >= self.required_messages
|
|
||||||
|
|
||||||
async def commit_phase(self, validator: str, prepare_msg: PBFTMessage) -> bool:
|
|
||||||
"""Phase 3: Commit"""
|
|
||||||
key = f"{prepare_msg.sequence_number}:{prepare_msg.view_number}"
|
|
||||||
|
|
||||||
# Create commit message
|
|
||||||
commit_msg = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.COMMIT,
|
|
||||||
sender=validator,
|
|
||||||
view_number=prepare_msg.view_number,
|
|
||||||
sequence_number=prepare_msg.sequence_number,
|
|
||||||
digest=prepare_msg.digest,
|
|
||||||
signature="", # Would be signed
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store commit message
|
|
||||||
if key not in self.state.committed_messages:
|
|
||||||
self.state.committed_messages[key] = []
|
|
||||||
self.state.committed_messages[key].append(commit_msg)
|
|
||||||
|
|
||||||
# Broadcast commit message
|
|
||||||
await self._broadcast_message(commit_msg)
|
|
||||||
|
|
||||||
# Check if we have enough commit messages
|
|
||||||
if len(self.state.committed_messages[key]) >= self.required_messages:
|
|
||||||
return await self.execute_phase(key)
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def execute_phase(self, key: str) -> bool:
|
|
||||||
"""Phase 4: Execute"""
|
|
||||||
# Extract sequence and view from key
|
|
||||||
sequence, view = map(int, key.split(':'))
|
|
||||||
|
|
||||||
# Update state
|
|
||||||
self.state.current_sequence = sequence
|
|
||||||
|
|
||||||
# Clean up old messages
|
|
||||||
self._cleanup_messages(sequence)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _broadcast_message(self, message: PBFTMessage):
|
|
||||||
"""Broadcast message to all validators"""
|
|
||||||
validators = self.consensus.get_consensus_participants()
|
|
||||||
|
|
||||||
for validator in validators:
|
|
||||||
if validator != message.sender:
|
|
||||||
# In real implementation, this would send over network
|
|
||||||
await self._send_to_validator(validator, message)
|
|
||||||
|
|
||||||
async def _send_to_validator(self, validator: str, message: PBFTMessage):
|
|
||||||
"""Send message to specific validator"""
|
|
||||||
# Network communication would be implemented here
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _cleanup_messages(self, sequence: int):
|
|
||||||
"""Clean up old messages to prevent memory leaks"""
|
|
||||||
old_keys = [
|
|
||||||
key for key in self.state.prepared_messages.keys()
|
|
||||||
if int(key.split(':')[0]) < sequence
|
|
||||||
]
|
|
||||||
|
|
||||||
for key in old_keys:
|
|
||||||
self.state.prepared_messages.pop(key, None)
|
|
||||||
self.state.committed_messages.pop(key, None)
|
|
||||||
self.state.pre_prepare_messages.pop(key, None)
|
|
||||||
|
|
||||||
def handle_view_change(self, new_view: int) -> bool:
|
|
||||||
"""Handle view change when proposer fails"""
|
|
||||||
self.state.current_view = new_view
|
|
||||||
# Reset state for new view
|
|
||||||
self.state.prepared_messages.clear()
|
|
||||||
self.state.committed_messages.clear()
|
|
||||||
self.state.pre_prepare_messages.clear()
|
|
||||||
return True
|
|
||||||
@@ -1,345 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable, ContextManager, Optional
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
|
|
||||||
from ..logger import get_logger
|
|
||||||
from ..metrics import metrics_registry
|
|
||||||
from ..config import ProposerConfig
|
|
||||||
from ..models import Block, Account
|
|
||||||
from ..gossip import gossip_broker
|
|
||||||
|
|
||||||
_METRIC_KEY_SANITIZE = re.compile(r"[^a-zA-Z0-9_]")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_metric_suffix(value: str) -> str:
|
|
||||||
sanitized = _METRIC_KEY_SANITIZE.sub("_", value).strip("_")
|
|
||||||
return sanitized or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
class CircuitBreaker:
|
|
||||||
def __init__(self, threshold: int, timeout: int):
|
|
||||||
self._threshold = threshold
|
|
||||||
self._timeout = timeout
|
|
||||||
self._failures = 0
|
|
||||||
self._last_failure_time = 0.0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self) -> str:
|
|
||||||
if self._state == "open":
|
|
||||||
if time.time() - self._last_failure_time > self._timeout:
|
|
||||||
self._state = "half-open"
|
|
||||||
return self._state
|
|
||||||
|
|
||||||
def allow_request(self) -> bool:
|
|
||||||
state = self.state
|
|
||||||
if state == "closed":
|
|
||||||
return True
|
|
||||||
if state == "half-open":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def record_failure(self) -> None:
|
|
||||||
self._failures += 1
|
|
||||||
self._last_failure_time = time.time()
|
|
||||||
if self._failures >= self._threshold:
|
|
||||||
self._state = "open"
|
|
||||||
|
|
||||||
def record_success(self) -> None:
|
|
||||||
self._failures = 0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
class PoAProposer:
|
|
||||||
"""Proof-of-Authority block proposer.
|
|
||||||
|
|
||||||
Responsible for periodically proposing blocks if this node is configured as a proposer.
|
|
||||||
In the real implementation, this would involve checking the mempool, validating transactions,
|
|
||||||
and signing the block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
config: ProposerConfig,
|
|
||||||
session_factory: Callable[[], ContextManager[Session]],
|
|
||||||
) -> None:
|
|
||||||
self._config = config
|
|
||||||
self._session_factory = session_factory
|
|
||||||
self._logger = get_logger(__name__)
|
|
||||||
self._stop_event = asyncio.Event()
|
|
||||||
self._task: Optional[asyncio.Task[None]] = None
|
|
||||||
self._last_proposer_id: Optional[str] = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
if self._task is not None:
|
|
||||||
return
|
|
||||||
self._logger.info("Starting PoA proposer loop", extra={"interval": self._config.interval_seconds})
|
|
||||||
await self._ensure_genesis_block()
|
|
||||||
self._stop_event.clear()
|
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
if self._task is None:
|
|
||||||
return
|
|
||||||
self._logger.info("Stopping PoA proposer loop")
|
|
||||||
self._stop_event.set()
|
|
||||||
await self._task
|
|
||||||
self._task = None
|
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
await self._wait_until_next_slot()
|
|
||||||
if self._stop_event.is_set():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
await self._propose_block()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
self._logger.exception("Failed to propose block", extra={"error": str(exc)})
|
|
||||||
|
|
||||||
async def _wait_until_next_slot(self) -> None:
|
|
||||||
head = self._fetch_chain_head()
|
|
||||||
if head is None:
|
|
||||||
return
|
|
||||||
now = datetime.utcnow()
|
|
||||||
elapsed = (now - head.timestamp).total_seconds()
|
|
||||||
sleep_for = max(self._config.interval_seconds - elapsed, 0.1)
|
|
||||||
if sleep_for <= 0:
|
|
||||||
sleep_for = 0.1
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._stop_event.wait(), timeout=sleep_for)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _propose_block(self) -> None:
|
|
||||||
# Check internal mempool and include transactions
|
|
||||||
from ..mempool import get_mempool
|
|
||||||
from ..models import Transaction, Account
|
|
||||||
mempool = get_mempool()
|
|
||||||
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
next_height = 0
|
|
||||||
parent_hash = "0x00"
|
|
||||||
interval_seconds: Optional[float] = None
|
|
||||||
if head is not None:
|
|
||||||
next_height = head.height + 1
|
|
||||||
parent_hash = head.hash
|
|
||||||
interval_seconds = (datetime.utcnow() - head.timestamp).total_seconds()
|
|
||||||
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
# Pull transactions from mempool
|
|
||||||
max_txs = self._config.max_txs_per_block
|
|
||||||
max_bytes = self._config.max_block_size_bytes
|
|
||||||
pending_txs = mempool.drain(max_txs, max_bytes, self._config.chain_id)
|
|
||||||
self._logger.info(f"[PROPOSE] drained {len(pending_txs)} txs from mempool, chain={self._config.chain_id}")
|
|
||||||
|
|
||||||
# Process transactions and update balances
|
|
||||||
processed_txs = []
|
|
||||||
for tx in pending_txs:
|
|
||||||
try:
|
|
||||||
# Parse transaction data
|
|
||||||
tx_data = tx.content
|
|
||||||
sender = tx_data.get("from")
|
|
||||||
recipient = tx_data.get("to")
|
|
||||||
value = tx_data.get("amount", 0)
|
|
||||||
fee = tx_data.get("fee", 0)
|
|
||||||
|
|
||||||
if not sender or not recipient:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get sender account
|
|
||||||
sender_account = session.get(Account, (self._config.chain_id, sender))
|
|
||||||
if not sender_account:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check sufficient balance
|
|
||||||
total_cost = value + fee
|
|
||||||
if sender_account.balance < total_cost:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get or create recipient account
|
|
||||||
recipient_account = session.get(Account, (self._config.chain_id, recipient))
|
|
||||||
if not recipient_account:
|
|
||||||
recipient_account = Account(chain_id=self._config.chain_id, address=recipient, balance=0, nonce=0)
|
|
||||||
session.add(recipient_account)
|
|
||||||
session.flush()
|
|
||||||
|
|
||||||
# Update balances
|
|
||||||
sender_account.balance -= total_cost
|
|
||||||
sender_account.nonce += 1
|
|
||||||
recipient_account.balance += value
|
|
||||||
|
|
||||||
# Create transaction record
|
|
||||||
transaction = Transaction(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
tx_hash=tx.tx_hash,
|
|
||||||
sender=sender,
|
|
||||||
recipient=recipient,
|
|
||||||
payload=tx_data,
|
|
||||||
value=value,
|
|
||||||
fee=fee,
|
|
||||||
nonce=sender_account.nonce - 1,
|
|
||||||
timestamp=timestamp,
|
|
||||||
block_height=next_height,
|
|
||||||
status="confirmed"
|
|
||||||
)
|
|
||||||
session.add(transaction)
|
|
||||||
processed_txs.append(tx)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self._logger.warning(f"Failed to process transaction {tx.tx_hash}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Compute block hash with transaction data
|
|
||||||
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp, processed_txs)
|
|
||||||
|
|
||||||
block = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=next_height,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash=parent_hash,
|
|
||||||
proposer=self._config.proposer_id,
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=len(processed_txs),
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(block)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
metrics_registry.increment("blocks_proposed_total")
|
|
||||||
metrics_registry.set_gauge("chain_head_height", float(next_height))
|
|
||||||
if interval_seconds is not None and interval_seconds >= 0:
|
|
||||||
metrics_registry.observe("block_interval_seconds", interval_seconds)
|
|
||||||
metrics_registry.set_gauge("poa_last_block_interval_seconds", float(interval_seconds))
|
|
||||||
|
|
||||||
proposer_suffix = _sanitize_metric_suffix(self._config.proposer_id)
|
|
||||||
metrics_registry.increment(f"poa_blocks_proposed_total_{proposer_suffix}")
|
|
||||||
if self._last_proposer_id is not None and self._last_proposer_id != self._config.proposer_id:
|
|
||||||
metrics_registry.increment("poa_proposer_switches_total")
|
|
||||||
self._last_proposer_id = self._config.proposer_id
|
|
||||||
|
|
||||||
self._logger.info(
|
|
||||||
"Proposed block",
|
|
||||||
extra={
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast the new block
|
|
||||||
tx_list = [tx.content for tx in processed_txs] if processed_txs else []
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"chain_id": self._config.chain_id,
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"parent_hash": block.parent_hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
"timestamp": block.timestamp.isoformat(),
|
|
||||||
"tx_count": block.tx_count,
|
|
||||||
"state_root": block.state_root,
|
|
||||||
"transactions": tx_list,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_genesis_block(self) -> None:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
if head is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use a deterministic genesis timestamp so all nodes agree on the genesis block hash
|
|
||||||
timestamp = datetime(2025, 1, 1, 0, 0, 0)
|
|
||||||
block_hash = self._compute_block_hash(0, "0x00", timestamp)
|
|
||||||
genesis = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=0,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash="0x00",
|
|
||||||
proposer=self._config.proposer_id, # Use configured proposer as genesis proposer
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(genesis)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Initialize accounts from genesis allocations file (if present)
|
|
||||||
await self._initialize_genesis_allocations(session)
|
|
||||||
|
|
||||||
# Broadcast genesis block for initial sync
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"chain_id": self._config.chain_id,
|
|
||||||
"height": genesis.height,
|
|
||||||
"hash": genesis.hash,
|
|
||||||
"parent_hash": genesis.parent_hash,
|
|
||||||
"proposer": genesis.proposer,
|
|
||||||
"timestamp": genesis.timestamp.isoformat(),
|
|
||||||
"tx_count": genesis.tx_count,
|
|
||||||
"state_root": genesis.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _initialize_genesis_allocations(self, session: Session) -> None:
|
|
||||||
"""Create Account entries from the genesis allocations file."""
|
|
||||||
# Use standardized data directory from configuration
|
|
||||||
from ..config import settings
|
|
||||||
|
|
||||||
genesis_paths = [
|
|
||||||
Path(f"/var/lib/aitbc/data/{self._config.chain_id}/genesis.json"), # Standard location
|
|
||||||
]
|
|
||||||
|
|
||||||
genesis_path = None
|
|
||||||
for path in genesis_paths:
|
|
||||||
if path.exists():
|
|
||||||
genesis_path = path
|
|
||||||
break
|
|
||||||
|
|
||||||
if not genesis_path:
|
|
||||||
self._logger.warning("Genesis allocations file not found; skipping account initialization", extra={"paths": str(genesis_paths)})
|
|
||||||
return
|
|
||||||
|
|
||||||
with open(genesis_path) as f:
|
|
||||||
genesis_data = json.load(f)
|
|
||||||
|
|
||||||
allocations = genesis_data.get("allocations", [])
|
|
||||||
created = 0
|
|
||||||
for alloc in allocations:
|
|
||||||
addr = alloc["address"]
|
|
||||||
balance = int(alloc["balance"])
|
|
||||||
nonce = int(alloc.get("nonce", 0))
|
|
||||||
# Check if account already exists (idempotent)
|
|
||||||
acct = session.get(Account, (self._config.chain_id, addr))
|
|
||||||
if acct is None:
|
|
||||||
acct = Account(chain_id=self._config.chain_id, address=addr, balance=balance, nonce=nonce)
|
|
||||||
session.add(acct)
|
|
||||||
created += 1
|
|
||||||
session.commit()
|
|
||||||
self._logger.info("Initialized genesis accounts", extra={"count": created, "total": len(allocations), "path": str(genesis_path)})
|
|
||||||
|
|
||||||
def _fetch_chain_head(self) -> Optional[Block]:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
|
|
||||||
def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime, transactions: list = None) -> str:
|
|
||||||
# Include transaction hashes in block hash computation
|
|
||||||
tx_hashes = []
|
|
||||||
if transactions:
|
|
||||||
tx_hashes = [tx.tx_hash for tx in transactions]
|
|
||||||
|
|
||||||
payload = f"{self._config.chain_id}|{height}|{parent_hash}|{timestamp.isoformat()}|{'|'.join(sorted(tx_hashes))}".encode()
|
|
||||||
return "0x" + hashlib.sha256(payload).hexdigest()
|
|
||||||
@@ -1,229 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Callable, ContextManager, Optional
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
|
|
||||||
from ..logger import get_logger
|
|
||||||
from ..metrics import metrics_registry
|
|
||||||
from ..config import ProposerConfig
|
|
||||||
from ..models import Block
|
|
||||||
from ..gossip import gossip_broker
|
|
||||||
|
|
||||||
_METRIC_KEY_SANITIZE = re.compile(r"[^a-zA-Z0-9_]")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_metric_suffix(value: str) -> str:
|
|
||||||
sanitized = _METRIC_KEY_SANITIZE.sub("_", value).strip("_")
|
|
||||||
return sanitized or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
class CircuitBreaker:
|
|
||||||
def __init__(self, threshold: int, timeout: int):
|
|
||||||
self._threshold = threshold
|
|
||||||
self._timeout = timeout
|
|
||||||
self._failures = 0
|
|
||||||
self._last_failure_time = 0.0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self) -> str:
|
|
||||||
if self._state == "open":
|
|
||||||
if time.time() - self._last_failure_time > self._timeout:
|
|
||||||
self._state = "half-open"
|
|
||||||
return self._state
|
|
||||||
|
|
||||||
def allow_request(self) -> bool:
|
|
||||||
state = self.state
|
|
||||||
if state == "closed":
|
|
||||||
return True
|
|
||||||
if state == "half-open":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def record_failure(self) -> None:
|
|
||||||
self._failures += 1
|
|
||||||
self._last_failure_time = time.time()
|
|
||||||
if self._failures >= self._threshold:
|
|
||||||
self._state = "open"
|
|
||||||
|
|
||||||
def record_success(self) -> None:
|
|
||||||
self._failures = 0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
class PoAProposer:
|
|
||||||
"""Proof-of-Authority block proposer.
|
|
||||||
|
|
||||||
Responsible for periodically proposing blocks if this node is configured as a proposer.
|
|
||||||
In the real implementation, this would involve checking the mempool, validating transactions,
|
|
||||||
and signing the block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
config: ProposerConfig,
|
|
||||||
session_factory: Callable[[], ContextManager[Session]],
|
|
||||||
) -> None:
|
|
||||||
self._config = config
|
|
||||||
self._session_factory = session_factory
|
|
||||||
self._logger = get_logger(__name__)
|
|
||||||
self._stop_event = asyncio.Event()
|
|
||||||
self._task: Optional[asyncio.Task[None]] = None
|
|
||||||
self._last_proposer_id: Optional[str] = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
if self._task is not None:
|
|
||||||
return
|
|
||||||
self._logger.info("Starting PoA proposer loop", extra={"interval": self._config.interval_seconds})
|
|
||||||
self._ensure_genesis_block()
|
|
||||||
self._stop_event.clear()
|
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
if self._task is None:
|
|
||||||
return
|
|
||||||
self._logger.info("Stopping PoA proposer loop")
|
|
||||||
self._stop_event.set()
|
|
||||||
await self._task
|
|
||||||
self._task = None
|
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
await self._wait_until_next_slot()
|
|
||||||
if self._stop_event.is_set():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
self._propose_block()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
self._logger.exception("Failed to propose block", extra={"error": str(exc)})
|
|
||||||
|
|
||||||
async def _wait_until_next_slot(self) -> None:
|
|
||||||
head = self._fetch_chain_head()
|
|
||||||
if head is None:
|
|
||||||
return
|
|
||||||
now = datetime.utcnow()
|
|
||||||
elapsed = (now - head.timestamp).total_seconds()
|
|
||||||
sleep_for = max(self._config.interval_seconds - elapsed, 0.1)
|
|
||||||
if sleep_for <= 0:
|
|
||||||
sleep_for = 0.1
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._stop_event.wait(), timeout=sleep_for)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _propose_block(self) -> None:
|
|
||||||
# Check internal mempool
|
|
||||||
from ..mempool import get_mempool
|
|
||||||
if get_mempool().size(self._config.chain_id) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
next_height = 0
|
|
||||||
parent_hash = "0x00"
|
|
||||||
interval_seconds: Optional[float] = None
|
|
||||||
if head is not None:
|
|
||||||
next_height = head.height + 1
|
|
||||||
parent_hash = head.hash
|
|
||||||
interval_seconds = (datetime.utcnow() - head.timestamp).total_seconds()
|
|
||||||
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp)
|
|
||||||
|
|
||||||
block = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=next_height,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash=parent_hash,
|
|
||||||
proposer=self._config.proposer_id,
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(block)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
metrics_registry.increment("blocks_proposed_total")
|
|
||||||
metrics_registry.set_gauge("chain_head_height", float(next_height))
|
|
||||||
if interval_seconds is not None and interval_seconds >= 0:
|
|
||||||
metrics_registry.observe("block_interval_seconds", interval_seconds)
|
|
||||||
metrics_registry.set_gauge("poa_last_block_interval_seconds", float(interval_seconds))
|
|
||||||
|
|
||||||
proposer_suffix = _sanitize_metric_suffix(self._config.proposer_id)
|
|
||||||
metrics_registry.increment(f"poa_blocks_proposed_total_{proposer_suffix}")
|
|
||||||
if self._last_proposer_id is not None and self._last_proposer_id != self._config.proposer_id:
|
|
||||||
metrics_registry.increment("poa_proposer_switches_total")
|
|
||||||
self._last_proposer_id = self._config.proposer_id
|
|
||||||
|
|
||||||
self._logger.info(
|
|
||||||
"Proposed block",
|
|
||||||
extra={
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast the new block
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"parent_hash": block.parent_hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
"timestamp": block.timestamp.isoformat(),
|
|
||||||
"tx_count": block.tx_count,
|
|
||||||
"state_root": block.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_genesis_block(self) -> None:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
if head is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use a deterministic genesis timestamp so all nodes agree on the genesis block hash
|
|
||||||
timestamp = datetime(2025, 1, 1, 0, 0, 0)
|
|
||||||
block_hash = self._compute_block_hash(0, "0x00", timestamp)
|
|
||||||
genesis = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=0,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash="0x00",
|
|
||||||
proposer="genesis",
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(genesis)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Broadcast genesis block for initial sync
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"height": genesis.height,
|
|
||||||
"hash": genesis.hash,
|
|
||||||
"parent_hash": genesis.parent_hash,
|
|
||||||
"proposer": genesis.proposer,
|
|
||||||
"timestamp": genesis.timestamp.isoformat(),
|
|
||||||
"tx_count": genesis.tx_count,
|
|
||||||
"state_root": genesis.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _fetch_chain_head(self) -> Optional[Block]:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
|
|
||||||
def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime) -> str:
|
|
||||||
payload = f"{self._config.chain_id}|{height}|{parent_hash}|{timestamp.isoformat()}".encode()
|
|
||||||
return "0x" + hashlib.sha256(payload).hexdigest()
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
--- apps/blockchain-node/src/aitbc_chain/consensus/poa.py
|
|
||||||
+++ apps/blockchain-node/src/aitbc_chain/consensus/poa.py
|
|
||||||
@@ -101,7 +101,7 @@
|
|
||||||
# Wait for interval before proposing next block
|
|
||||||
await asyncio.sleep(self.config.interval_seconds)
|
|
||||||
|
|
||||||
- self._propose_block()
|
|
||||||
+ await self._propose_block()
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
@@ -1,146 +0,0 @@
|
|||||||
"""
|
|
||||||
Validator Rotation Mechanism
|
|
||||||
Handles automatic rotation of validators based on performance and stake
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
from typing import List, Dict, Optional
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import MultiValidatorPoA, Validator, ValidatorRole
|
|
||||||
|
|
||||||
class RotationStrategy(Enum):
|
|
||||||
ROUND_ROBIN = "round_robin"
|
|
||||||
STAKE_WEIGHTED = "stake_weighted"
|
|
||||||
REPUTATION_BASED = "reputation_based"
|
|
||||||
HYBRID = "hybrid"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RotationConfig:
|
|
||||||
strategy: RotationStrategy
|
|
||||||
rotation_interval: int # blocks
|
|
||||||
min_stake: float
|
|
||||||
reputation_threshold: float
|
|
||||||
max_validators: int
|
|
||||||
|
|
||||||
class ValidatorRotation:
|
|
||||||
"""Manages validator rotation based on various strategies"""
|
|
||||||
|
|
||||||
def __init__(self, consensus: MultiValidatorPoA, config: RotationConfig):
|
|
||||||
self.consensus = consensus
|
|
||||||
self.config = config
|
|
||||||
self.last_rotation_height = 0
|
|
||||||
|
|
||||||
def should_rotate(self, current_height: int) -> bool:
|
|
||||||
"""Check if rotation should occur at current height"""
|
|
||||||
return (current_height - self.last_rotation_height) >= self.config.rotation_interval
|
|
||||||
|
|
||||||
def rotate_validators(self, current_height: int) -> bool:
|
|
||||||
"""Perform validator rotation based on configured strategy"""
|
|
||||||
if not self.should_rotate(current_height):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if self.config.strategy == RotationStrategy.ROUND_ROBIN:
|
|
||||||
return self._rotate_round_robin()
|
|
||||||
elif self.config.strategy == RotationStrategy.STAKE_WEIGHTED:
|
|
||||||
return self._rotate_stake_weighted()
|
|
||||||
elif self.config.strategy == RotationStrategy.REPUTATION_BASED:
|
|
||||||
return self._rotate_reputation_based()
|
|
||||||
elif self.config.strategy == RotationStrategy.HYBRID:
|
|
||||||
return self._rotate_hybrid()
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _rotate_round_robin(self) -> bool:
|
|
||||||
"""Round-robin rotation of validator roles"""
|
|
||||||
validators = list(self.consensus.validators.values())
|
|
||||||
active_validators = [v for v in validators if v.is_active]
|
|
||||||
|
|
||||||
# Rotate roles among active validators
|
|
||||||
for i, validator in enumerate(active_validators):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 3: # Top 3 become validators
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_stake_weighted(self) -> bool:
|
|
||||||
"""Stake-weighted rotation"""
|
|
||||||
validators = sorted(
|
|
||||||
[v for v in self.consensus.validators.values() if v.is_active],
|
|
||||||
key=lambda v: v.stake,
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, validator in enumerate(validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_reputation_based(self) -> bool:
|
|
||||||
"""Reputation-based rotation"""
|
|
||||||
validators = sorted(
|
|
||||||
[v for v in self.consensus.validators.values() if v.is_active],
|
|
||||||
key=lambda v: v.reputation,
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Filter by reputation threshold
|
|
||||||
qualified_validators = [
|
|
||||||
v for v in validators
|
|
||||||
if v.reputation >= self.config.reputation_threshold
|
|
||||||
]
|
|
||||||
|
|
||||||
for i, validator in enumerate(qualified_validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_hybrid(self) -> bool:
|
|
||||||
"""Hybrid rotation considering both stake and reputation"""
|
|
||||||
validators = [v for v in self.consensus.validators.values() if v.is_active]
|
|
||||||
|
|
||||||
# Calculate hybrid score
|
|
||||||
for validator in validators:
|
|
||||||
validator.hybrid_score = validator.stake * validator.reputation
|
|
||||||
|
|
||||||
# Sort by hybrid score
|
|
||||||
validators.sort(key=lambda v: v.hybrid_score, reverse=True)
|
|
||||||
|
|
||||||
for i, validator in enumerate(validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Default rotation configuration
|
|
||||||
DEFAULT_ROTATION_CONFIG = RotationConfig(
|
|
||||||
strategy=RotationStrategy.HYBRID,
|
|
||||||
rotation_interval=100, # Rotate every 100 blocks
|
|
||||||
min_stake=1000.0,
|
|
||||||
reputation_threshold=0.7,
|
|
||||||
max_validators=10
|
|
||||||
)
|
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
"""
|
|
||||||
Slashing Conditions Implementation
|
|
||||||
Handles detection and penalties for validator misbehavior
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
from typing import Dict, List, Optional, Set
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import Validator, ValidatorRole
|
|
||||||
|
|
||||||
class SlashingCondition(Enum):
|
|
||||||
DOUBLE_SIGN = "double_sign"
|
|
||||||
UNAVAILABLE = "unavailable"
|
|
||||||
INVALID_BLOCK = "invalid_block"
|
|
||||||
SLOW_RESPONSE = "slow_response"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SlashingEvent:
|
|
||||||
validator_address: str
|
|
||||||
condition: SlashingCondition
|
|
||||||
evidence: str
|
|
||||||
block_height: int
|
|
||||||
timestamp: float
|
|
||||||
slash_amount: float
|
|
||||||
|
|
||||||
class SlashingManager:
|
|
||||||
"""Manages validator slashing conditions and penalties"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.slashing_events: List[SlashingEvent] = []
|
|
||||||
self.slash_rates = {
|
|
||||||
SlashingCondition.DOUBLE_SIGN: 0.5, # 50% slash
|
|
||||||
SlashingCondition.UNAVAILABLE: 0.1, # 10% slash
|
|
||||||
SlashingCondition.INVALID_BLOCK: 0.3, # 30% slash
|
|
||||||
SlashingCondition.SLOW_RESPONSE: 0.05 # 5% slash
|
|
||||||
}
|
|
||||||
self.slash_thresholds = {
|
|
||||||
SlashingCondition.DOUBLE_SIGN: 1, # Immediate slash
|
|
||||||
SlashingCondition.UNAVAILABLE: 3, # After 3 offenses
|
|
||||||
SlashingCondition.INVALID_BLOCK: 1, # Immediate slash
|
|
||||||
SlashingCondition.SLOW_RESPONSE: 5 # After 5 offenses
|
|
||||||
}
|
|
||||||
|
|
||||||
def detect_double_sign(self, validator: str, block_hash1: str, block_hash2: str, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect double signing (validator signed two different blocks at same height)"""
|
|
||||||
if block_hash1 == block_hash2:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.DOUBLE_SIGN,
|
|
||||||
evidence=f"Double sign detected: {block_hash1} vs {block_hash2} at height {height}",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.DOUBLE_SIGN]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_unavailability(self, validator: str, missed_blocks: int, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect validator unavailability (missing consensus participation)"""
|
|
||||||
if missed_blocks < self.slash_thresholds[SlashingCondition.UNAVAILABLE]:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.UNAVAILABLE,
|
|
||||||
evidence=f"Missed {missed_blocks} consecutive blocks",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.UNAVAILABLE]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_invalid_block(self, validator: str, block_hash: str, reason: str, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect invalid block proposal"""
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.INVALID_BLOCK,
|
|
||||||
evidence=f"Invalid block {block_hash}: {reason}",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.INVALID_BLOCK]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_slow_response(self, validator: str, response_time: float, threshold: float, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect slow consensus participation"""
|
|
||||||
if response_time <= threshold:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.SLOW_RESPONSE,
|
|
||||||
evidence=f"Slow response: {response_time}s (threshold: {threshold}s)",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.SLOW_RESPONSE]
|
|
||||||
)
|
|
||||||
|
|
||||||
def apply_slashing(self, validator: Validator, event: SlashingEvent) -> bool:
|
|
||||||
"""Apply slashing penalty to validator"""
|
|
||||||
slash_amount = validator.stake * event.slash_amount
|
|
||||||
validator.stake -= slash_amount
|
|
||||||
|
|
||||||
# Demote validator role if stake is too low
|
|
||||||
if validator.stake < 100: # Minimum stake threshold
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
# Record slashing event
|
|
||||||
self.slashing_events.append(event)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_validator_slash_count(self, validator_address: str, condition: SlashingCondition) -> int:
|
|
||||||
"""Get count of slashing events for validator and condition"""
|
|
||||||
return len([
|
|
||||||
event for event in self.slashing_events
|
|
||||||
if event.validator_address == validator_address and event.condition == condition
|
|
||||||
])
|
|
||||||
|
|
||||||
def should_slash(self, validator: str, condition: SlashingCondition) -> bool:
|
|
||||||
"""Check if validator should be slashed for condition"""
|
|
||||||
current_count = self.get_validator_slash_count(validator, condition)
|
|
||||||
threshold = self.slash_thresholds.get(condition, 1)
|
|
||||||
return current_count >= threshold
|
|
||||||
|
|
||||||
def get_slashing_history(self, validator_address: Optional[str] = None) -> List[SlashingEvent]:
|
|
||||||
"""Get slashing history for validator or all validators"""
|
|
||||||
if validator_address:
|
|
||||||
return [event for event in self.slashing_events if event.validator_address == validator_address]
|
|
||||||
return self.slashing_events.copy()
|
|
||||||
|
|
||||||
def calculate_total_slashed(self, validator_address: str) -> float:
|
|
||||||
"""Calculate total amount slashed for validator"""
|
|
||||||
events = self.get_slashing_history(validator_address)
|
|
||||||
return sum(event.slash_amount for event in events)
|
|
||||||
|
|
||||||
# Global slashing manager
|
|
||||||
slashing_manager = SlashingManager()
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from .poa import PoAProposer, ProposerConfig, CircuitBreaker
|
|
||||||
|
|
||||||
__all__ = ["PoAProposer", "ProposerConfig", "CircuitBreaker"]
|
|
||||||
@@ -1,210 +0,0 @@
|
|||||||
"""
|
|
||||||
Validator Key Management
|
|
||||||
Handles cryptographic key operations for validators
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from typing import Dict, Optional, Tuple
|
|
||||||
from cryptography.hazmat.primitives import hashes, serialization
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
||||||
from cryptography.hazmat.backends import default_backend
|
|
||||||
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ValidatorKeyPair:
|
|
||||||
address: str
|
|
||||||
private_key_pem: str
|
|
||||||
public_key_pem: str
|
|
||||||
created_at: float
|
|
||||||
last_rotated: float
|
|
||||||
|
|
||||||
class KeyManager:
|
|
||||||
"""Manages validator cryptographic keys"""
|
|
||||||
|
|
||||||
def __init__(self, keys_dir: str = "/opt/aitbc/keys"):
|
|
||||||
self.keys_dir = keys_dir
|
|
||||||
self.key_pairs: Dict[str, ValidatorKeyPair] = {}
|
|
||||||
self._ensure_keys_directory()
|
|
||||||
self._load_existing_keys()
|
|
||||||
|
|
||||||
def _ensure_keys_directory(self):
|
|
||||||
"""Ensure keys directory exists and has proper permissions"""
|
|
||||||
os.makedirs(self.keys_dir, mode=0o700, exist_ok=True)
|
|
||||||
|
|
||||||
def _load_existing_keys(self):
|
|
||||||
"""Load existing key pairs from disk"""
|
|
||||||
keys_file = os.path.join(self.keys_dir, "validator_keys.json")
|
|
||||||
|
|
||||||
if os.path.exists(keys_file):
|
|
||||||
try:
|
|
||||||
with open(keys_file, 'r') as f:
|
|
||||||
keys_data = json.load(f)
|
|
||||||
|
|
||||||
for address, key_data in keys_data.items():
|
|
||||||
self.key_pairs[address] = ValidatorKeyPair(
|
|
||||||
address=address,
|
|
||||||
private_key_pem=key_data['private_key_pem'],
|
|
||||||
public_key_pem=key_data['public_key_pem'],
|
|
||||||
created_at=key_data['created_at'],
|
|
||||||
last_rotated=key_data['last_rotated']
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading keys: {e}")
|
|
||||||
|
|
||||||
def generate_key_pair(self, address: str) -> ValidatorKeyPair:
|
|
||||||
"""Generate new RSA key pair for validator"""
|
|
||||||
# Generate private key
|
|
||||||
private_key = rsa.generate_private_key(
|
|
||||||
public_exponent=65537,
|
|
||||||
key_size=2048,
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Serialize private key
|
|
||||||
private_key_pem = private_key.private_bytes(
|
|
||||||
encoding=Encoding.PEM,
|
|
||||||
format=PrivateFormat.PKCS8,
|
|
||||||
encryption_algorithm=NoEncryption()
|
|
||||||
).decode('utf-8')
|
|
||||||
|
|
||||||
# Get public key
|
|
||||||
public_key = private_key.public_key()
|
|
||||||
public_key_pem = public_key.public_bytes(
|
|
||||||
encoding=Encoding.PEM,
|
|
||||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
|
||||||
).decode('utf-8')
|
|
||||||
|
|
||||||
# Create key pair object
|
|
||||||
current_time = time.time()
|
|
||||||
key_pair = ValidatorKeyPair(
|
|
||||||
address=address,
|
|
||||||
private_key_pem=private_key_pem,
|
|
||||||
public_key_pem=public_key_pem,
|
|
||||||
created_at=current_time,
|
|
||||||
last_rotated=current_time
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store key pair
|
|
||||||
self.key_pairs[address] = key_pair
|
|
||||||
self._save_keys()
|
|
||||||
|
|
||||||
return key_pair
|
|
||||||
|
|
||||||
def get_key_pair(self, address: str) -> Optional[ValidatorKeyPair]:
|
|
||||||
"""Get key pair for validator"""
|
|
||||||
return self.key_pairs.get(address)
|
|
||||||
|
|
||||||
def rotate_key(self, address: str) -> Optional[ValidatorKeyPair]:
|
|
||||||
"""Rotate validator keys"""
|
|
||||||
if address not in self.key_pairs:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Generate new key pair
|
|
||||||
new_key_pair = self.generate_key_pair(address)
|
|
||||||
|
|
||||||
# Update rotation time
|
|
||||||
new_key_pair.created_at = self.key_pairs[address].created_at
|
|
||||||
new_key_pair.last_rotated = time.time()
|
|
||||||
|
|
||||||
self._save_keys()
|
|
||||||
return new_key_pair
|
|
||||||
|
|
||||||
def sign_message(self, address: str, message: str) -> Optional[str]:
|
|
||||||
"""Sign message with validator private key"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load private key from PEM
|
|
||||||
private_key = serialization.load_pem_private_key(
|
|
||||||
key_pair.private_key_pem.encode(),
|
|
||||||
password=None,
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sign message
|
|
||||||
signature = private_key.sign(
|
|
||||||
message.encode('utf-8'),
|
|
||||||
hashes.SHA256(),
|
|
||||||
default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
return signature.hex()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error signing message: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def verify_signature(self, address: str, message: str, signature: str) -> bool:
|
|
||||||
"""Verify message signature"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load public key from PEM
|
|
||||||
public_key = serialization.load_pem_public_key(
|
|
||||||
key_pair.public_key_pem.encode(),
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify signature
|
|
||||||
public_key.verify(
|
|
||||||
bytes.fromhex(signature),
|
|
||||||
message.encode('utf-8'),
|
|
||||||
hashes.SHA256(),
|
|
||||||
default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error verifying signature: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_public_key_pem(self, address: str) -> Optional[str]:
|
|
||||||
"""Get public key PEM for validator"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
return key_pair.public_key_pem if key_pair else None
|
|
||||||
|
|
||||||
def _save_keys(self):
|
|
||||||
"""Save key pairs to disk"""
|
|
||||||
keys_file = os.path.join(self.keys_dir, "validator_keys.json")
|
|
||||||
|
|
||||||
keys_data = {}
|
|
||||||
for address, key_pair in self.key_pairs.items():
|
|
||||||
keys_data[address] = {
|
|
||||||
'private_key_pem': key_pair.private_key_pem,
|
|
||||||
'public_key_pem': key_pair.public_key_pem,
|
|
||||||
'created_at': key_pair.created_at,
|
|
||||||
'last_rotated': key_pair.last_rotated
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(keys_file, 'w') as f:
|
|
||||||
json.dump(keys_data, f, indent=2)
|
|
||||||
|
|
||||||
# Set secure permissions
|
|
||||||
os.chmod(keys_file, 0o600)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error saving keys: {e}")
|
|
||||||
|
|
||||||
def should_rotate_key(self, address: str, rotation_interval: int = 86400) -> bool:
|
|
||||||
"""Check if key should be rotated (default: 24 hours)"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return (time.time() - key_pair.last_rotated) >= rotation_interval
|
|
||||||
|
|
||||||
def get_key_age(self, address: str) -> Optional[float]:
|
|
||||||
"""Get age of key in seconds"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return time.time() - key_pair.created_at
|
|
||||||
|
|
||||||
# Global key manager
|
|
||||||
key_manager = KeyManager()
|
|
||||||
@@ -1,119 +0,0 @@
|
|||||||
"""
|
|
||||||
Multi-Validator Proof of Authority Consensus Implementation
|
|
||||||
Extends single validator PoA to support multiple validators with rotation
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
from typing import List, Dict, Optional, Set
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from ..config import settings
|
|
||||||
from ..models import Block, Transaction
|
|
||||||
from ..database import session_scope
|
|
||||||
|
|
||||||
class ValidatorRole(Enum):
|
|
||||||
PROPOSER = "proposer"
|
|
||||||
VALIDATOR = "validator"
|
|
||||||
STANDBY = "standby"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Validator:
|
|
||||||
address: str
|
|
||||||
stake: float
|
|
||||||
reputation: float
|
|
||||||
role: ValidatorRole
|
|
||||||
last_proposed: int
|
|
||||||
is_active: bool
|
|
||||||
|
|
||||||
class MultiValidatorPoA:
|
|
||||||
"""Multi-Validator Proof of Authority consensus mechanism"""
|
|
||||||
|
|
||||||
def __init__(self, chain_id: str):
|
|
||||||
self.chain_id = chain_id
|
|
||||||
self.validators: Dict[str, Validator] = {}
|
|
||||||
self.current_proposer_index = 0
|
|
||||||
self.round_robin_enabled = True
|
|
||||||
self.consensus_timeout = 30 # seconds
|
|
||||||
|
|
||||||
def add_validator(self, address: str, stake: float = 1000.0) -> bool:
|
|
||||||
"""Add a new validator to the consensus"""
|
|
||||||
if address in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.validators[address] = Validator(
|
|
||||||
address=address,
|
|
||||||
stake=stake,
|
|
||||||
reputation=1.0,
|
|
||||||
role=ValidatorRole.STANDBY,
|
|
||||||
last_proposed=0,
|
|
||||||
is_active=True
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def remove_validator(self, address: str) -> bool:
|
|
||||||
"""Remove a validator from the consensus"""
|
|
||||||
if address not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[address]
|
|
||||||
validator.is_active = False
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
return True
|
|
||||||
|
|
||||||
def select_proposer(self, block_height: int) -> Optional[str]:
|
|
||||||
"""Select proposer for the current block using round-robin"""
|
|
||||||
active_validators = [
|
|
||||||
v for v in self.validators.values()
|
|
||||||
if v.is_active and v.role in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]
|
|
||||||
]
|
|
||||||
|
|
||||||
if not active_validators:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Round-robin selection
|
|
||||||
proposer_index = block_height % len(active_validators)
|
|
||||||
return active_validators[proposer_index].address
|
|
||||||
|
|
||||||
def validate_block(self, block: Block, proposer: str) -> bool:
|
|
||||||
"""Validate a proposed block"""
|
|
||||||
if proposer not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[proposer]
|
|
||||||
if not validator.is_active:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check if validator is allowed to propose
|
|
||||||
if validator.role not in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Additional validation logic here
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_consensus_participants(self) -> List[str]:
|
|
||||||
"""Get list of active consensus participants"""
|
|
||||||
return [
|
|
||||||
v.address for v in self.validators.values()
|
|
||||||
if v.is_active and v.role in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]
|
|
||||||
]
|
|
||||||
|
|
||||||
def update_validator_reputation(self, address: str, delta: float) -> bool:
|
|
||||||
"""Update validator reputation"""
|
|
||||||
if address not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[address]
|
|
||||||
validator.reputation = max(0.0, min(1.0, validator.reputation + delta))
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Global consensus instance
|
|
||||||
consensus_instances: Dict[str, MultiValidatorPoA] = {}
|
|
||||||
|
|
||||||
def get_consensus(chain_id: str) -> MultiValidatorPoA:
|
|
||||||
"""Get or create consensus instance for chain"""
|
|
||||||
if chain_id not in consensus_instances:
|
|
||||||
consensus_instances[chain_id] = MultiValidatorPoA(chain_id)
|
|
||||||
return consensus_instances[chain_id]
|
|
||||||
@@ -1,193 +0,0 @@
|
|||||||
"""
|
|
||||||
Practical Byzantine Fault Tolerance (PBFT) Consensus Implementation
|
|
||||||
Provides Byzantine fault tolerance for up to 1/3 faulty validators
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
from typing import List, Dict, Optional, Set, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import MultiValidatorPoA, Validator
|
|
||||||
|
|
||||||
class PBFTPhase(Enum):
|
|
||||||
PRE_PREPARE = "pre_prepare"
|
|
||||||
PREPARE = "prepare"
|
|
||||||
COMMIT = "commit"
|
|
||||||
EXECUTE = "execute"
|
|
||||||
|
|
||||||
class PBFTMessageType(Enum):
|
|
||||||
PRE_PREPARE = "pre_prepare"
|
|
||||||
PREPARE = "prepare"
|
|
||||||
COMMIT = "commit"
|
|
||||||
VIEW_CHANGE = "view_change"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PBFTMessage:
|
|
||||||
message_type: PBFTMessageType
|
|
||||||
sender: str
|
|
||||||
view_number: int
|
|
||||||
sequence_number: int
|
|
||||||
digest: str
|
|
||||||
signature: str
|
|
||||||
timestamp: float
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PBFTState:
|
|
||||||
current_view: int
|
|
||||||
current_sequence: int
|
|
||||||
prepared_messages: Dict[str, List[PBFTMessage]]
|
|
||||||
committed_messages: Dict[str, List[PBFTMessage]]
|
|
||||||
pre_prepare_messages: Dict[str, PBFTMessage]
|
|
||||||
|
|
||||||
class PBFTConsensus:
|
|
||||||
"""PBFT consensus implementation"""
|
|
||||||
|
|
||||||
def __init__(self, consensus: MultiValidatorPoA):
|
|
||||||
self.consensus = consensus
|
|
||||||
self.state = PBFTState(
|
|
||||||
current_view=0,
|
|
||||||
current_sequence=0,
|
|
||||||
prepared_messages={},
|
|
||||||
committed_messages={},
|
|
||||||
pre_prepare_messages={}
|
|
||||||
)
|
|
||||||
self.fault_tolerance = max(1, len(consensus.get_consensus_participants()) // 3)
|
|
||||||
self.required_messages = 2 * self.fault_tolerance + 1
|
|
||||||
|
|
||||||
def get_message_digest(self, block_hash: str, sequence: int, view: int) -> str:
|
|
||||||
"""Generate message digest for PBFT"""
|
|
||||||
content = f"{block_hash}:{sequence}:{view}"
|
|
||||||
return hashlib.sha256(content.encode()).hexdigest()
|
|
||||||
|
|
||||||
async def pre_prepare_phase(self, proposer: str, block_hash: str) -> bool:
|
|
||||||
"""Phase 1: Pre-prepare"""
|
|
||||||
sequence = self.state.current_sequence + 1
|
|
||||||
view = self.state.current_view
|
|
||||||
digest = self.get_message_digest(block_hash, sequence, view)
|
|
||||||
|
|
||||||
message = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.PRE_PREPARE,
|
|
||||||
sender=proposer,
|
|
||||||
view_number=view,
|
|
||||||
sequence_number=sequence,
|
|
||||||
digest=digest,
|
|
||||||
signature="", # Would be signed in real implementation
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store pre-prepare message
|
|
||||||
key = f"{sequence}:{view}"
|
|
||||||
self.state.pre_prepare_messages[key] = message
|
|
||||||
|
|
||||||
# Broadcast to all validators
|
|
||||||
await self._broadcast_message(message)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def prepare_phase(self, validator: str, pre_prepare_msg: PBFTMessage) -> bool:
|
|
||||||
"""Phase 2: Prepare"""
|
|
||||||
key = f"{pre_prepare_msg.sequence_number}:{pre_prepare_msg.view_number}"
|
|
||||||
|
|
||||||
if key not in self.state.pre_prepare_messages:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Create prepare message
|
|
||||||
prepare_msg = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.PREPARE,
|
|
||||||
sender=validator,
|
|
||||||
view_number=pre_prepare_msg.view_number,
|
|
||||||
sequence_number=pre_prepare_msg.sequence_number,
|
|
||||||
digest=pre_prepare_msg.digest,
|
|
||||||
signature="", # Would be signed
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store prepare message
|
|
||||||
if key not in self.state.prepared_messages:
|
|
||||||
self.state.prepared_messages[key] = []
|
|
||||||
self.state.prepared_messages[key].append(prepare_msg)
|
|
||||||
|
|
||||||
# Broadcast prepare message
|
|
||||||
await self._broadcast_message(prepare_msg)
|
|
||||||
|
|
||||||
# Check if we have enough prepare messages
|
|
||||||
return len(self.state.prepared_messages[key]) >= self.required_messages
|
|
||||||
|
|
||||||
async def commit_phase(self, validator: str, prepare_msg: PBFTMessage) -> bool:
|
|
||||||
"""Phase 3: Commit"""
|
|
||||||
key = f"{prepare_msg.sequence_number}:{prepare_msg.view_number}"
|
|
||||||
|
|
||||||
# Create commit message
|
|
||||||
commit_msg = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.COMMIT,
|
|
||||||
sender=validator,
|
|
||||||
view_number=prepare_msg.view_number,
|
|
||||||
sequence_number=prepare_msg.sequence_number,
|
|
||||||
digest=prepare_msg.digest,
|
|
||||||
signature="", # Would be signed
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store commit message
|
|
||||||
if key not in self.state.committed_messages:
|
|
||||||
self.state.committed_messages[key] = []
|
|
||||||
self.state.committed_messages[key].append(commit_msg)
|
|
||||||
|
|
||||||
# Broadcast commit message
|
|
||||||
await self._broadcast_message(commit_msg)
|
|
||||||
|
|
||||||
# Check if we have enough commit messages
|
|
||||||
if len(self.state.committed_messages[key]) >= self.required_messages:
|
|
||||||
return await self.execute_phase(key)
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def execute_phase(self, key: str) -> bool:
|
|
||||||
"""Phase 4: Execute"""
|
|
||||||
# Extract sequence and view from key
|
|
||||||
sequence, view = map(int, key.split(':'))
|
|
||||||
|
|
||||||
# Update state
|
|
||||||
self.state.current_sequence = sequence
|
|
||||||
|
|
||||||
# Clean up old messages
|
|
||||||
self._cleanup_messages(sequence)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _broadcast_message(self, message: PBFTMessage):
|
|
||||||
"""Broadcast message to all validators"""
|
|
||||||
validators = self.consensus.get_consensus_participants()
|
|
||||||
|
|
||||||
for validator in validators:
|
|
||||||
if validator != message.sender:
|
|
||||||
# In real implementation, this would send over network
|
|
||||||
await self._send_to_validator(validator, message)
|
|
||||||
|
|
||||||
async def _send_to_validator(self, validator: str, message: PBFTMessage):
|
|
||||||
"""Send message to specific validator"""
|
|
||||||
# Network communication would be implemented here
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _cleanup_messages(self, sequence: int):
|
|
||||||
"""Clean up old messages to prevent memory leaks"""
|
|
||||||
old_keys = [
|
|
||||||
key for key in self.state.prepared_messages.keys()
|
|
||||||
if int(key.split(':')[0]) < sequence
|
|
||||||
]
|
|
||||||
|
|
||||||
for key in old_keys:
|
|
||||||
self.state.prepared_messages.pop(key, None)
|
|
||||||
self.state.committed_messages.pop(key, None)
|
|
||||||
self.state.pre_prepare_messages.pop(key, None)
|
|
||||||
|
|
||||||
def handle_view_change(self, new_view: int) -> bool:
|
|
||||||
"""Handle view change when proposer fails"""
|
|
||||||
self.state.current_view = new_view
|
|
||||||
# Reset state for new view
|
|
||||||
self.state.prepared_messages.clear()
|
|
||||||
self.state.committed_messages.clear()
|
|
||||||
self.state.pre_prepare_messages.clear()
|
|
||||||
return True
|
|
||||||
@@ -1,345 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable, ContextManager, Optional
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
|
|
||||||
from ..logger import get_logger
|
|
||||||
from ..metrics import metrics_registry
|
|
||||||
from ..config import ProposerConfig
|
|
||||||
from ..models import Block, Account
|
|
||||||
from ..gossip import gossip_broker
|
|
||||||
|
|
||||||
_METRIC_KEY_SANITIZE = re.compile(r"[^a-zA-Z0-9_]")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_metric_suffix(value: str) -> str:
|
|
||||||
sanitized = _METRIC_KEY_SANITIZE.sub("_", value).strip("_")
|
|
||||||
return sanitized or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
class CircuitBreaker:
|
|
||||||
def __init__(self, threshold: int, timeout: int):
|
|
||||||
self._threshold = threshold
|
|
||||||
self._timeout = timeout
|
|
||||||
self._failures = 0
|
|
||||||
self._last_failure_time = 0.0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self) -> str:
|
|
||||||
if self._state == "open":
|
|
||||||
if time.time() - self._last_failure_time > self._timeout:
|
|
||||||
self._state = "half-open"
|
|
||||||
return self._state
|
|
||||||
|
|
||||||
def allow_request(self) -> bool:
|
|
||||||
state = self.state
|
|
||||||
if state == "closed":
|
|
||||||
return True
|
|
||||||
if state == "half-open":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def record_failure(self) -> None:
|
|
||||||
self._failures += 1
|
|
||||||
self._last_failure_time = time.time()
|
|
||||||
if self._failures >= self._threshold:
|
|
||||||
self._state = "open"
|
|
||||||
|
|
||||||
def record_success(self) -> None:
|
|
||||||
self._failures = 0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
class PoAProposer:
|
|
||||||
"""Proof-of-Authority block proposer.
|
|
||||||
|
|
||||||
Responsible for periodically proposing blocks if this node is configured as a proposer.
|
|
||||||
In the real implementation, this would involve checking the mempool, validating transactions,
|
|
||||||
and signing the block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
config: ProposerConfig,
|
|
||||||
session_factory: Callable[[], ContextManager[Session]],
|
|
||||||
) -> None:
|
|
||||||
self._config = config
|
|
||||||
self._session_factory = session_factory
|
|
||||||
self._logger = get_logger(__name__)
|
|
||||||
self._stop_event = asyncio.Event()
|
|
||||||
self._task: Optional[asyncio.Task[None]] = None
|
|
||||||
self._last_proposer_id: Optional[str] = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
if self._task is not None:
|
|
||||||
return
|
|
||||||
self._logger.info("Starting PoA proposer loop", extra={"interval": self._config.interval_seconds})
|
|
||||||
await self._ensure_genesis_block()
|
|
||||||
self._stop_event.clear()
|
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
if self._task is None:
|
|
||||||
return
|
|
||||||
self._logger.info("Stopping PoA proposer loop")
|
|
||||||
self._stop_event.set()
|
|
||||||
await self._task
|
|
||||||
self._task = None
|
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
await self._wait_until_next_slot()
|
|
||||||
if self._stop_event.is_set():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
await self._propose_block()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
self._logger.exception("Failed to propose block", extra={"error": str(exc)})
|
|
||||||
|
|
||||||
async def _wait_until_next_slot(self) -> None:
|
|
||||||
head = self._fetch_chain_head()
|
|
||||||
if head is None:
|
|
||||||
return
|
|
||||||
now = datetime.utcnow()
|
|
||||||
elapsed = (now - head.timestamp).total_seconds()
|
|
||||||
sleep_for = max(self._config.interval_seconds - elapsed, 0.1)
|
|
||||||
if sleep_for <= 0:
|
|
||||||
sleep_for = 0.1
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._stop_event.wait(), timeout=sleep_for)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _propose_block(self) -> None:
|
|
||||||
# Check internal mempool and include transactions
|
|
||||||
from ..mempool import get_mempool
|
|
||||||
from ..models import Transaction, Account
|
|
||||||
mempool = get_mempool()
|
|
||||||
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
next_height = 0
|
|
||||||
parent_hash = "0x00"
|
|
||||||
interval_seconds: Optional[float] = None
|
|
||||||
if head is not None:
|
|
||||||
next_height = head.height + 1
|
|
||||||
parent_hash = head.hash
|
|
||||||
interval_seconds = (datetime.utcnow() - head.timestamp).total_seconds()
|
|
||||||
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
# Pull transactions from mempool
|
|
||||||
max_txs = self._config.max_txs_per_block
|
|
||||||
max_bytes = self._config.max_block_size_bytes
|
|
||||||
pending_txs = mempool.drain(max_txs, max_bytes, self._config.chain_id)
|
|
||||||
self._logger.info(f"[PROPOSE] drained {len(pending_txs)} txs from mempool, chain={self._config.chain_id}")
|
|
||||||
|
|
||||||
# Process transactions and update balances
|
|
||||||
processed_txs = []
|
|
||||||
for tx in pending_txs:
|
|
||||||
try:
|
|
||||||
# Parse transaction data
|
|
||||||
tx_data = tx.content
|
|
||||||
sender = tx_data.get("from")
|
|
||||||
recipient = tx_data.get("to")
|
|
||||||
value = tx_data.get("amount", 0)
|
|
||||||
fee = tx_data.get("fee", 0)
|
|
||||||
|
|
||||||
if not sender or not recipient:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get sender account
|
|
||||||
sender_account = session.get(Account, (self._config.chain_id, sender))
|
|
||||||
if not sender_account:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check sufficient balance
|
|
||||||
total_cost = value + fee
|
|
||||||
if sender_account.balance < total_cost:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get or create recipient account
|
|
||||||
recipient_account = session.get(Account, (self._config.chain_id, recipient))
|
|
||||||
if not recipient_account:
|
|
||||||
recipient_account = Account(chain_id=self._config.chain_id, address=recipient, balance=0, nonce=0)
|
|
||||||
session.add(recipient_account)
|
|
||||||
session.flush()
|
|
||||||
|
|
||||||
# Update balances
|
|
||||||
sender_account.balance -= total_cost
|
|
||||||
sender_account.nonce += 1
|
|
||||||
recipient_account.balance += value
|
|
||||||
|
|
||||||
# Create transaction record
|
|
||||||
transaction = Transaction(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
tx_hash=tx.tx_hash,
|
|
||||||
sender=sender,
|
|
||||||
recipient=recipient,
|
|
||||||
payload=tx_data,
|
|
||||||
value=value,
|
|
||||||
fee=fee,
|
|
||||||
nonce=sender_account.nonce - 1,
|
|
||||||
timestamp=timestamp,
|
|
||||||
block_height=next_height,
|
|
||||||
status="confirmed"
|
|
||||||
)
|
|
||||||
session.add(transaction)
|
|
||||||
processed_txs.append(tx)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self._logger.warning(f"Failed to process transaction {tx.tx_hash}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Compute block hash with transaction data
|
|
||||||
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp, processed_txs)
|
|
||||||
|
|
||||||
block = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=next_height,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash=parent_hash,
|
|
||||||
proposer=self._config.proposer_id,
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=len(processed_txs),
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(block)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
metrics_registry.increment("blocks_proposed_total")
|
|
||||||
metrics_registry.set_gauge("chain_head_height", float(next_height))
|
|
||||||
if interval_seconds is not None and interval_seconds >= 0:
|
|
||||||
metrics_registry.observe("block_interval_seconds", interval_seconds)
|
|
||||||
metrics_registry.set_gauge("poa_last_block_interval_seconds", float(interval_seconds))
|
|
||||||
|
|
||||||
proposer_suffix = _sanitize_metric_suffix(self._config.proposer_id)
|
|
||||||
metrics_registry.increment(f"poa_blocks_proposed_total_{proposer_suffix}")
|
|
||||||
if self._last_proposer_id is not None and self._last_proposer_id != self._config.proposer_id:
|
|
||||||
metrics_registry.increment("poa_proposer_switches_total")
|
|
||||||
self._last_proposer_id = self._config.proposer_id
|
|
||||||
|
|
||||||
self._logger.info(
|
|
||||||
"Proposed block",
|
|
||||||
extra={
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast the new block
|
|
||||||
tx_list = [tx.content for tx in processed_txs] if processed_txs else []
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"chain_id": self._config.chain_id,
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"parent_hash": block.parent_hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
"timestamp": block.timestamp.isoformat(),
|
|
||||||
"tx_count": block.tx_count,
|
|
||||||
"state_root": block.state_root,
|
|
||||||
"transactions": tx_list,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_genesis_block(self) -> None:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
if head is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use a deterministic genesis timestamp so all nodes agree on the genesis block hash
|
|
||||||
timestamp = datetime(2025, 1, 1, 0, 0, 0)
|
|
||||||
block_hash = self._compute_block_hash(0, "0x00", timestamp)
|
|
||||||
genesis = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=0,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash="0x00",
|
|
||||||
proposer=self._config.proposer_id, # Use configured proposer as genesis proposer
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(genesis)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Initialize accounts from genesis allocations file (if present)
|
|
||||||
await self._initialize_genesis_allocations(session)
|
|
||||||
|
|
||||||
# Broadcast genesis block for initial sync
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"chain_id": self._config.chain_id,
|
|
||||||
"height": genesis.height,
|
|
||||||
"hash": genesis.hash,
|
|
||||||
"parent_hash": genesis.parent_hash,
|
|
||||||
"proposer": genesis.proposer,
|
|
||||||
"timestamp": genesis.timestamp.isoformat(),
|
|
||||||
"tx_count": genesis.tx_count,
|
|
||||||
"state_root": genesis.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _initialize_genesis_allocations(self, session: Session) -> None:
|
|
||||||
"""Create Account entries from the genesis allocations file."""
|
|
||||||
# Use standardized data directory from configuration
|
|
||||||
from ..config import settings
|
|
||||||
|
|
||||||
genesis_paths = [
|
|
||||||
Path(f"/var/lib/aitbc/data/{self._config.chain_id}/genesis.json"), # Standard location
|
|
||||||
]
|
|
||||||
|
|
||||||
genesis_path = None
|
|
||||||
for path in genesis_paths:
|
|
||||||
if path.exists():
|
|
||||||
genesis_path = path
|
|
||||||
break
|
|
||||||
|
|
||||||
if not genesis_path:
|
|
||||||
self._logger.warning("Genesis allocations file not found; skipping account initialization", extra={"paths": str(genesis_paths)})
|
|
||||||
return
|
|
||||||
|
|
||||||
with open(genesis_path) as f:
|
|
||||||
genesis_data = json.load(f)
|
|
||||||
|
|
||||||
allocations = genesis_data.get("allocations", [])
|
|
||||||
created = 0
|
|
||||||
for alloc in allocations:
|
|
||||||
addr = alloc["address"]
|
|
||||||
balance = int(alloc["balance"])
|
|
||||||
nonce = int(alloc.get("nonce", 0))
|
|
||||||
# Check if account already exists (idempotent)
|
|
||||||
acct = session.get(Account, (self._config.chain_id, addr))
|
|
||||||
if acct is None:
|
|
||||||
acct = Account(chain_id=self._config.chain_id, address=addr, balance=balance, nonce=nonce)
|
|
||||||
session.add(acct)
|
|
||||||
created += 1
|
|
||||||
session.commit()
|
|
||||||
self._logger.info("Initialized genesis accounts", extra={"count": created, "total": len(allocations), "path": str(genesis_path)})
|
|
||||||
|
|
||||||
def _fetch_chain_head(self) -> Optional[Block]:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
|
|
||||||
def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime, transactions: list = None) -> str:
|
|
||||||
# Include transaction hashes in block hash computation
|
|
||||||
tx_hashes = []
|
|
||||||
if transactions:
|
|
||||||
tx_hashes = [tx.tx_hash for tx in transactions]
|
|
||||||
|
|
||||||
payload = f"{self._config.chain_id}|{height}|{parent_hash}|{timestamp.isoformat()}|{'|'.join(sorted(tx_hashes))}".encode()
|
|
||||||
return "0x" + hashlib.sha256(payload).hexdigest()
|
|
||||||
@@ -1,229 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Callable, ContextManager, Optional
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
|
|
||||||
from ..logger import get_logger
|
|
||||||
from ..metrics import metrics_registry
|
|
||||||
from ..config import ProposerConfig
|
|
||||||
from ..models import Block
|
|
||||||
from ..gossip import gossip_broker
|
|
||||||
|
|
||||||
_METRIC_KEY_SANITIZE = re.compile(r"[^a-zA-Z0-9_]")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_metric_suffix(value: str) -> str:
|
|
||||||
sanitized = _METRIC_KEY_SANITIZE.sub("_", value).strip("_")
|
|
||||||
return sanitized or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
class CircuitBreaker:
|
|
||||||
def __init__(self, threshold: int, timeout: int):
|
|
||||||
self._threshold = threshold
|
|
||||||
self._timeout = timeout
|
|
||||||
self._failures = 0
|
|
||||||
self._last_failure_time = 0.0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self) -> str:
|
|
||||||
if self._state == "open":
|
|
||||||
if time.time() - self._last_failure_time > self._timeout:
|
|
||||||
self._state = "half-open"
|
|
||||||
return self._state
|
|
||||||
|
|
||||||
def allow_request(self) -> bool:
|
|
||||||
state = self.state
|
|
||||||
if state == "closed":
|
|
||||||
return True
|
|
||||||
if state == "half-open":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def record_failure(self) -> None:
|
|
||||||
self._failures += 1
|
|
||||||
self._last_failure_time = time.time()
|
|
||||||
if self._failures >= self._threshold:
|
|
||||||
self._state = "open"
|
|
||||||
|
|
||||||
def record_success(self) -> None:
|
|
||||||
self._failures = 0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
class PoAProposer:
|
|
||||||
"""Proof-of-Authority block proposer.
|
|
||||||
|
|
||||||
Responsible for periodically proposing blocks if this node is configured as a proposer.
|
|
||||||
In the real implementation, this would involve checking the mempool, validating transactions,
|
|
||||||
and signing the block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
config: ProposerConfig,
|
|
||||||
session_factory: Callable[[], ContextManager[Session]],
|
|
||||||
) -> None:
|
|
||||||
self._config = config
|
|
||||||
self._session_factory = session_factory
|
|
||||||
self._logger = get_logger(__name__)
|
|
||||||
self._stop_event = asyncio.Event()
|
|
||||||
self._task: Optional[asyncio.Task[None]] = None
|
|
||||||
self._last_proposer_id: Optional[str] = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
if self._task is not None:
|
|
||||||
return
|
|
||||||
self._logger.info("Starting PoA proposer loop", extra={"interval": self._config.interval_seconds})
|
|
||||||
self._ensure_genesis_block()
|
|
||||||
self._stop_event.clear()
|
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
if self._task is None:
|
|
||||||
return
|
|
||||||
self._logger.info("Stopping PoA proposer loop")
|
|
||||||
self._stop_event.set()
|
|
||||||
await self._task
|
|
||||||
self._task = None
|
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
await self._wait_until_next_slot()
|
|
||||||
if self._stop_event.is_set():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
self._propose_block()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
self._logger.exception("Failed to propose block", extra={"error": str(exc)})
|
|
||||||
|
|
||||||
async def _wait_until_next_slot(self) -> None:
|
|
||||||
head = self._fetch_chain_head()
|
|
||||||
if head is None:
|
|
||||||
return
|
|
||||||
now = datetime.utcnow()
|
|
||||||
elapsed = (now - head.timestamp).total_seconds()
|
|
||||||
sleep_for = max(self._config.interval_seconds - elapsed, 0.1)
|
|
||||||
if sleep_for <= 0:
|
|
||||||
sleep_for = 0.1
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._stop_event.wait(), timeout=sleep_for)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _propose_block(self) -> None:
|
|
||||||
# Check internal mempool
|
|
||||||
from ..mempool import get_mempool
|
|
||||||
if get_mempool().size(self._config.chain_id) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
next_height = 0
|
|
||||||
parent_hash = "0x00"
|
|
||||||
interval_seconds: Optional[float] = None
|
|
||||||
if head is not None:
|
|
||||||
next_height = head.height + 1
|
|
||||||
parent_hash = head.hash
|
|
||||||
interval_seconds = (datetime.utcnow() - head.timestamp).total_seconds()
|
|
||||||
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp)
|
|
||||||
|
|
||||||
block = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=next_height,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash=parent_hash,
|
|
||||||
proposer=self._config.proposer_id,
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(block)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
metrics_registry.increment("blocks_proposed_total")
|
|
||||||
metrics_registry.set_gauge("chain_head_height", float(next_height))
|
|
||||||
if interval_seconds is not None and interval_seconds >= 0:
|
|
||||||
metrics_registry.observe("block_interval_seconds", interval_seconds)
|
|
||||||
metrics_registry.set_gauge("poa_last_block_interval_seconds", float(interval_seconds))
|
|
||||||
|
|
||||||
proposer_suffix = _sanitize_metric_suffix(self._config.proposer_id)
|
|
||||||
metrics_registry.increment(f"poa_blocks_proposed_total_{proposer_suffix}")
|
|
||||||
if self._last_proposer_id is not None and self._last_proposer_id != self._config.proposer_id:
|
|
||||||
metrics_registry.increment("poa_proposer_switches_total")
|
|
||||||
self._last_proposer_id = self._config.proposer_id
|
|
||||||
|
|
||||||
self._logger.info(
|
|
||||||
"Proposed block",
|
|
||||||
extra={
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast the new block
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"parent_hash": block.parent_hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
"timestamp": block.timestamp.isoformat(),
|
|
||||||
"tx_count": block.tx_count,
|
|
||||||
"state_root": block.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_genesis_block(self) -> None:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
if head is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use a deterministic genesis timestamp so all nodes agree on the genesis block hash
|
|
||||||
timestamp = datetime(2025, 1, 1, 0, 0, 0)
|
|
||||||
block_hash = self._compute_block_hash(0, "0x00", timestamp)
|
|
||||||
genesis = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=0,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash="0x00",
|
|
||||||
proposer="genesis",
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(genesis)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Broadcast genesis block for initial sync
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"height": genesis.height,
|
|
||||||
"hash": genesis.hash,
|
|
||||||
"parent_hash": genesis.parent_hash,
|
|
||||||
"proposer": genesis.proposer,
|
|
||||||
"timestamp": genesis.timestamp.isoformat(),
|
|
||||||
"tx_count": genesis.tx_count,
|
|
||||||
"state_root": genesis.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _fetch_chain_head(self) -> Optional[Block]:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
|
|
||||||
def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime) -> str:
|
|
||||||
payload = f"{self._config.chain_id}|{height}|{parent_hash}|{timestamp.isoformat()}".encode()
|
|
||||||
return "0x" + hashlib.sha256(payload).hexdigest()
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
--- apps/blockchain-node/src/aitbc_chain/consensus/poa.py
|
|
||||||
+++ apps/blockchain-node/src/aitbc_chain/consensus/poa.py
|
|
||||||
@@ -101,7 +101,7 @@
|
|
||||||
# Wait for interval before proposing next block
|
|
||||||
await asyncio.sleep(self.config.interval_seconds)
|
|
||||||
|
|
||||||
- self._propose_block()
|
|
||||||
+ await self._propose_block()
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
@@ -1,146 +0,0 @@
|
|||||||
"""
|
|
||||||
Validator Rotation Mechanism
|
|
||||||
Handles automatic rotation of validators based on performance and stake
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
from typing import List, Dict, Optional
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import MultiValidatorPoA, Validator, ValidatorRole
|
|
||||||
|
|
||||||
class RotationStrategy(Enum):
|
|
||||||
ROUND_ROBIN = "round_robin"
|
|
||||||
STAKE_WEIGHTED = "stake_weighted"
|
|
||||||
REPUTATION_BASED = "reputation_based"
|
|
||||||
HYBRID = "hybrid"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RotationConfig:
|
|
||||||
strategy: RotationStrategy
|
|
||||||
rotation_interval: int # blocks
|
|
||||||
min_stake: float
|
|
||||||
reputation_threshold: float
|
|
||||||
max_validators: int
|
|
||||||
|
|
||||||
class ValidatorRotation:
|
|
||||||
"""Manages validator rotation based on various strategies"""
|
|
||||||
|
|
||||||
def __init__(self, consensus: MultiValidatorPoA, config: RotationConfig):
|
|
||||||
self.consensus = consensus
|
|
||||||
self.config = config
|
|
||||||
self.last_rotation_height = 0
|
|
||||||
|
|
||||||
def should_rotate(self, current_height: int) -> bool:
|
|
||||||
"""Check if rotation should occur at current height"""
|
|
||||||
return (current_height - self.last_rotation_height) >= self.config.rotation_interval
|
|
||||||
|
|
||||||
def rotate_validators(self, current_height: int) -> bool:
|
|
||||||
"""Perform validator rotation based on configured strategy"""
|
|
||||||
if not self.should_rotate(current_height):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if self.config.strategy == RotationStrategy.ROUND_ROBIN:
|
|
||||||
return self._rotate_round_robin()
|
|
||||||
elif self.config.strategy == RotationStrategy.STAKE_WEIGHTED:
|
|
||||||
return self._rotate_stake_weighted()
|
|
||||||
elif self.config.strategy == RotationStrategy.REPUTATION_BASED:
|
|
||||||
return self._rotate_reputation_based()
|
|
||||||
elif self.config.strategy == RotationStrategy.HYBRID:
|
|
||||||
return self._rotate_hybrid()
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _rotate_round_robin(self) -> bool:
|
|
||||||
"""Round-robin rotation of validator roles"""
|
|
||||||
validators = list(self.consensus.validators.values())
|
|
||||||
active_validators = [v for v in validators if v.is_active]
|
|
||||||
|
|
||||||
# Rotate roles among active validators
|
|
||||||
for i, validator in enumerate(active_validators):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 3: # Top 3 become validators
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_stake_weighted(self) -> bool:
|
|
||||||
"""Stake-weighted rotation"""
|
|
||||||
validators = sorted(
|
|
||||||
[v for v in self.consensus.validators.values() if v.is_active],
|
|
||||||
key=lambda v: v.stake,
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, validator in enumerate(validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_reputation_based(self) -> bool:
|
|
||||||
"""Reputation-based rotation"""
|
|
||||||
validators = sorted(
|
|
||||||
[v for v in self.consensus.validators.values() if v.is_active],
|
|
||||||
key=lambda v: v.reputation,
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Filter by reputation threshold
|
|
||||||
qualified_validators = [
|
|
||||||
v for v in validators
|
|
||||||
if v.reputation >= self.config.reputation_threshold
|
|
||||||
]
|
|
||||||
|
|
||||||
for i, validator in enumerate(qualified_validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_hybrid(self) -> bool:
|
|
||||||
"""Hybrid rotation considering both stake and reputation"""
|
|
||||||
validators = [v for v in self.consensus.validators.values() if v.is_active]
|
|
||||||
|
|
||||||
# Calculate hybrid score
|
|
||||||
for validator in validators:
|
|
||||||
validator.hybrid_score = validator.stake * validator.reputation
|
|
||||||
|
|
||||||
# Sort by hybrid score
|
|
||||||
validators.sort(key=lambda v: v.hybrid_score, reverse=True)
|
|
||||||
|
|
||||||
for i, validator in enumerate(validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Default rotation configuration
|
|
||||||
DEFAULT_ROTATION_CONFIG = RotationConfig(
|
|
||||||
strategy=RotationStrategy.HYBRID,
|
|
||||||
rotation_interval=100, # Rotate every 100 blocks
|
|
||||||
min_stake=1000.0,
|
|
||||||
reputation_threshold=0.7,
|
|
||||||
max_validators=10
|
|
||||||
)
|
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
"""
|
|
||||||
Slashing Conditions Implementation
|
|
||||||
Handles detection and penalties for validator misbehavior
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
from typing import Dict, List, Optional, Set
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import Validator, ValidatorRole
|
|
||||||
|
|
||||||
class SlashingCondition(Enum):
|
|
||||||
DOUBLE_SIGN = "double_sign"
|
|
||||||
UNAVAILABLE = "unavailable"
|
|
||||||
INVALID_BLOCK = "invalid_block"
|
|
||||||
SLOW_RESPONSE = "slow_response"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SlashingEvent:
|
|
||||||
validator_address: str
|
|
||||||
condition: SlashingCondition
|
|
||||||
evidence: str
|
|
||||||
block_height: int
|
|
||||||
timestamp: float
|
|
||||||
slash_amount: float
|
|
||||||
|
|
||||||
class SlashingManager:
|
|
||||||
"""Manages validator slashing conditions and penalties"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.slashing_events: List[SlashingEvent] = []
|
|
||||||
self.slash_rates = {
|
|
||||||
SlashingCondition.DOUBLE_SIGN: 0.5, # 50% slash
|
|
||||||
SlashingCondition.UNAVAILABLE: 0.1, # 10% slash
|
|
||||||
SlashingCondition.INVALID_BLOCK: 0.3, # 30% slash
|
|
||||||
SlashingCondition.SLOW_RESPONSE: 0.05 # 5% slash
|
|
||||||
}
|
|
||||||
self.slash_thresholds = {
|
|
||||||
SlashingCondition.DOUBLE_SIGN: 1, # Immediate slash
|
|
||||||
SlashingCondition.UNAVAILABLE: 3, # After 3 offenses
|
|
||||||
SlashingCondition.INVALID_BLOCK: 1, # Immediate slash
|
|
||||||
SlashingCondition.SLOW_RESPONSE: 5 # After 5 offenses
|
|
||||||
}
|
|
||||||
|
|
||||||
def detect_double_sign(self, validator: str, block_hash1: str, block_hash2: str, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect double signing (validator signed two different blocks at same height)"""
|
|
||||||
if block_hash1 == block_hash2:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.DOUBLE_SIGN,
|
|
||||||
evidence=f"Double sign detected: {block_hash1} vs {block_hash2} at height {height}",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.DOUBLE_SIGN]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_unavailability(self, validator: str, missed_blocks: int, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect validator unavailability (missing consensus participation)"""
|
|
||||||
if missed_blocks < self.slash_thresholds[SlashingCondition.UNAVAILABLE]:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.UNAVAILABLE,
|
|
||||||
evidence=f"Missed {missed_blocks} consecutive blocks",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.UNAVAILABLE]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_invalid_block(self, validator: str, block_hash: str, reason: str, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect invalid block proposal"""
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.INVALID_BLOCK,
|
|
||||||
evidence=f"Invalid block {block_hash}: {reason}",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.INVALID_BLOCK]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_slow_response(self, validator: str, response_time: float, threshold: float, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect slow consensus participation"""
|
|
||||||
if response_time <= threshold:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.SLOW_RESPONSE,
|
|
||||||
evidence=f"Slow response: {response_time}s (threshold: {threshold}s)",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.SLOW_RESPONSE]
|
|
||||||
)
|
|
||||||
|
|
||||||
def apply_slashing(self, validator: Validator, event: SlashingEvent) -> bool:
|
|
||||||
"""Apply slashing penalty to validator"""
|
|
||||||
slash_amount = validator.stake * event.slash_amount
|
|
||||||
validator.stake -= slash_amount
|
|
||||||
|
|
||||||
# Demote validator role if stake is too low
|
|
||||||
if validator.stake < 100: # Minimum stake threshold
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
# Record slashing event
|
|
||||||
self.slashing_events.append(event)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_validator_slash_count(self, validator_address: str, condition: SlashingCondition) -> int:
|
|
||||||
"""Get count of slashing events for validator and condition"""
|
|
||||||
return len([
|
|
||||||
event for event in self.slashing_events
|
|
||||||
if event.validator_address == validator_address and event.condition == condition
|
|
||||||
])
|
|
||||||
|
|
||||||
def should_slash(self, validator: str, condition: SlashingCondition) -> bool:
|
|
||||||
"""Check if validator should be slashed for condition"""
|
|
||||||
current_count = self.get_validator_slash_count(validator, condition)
|
|
||||||
threshold = self.slash_thresholds.get(condition, 1)
|
|
||||||
return current_count >= threshold
|
|
||||||
|
|
||||||
def get_slashing_history(self, validator_address: Optional[str] = None) -> List[SlashingEvent]:
|
|
||||||
"""Get slashing history for validator or all validators"""
|
|
||||||
if validator_address:
|
|
||||||
return [event for event in self.slashing_events if event.validator_address == validator_address]
|
|
||||||
return self.slashing_events.copy()
|
|
||||||
|
|
||||||
def calculate_total_slashed(self, validator_address: str) -> float:
|
|
||||||
"""Calculate total amount slashed for validator"""
|
|
||||||
events = self.get_slashing_history(validator_address)
|
|
||||||
return sum(event.slash_amount for event in events)
|
|
||||||
|
|
||||||
# Global slashing manager
|
|
||||||
slashing_manager = SlashingManager()
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from .poa import PoAProposer, ProposerConfig, CircuitBreaker
|
|
||||||
|
|
||||||
__all__ = ["PoAProposer", "ProposerConfig", "CircuitBreaker"]
|
|
||||||
@@ -1,210 +0,0 @@
|
|||||||
"""
|
|
||||||
Validator Key Management
|
|
||||||
Handles cryptographic key operations for validators
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from typing import Dict, Optional, Tuple
|
|
||||||
from cryptography.hazmat.primitives import hashes, serialization
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
||||||
from cryptography.hazmat.backends import default_backend
|
|
||||||
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ValidatorKeyPair:
|
|
||||||
address: str
|
|
||||||
private_key_pem: str
|
|
||||||
public_key_pem: str
|
|
||||||
created_at: float
|
|
||||||
last_rotated: float
|
|
||||||
|
|
||||||
class KeyManager:
|
|
||||||
"""Manages validator cryptographic keys"""
|
|
||||||
|
|
||||||
def __init__(self, keys_dir: str = "/opt/aitbc/keys"):
|
|
||||||
self.keys_dir = keys_dir
|
|
||||||
self.key_pairs: Dict[str, ValidatorKeyPair] = {}
|
|
||||||
self._ensure_keys_directory()
|
|
||||||
self._load_existing_keys()
|
|
||||||
|
|
||||||
def _ensure_keys_directory(self):
|
|
||||||
"""Ensure keys directory exists and has proper permissions"""
|
|
||||||
os.makedirs(self.keys_dir, mode=0o700, exist_ok=True)
|
|
||||||
|
|
||||||
def _load_existing_keys(self):
|
|
||||||
"""Load existing key pairs from disk"""
|
|
||||||
keys_file = os.path.join(self.keys_dir, "validator_keys.json")
|
|
||||||
|
|
||||||
if os.path.exists(keys_file):
|
|
||||||
try:
|
|
||||||
with open(keys_file, 'r') as f:
|
|
||||||
keys_data = json.load(f)
|
|
||||||
|
|
||||||
for address, key_data in keys_data.items():
|
|
||||||
self.key_pairs[address] = ValidatorKeyPair(
|
|
||||||
address=address,
|
|
||||||
private_key_pem=key_data['private_key_pem'],
|
|
||||||
public_key_pem=key_data['public_key_pem'],
|
|
||||||
created_at=key_data['created_at'],
|
|
||||||
last_rotated=key_data['last_rotated']
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading keys: {e}")
|
|
||||||
|
|
||||||
def generate_key_pair(self, address: str) -> ValidatorKeyPair:
|
|
||||||
"""Generate new RSA key pair for validator"""
|
|
||||||
# Generate private key
|
|
||||||
private_key = rsa.generate_private_key(
|
|
||||||
public_exponent=65537,
|
|
||||||
key_size=2048,
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Serialize private key
|
|
||||||
private_key_pem = private_key.private_bytes(
|
|
||||||
encoding=Encoding.PEM,
|
|
||||||
format=PrivateFormat.PKCS8,
|
|
||||||
encryption_algorithm=NoEncryption()
|
|
||||||
).decode('utf-8')
|
|
||||||
|
|
||||||
# Get public key
|
|
||||||
public_key = private_key.public_key()
|
|
||||||
public_key_pem = public_key.public_bytes(
|
|
||||||
encoding=Encoding.PEM,
|
|
||||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
|
||||||
).decode('utf-8')
|
|
||||||
|
|
||||||
# Create key pair object
|
|
||||||
current_time = time.time()
|
|
||||||
key_pair = ValidatorKeyPair(
|
|
||||||
address=address,
|
|
||||||
private_key_pem=private_key_pem,
|
|
||||||
public_key_pem=public_key_pem,
|
|
||||||
created_at=current_time,
|
|
||||||
last_rotated=current_time
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store key pair
|
|
||||||
self.key_pairs[address] = key_pair
|
|
||||||
self._save_keys()
|
|
||||||
|
|
||||||
return key_pair
|
|
||||||
|
|
||||||
def get_key_pair(self, address: str) -> Optional[ValidatorKeyPair]:
|
|
||||||
"""Get key pair for validator"""
|
|
||||||
return self.key_pairs.get(address)
|
|
||||||
|
|
||||||
def rotate_key(self, address: str) -> Optional[ValidatorKeyPair]:
|
|
||||||
"""Rotate validator keys"""
|
|
||||||
if address not in self.key_pairs:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Generate new key pair
|
|
||||||
new_key_pair = self.generate_key_pair(address)
|
|
||||||
|
|
||||||
# Update rotation time
|
|
||||||
new_key_pair.created_at = self.key_pairs[address].created_at
|
|
||||||
new_key_pair.last_rotated = time.time()
|
|
||||||
|
|
||||||
self._save_keys()
|
|
||||||
return new_key_pair
|
|
||||||
|
|
||||||
def sign_message(self, address: str, message: str) -> Optional[str]:
|
|
||||||
"""Sign message with validator private key"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load private key from PEM
|
|
||||||
private_key = serialization.load_pem_private_key(
|
|
||||||
key_pair.private_key_pem.encode(),
|
|
||||||
password=None,
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sign message
|
|
||||||
signature = private_key.sign(
|
|
||||||
message.encode('utf-8'),
|
|
||||||
hashes.SHA256(),
|
|
||||||
default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
return signature.hex()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error signing message: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def verify_signature(self, address: str, message: str, signature: str) -> bool:
|
|
||||||
"""Verify message signature"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load public key from PEM
|
|
||||||
public_key = serialization.load_pem_public_key(
|
|
||||||
key_pair.public_key_pem.encode(),
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify signature
|
|
||||||
public_key.verify(
|
|
||||||
bytes.fromhex(signature),
|
|
||||||
message.encode('utf-8'),
|
|
||||||
hashes.SHA256(),
|
|
||||||
default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error verifying signature: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_public_key_pem(self, address: str) -> Optional[str]:
|
|
||||||
"""Get public key PEM for validator"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
return key_pair.public_key_pem if key_pair else None
|
|
||||||
|
|
||||||
def _save_keys(self):
|
|
||||||
"""Save key pairs to disk"""
|
|
||||||
keys_file = os.path.join(self.keys_dir, "validator_keys.json")
|
|
||||||
|
|
||||||
keys_data = {}
|
|
||||||
for address, key_pair in self.key_pairs.items():
|
|
||||||
keys_data[address] = {
|
|
||||||
'private_key_pem': key_pair.private_key_pem,
|
|
||||||
'public_key_pem': key_pair.public_key_pem,
|
|
||||||
'created_at': key_pair.created_at,
|
|
||||||
'last_rotated': key_pair.last_rotated
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(keys_file, 'w') as f:
|
|
||||||
json.dump(keys_data, f, indent=2)
|
|
||||||
|
|
||||||
# Set secure permissions
|
|
||||||
os.chmod(keys_file, 0o600)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error saving keys: {e}")
|
|
||||||
|
|
||||||
def should_rotate_key(self, address: str, rotation_interval: int = 86400) -> bool:
|
|
||||||
"""Check if key should be rotated (default: 24 hours)"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return (time.time() - key_pair.last_rotated) >= rotation_interval
|
|
||||||
|
|
||||||
def get_key_age(self, address: str) -> Optional[float]:
|
|
||||||
"""Get age of key in seconds"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return time.time() - key_pair.created_at
|
|
||||||
|
|
||||||
# Global key manager
|
|
||||||
key_manager = KeyManager()
|
|
||||||
@@ -1,119 +0,0 @@
|
|||||||
"""
|
|
||||||
Multi-Validator Proof of Authority Consensus Implementation
|
|
||||||
Extends single validator PoA to support multiple validators with rotation
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
from typing import List, Dict, Optional, Set
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from ..config import settings
|
|
||||||
from ..models import Block, Transaction
|
|
||||||
from ..database import session_scope
|
|
||||||
|
|
||||||
class ValidatorRole(Enum):
|
|
||||||
PROPOSER = "proposer"
|
|
||||||
VALIDATOR = "validator"
|
|
||||||
STANDBY = "standby"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Validator:
|
|
||||||
address: str
|
|
||||||
stake: float
|
|
||||||
reputation: float
|
|
||||||
role: ValidatorRole
|
|
||||||
last_proposed: int
|
|
||||||
is_active: bool
|
|
||||||
|
|
||||||
class MultiValidatorPoA:
|
|
||||||
"""Multi-Validator Proof of Authority consensus mechanism"""
|
|
||||||
|
|
||||||
def __init__(self, chain_id: str):
|
|
||||||
self.chain_id = chain_id
|
|
||||||
self.validators: Dict[str, Validator] = {}
|
|
||||||
self.current_proposer_index = 0
|
|
||||||
self.round_robin_enabled = True
|
|
||||||
self.consensus_timeout = 30 # seconds
|
|
||||||
|
|
||||||
def add_validator(self, address: str, stake: float = 1000.0) -> bool:
|
|
||||||
"""Add a new validator to the consensus"""
|
|
||||||
if address in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.validators[address] = Validator(
|
|
||||||
address=address,
|
|
||||||
stake=stake,
|
|
||||||
reputation=1.0,
|
|
||||||
role=ValidatorRole.STANDBY,
|
|
||||||
last_proposed=0,
|
|
||||||
is_active=True
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def remove_validator(self, address: str) -> bool:
|
|
||||||
"""Remove a validator from the consensus"""
|
|
||||||
if address not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[address]
|
|
||||||
validator.is_active = False
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
return True
|
|
||||||
|
|
||||||
def select_proposer(self, block_height: int) -> Optional[str]:
|
|
||||||
"""Select proposer for the current block using round-robin"""
|
|
||||||
active_validators = [
|
|
||||||
v for v in self.validators.values()
|
|
||||||
if v.is_active and v.role in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]
|
|
||||||
]
|
|
||||||
|
|
||||||
if not active_validators:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Round-robin selection
|
|
||||||
proposer_index = block_height % len(active_validators)
|
|
||||||
return active_validators[proposer_index].address
|
|
||||||
|
|
||||||
def validate_block(self, block: Block, proposer: str) -> bool:
|
|
||||||
"""Validate a proposed block"""
|
|
||||||
if proposer not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[proposer]
|
|
||||||
if not validator.is_active:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check if validator is allowed to propose
|
|
||||||
if validator.role not in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Additional validation logic here
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_consensus_participants(self) -> List[str]:
|
|
||||||
"""Get list of active consensus participants"""
|
|
||||||
return [
|
|
||||||
v.address for v in self.validators.values()
|
|
||||||
if v.is_active and v.role in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]
|
|
||||||
]
|
|
||||||
|
|
||||||
def update_validator_reputation(self, address: str, delta: float) -> bool:
|
|
||||||
"""Update validator reputation"""
|
|
||||||
if address not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[address]
|
|
||||||
validator.reputation = max(0.0, min(1.0, validator.reputation + delta))
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Global consensus instance
|
|
||||||
consensus_instances: Dict[str, MultiValidatorPoA] = {}
|
|
||||||
|
|
||||||
def get_consensus(chain_id: str) -> MultiValidatorPoA:
|
|
||||||
"""Get or create consensus instance for chain"""
|
|
||||||
if chain_id not in consensus_instances:
|
|
||||||
consensus_instances[chain_id] = MultiValidatorPoA(chain_id)
|
|
||||||
return consensus_instances[chain_id]
|
|
||||||
@@ -1,193 +0,0 @@
|
|||||||
"""
|
|
||||||
Practical Byzantine Fault Tolerance (PBFT) Consensus Implementation
|
|
||||||
Provides Byzantine fault tolerance for up to 1/3 faulty validators
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
from typing import List, Dict, Optional, Set, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import MultiValidatorPoA, Validator
|
|
||||||
|
|
||||||
class PBFTPhase(Enum):
|
|
||||||
PRE_PREPARE = "pre_prepare"
|
|
||||||
PREPARE = "prepare"
|
|
||||||
COMMIT = "commit"
|
|
||||||
EXECUTE = "execute"
|
|
||||||
|
|
||||||
class PBFTMessageType(Enum):
|
|
||||||
PRE_PREPARE = "pre_prepare"
|
|
||||||
PREPARE = "prepare"
|
|
||||||
COMMIT = "commit"
|
|
||||||
VIEW_CHANGE = "view_change"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PBFTMessage:
|
|
||||||
message_type: PBFTMessageType
|
|
||||||
sender: str
|
|
||||||
view_number: int
|
|
||||||
sequence_number: int
|
|
||||||
digest: str
|
|
||||||
signature: str
|
|
||||||
timestamp: float
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PBFTState:
|
|
||||||
current_view: int
|
|
||||||
current_sequence: int
|
|
||||||
prepared_messages: Dict[str, List[PBFTMessage]]
|
|
||||||
committed_messages: Dict[str, List[PBFTMessage]]
|
|
||||||
pre_prepare_messages: Dict[str, PBFTMessage]
|
|
||||||
|
|
||||||
class PBFTConsensus:
|
|
||||||
"""PBFT consensus implementation"""
|
|
||||||
|
|
||||||
def __init__(self, consensus: MultiValidatorPoA):
|
|
||||||
self.consensus = consensus
|
|
||||||
self.state = PBFTState(
|
|
||||||
current_view=0,
|
|
||||||
current_sequence=0,
|
|
||||||
prepared_messages={},
|
|
||||||
committed_messages={},
|
|
||||||
pre_prepare_messages={}
|
|
||||||
)
|
|
||||||
self.fault_tolerance = max(1, len(consensus.get_consensus_participants()) // 3)
|
|
||||||
self.required_messages = 2 * self.fault_tolerance + 1
|
|
||||||
|
|
||||||
def get_message_digest(self, block_hash: str, sequence: int, view: int) -> str:
|
|
||||||
"""Generate message digest for PBFT"""
|
|
||||||
content = f"{block_hash}:{sequence}:{view}"
|
|
||||||
return hashlib.sha256(content.encode()).hexdigest()
|
|
||||||
|
|
||||||
async def pre_prepare_phase(self, proposer: str, block_hash: str) -> bool:
|
|
||||||
"""Phase 1: Pre-prepare"""
|
|
||||||
sequence = self.state.current_sequence + 1
|
|
||||||
view = self.state.current_view
|
|
||||||
digest = self.get_message_digest(block_hash, sequence, view)
|
|
||||||
|
|
||||||
message = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.PRE_PREPARE,
|
|
||||||
sender=proposer,
|
|
||||||
view_number=view,
|
|
||||||
sequence_number=sequence,
|
|
||||||
digest=digest,
|
|
||||||
signature="", # Would be signed in real implementation
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store pre-prepare message
|
|
||||||
key = f"{sequence}:{view}"
|
|
||||||
self.state.pre_prepare_messages[key] = message
|
|
||||||
|
|
||||||
# Broadcast to all validators
|
|
||||||
await self._broadcast_message(message)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def prepare_phase(self, validator: str, pre_prepare_msg: PBFTMessage) -> bool:
|
|
||||||
"""Phase 2: Prepare"""
|
|
||||||
key = f"{pre_prepare_msg.sequence_number}:{pre_prepare_msg.view_number}"
|
|
||||||
|
|
||||||
if key not in self.state.pre_prepare_messages:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Create prepare message
|
|
||||||
prepare_msg = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.PREPARE,
|
|
||||||
sender=validator,
|
|
||||||
view_number=pre_prepare_msg.view_number,
|
|
||||||
sequence_number=pre_prepare_msg.sequence_number,
|
|
||||||
digest=pre_prepare_msg.digest,
|
|
||||||
signature="", # Would be signed
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store prepare message
|
|
||||||
if key not in self.state.prepared_messages:
|
|
||||||
self.state.prepared_messages[key] = []
|
|
||||||
self.state.prepared_messages[key].append(prepare_msg)
|
|
||||||
|
|
||||||
# Broadcast prepare message
|
|
||||||
await self._broadcast_message(prepare_msg)
|
|
||||||
|
|
||||||
# Check if we have enough prepare messages
|
|
||||||
return len(self.state.prepared_messages[key]) >= self.required_messages
|
|
||||||
|
|
||||||
async def commit_phase(self, validator: str, prepare_msg: PBFTMessage) -> bool:
|
|
||||||
"""Phase 3: Commit"""
|
|
||||||
key = f"{prepare_msg.sequence_number}:{prepare_msg.view_number}"
|
|
||||||
|
|
||||||
# Create commit message
|
|
||||||
commit_msg = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.COMMIT,
|
|
||||||
sender=validator,
|
|
||||||
view_number=prepare_msg.view_number,
|
|
||||||
sequence_number=prepare_msg.sequence_number,
|
|
||||||
digest=prepare_msg.digest,
|
|
||||||
signature="", # Would be signed
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store commit message
|
|
||||||
if key not in self.state.committed_messages:
|
|
||||||
self.state.committed_messages[key] = []
|
|
||||||
self.state.committed_messages[key].append(commit_msg)
|
|
||||||
|
|
||||||
# Broadcast commit message
|
|
||||||
await self._broadcast_message(commit_msg)
|
|
||||||
|
|
||||||
# Check if we have enough commit messages
|
|
||||||
if len(self.state.committed_messages[key]) >= self.required_messages:
|
|
||||||
return await self.execute_phase(key)
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def execute_phase(self, key: str) -> bool:
|
|
||||||
"""Phase 4: Execute"""
|
|
||||||
# Extract sequence and view from key
|
|
||||||
sequence, view = map(int, key.split(':'))
|
|
||||||
|
|
||||||
# Update state
|
|
||||||
self.state.current_sequence = sequence
|
|
||||||
|
|
||||||
# Clean up old messages
|
|
||||||
self._cleanup_messages(sequence)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _broadcast_message(self, message: PBFTMessage):
|
|
||||||
"""Broadcast message to all validators"""
|
|
||||||
validators = self.consensus.get_consensus_participants()
|
|
||||||
|
|
||||||
for validator in validators:
|
|
||||||
if validator != message.sender:
|
|
||||||
# In real implementation, this would send over network
|
|
||||||
await self._send_to_validator(validator, message)
|
|
||||||
|
|
||||||
async def _send_to_validator(self, validator: str, message: PBFTMessage):
|
|
||||||
"""Send message to specific validator"""
|
|
||||||
# Network communication would be implemented here
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _cleanup_messages(self, sequence: int):
|
|
||||||
"""Clean up old messages to prevent memory leaks"""
|
|
||||||
old_keys = [
|
|
||||||
key for key in self.state.prepared_messages.keys()
|
|
||||||
if int(key.split(':')[0]) < sequence
|
|
||||||
]
|
|
||||||
|
|
||||||
for key in old_keys:
|
|
||||||
self.state.prepared_messages.pop(key, None)
|
|
||||||
self.state.committed_messages.pop(key, None)
|
|
||||||
self.state.pre_prepare_messages.pop(key, None)
|
|
||||||
|
|
||||||
def handle_view_change(self, new_view: int) -> bool:
|
|
||||||
"""Handle view change when proposer fails"""
|
|
||||||
self.state.current_view = new_view
|
|
||||||
# Reset state for new view
|
|
||||||
self.state.prepared_messages.clear()
|
|
||||||
self.state.committed_messages.clear()
|
|
||||||
self.state.pre_prepare_messages.clear()
|
|
||||||
return True
|
|
||||||
@@ -1,345 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable, ContextManager, Optional
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
|
|
||||||
from ..logger import get_logger
|
|
||||||
from ..metrics import metrics_registry
|
|
||||||
from ..config import ProposerConfig
|
|
||||||
from ..models import Block, Account
|
|
||||||
from ..gossip import gossip_broker
|
|
||||||
|
|
||||||
_METRIC_KEY_SANITIZE = re.compile(r"[^a-zA-Z0-9_]")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_metric_suffix(value: str) -> str:
|
|
||||||
sanitized = _METRIC_KEY_SANITIZE.sub("_", value).strip("_")
|
|
||||||
return sanitized or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
class CircuitBreaker:
|
|
||||||
def __init__(self, threshold: int, timeout: int):
|
|
||||||
self._threshold = threshold
|
|
||||||
self._timeout = timeout
|
|
||||||
self._failures = 0
|
|
||||||
self._last_failure_time = 0.0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self) -> str:
|
|
||||||
if self._state == "open":
|
|
||||||
if time.time() - self._last_failure_time > self._timeout:
|
|
||||||
self._state = "half-open"
|
|
||||||
return self._state
|
|
||||||
|
|
||||||
def allow_request(self) -> bool:
|
|
||||||
state = self.state
|
|
||||||
if state == "closed":
|
|
||||||
return True
|
|
||||||
if state == "half-open":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def record_failure(self) -> None:
|
|
||||||
self._failures += 1
|
|
||||||
self._last_failure_time = time.time()
|
|
||||||
if self._failures >= self._threshold:
|
|
||||||
self._state = "open"
|
|
||||||
|
|
||||||
def record_success(self) -> None:
|
|
||||||
self._failures = 0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
class PoAProposer:
|
|
||||||
"""Proof-of-Authority block proposer.
|
|
||||||
|
|
||||||
Responsible for periodically proposing blocks if this node is configured as a proposer.
|
|
||||||
In the real implementation, this would involve checking the mempool, validating transactions,
|
|
||||||
and signing the block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
config: ProposerConfig,
|
|
||||||
session_factory: Callable[[], ContextManager[Session]],
|
|
||||||
) -> None:
|
|
||||||
self._config = config
|
|
||||||
self._session_factory = session_factory
|
|
||||||
self._logger = get_logger(__name__)
|
|
||||||
self._stop_event = asyncio.Event()
|
|
||||||
self._task: Optional[asyncio.Task[None]] = None
|
|
||||||
self._last_proposer_id: Optional[str] = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
if self._task is not None:
|
|
||||||
return
|
|
||||||
self._logger.info("Starting PoA proposer loop", extra={"interval": self._config.interval_seconds})
|
|
||||||
await self._ensure_genesis_block()
|
|
||||||
self._stop_event.clear()
|
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
if self._task is None:
|
|
||||||
return
|
|
||||||
self._logger.info("Stopping PoA proposer loop")
|
|
||||||
self._stop_event.set()
|
|
||||||
await self._task
|
|
||||||
self._task = None
|
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
await self._wait_until_next_slot()
|
|
||||||
if self._stop_event.is_set():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
await self._propose_block()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
self._logger.exception("Failed to propose block", extra={"error": str(exc)})
|
|
||||||
|
|
||||||
async def _wait_until_next_slot(self) -> None:
|
|
||||||
head = self._fetch_chain_head()
|
|
||||||
if head is None:
|
|
||||||
return
|
|
||||||
now = datetime.utcnow()
|
|
||||||
elapsed = (now - head.timestamp).total_seconds()
|
|
||||||
sleep_for = max(self._config.interval_seconds - elapsed, 0.1)
|
|
||||||
if sleep_for <= 0:
|
|
||||||
sleep_for = 0.1
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._stop_event.wait(), timeout=sleep_for)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _propose_block(self) -> None:
|
|
||||||
# Check internal mempool and include transactions
|
|
||||||
from ..mempool import get_mempool
|
|
||||||
from ..models import Transaction, Account
|
|
||||||
mempool = get_mempool()
|
|
||||||
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
next_height = 0
|
|
||||||
parent_hash = "0x00"
|
|
||||||
interval_seconds: Optional[float] = None
|
|
||||||
if head is not None:
|
|
||||||
next_height = head.height + 1
|
|
||||||
parent_hash = head.hash
|
|
||||||
interval_seconds = (datetime.utcnow() - head.timestamp).total_seconds()
|
|
||||||
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
# Pull transactions from mempool
|
|
||||||
max_txs = self._config.max_txs_per_block
|
|
||||||
max_bytes = self._config.max_block_size_bytes
|
|
||||||
pending_txs = mempool.drain(max_txs, max_bytes, self._config.chain_id)
|
|
||||||
self._logger.info(f"[PROPOSE] drained {len(pending_txs)} txs from mempool, chain={self._config.chain_id}")
|
|
||||||
|
|
||||||
# Process transactions and update balances
|
|
||||||
processed_txs = []
|
|
||||||
for tx in pending_txs:
|
|
||||||
try:
|
|
||||||
# Parse transaction data
|
|
||||||
tx_data = tx.content
|
|
||||||
sender = tx_data.get("from")
|
|
||||||
recipient = tx_data.get("to")
|
|
||||||
value = tx_data.get("amount", 0)
|
|
||||||
fee = tx_data.get("fee", 0)
|
|
||||||
|
|
||||||
if not sender or not recipient:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get sender account
|
|
||||||
sender_account = session.get(Account, (self._config.chain_id, sender))
|
|
||||||
if not sender_account:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check sufficient balance
|
|
||||||
total_cost = value + fee
|
|
||||||
if sender_account.balance < total_cost:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get or create recipient account
|
|
||||||
recipient_account = session.get(Account, (self._config.chain_id, recipient))
|
|
||||||
if not recipient_account:
|
|
||||||
recipient_account = Account(chain_id=self._config.chain_id, address=recipient, balance=0, nonce=0)
|
|
||||||
session.add(recipient_account)
|
|
||||||
session.flush()
|
|
||||||
|
|
||||||
# Update balances
|
|
||||||
sender_account.balance -= total_cost
|
|
||||||
sender_account.nonce += 1
|
|
||||||
recipient_account.balance += value
|
|
||||||
|
|
||||||
# Create transaction record
|
|
||||||
transaction = Transaction(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
tx_hash=tx.tx_hash,
|
|
||||||
sender=sender,
|
|
||||||
recipient=recipient,
|
|
||||||
payload=tx_data,
|
|
||||||
value=value,
|
|
||||||
fee=fee,
|
|
||||||
nonce=sender_account.nonce - 1,
|
|
||||||
timestamp=timestamp,
|
|
||||||
block_height=next_height,
|
|
||||||
status="confirmed"
|
|
||||||
)
|
|
||||||
session.add(transaction)
|
|
||||||
processed_txs.append(tx)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self._logger.warning(f"Failed to process transaction {tx.tx_hash}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Compute block hash with transaction data
|
|
||||||
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp, processed_txs)
|
|
||||||
|
|
||||||
block = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=next_height,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash=parent_hash,
|
|
||||||
proposer=self._config.proposer_id,
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=len(processed_txs),
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(block)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
metrics_registry.increment("blocks_proposed_total")
|
|
||||||
metrics_registry.set_gauge("chain_head_height", float(next_height))
|
|
||||||
if interval_seconds is not None and interval_seconds >= 0:
|
|
||||||
metrics_registry.observe("block_interval_seconds", interval_seconds)
|
|
||||||
metrics_registry.set_gauge("poa_last_block_interval_seconds", float(interval_seconds))
|
|
||||||
|
|
||||||
proposer_suffix = _sanitize_metric_suffix(self._config.proposer_id)
|
|
||||||
metrics_registry.increment(f"poa_blocks_proposed_total_{proposer_suffix}")
|
|
||||||
if self._last_proposer_id is not None and self._last_proposer_id != self._config.proposer_id:
|
|
||||||
metrics_registry.increment("poa_proposer_switches_total")
|
|
||||||
self._last_proposer_id = self._config.proposer_id
|
|
||||||
|
|
||||||
self._logger.info(
|
|
||||||
"Proposed block",
|
|
||||||
extra={
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast the new block
|
|
||||||
tx_list = [tx.content for tx in processed_txs] if processed_txs else []
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"chain_id": self._config.chain_id,
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"parent_hash": block.parent_hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
"timestamp": block.timestamp.isoformat(),
|
|
||||||
"tx_count": block.tx_count,
|
|
||||||
"state_root": block.state_root,
|
|
||||||
"transactions": tx_list,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_genesis_block(self) -> None:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
if head is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use a deterministic genesis timestamp so all nodes agree on the genesis block hash
|
|
||||||
timestamp = datetime(2025, 1, 1, 0, 0, 0)
|
|
||||||
block_hash = self._compute_block_hash(0, "0x00", timestamp)
|
|
||||||
genesis = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=0,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash="0x00",
|
|
||||||
proposer=self._config.proposer_id, # Use configured proposer as genesis proposer
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(genesis)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Initialize accounts from genesis allocations file (if present)
|
|
||||||
await self._initialize_genesis_allocations(session)
|
|
||||||
|
|
||||||
# Broadcast genesis block for initial sync
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"chain_id": self._config.chain_id,
|
|
||||||
"height": genesis.height,
|
|
||||||
"hash": genesis.hash,
|
|
||||||
"parent_hash": genesis.parent_hash,
|
|
||||||
"proposer": genesis.proposer,
|
|
||||||
"timestamp": genesis.timestamp.isoformat(),
|
|
||||||
"tx_count": genesis.tx_count,
|
|
||||||
"state_root": genesis.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _initialize_genesis_allocations(self, session: Session) -> None:
|
|
||||||
"""Create Account entries from the genesis allocations file."""
|
|
||||||
# Use standardized data directory from configuration
|
|
||||||
from ..config import settings
|
|
||||||
|
|
||||||
genesis_paths = [
|
|
||||||
Path(f"/var/lib/aitbc/data/{self._config.chain_id}/genesis.json"), # Standard location
|
|
||||||
]
|
|
||||||
|
|
||||||
genesis_path = None
|
|
||||||
for path in genesis_paths:
|
|
||||||
if path.exists():
|
|
||||||
genesis_path = path
|
|
||||||
break
|
|
||||||
|
|
||||||
if not genesis_path:
|
|
||||||
self._logger.warning("Genesis allocations file not found; skipping account initialization", extra={"paths": str(genesis_paths)})
|
|
||||||
return
|
|
||||||
|
|
||||||
with open(genesis_path) as f:
|
|
||||||
genesis_data = json.load(f)
|
|
||||||
|
|
||||||
allocations = genesis_data.get("allocations", [])
|
|
||||||
created = 0
|
|
||||||
for alloc in allocations:
|
|
||||||
addr = alloc["address"]
|
|
||||||
balance = int(alloc["balance"])
|
|
||||||
nonce = int(alloc.get("nonce", 0))
|
|
||||||
# Check if account already exists (idempotent)
|
|
||||||
acct = session.get(Account, (self._config.chain_id, addr))
|
|
||||||
if acct is None:
|
|
||||||
acct = Account(chain_id=self._config.chain_id, address=addr, balance=balance, nonce=nonce)
|
|
||||||
session.add(acct)
|
|
||||||
created += 1
|
|
||||||
session.commit()
|
|
||||||
self._logger.info("Initialized genesis accounts", extra={"count": created, "total": len(allocations), "path": str(genesis_path)})
|
|
||||||
|
|
||||||
def _fetch_chain_head(self) -> Optional[Block]:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
|
|
||||||
def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime, transactions: list = None) -> str:
|
|
||||||
# Include transaction hashes in block hash computation
|
|
||||||
tx_hashes = []
|
|
||||||
if transactions:
|
|
||||||
tx_hashes = [tx.tx_hash for tx in transactions]
|
|
||||||
|
|
||||||
payload = f"{self._config.chain_id}|{height}|{parent_hash}|{timestamp.isoformat()}|{'|'.join(sorted(tx_hashes))}".encode()
|
|
||||||
return "0x" + hashlib.sha256(payload).hexdigest()
|
|
||||||
@@ -1,229 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Callable, ContextManager, Optional
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
|
|
||||||
from ..logger import get_logger
|
|
||||||
from ..metrics import metrics_registry
|
|
||||||
from ..config import ProposerConfig
|
|
||||||
from ..models import Block
|
|
||||||
from ..gossip import gossip_broker
|
|
||||||
|
|
||||||
_METRIC_KEY_SANITIZE = re.compile(r"[^a-zA-Z0-9_]")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_metric_suffix(value: str) -> str:
|
|
||||||
sanitized = _METRIC_KEY_SANITIZE.sub("_", value).strip("_")
|
|
||||||
return sanitized or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
class CircuitBreaker:
|
|
||||||
def __init__(self, threshold: int, timeout: int):
|
|
||||||
self._threshold = threshold
|
|
||||||
self._timeout = timeout
|
|
||||||
self._failures = 0
|
|
||||||
self._last_failure_time = 0.0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self) -> str:
|
|
||||||
if self._state == "open":
|
|
||||||
if time.time() - self._last_failure_time > self._timeout:
|
|
||||||
self._state = "half-open"
|
|
||||||
return self._state
|
|
||||||
|
|
||||||
def allow_request(self) -> bool:
|
|
||||||
state = self.state
|
|
||||||
if state == "closed":
|
|
||||||
return True
|
|
||||||
if state == "half-open":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def record_failure(self) -> None:
|
|
||||||
self._failures += 1
|
|
||||||
self._last_failure_time = time.time()
|
|
||||||
if self._failures >= self._threshold:
|
|
||||||
self._state = "open"
|
|
||||||
|
|
||||||
def record_success(self) -> None:
|
|
||||||
self._failures = 0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
class PoAProposer:
|
|
||||||
"""Proof-of-Authority block proposer.
|
|
||||||
|
|
||||||
Responsible for periodically proposing blocks if this node is configured as a proposer.
|
|
||||||
In the real implementation, this would involve checking the mempool, validating transactions,
|
|
||||||
and signing the block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
config: ProposerConfig,
|
|
||||||
session_factory: Callable[[], ContextManager[Session]],
|
|
||||||
) -> None:
|
|
||||||
self._config = config
|
|
||||||
self._session_factory = session_factory
|
|
||||||
self._logger = get_logger(__name__)
|
|
||||||
self._stop_event = asyncio.Event()
|
|
||||||
self._task: Optional[asyncio.Task[None]] = None
|
|
||||||
self._last_proposer_id: Optional[str] = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
if self._task is not None:
|
|
||||||
return
|
|
||||||
self._logger.info("Starting PoA proposer loop", extra={"interval": self._config.interval_seconds})
|
|
||||||
self._ensure_genesis_block()
|
|
||||||
self._stop_event.clear()
|
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
if self._task is None:
|
|
||||||
return
|
|
||||||
self._logger.info("Stopping PoA proposer loop")
|
|
||||||
self._stop_event.set()
|
|
||||||
await self._task
|
|
||||||
self._task = None
|
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
await self._wait_until_next_slot()
|
|
||||||
if self._stop_event.is_set():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
self._propose_block()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
self._logger.exception("Failed to propose block", extra={"error": str(exc)})
|
|
||||||
|
|
||||||
async def _wait_until_next_slot(self) -> None:
|
|
||||||
head = self._fetch_chain_head()
|
|
||||||
if head is None:
|
|
||||||
return
|
|
||||||
now = datetime.utcnow()
|
|
||||||
elapsed = (now - head.timestamp).total_seconds()
|
|
||||||
sleep_for = max(self._config.interval_seconds - elapsed, 0.1)
|
|
||||||
if sleep_for <= 0:
|
|
||||||
sleep_for = 0.1
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._stop_event.wait(), timeout=sleep_for)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _propose_block(self) -> None:
|
|
||||||
# Check internal mempool
|
|
||||||
from ..mempool import get_mempool
|
|
||||||
if get_mempool().size(self._config.chain_id) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
next_height = 0
|
|
||||||
parent_hash = "0x00"
|
|
||||||
interval_seconds: Optional[float] = None
|
|
||||||
if head is not None:
|
|
||||||
next_height = head.height + 1
|
|
||||||
parent_hash = head.hash
|
|
||||||
interval_seconds = (datetime.utcnow() - head.timestamp).total_seconds()
|
|
||||||
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp)
|
|
||||||
|
|
||||||
block = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=next_height,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash=parent_hash,
|
|
||||||
proposer=self._config.proposer_id,
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(block)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
metrics_registry.increment("blocks_proposed_total")
|
|
||||||
metrics_registry.set_gauge("chain_head_height", float(next_height))
|
|
||||||
if interval_seconds is not None and interval_seconds >= 0:
|
|
||||||
metrics_registry.observe("block_interval_seconds", interval_seconds)
|
|
||||||
metrics_registry.set_gauge("poa_last_block_interval_seconds", float(interval_seconds))
|
|
||||||
|
|
||||||
proposer_suffix = _sanitize_metric_suffix(self._config.proposer_id)
|
|
||||||
metrics_registry.increment(f"poa_blocks_proposed_total_{proposer_suffix}")
|
|
||||||
if self._last_proposer_id is not None and self._last_proposer_id != self._config.proposer_id:
|
|
||||||
metrics_registry.increment("poa_proposer_switches_total")
|
|
||||||
self._last_proposer_id = self._config.proposer_id
|
|
||||||
|
|
||||||
self._logger.info(
|
|
||||||
"Proposed block",
|
|
||||||
extra={
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast the new block
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"parent_hash": block.parent_hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
"timestamp": block.timestamp.isoformat(),
|
|
||||||
"tx_count": block.tx_count,
|
|
||||||
"state_root": block.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_genesis_block(self) -> None:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
if head is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use a deterministic genesis timestamp so all nodes agree on the genesis block hash
|
|
||||||
timestamp = datetime(2025, 1, 1, 0, 0, 0)
|
|
||||||
block_hash = self._compute_block_hash(0, "0x00", timestamp)
|
|
||||||
genesis = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=0,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash="0x00",
|
|
||||||
proposer="genesis",
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(genesis)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Broadcast genesis block for initial sync
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"height": genesis.height,
|
|
||||||
"hash": genesis.hash,
|
|
||||||
"parent_hash": genesis.parent_hash,
|
|
||||||
"proposer": genesis.proposer,
|
|
||||||
"timestamp": genesis.timestamp.isoformat(),
|
|
||||||
"tx_count": genesis.tx_count,
|
|
||||||
"state_root": genesis.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _fetch_chain_head(self) -> Optional[Block]:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
|
|
||||||
def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime) -> str:
|
|
||||||
payload = f"{self._config.chain_id}|{height}|{parent_hash}|{timestamp.isoformat()}".encode()
|
|
||||||
return "0x" + hashlib.sha256(payload).hexdigest()
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
--- apps/blockchain-node/src/aitbc_chain/consensus/poa.py
|
|
||||||
+++ apps/blockchain-node/src/aitbc_chain/consensus/poa.py
|
|
||||||
@@ -101,7 +101,7 @@
|
|
||||||
# Wait for interval before proposing next block
|
|
||||||
await asyncio.sleep(self.config.interval_seconds)
|
|
||||||
|
|
||||||
- self._propose_block()
|
|
||||||
+ await self._propose_block()
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
@@ -1,146 +0,0 @@
|
|||||||
"""
|
|
||||||
Validator Rotation Mechanism
|
|
||||||
Handles automatic rotation of validators based on performance and stake
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
from typing import List, Dict, Optional
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import MultiValidatorPoA, Validator, ValidatorRole
|
|
||||||
|
|
||||||
class RotationStrategy(Enum):
|
|
||||||
ROUND_ROBIN = "round_robin"
|
|
||||||
STAKE_WEIGHTED = "stake_weighted"
|
|
||||||
REPUTATION_BASED = "reputation_based"
|
|
||||||
HYBRID = "hybrid"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RotationConfig:
|
|
||||||
strategy: RotationStrategy
|
|
||||||
rotation_interval: int # blocks
|
|
||||||
min_stake: float
|
|
||||||
reputation_threshold: float
|
|
||||||
max_validators: int
|
|
||||||
|
|
||||||
class ValidatorRotation:
|
|
||||||
"""Manages validator rotation based on various strategies"""
|
|
||||||
|
|
||||||
def __init__(self, consensus: MultiValidatorPoA, config: RotationConfig):
|
|
||||||
self.consensus = consensus
|
|
||||||
self.config = config
|
|
||||||
self.last_rotation_height = 0
|
|
||||||
|
|
||||||
def should_rotate(self, current_height: int) -> bool:
|
|
||||||
"""Check if rotation should occur at current height"""
|
|
||||||
return (current_height - self.last_rotation_height) >= self.config.rotation_interval
|
|
||||||
|
|
||||||
def rotate_validators(self, current_height: int) -> bool:
|
|
||||||
"""Perform validator rotation based on configured strategy"""
|
|
||||||
if not self.should_rotate(current_height):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if self.config.strategy == RotationStrategy.ROUND_ROBIN:
|
|
||||||
return self._rotate_round_robin()
|
|
||||||
elif self.config.strategy == RotationStrategy.STAKE_WEIGHTED:
|
|
||||||
return self._rotate_stake_weighted()
|
|
||||||
elif self.config.strategy == RotationStrategy.REPUTATION_BASED:
|
|
||||||
return self._rotate_reputation_based()
|
|
||||||
elif self.config.strategy == RotationStrategy.HYBRID:
|
|
||||||
return self._rotate_hybrid()
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _rotate_round_robin(self) -> bool:
|
|
||||||
"""Round-robin rotation of validator roles"""
|
|
||||||
validators = list(self.consensus.validators.values())
|
|
||||||
active_validators = [v for v in validators if v.is_active]
|
|
||||||
|
|
||||||
# Rotate roles among active validators
|
|
||||||
for i, validator in enumerate(active_validators):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 3: # Top 3 become validators
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_stake_weighted(self) -> bool:
|
|
||||||
"""Stake-weighted rotation"""
|
|
||||||
validators = sorted(
|
|
||||||
[v for v in self.consensus.validators.values() if v.is_active],
|
|
||||||
key=lambda v: v.stake,
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, validator in enumerate(validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_reputation_based(self) -> bool:
|
|
||||||
"""Reputation-based rotation"""
|
|
||||||
validators = sorted(
|
|
||||||
[v for v in self.consensus.validators.values() if v.is_active],
|
|
||||||
key=lambda v: v.reputation,
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Filter by reputation threshold
|
|
||||||
qualified_validators = [
|
|
||||||
v for v in validators
|
|
||||||
if v.reputation >= self.config.reputation_threshold
|
|
||||||
]
|
|
||||||
|
|
||||||
for i, validator in enumerate(qualified_validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_hybrid(self) -> bool:
|
|
||||||
"""Hybrid rotation considering both stake and reputation"""
|
|
||||||
validators = [v for v in self.consensus.validators.values() if v.is_active]
|
|
||||||
|
|
||||||
# Calculate hybrid score
|
|
||||||
for validator in validators:
|
|
||||||
validator.hybrid_score = validator.stake * validator.reputation
|
|
||||||
|
|
||||||
# Sort by hybrid score
|
|
||||||
validators.sort(key=lambda v: v.hybrid_score, reverse=True)
|
|
||||||
|
|
||||||
for i, validator in enumerate(validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Default rotation configuration
|
|
||||||
DEFAULT_ROTATION_CONFIG = RotationConfig(
|
|
||||||
strategy=RotationStrategy.HYBRID,
|
|
||||||
rotation_interval=100, # Rotate every 100 blocks
|
|
||||||
min_stake=1000.0,
|
|
||||||
reputation_threshold=0.7,
|
|
||||||
max_validators=10
|
|
||||||
)
|
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
"""
|
|
||||||
Slashing Conditions Implementation
|
|
||||||
Handles detection and penalties for validator misbehavior
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
from typing import Dict, List, Optional, Set
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import Validator, ValidatorRole
|
|
||||||
|
|
||||||
class SlashingCondition(Enum):
|
|
||||||
DOUBLE_SIGN = "double_sign"
|
|
||||||
UNAVAILABLE = "unavailable"
|
|
||||||
INVALID_BLOCK = "invalid_block"
|
|
||||||
SLOW_RESPONSE = "slow_response"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SlashingEvent:
|
|
||||||
validator_address: str
|
|
||||||
condition: SlashingCondition
|
|
||||||
evidence: str
|
|
||||||
block_height: int
|
|
||||||
timestamp: float
|
|
||||||
slash_amount: float
|
|
||||||
|
|
||||||
class SlashingManager:
|
|
||||||
"""Manages validator slashing conditions and penalties"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.slashing_events: List[SlashingEvent] = []
|
|
||||||
self.slash_rates = {
|
|
||||||
SlashingCondition.DOUBLE_SIGN: 0.5, # 50% slash
|
|
||||||
SlashingCondition.UNAVAILABLE: 0.1, # 10% slash
|
|
||||||
SlashingCondition.INVALID_BLOCK: 0.3, # 30% slash
|
|
||||||
SlashingCondition.SLOW_RESPONSE: 0.05 # 5% slash
|
|
||||||
}
|
|
||||||
self.slash_thresholds = {
|
|
||||||
SlashingCondition.DOUBLE_SIGN: 1, # Immediate slash
|
|
||||||
SlashingCondition.UNAVAILABLE: 3, # After 3 offenses
|
|
||||||
SlashingCondition.INVALID_BLOCK: 1, # Immediate slash
|
|
||||||
SlashingCondition.SLOW_RESPONSE: 5 # After 5 offenses
|
|
||||||
}
|
|
||||||
|
|
||||||
def detect_double_sign(self, validator: str, block_hash1: str, block_hash2: str, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect double signing (validator signed two different blocks at same height)"""
|
|
||||||
if block_hash1 == block_hash2:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.DOUBLE_SIGN,
|
|
||||||
evidence=f"Double sign detected: {block_hash1} vs {block_hash2} at height {height}",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.DOUBLE_SIGN]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_unavailability(self, validator: str, missed_blocks: int, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect validator unavailability (missing consensus participation)"""
|
|
||||||
if missed_blocks < self.slash_thresholds[SlashingCondition.UNAVAILABLE]:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.UNAVAILABLE,
|
|
||||||
evidence=f"Missed {missed_blocks} consecutive blocks",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.UNAVAILABLE]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_invalid_block(self, validator: str, block_hash: str, reason: str, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect invalid block proposal"""
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.INVALID_BLOCK,
|
|
||||||
evidence=f"Invalid block {block_hash}: {reason}",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.INVALID_BLOCK]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_slow_response(self, validator: str, response_time: float, threshold: float, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect slow consensus participation"""
|
|
||||||
if response_time <= threshold:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.SLOW_RESPONSE,
|
|
||||||
evidence=f"Slow response: {response_time}s (threshold: {threshold}s)",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.SLOW_RESPONSE]
|
|
||||||
)
|
|
||||||
|
|
||||||
def apply_slashing(self, validator: Validator, event: SlashingEvent) -> bool:
|
|
||||||
"""Apply slashing penalty to validator"""
|
|
||||||
slash_amount = validator.stake * event.slash_amount
|
|
||||||
validator.stake -= slash_amount
|
|
||||||
|
|
||||||
# Demote validator role if stake is too low
|
|
||||||
if validator.stake < 100: # Minimum stake threshold
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
# Record slashing event
|
|
||||||
self.slashing_events.append(event)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_validator_slash_count(self, validator_address: str, condition: SlashingCondition) -> int:
|
|
||||||
"""Get count of slashing events for validator and condition"""
|
|
||||||
return len([
|
|
||||||
event for event in self.slashing_events
|
|
||||||
if event.validator_address == validator_address and event.condition == condition
|
|
||||||
])
|
|
||||||
|
|
||||||
def should_slash(self, validator: str, condition: SlashingCondition) -> bool:
|
|
||||||
"""Check if validator should be slashed for condition"""
|
|
||||||
current_count = self.get_validator_slash_count(validator, condition)
|
|
||||||
threshold = self.slash_thresholds.get(condition, 1)
|
|
||||||
return current_count >= threshold
|
|
||||||
|
|
||||||
def get_slashing_history(self, validator_address: Optional[str] = None) -> List[SlashingEvent]:
|
|
||||||
"""Get slashing history for validator or all validators"""
|
|
||||||
if validator_address:
|
|
||||||
return [event for event in self.slashing_events if event.validator_address == validator_address]
|
|
||||||
return self.slashing_events.copy()
|
|
||||||
|
|
||||||
def calculate_total_slashed(self, validator_address: str) -> float:
|
|
||||||
"""Calculate total amount slashed for validator"""
|
|
||||||
events = self.get_slashing_history(validator_address)
|
|
||||||
return sum(event.slash_amount for event in events)
|
|
||||||
|
|
||||||
# Global slashing manager
|
|
||||||
slashing_manager = SlashingManager()
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from .poa import PoAProposer, ProposerConfig, CircuitBreaker
|
|
||||||
|
|
||||||
__all__ = ["PoAProposer", "ProposerConfig", "CircuitBreaker"]
|
|
||||||
@@ -1,210 +0,0 @@
|
|||||||
"""
|
|
||||||
Validator Key Management
|
|
||||||
Handles cryptographic key operations for validators
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from typing import Dict, Optional, Tuple
|
|
||||||
from cryptography.hazmat.primitives import hashes, serialization
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
||||||
from cryptography.hazmat.backends import default_backend
|
|
||||||
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ValidatorKeyPair:
|
|
||||||
address: str
|
|
||||||
private_key_pem: str
|
|
||||||
public_key_pem: str
|
|
||||||
created_at: float
|
|
||||||
last_rotated: float
|
|
||||||
|
|
||||||
class KeyManager:
|
|
||||||
"""Manages validator cryptographic keys"""
|
|
||||||
|
|
||||||
def __init__(self, keys_dir: str = "/opt/aitbc/keys"):
|
|
||||||
self.keys_dir = keys_dir
|
|
||||||
self.key_pairs: Dict[str, ValidatorKeyPair] = {}
|
|
||||||
self._ensure_keys_directory()
|
|
||||||
self._load_existing_keys()
|
|
||||||
|
|
||||||
def _ensure_keys_directory(self):
|
|
||||||
"""Ensure keys directory exists and has proper permissions"""
|
|
||||||
os.makedirs(self.keys_dir, mode=0o700, exist_ok=True)
|
|
||||||
|
|
||||||
def _load_existing_keys(self):
|
|
||||||
"""Load existing key pairs from disk"""
|
|
||||||
keys_file = os.path.join(self.keys_dir, "validator_keys.json")
|
|
||||||
|
|
||||||
if os.path.exists(keys_file):
|
|
||||||
try:
|
|
||||||
with open(keys_file, 'r') as f:
|
|
||||||
keys_data = json.load(f)
|
|
||||||
|
|
||||||
for address, key_data in keys_data.items():
|
|
||||||
self.key_pairs[address] = ValidatorKeyPair(
|
|
||||||
address=address,
|
|
||||||
private_key_pem=key_data['private_key_pem'],
|
|
||||||
public_key_pem=key_data['public_key_pem'],
|
|
||||||
created_at=key_data['created_at'],
|
|
||||||
last_rotated=key_data['last_rotated']
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading keys: {e}")
|
|
||||||
|
|
||||||
def generate_key_pair(self, address: str) -> ValidatorKeyPair:
|
|
||||||
"""Generate new RSA key pair for validator"""
|
|
||||||
# Generate private key
|
|
||||||
private_key = rsa.generate_private_key(
|
|
||||||
public_exponent=65537,
|
|
||||||
key_size=2048,
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Serialize private key
|
|
||||||
private_key_pem = private_key.private_bytes(
|
|
||||||
encoding=Encoding.PEM,
|
|
||||||
format=PrivateFormat.PKCS8,
|
|
||||||
encryption_algorithm=NoEncryption()
|
|
||||||
).decode('utf-8')
|
|
||||||
|
|
||||||
# Get public key
|
|
||||||
public_key = private_key.public_key()
|
|
||||||
public_key_pem = public_key.public_bytes(
|
|
||||||
encoding=Encoding.PEM,
|
|
||||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
|
||||||
).decode('utf-8')
|
|
||||||
|
|
||||||
# Create key pair object
|
|
||||||
current_time = time.time()
|
|
||||||
key_pair = ValidatorKeyPair(
|
|
||||||
address=address,
|
|
||||||
private_key_pem=private_key_pem,
|
|
||||||
public_key_pem=public_key_pem,
|
|
||||||
created_at=current_time,
|
|
||||||
last_rotated=current_time
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store key pair
|
|
||||||
self.key_pairs[address] = key_pair
|
|
||||||
self._save_keys()
|
|
||||||
|
|
||||||
return key_pair
|
|
||||||
|
|
||||||
def get_key_pair(self, address: str) -> Optional[ValidatorKeyPair]:
|
|
||||||
"""Get key pair for validator"""
|
|
||||||
return self.key_pairs.get(address)
|
|
||||||
|
|
||||||
def rotate_key(self, address: str) -> Optional[ValidatorKeyPair]:
|
|
||||||
"""Rotate validator keys"""
|
|
||||||
if address not in self.key_pairs:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Generate new key pair
|
|
||||||
new_key_pair = self.generate_key_pair(address)
|
|
||||||
|
|
||||||
# Update rotation time
|
|
||||||
new_key_pair.created_at = self.key_pairs[address].created_at
|
|
||||||
new_key_pair.last_rotated = time.time()
|
|
||||||
|
|
||||||
self._save_keys()
|
|
||||||
return new_key_pair
|
|
||||||
|
|
||||||
def sign_message(self, address: str, message: str) -> Optional[str]:
|
|
||||||
"""Sign message with validator private key"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load private key from PEM
|
|
||||||
private_key = serialization.load_pem_private_key(
|
|
||||||
key_pair.private_key_pem.encode(),
|
|
||||||
password=None,
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sign message
|
|
||||||
signature = private_key.sign(
|
|
||||||
message.encode('utf-8'),
|
|
||||||
hashes.SHA256(),
|
|
||||||
default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
return signature.hex()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error signing message: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def verify_signature(self, address: str, message: str, signature: str) -> bool:
|
|
||||||
"""Verify message signature"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load public key from PEM
|
|
||||||
public_key = serialization.load_pem_public_key(
|
|
||||||
key_pair.public_key_pem.encode(),
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify signature
|
|
||||||
public_key.verify(
|
|
||||||
bytes.fromhex(signature),
|
|
||||||
message.encode('utf-8'),
|
|
||||||
hashes.SHA256(),
|
|
||||||
default_backend()
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error verifying signature: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_public_key_pem(self, address: str) -> Optional[str]:
|
|
||||||
"""Get public key PEM for validator"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
return key_pair.public_key_pem if key_pair else None
|
|
||||||
|
|
||||||
def _save_keys(self):
|
|
||||||
"""Save key pairs to disk"""
|
|
||||||
keys_file = os.path.join(self.keys_dir, "validator_keys.json")
|
|
||||||
|
|
||||||
keys_data = {}
|
|
||||||
for address, key_pair in self.key_pairs.items():
|
|
||||||
keys_data[address] = {
|
|
||||||
'private_key_pem': key_pair.private_key_pem,
|
|
||||||
'public_key_pem': key_pair.public_key_pem,
|
|
||||||
'created_at': key_pair.created_at,
|
|
||||||
'last_rotated': key_pair.last_rotated
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(keys_file, 'w') as f:
|
|
||||||
json.dump(keys_data, f, indent=2)
|
|
||||||
|
|
||||||
# Set secure permissions
|
|
||||||
os.chmod(keys_file, 0o600)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error saving keys: {e}")
|
|
||||||
|
|
||||||
def should_rotate_key(self, address: str, rotation_interval: int = 86400) -> bool:
|
|
||||||
"""Check if key should be rotated (default: 24 hours)"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return (time.time() - key_pair.last_rotated) >= rotation_interval
|
|
||||||
|
|
||||||
def get_key_age(self, address: str) -> Optional[float]:
|
|
||||||
"""Get age of key in seconds"""
|
|
||||||
key_pair = self.get_key_pair(address)
|
|
||||||
if not key_pair:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return time.time() - key_pair.created_at
|
|
||||||
|
|
||||||
# Global key manager
|
|
||||||
key_manager = KeyManager()
|
|
||||||
@@ -1,119 +0,0 @@
|
|||||||
"""
|
|
||||||
Multi-Validator Proof of Authority Consensus Implementation
|
|
||||||
Extends single validator PoA to support multiple validators with rotation
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
from typing import List, Dict, Optional, Set
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from ..config import settings
|
|
||||||
from ..models import Block, Transaction
|
|
||||||
from ..database import session_scope
|
|
||||||
|
|
||||||
class ValidatorRole(Enum):
|
|
||||||
PROPOSER = "proposer"
|
|
||||||
VALIDATOR = "validator"
|
|
||||||
STANDBY = "standby"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Validator:
|
|
||||||
address: str
|
|
||||||
stake: float
|
|
||||||
reputation: float
|
|
||||||
role: ValidatorRole
|
|
||||||
last_proposed: int
|
|
||||||
is_active: bool
|
|
||||||
|
|
||||||
class MultiValidatorPoA:
|
|
||||||
"""Multi-Validator Proof of Authority consensus mechanism"""
|
|
||||||
|
|
||||||
def __init__(self, chain_id: str):
|
|
||||||
self.chain_id = chain_id
|
|
||||||
self.validators: Dict[str, Validator] = {}
|
|
||||||
self.current_proposer_index = 0
|
|
||||||
self.round_robin_enabled = True
|
|
||||||
self.consensus_timeout = 30 # seconds
|
|
||||||
|
|
||||||
def add_validator(self, address: str, stake: float = 1000.0) -> bool:
|
|
||||||
"""Add a new validator to the consensus"""
|
|
||||||
if address in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.validators[address] = Validator(
|
|
||||||
address=address,
|
|
||||||
stake=stake,
|
|
||||||
reputation=1.0,
|
|
||||||
role=ValidatorRole.STANDBY,
|
|
||||||
last_proposed=0,
|
|
||||||
is_active=True
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def remove_validator(self, address: str) -> bool:
|
|
||||||
"""Remove a validator from the consensus"""
|
|
||||||
if address not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[address]
|
|
||||||
validator.is_active = False
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
return True
|
|
||||||
|
|
||||||
def select_proposer(self, block_height: int) -> Optional[str]:
|
|
||||||
"""Select proposer for the current block using round-robin"""
|
|
||||||
active_validators = [
|
|
||||||
v for v in self.validators.values()
|
|
||||||
if v.is_active and v.role in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]
|
|
||||||
]
|
|
||||||
|
|
||||||
if not active_validators:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Round-robin selection
|
|
||||||
proposer_index = block_height % len(active_validators)
|
|
||||||
return active_validators[proposer_index].address
|
|
||||||
|
|
||||||
def validate_block(self, block: Block, proposer: str) -> bool:
|
|
||||||
"""Validate a proposed block"""
|
|
||||||
if proposer not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[proposer]
|
|
||||||
if not validator.is_active:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check if validator is allowed to propose
|
|
||||||
if validator.role not in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Additional validation logic here
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_consensus_participants(self) -> List[str]:
|
|
||||||
"""Get list of active consensus participants"""
|
|
||||||
return [
|
|
||||||
v.address for v in self.validators.values()
|
|
||||||
if v.is_active and v.role in [ValidatorRole.PROPOSER, ValidatorRole.VALIDATOR]
|
|
||||||
]
|
|
||||||
|
|
||||||
def update_validator_reputation(self, address: str, delta: float) -> bool:
|
|
||||||
"""Update validator reputation"""
|
|
||||||
if address not in self.validators:
|
|
||||||
return False
|
|
||||||
|
|
||||||
validator = self.validators[address]
|
|
||||||
validator.reputation = max(0.0, min(1.0, validator.reputation + delta))
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Global consensus instance
|
|
||||||
consensus_instances: Dict[str, MultiValidatorPoA] = {}
|
|
||||||
|
|
||||||
def get_consensus(chain_id: str) -> MultiValidatorPoA:
|
|
||||||
"""Get or create consensus instance for chain"""
|
|
||||||
if chain_id not in consensus_instances:
|
|
||||||
consensus_instances[chain_id] = MultiValidatorPoA(chain_id)
|
|
||||||
return consensus_instances[chain_id]
|
|
||||||
@@ -1,193 +0,0 @@
|
|||||||
"""
|
|
||||||
Practical Byzantine Fault Tolerance (PBFT) Consensus Implementation
|
|
||||||
Provides Byzantine fault tolerance for up to 1/3 faulty validators
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import hashlib
|
|
||||||
from typing import List, Dict, Optional, Set, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import MultiValidatorPoA, Validator
|
|
||||||
|
|
||||||
class PBFTPhase(Enum):
|
|
||||||
PRE_PREPARE = "pre_prepare"
|
|
||||||
PREPARE = "prepare"
|
|
||||||
COMMIT = "commit"
|
|
||||||
EXECUTE = "execute"
|
|
||||||
|
|
||||||
class PBFTMessageType(Enum):
|
|
||||||
PRE_PREPARE = "pre_prepare"
|
|
||||||
PREPARE = "prepare"
|
|
||||||
COMMIT = "commit"
|
|
||||||
VIEW_CHANGE = "view_change"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PBFTMessage:
|
|
||||||
message_type: PBFTMessageType
|
|
||||||
sender: str
|
|
||||||
view_number: int
|
|
||||||
sequence_number: int
|
|
||||||
digest: str
|
|
||||||
signature: str
|
|
||||||
timestamp: float
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PBFTState:
|
|
||||||
current_view: int
|
|
||||||
current_sequence: int
|
|
||||||
prepared_messages: Dict[str, List[PBFTMessage]]
|
|
||||||
committed_messages: Dict[str, List[PBFTMessage]]
|
|
||||||
pre_prepare_messages: Dict[str, PBFTMessage]
|
|
||||||
|
|
||||||
class PBFTConsensus:
|
|
||||||
"""PBFT consensus implementation"""
|
|
||||||
|
|
||||||
def __init__(self, consensus: MultiValidatorPoA):
|
|
||||||
self.consensus = consensus
|
|
||||||
self.state = PBFTState(
|
|
||||||
current_view=0,
|
|
||||||
current_sequence=0,
|
|
||||||
prepared_messages={},
|
|
||||||
committed_messages={},
|
|
||||||
pre_prepare_messages={}
|
|
||||||
)
|
|
||||||
self.fault_tolerance = max(1, len(consensus.get_consensus_participants()) // 3)
|
|
||||||
self.required_messages = 2 * self.fault_tolerance + 1
|
|
||||||
|
|
||||||
def get_message_digest(self, block_hash: str, sequence: int, view: int) -> str:
|
|
||||||
"""Generate message digest for PBFT"""
|
|
||||||
content = f"{block_hash}:{sequence}:{view}"
|
|
||||||
return hashlib.sha256(content.encode()).hexdigest()
|
|
||||||
|
|
||||||
async def pre_prepare_phase(self, proposer: str, block_hash: str) -> bool:
|
|
||||||
"""Phase 1: Pre-prepare"""
|
|
||||||
sequence = self.state.current_sequence + 1
|
|
||||||
view = self.state.current_view
|
|
||||||
digest = self.get_message_digest(block_hash, sequence, view)
|
|
||||||
|
|
||||||
message = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.PRE_PREPARE,
|
|
||||||
sender=proposer,
|
|
||||||
view_number=view,
|
|
||||||
sequence_number=sequence,
|
|
||||||
digest=digest,
|
|
||||||
signature="", # Would be signed in real implementation
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store pre-prepare message
|
|
||||||
key = f"{sequence}:{view}"
|
|
||||||
self.state.pre_prepare_messages[key] = message
|
|
||||||
|
|
||||||
# Broadcast to all validators
|
|
||||||
await self._broadcast_message(message)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def prepare_phase(self, validator: str, pre_prepare_msg: PBFTMessage) -> bool:
|
|
||||||
"""Phase 2: Prepare"""
|
|
||||||
key = f"{pre_prepare_msg.sequence_number}:{pre_prepare_msg.view_number}"
|
|
||||||
|
|
||||||
if key not in self.state.pre_prepare_messages:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Create prepare message
|
|
||||||
prepare_msg = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.PREPARE,
|
|
||||||
sender=validator,
|
|
||||||
view_number=pre_prepare_msg.view_number,
|
|
||||||
sequence_number=pre_prepare_msg.sequence_number,
|
|
||||||
digest=pre_prepare_msg.digest,
|
|
||||||
signature="", # Would be signed
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store prepare message
|
|
||||||
if key not in self.state.prepared_messages:
|
|
||||||
self.state.prepared_messages[key] = []
|
|
||||||
self.state.prepared_messages[key].append(prepare_msg)
|
|
||||||
|
|
||||||
# Broadcast prepare message
|
|
||||||
await self._broadcast_message(prepare_msg)
|
|
||||||
|
|
||||||
# Check if we have enough prepare messages
|
|
||||||
return len(self.state.prepared_messages[key]) >= self.required_messages
|
|
||||||
|
|
||||||
async def commit_phase(self, validator: str, prepare_msg: PBFTMessage) -> bool:
|
|
||||||
"""Phase 3: Commit"""
|
|
||||||
key = f"{prepare_msg.sequence_number}:{prepare_msg.view_number}"
|
|
||||||
|
|
||||||
# Create commit message
|
|
||||||
commit_msg = PBFTMessage(
|
|
||||||
message_type=PBFTMessageType.COMMIT,
|
|
||||||
sender=validator,
|
|
||||||
view_number=prepare_msg.view_number,
|
|
||||||
sequence_number=prepare_msg.sequence_number,
|
|
||||||
digest=prepare_msg.digest,
|
|
||||||
signature="", # Would be signed
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store commit message
|
|
||||||
if key not in self.state.committed_messages:
|
|
||||||
self.state.committed_messages[key] = []
|
|
||||||
self.state.committed_messages[key].append(commit_msg)
|
|
||||||
|
|
||||||
# Broadcast commit message
|
|
||||||
await self._broadcast_message(commit_msg)
|
|
||||||
|
|
||||||
# Check if we have enough commit messages
|
|
||||||
if len(self.state.committed_messages[key]) >= self.required_messages:
|
|
||||||
return await self.execute_phase(key)
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def execute_phase(self, key: str) -> bool:
|
|
||||||
"""Phase 4: Execute"""
|
|
||||||
# Extract sequence and view from key
|
|
||||||
sequence, view = map(int, key.split(':'))
|
|
||||||
|
|
||||||
# Update state
|
|
||||||
self.state.current_sequence = sequence
|
|
||||||
|
|
||||||
# Clean up old messages
|
|
||||||
self._cleanup_messages(sequence)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _broadcast_message(self, message: PBFTMessage):
|
|
||||||
"""Broadcast message to all validators"""
|
|
||||||
validators = self.consensus.get_consensus_participants()
|
|
||||||
|
|
||||||
for validator in validators:
|
|
||||||
if validator != message.sender:
|
|
||||||
# In real implementation, this would send over network
|
|
||||||
await self._send_to_validator(validator, message)
|
|
||||||
|
|
||||||
async def _send_to_validator(self, validator: str, message: PBFTMessage):
|
|
||||||
"""Send message to specific validator"""
|
|
||||||
# Network communication would be implemented here
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _cleanup_messages(self, sequence: int):
|
|
||||||
"""Clean up old messages to prevent memory leaks"""
|
|
||||||
old_keys = [
|
|
||||||
key for key in self.state.prepared_messages.keys()
|
|
||||||
if int(key.split(':')[0]) < sequence
|
|
||||||
]
|
|
||||||
|
|
||||||
for key in old_keys:
|
|
||||||
self.state.prepared_messages.pop(key, None)
|
|
||||||
self.state.committed_messages.pop(key, None)
|
|
||||||
self.state.pre_prepare_messages.pop(key, None)
|
|
||||||
|
|
||||||
def handle_view_change(self, new_view: int) -> bool:
|
|
||||||
"""Handle view change when proposer fails"""
|
|
||||||
self.state.current_view = new_view
|
|
||||||
# Reset state for new view
|
|
||||||
self.state.prepared_messages.clear()
|
|
||||||
self.state.committed_messages.clear()
|
|
||||||
self.state.pre_prepare_messages.clear()
|
|
||||||
return True
|
|
||||||
@@ -1,345 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable, ContextManager, Optional
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
|
|
||||||
from ..logger import get_logger
|
|
||||||
from ..metrics import metrics_registry
|
|
||||||
from ..config import ProposerConfig
|
|
||||||
from ..models import Block, Account
|
|
||||||
from ..gossip import gossip_broker
|
|
||||||
|
|
||||||
_METRIC_KEY_SANITIZE = re.compile(r"[^a-zA-Z0-9_]")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_metric_suffix(value: str) -> str:
|
|
||||||
sanitized = _METRIC_KEY_SANITIZE.sub("_", value).strip("_")
|
|
||||||
return sanitized or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
class CircuitBreaker:
|
|
||||||
def __init__(self, threshold: int, timeout: int):
|
|
||||||
self._threshold = threshold
|
|
||||||
self._timeout = timeout
|
|
||||||
self._failures = 0
|
|
||||||
self._last_failure_time = 0.0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self) -> str:
|
|
||||||
if self._state == "open":
|
|
||||||
if time.time() - self._last_failure_time > self._timeout:
|
|
||||||
self._state = "half-open"
|
|
||||||
return self._state
|
|
||||||
|
|
||||||
def allow_request(self) -> bool:
|
|
||||||
state = self.state
|
|
||||||
if state == "closed":
|
|
||||||
return True
|
|
||||||
if state == "half-open":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def record_failure(self) -> None:
|
|
||||||
self._failures += 1
|
|
||||||
self._last_failure_time = time.time()
|
|
||||||
if self._failures >= self._threshold:
|
|
||||||
self._state = "open"
|
|
||||||
|
|
||||||
def record_success(self) -> None:
|
|
||||||
self._failures = 0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
class PoAProposer:
|
|
||||||
"""Proof-of-Authority block proposer.
|
|
||||||
|
|
||||||
Responsible for periodically proposing blocks if this node is configured as a proposer.
|
|
||||||
In the real implementation, this would involve checking the mempool, validating transactions,
|
|
||||||
and signing the block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
config: ProposerConfig,
|
|
||||||
session_factory: Callable[[], ContextManager[Session]],
|
|
||||||
) -> None:
|
|
||||||
self._config = config
|
|
||||||
self._session_factory = session_factory
|
|
||||||
self._logger = get_logger(__name__)
|
|
||||||
self._stop_event = asyncio.Event()
|
|
||||||
self._task: Optional[asyncio.Task[None]] = None
|
|
||||||
self._last_proposer_id: Optional[str] = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
if self._task is not None:
|
|
||||||
return
|
|
||||||
self._logger.info("Starting PoA proposer loop", extra={"interval": self._config.interval_seconds})
|
|
||||||
await self._ensure_genesis_block()
|
|
||||||
self._stop_event.clear()
|
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
if self._task is None:
|
|
||||||
return
|
|
||||||
self._logger.info("Stopping PoA proposer loop")
|
|
||||||
self._stop_event.set()
|
|
||||||
await self._task
|
|
||||||
self._task = None
|
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
await self._wait_until_next_slot()
|
|
||||||
if self._stop_event.is_set():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
await self._propose_block()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
self._logger.exception("Failed to propose block", extra={"error": str(exc)})
|
|
||||||
|
|
||||||
async def _wait_until_next_slot(self) -> None:
|
|
||||||
head = self._fetch_chain_head()
|
|
||||||
if head is None:
|
|
||||||
return
|
|
||||||
now = datetime.utcnow()
|
|
||||||
elapsed = (now - head.timestamp).total_seconds()
|
|
||||||
sleep_for = max(self._config.interval_seconds - elapsed, 0.1)
|
|
||||||
if sleep_for <= 0:
|
|
||||||
sleep_for = 0.1
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._stop_event.wait(), timeout=sleep_for)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _propose_block(self) -> None:
|
|
||||||
# Check internal mempool and include transactions
|
|
||||||
from ..mempool import get_mempool
|
|
||||||
from ..models import Transaction, Account
|
|
||||||
mempool = get_mempool()
|
|
||||||
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
next_height = 0
|
|
||||||
parent_hash = "0x00"
|
|
||||||
interval_seconds: Optional[float] = None
|
|
||||||
if head is not None:
|
|
||||||
next_height = head.height + 1
|
|
||||||
parent_hash = head.hash
|
|
||||||
interval_seconds = (datetime.utcnow() - head.timestamp).total_seconds()
|
|
||||||
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
# Pull transactions from mempool
|
|
||||||
max_txs = self._config.max_txs_per_block
|
|
||||||
max_bytes = self._config.max_block_size_bytes
|
|
||||||
pending_txs = mempool.drain(max_txs, max_bytes, self._config.chain_id)
|
|
||||||
self._logger.info(f"[PROPOSE] drained {len(pending_txs)} txs from mempool, chain={self._config.chain_id}")
|
|
||||||
|
|
||||||
# Process transactions and update balances
|
|
||||||
processed_txs = []
|
|
||||||
for tx in pending_txs:
|
|
||||||
try:
|
|
||||||
# Parse transaction data
|
|
||||||
tx_data = tx.content
|
|
||||||
sender = tx_data.get("from")
|
|
||||||
recipient = tx_data.get("to")
|
|
||||||
value = tx_data.get("amount", 0)
|
|
||||||
fee = tx_data.get("fee", 0)
|
|
||||||
|
|
||||||
if not sender or not recipient:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get sender account
|
|
||||||
sender_account = session.get(Account, (self._config.chain_id, sender))
|
|
||||||
if not sender_account:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check sufficient balance
|
|
||||||
total_cost = value + fee
|
|
||||||
if sender_account.balance < total_cost:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get or create recipient account
|
|
||||||
recipient_account = session.get(Account, (self._config.chain_id, recipient))
|
|
||||||
if not recipient_account:
|
|
||||||
recipient_account = Account(chain_id=self._config.chain_id, address=recipient, balance=0, nonce=0)
|
|
||||||
session.add(recipient_account)
|
|
||||||
session.flush()
|
|
||||||
|
|
||||||
# Update balances
|
|
||||||
sender_account.balance -= total_cost
|
|
||||||
sender_account.nonce += 1
|
|
||||||
recipient_account.balance += value
|
|
||||||
|
|
||||||
# Create transaction record
|
|
||||||
transaction = Transaction(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
tx_hash=tx.tx_hash,
|
|
||||||
sender=sender,
|
|
||||||
recipient=recipient,
|
|
||||||
payload=tx_data,
|
|
||||||
value=value,
|
|
||||||
fee=fee,
|
|
||||||
nonce=sender_account.nonce - 1,
|
|
||||||
timestamp=timestamp,
|
|
||||||
block_height=next_height,
|
|
||||||
status="confirmed"
|
|
||||||
)
|
|
||||||
session.add(transaction)
|
|
||||||
processed_txs.append(tx)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self._logger.warning(f"Failed to process transaction {tx.tx_hash}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Compute block hash with transaction data
|
|
||||||
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp, processed_txs)
|
|
||||||
|
|
||||||
block = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=next_height,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash=parent_hash,
|
|
||||||
proposer=self._config.proposer_id,
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=len(processed_txs),
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(block)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
metrics_registry.increment("blocks_proposed_total")
|
|
||||||
metrics_registry.set_gauge("chain_head_height", float(next_height))
|
|
||||||
if interval_seconds is not None and interval_seconds >= 0:
|
|
||||||
metrics_registry.observe("block_interval_seconds", interval_seconds)
|
|
||||||
metrics_registry.set_gauge("poa_last_block_interval_seconds", float(interval_seconds))
|
|
||||||
|
|
||||||
proposer_suffix = _sanitize_metric_suffix(self._config.proposer_id)
|
|
||||||
metrics_registry.increment(f"poa_blocks_proposed_total_{proposer_suffix}")
|
|
||||||
if self._last_proposer_id is not None and self._last_proposer_id != self._config.proposer_id:
|
|
||||||
metrics_registry.increment("poa_proposer_switches_total")
|
|
||||||
self._last_proposer_id = self._config.proposer_id
|
|
||||||
|
|
||||||
self._logger.info(
|
|
||||||
"Proposed block",
|
|
||||||
extra={
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast the new block
|
|
||||||
tx_list = [tx.content for tx in processed_txs] if processed_txs else []
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"chain_id": self._config.chain_id,
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"parent_hash": block.parent_hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
"timestamp": block.timestamp.isoformat(),
|
|
||||||
"tx_count": block.tx_count,
|
|
||||||
"state_root": block.state_root,
|
|
||||||
"transactions": tx_list,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_genesis_block(self) -> None:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
if head is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use a deterministic genesis timestamp so all nodes agree on the genesis block hash
|
|
||||||
timestamp = datetime(2025, 1, 1, 0, 0, 0)
|
|
||||||
block_hash = self._compute_block_hash(0, "0x00", timestamp)
|
|
||||||
genesis = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=0,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash="0x00",
|
|
||||||
proposer=self._config.proposer_id, # Use configured proposer as genesis proposer
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(genesis)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Initialize accounts from genesis allocations file (if present)
|
|
||||||
await self._initialize_genesis_allocations(session)
|
|
||||||
|
|
||||||
# Broadcast genesis block for initial sync
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"chain_id": self._config.chain_id,
|
|
||||||
"height": genesis.height,
|
|
||||||
"hash": genesis.hash,
|
|
||||||
"parent_hash": genesis.parent_hash,
|
|
||||||
"proposer": genesis.proposer,
|
|
||||||
"timestamp": genesis.timestamp.isoformat(),
|
|
||||||
"tx_count": genesis.tx_count,
|
|
||||||
"state_root": genesis.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _initialize_genesis_allocations(self, session: Session) -> None:
|
|
||||||
"""Create Account entries from the genesis allocations file."""
|
|
||||||
# Use standardized data directory from configuration
|
|
||||||
from ..config import settings
|
|
||||||
|
|
||||||
genesis_paths = [
|
|
||||||
Path(f"/var/lib/aitbc/data/{self._config.chain_id}/genesis.json"), # Standard location
|
|
||||||
]
|
|
||||||
|
|
||||||
genesis_path = None
|
|
||||||
for path in genesis_paths:
|
|
||||||
if path.exists():
|
|
||||||
genesis_path = path
|
|
||||||
break
|
|
||||||
|
|
||||||
if not genesis_path:
|
|
||||||
self._logger.warning("Genesis allocations file not found; skipping account initialization", extra={"paths": str(genesis_paths)})
|
|
||||||
return
|
|
||||||
|
|
||||||
with open(genesis_path) as f:
|
|
||||||
genesis_data = json.load(f)
|
|
||||||
|
|
||||||
allocations = genesis_data.get("allocations", [])
|
|
||||||
created = 0
|
|
||||||
for alloc in allocations:
|
|
||||||
addr = alloc["address"]
|
|
||||||
balance = int(alloc["balance"])
|
|
||||||
nonce = int(alloc.get("nonce", 0))
|
|
||||||
# Check if account already exists (idempotent)
|
|
||||||
acct = session.get(Account, (self._config.chain_id, addr))
|
|
||||||
if acct is None:
|
|
||||||
acct = Account(chain_id=self._config.chain_id, address=addr, balance=balance, nonce=nonce)
|
|
||||||
session.add(acct)
|
|
||||||
created += 1
|
|
||||||
session.commit()
|
|
||||||
self._logger.info("Initialized genesis accounts", extra={"count": created, "total": len(allocations), "path": str(genesis_path)})
|
|
||||||
|
|
||||||
def _fetch_chain_head(self) -> Optional[Block]:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
|
|
||||||
def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime, transactions: list = None) -> str:
|
|
||||||
# Include transaction hashes in block hash computation
|
|
||||||
tx_hashes = []
|
|
||||||
if transactions:
|
|
||||||
tx_hashes = [tx.tx_hash for tx in transactions]
|
|
||||||
|
|
||||||
payload = f"{self._config.chain_id}|{height}|{parent_hash}|{timestamp.isoformat()}|{'|'.join(sorted(tx_hashes))}".encode()
|
|
||||||
return "0x" + hashlib.sha256(payload).hexdigest()
|
|
||||||
@@ -1,229 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Callable, ContextManager, Optional
|
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
|
|
||||||
from ..logger import get_logger
|
|
||||||
from ..metrics import metrics_registry
|
|
||||||
from ..config import ProposerConfig
|
|
||||||
from ..models import Block
|
|
||||||
from ..gossip import gossip_broker
|
|
||||||
|
|
||||||
_METRIC_KEY_SANITIZE = re.compile(r"[^a-zA-Z0-9_]")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_metric_suffix(value: str) -> str:
|
|
||||||
sanitized = _METRIC_KEY_SANITIZE.sub("_", value).strip("_")
|
|
||||||
return sanitized or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
class CircuitBreaker:
|
|
||||||
def __init__(self, threshold: int, timeout: int):
|
|
||||||
self._threshold = threshold
|
|
||||||
self._timeout = timeout
|
|
||||||
self._failures = 0
|
|
||||||
self._last_failure_time = 0.0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self) -> str:
|
|
||||||
if self._state == "open":
|
|
||||||
if time.time() - self._last_failure_time > self._timeout:
|
|
||||||
self._state = "half-open"
|
|
||||||
return self._state
|
|
||||||
|
|
||||||
def allow_request(self) -> bool:
|
|
||||||
state = self.state
|
|
||||||
if state == "closed":
|
|
||||||
return True
|
|
||||||
if state == "half-open":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def record_failure(self) -> None:
|
|
||||||
self._failures += 1
|
|
||||||
self._last_failure_time = time.time()
|
|
||||||
if self._failures >= self._threshold:
|
|
||||||
self._state = "open"
|
|
||||||
|
|
||||||
def record_success(self) -> None:
|
|
||||||
self._failures = 0
|
|
||||||
self._state = "closed"
|
|
||||||
|
|
||||||
class PoAProposer:
|
|
||||||
"""Proof-of-Authority block proposer.
|
|
||||||
|
|
||||||
Responsible for periodically proposing blocks if this node is configured as a proposer.
|
|
||||||
In the real implementation, this would involve checking the mempool, validating transactions,
|
|
||||||
and signing the block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
config: ProposerConfig,
|
|
||||||
session_factory: Callable[[], ContextManager[Session]],
|
|
||||||
) -> None:
|
|
||||||
self._config = config
|
|
||||||
self._session_factory = session_factory
|
|
||||||
self._logger = get_logger(__name__)
|
|
||||||
self._stop_event = asyncio.Event()
|
|
||||||
self._task: Optional[asyncio.Task[None]] = None
|
|
||||||
self._last_proposer_id: Optional[str] = None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
if self._task is not None:
|
|
||||||
return
|
|
||||||
self._logger.info("Starting PoA proposer loop", extra={"interval": self._config.interval_seconds})
|
|
||||||
self._ensure_genesis_block()
|
|
||||||
self._stop_event.clear()
|
|
||||||
self._task = asyncio.create_task(self._run_loop())
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
|
||||||
if self._task is None:
|
|
||||||
return
|
|
||||||
self._logger.info("Stopping PoA proposer loop")
|
|
||||||
self._stop_event.set()
|
|
||||||
await self._task
|
|
||||||
self._task = None
|
|
||||||
|
|
||||||
async def _run_loop(self) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
await self._wait_until_next_slot()
|
|
||||||
if self._stop_event.is_set():
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
self._propose_block()
|
|
||||||
except Exception as exc: # pragma: no cover - defensive logging
|
|
||||||
self._logger.exception("Failed to propose block", extra={"error": str(exc)})
|
|
||||||
|
|
||||||
async def _wait_until_next_slot(self) -> None:
|
|
||||||
head = self._fetch_chain_head()
|
|
||||||
if head is None:
|
|
||||||
return
|
|
||||||
now = datetime.utcnow()
|
|
||||||
elapsed = (now - head.timestamp).total_seconds()
|
|
||||||
sleep_for = max(self._config.interval_seconds - elapsed, 0.1)
|
|
||||||
if sleep_for <= 0:
|
|
||||||
sleep_for = 0.1
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self._stop_event.wait(), timeout=sleep_for)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _propose_block(self) -> None:
|
|
||||||
# Check internal mempool
|
|
||||||
from ..mempool import get_mempool
|
|
||||||
if get_mempool().size(self._config.chain_id) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
next_height = 0
|
|
||||||
parent_hash = "0x00"
|
|
||||||
interval_seconds: Optional[float] = None
|
|
||||||
if head is not None:
|
|
||||||
next_height = head.height + 1
|
|
||||||
parent_hash = head.hash
|
|
||||||
interval_seconds = (datetime.utcnow() - head.timestamp).total_seconds()
|
|
||||||
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
block_hash = self._compute_block_hash(next_height, parent_hash, timestamp)
|
|
||||||
|
|
||||||
block = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=next_height,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash=parent_hash,
|
|
||||||
proposer=self._config.proposer_id,
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(block)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
metrics_registry.increment("blocks_proposed_total")
|
|
||||||
metrics_registry.set_gauge("chain_head_height", float(next_height))
|
|
||||||
if interval_seconds is not None and interval_seconds >= 0:
|
|
||||||
metrics_registry.observe("block_interval_seconds", interval_seconds)
|
|
||||||
metrics_registry.set_gauge("poa_last_block_interval_seconds", float(interval_seconds))
|
|
||||||
|
|
||||||
proposer_suffix = _sanitize_metric_suffix(self._config.proposer_id)
|
|
||||||
metrics_registry.increment(f"poa_blocks_proposed_total_{proposer_suffix}")
|
|
||||||
if self._last_proposer_id is not None and self._last_proposer_id != self._config.proposer_id:
|
|
||||||
metrics_registry.increment("poa_proposer_switches_total")
|
|
||||||
self._last_proposer_id = self._config.proposer_id
|
|
||||||
|
|
||||||
self._logger.info(
|
|
||||||
"Proposed block",
|
|
||||||
extra={
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast the new block
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"height": block.height,
|
|
||||||
"hash": block.hash,
|
|
||||||
"parent_hash": block.parent_hash,
|
|
||||||
"proposer": block.proposer,
|
|
||||||
"timestamp": block.timestamp.isoformat(),
|
|
||||||
"tx_count": block.tx_count,
|
|
||||||
"state_root": block.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_genesis_block(self) -> None:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
head = session.exec(select(Block).where(Block.chain_id == self._config.chain_id).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
if head is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Use a deterministic genesis timestamp so all nodes agree on the genesis block hash
|
|
||||||
timestamp = datetime(2025, 1, 1, 0, 0, 0)
|
|
||||||
block_hash = self._compute_block_hash(0, "0x00", timestamp)
|
|
||||||
genesis = Block(
|
|
||||||
chain_id=self._config.chain_id,
|
|
||||||
height=0,
|
|
||||||
hash=block_hash,
|
|
||||||
parent_hash="0x00",
|
|
||||||
proposer="genesis",
|
|
||||||
timestamp=timestamp,
|
|
||||||
tx_count=0,
|
|
||||||
state_root=None,
|
|
||||||
)
|
|
||||||
session.add(genesis)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Broadcast genesis block for initial sync
|
|
||||||
await gossip_broker.publish(
|
|
||||||
"blocks",
|
|
||||||
{
|
|
||||||
"height": genesis.height,
|
|
||||||
"hash": genesis.hash,
|
|
||||||
"parent_hash": genesis.parent_hash,
|
|
||||||
"proposer": genesis.proposer,
|
|
||||||
"timestamp": genesis.timestamp.isoformat(),
|
|
||||||
"tx_count": genesis.tx_count,
|
|
||||||
"state_root": genesis.state_root,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _fetch_chain_head(self) -> Optional[Block]:
|
|
||||||
with self._session_factory() as session:
|
|
||||||
return session.exec(select(Block).order_by(Block.height.desc()).limit(1)).first()
|
|
||||||
|
|
||||||
def _compute_block_hash(self, height: int, parent_hash: str, timestamp: datetime) -> str:
|
|
||||||
payload = f"{self._config.chain_id}|{height}|{parent_hash}|{timestamp.isoformat()}".encode()
|
|
||||||
return "0x" + hashlib.sha256(payload).hexdigest()
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
--- apps/blockchain-node/src/aitbc_chain/consensus/poa.py
|
|
||||||
+++ apps/blockchain-node/src/aitbc_chain/consensus/poa.py
|
|
||||||
@@ -101,7 +101,7 @@
|
|
||||||
# Wait for interval before proposing next block
|
|
||||||
await asyncio.sleep(self.config.interval_seconds)
|
|
||||||
|
|
||||||
- self._propose_block()
|
|
||||||
+ await self._propose_block()
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
@@ -1,146 +0,0 @@
|
|||||||
"""
|
|
||||||
Validator Rotation Mechanism
|
|
||||||
Handles automatic rotation of validators based on performance and stake
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
from typing import List, Dict, Optional
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import MultiValidatorPoA, Validator, ValidatorRole
|
|
||||||
|
|
||||||
class RotationStrategy(Enum):
|
|
||||||
ROUND_ROBIN = "round_robin"
|
|
||||||
STAKE_WEIGHTED = "stake_weighted"
|
|
||||||
REPUTATION_BASED = "reputation_based"
|
|
||||||
HYBRID = "hybrid"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RotationConfig:
|
|
||||||
strategy: RotationStrategy
|
|
||||||
rotation_interval: int # blocks
|
|
||||||
min_stake: float
|
|
||||||
reputation_threshold: float
|
|
||||||
max_validators: int
|
|
||||||
|
|
||||||
class ValidatorRotation:
|
|
||||||
"""Manages validator rotation based on various strategies"""
|
|
||||||
|
|
||||||
def __init__(self, consensus: MultiValidatorPoA, config: RotationConfig):
|
|
||||||
self.consensus = consensus
|
|
||||||
self.config = config
|
|
||||||
self.last_rotation_height = 0
|
|
||||||
|
|
||||||
def should_rotate(self, current_height: int) -> bool:
|
|
||||||
"""Check if rotation should occur at current height"""
|
|
||||||
return (current_height - self.last_rotation_height) >= self.config.rotation_interval
|
|
||||||
|
|
||||||
def rotate_validators(self, current_height: int) -> bool:
|
|
||||||
"""Perform validator rotation based on configured strategy"""
|
|
||||||
if not self.should_rotate(current_height):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if self.config.strategy == RotationStrategy.ROUND_ROBIN:
|
|
||||||
return self._rotate_round_robin()
|
|
||||||
elif self.config.strategy == RotationStrategy.STAKE_WEIGHTED:
|
|
||||||
return self._rotate_stake_weighted()
|
|
||||||
elif self.config.strategy == RotationStrategy.REPUTATION_BASED:
|
|
||||||
return self._rotate_reputation_based()
|
|
||||||
elif self.config.strategy == RotationStrategy.HYBRID:
|
|
||||||
return self._rotate_hybrid()
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _rotate_round_robin(self) -> bool:
|
|
||||||
"""Round-robin rotation of validator roles"""
|
|
||||||
validators = list(self.consensus.validators.values())
|
|
||||||
active_validators = [v for v in validators if v.is_active]
|
|
||||||
|
|
||||||
# Rotate roles among active validators
|
|
||||||
for i, validator in enumerate(active_validators):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 3: # Top 3 become validators
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_stake_weighted(self) -> bool:
|
|
||||||
"""Stake-weighted rotation"""
|
|
||||||
validators = sorted(
|
|
||||||
[v for v in self.consensus.validators.values() if v.is_active],
|
|
||||||
key=lambda v: v.stake,
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, validator in enumerate(validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_reputation_based(self) -> bool:
|
|
||||||
"""Reputation-based rotation"""
|
|
||||||
validators = sorted(
|
|
||||||
[v for v in self.consensus.validators.values() if v.is_active],
|
|
||||||
key=lambda v: v.reputation,
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Filter by reputation threshold
|
|
||||||
qualified_validators = [
|
|
||||||
v for v in validators
|
|
||||||
if v.reputation >= self.config.reputation_threshold
|
|
||||||
]
|
|
||||||
|
|
||||||
for i, validator in enumerate(qualified_validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _rotate_hybrid(self) -> bool:
|
|
||||||
"""Hybrid rotation considering both stake and reputation"""
|
|
||||||
validators = [v for v in self.consensus.validators.values() if v.is_active]
|
|
||||||
|
|
||||||
# Calculate hybrid score
|
|
||||||
for validator in validators:
|
|
||||||
validator.hybrid_score = validator.stake * validator.reputation
|
|
||||||
|
|
||||||
# Sort by hybrid score
|
|
||||||
validators.sort(key=lambda v: v.hybrid_score, reverse=True)
|
|
||||||
|
|
||||||
for i, validator in enumerate(validators[:self.config.max_validators]):
|
|
||||||
if i == 0:
|
|
||||||
validator.role = ValidatorRole.PROPOSER
|
|
||||||
elif i < 4:
|
|
||||||
validator.role = ValidatorRole.VALIDATOR
|
|
||||||
else:
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
self.last_rotation_height += self.config.rotation_interval
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Default rotation configuration
|
|
||||||
DEFAULT_ROTATION_CONFIG = RotationConfig(
|
|
||||||
strategy=RotationStrategy.HYBRID,
|
|
||||||
rotation_interval=100, # Rotate every 100 blocks
|
|
||||||
min_stake=1000.0,
|
|
||||||
reputation_threshold=0.7,
|
|
||||||
max_validators=10
|
|
||||||
)
|
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
"""
|
|
||||||
Slashing Conditions Implementation
|
|
||||||
Handles detection and penalties for validator misbehavior
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
from typing import Dict, List, Optional, Set
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from .multi_validator_poa import Validator, ValidatorRole
|
|
||||||
|
|
||||||
class SlashingCondition(Enum):
|
|
||||||
DOUBLE_SIGN = "double_sign"
|
|
||||||
UNAVAILABLE = "unavailable"
|
|
||||||
INVALID_BLOCK = "invalid_block"
|
|
||||||
SLOW_RESPONSE = "slow_response"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SlashingEvent:
|
|
||||||
validator_address: str
|
|
||||||
condition: SlashingCondition
|
|
||||||
evidence: str
|
|
||||||
block_height: int
|
|
||||||
timestamp: float
|
|
||||||
slash_amount: float
|
|
||||||
|
|
||||||
class SlashingManager:
|
|
||||||
"""Manages validator slashing conditions and penalties"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.slashing_events: List[SlashingEvent] = []
|
|
||||||
self.slash_rates = {
|
|
||||||
SlashingCondition.DOUBLE_SIGN: 0.5, # 50% slash
|
|
||||||
SlashingCondition.UNAVAILABLE: 0.1, # 10% slash
|
|
||||||
SlashingCondition.INVALID_BLOCK: 0.3, # 30% slash
|
|
||||||
SlashingCondition.SLOW_RESPONSE: 0.05 # 5% slash
|
|
||||||
}
|
|
||||||
self.slash_thresholds = {
|
|
||||||
SlashingCondition.DOUBLE_SIGN: 1, # Immediate slash
|
|
||||||
SlashingCondition.UNAVAILABLE: 3, # After 3 offenses
|
|
||||||
SlashingCondition.INVALID_BLOCK: 1, # Immediate slash
|
|
||||||
SlashingCondition.SLOW_RESPONSE: 5 # After 5 offenses
|
|
||||||
}
|
|
||||||
|
|
||||||
def detect_double_sign(self, validator: str, block_hash1: str, block_hash2: str, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect double signing (validator signed two different blocks at same height)"""
|
|
||||||
if block_hash1 == block_hash2:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.DOUBLE_SIGN,
|
|
||||||
evidence=f"Double sign detected: {block_hash1} vs {block_hash2} at height {height}",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.DOUBLE_SIGN]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_unavailability(self, validator: str, missed_blocks: int, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect validator unavailability (missing consensus participation)"""
|
|
||||||
if missed_blocks < self.slash_thresholds[SlashingCondition.UNAVAILABLE]:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.UNAVAILABLE,
|
|
||||||
evidence=f"Missed {missed_blocks} consecutive blocks",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.UNAVAILABLE]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_invalid_block(self, validator: str, block_hash: str, reason: str, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect invalid block proposal"""
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.INVALID_BLOCK,
|
|
||||||
evidence=f"Invalid block {block_hash}: {reason}",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.INVALID_BLOCK]
|
|
||||||
)
|
|
||||||
|
|
||||||
def detect_slow_response(self, validator: str, response_time: float, threshold: float, height: int) -> Optional[SlashingEvent]:
|
|
||||||
"""Detect slow consensus participation"""
|
|
||||||
if response_time <= threshold:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SlashingEvent(
|
|
||||||
validator_address=validator,
|
|
||||||
condition=SlashingCondition.SLOW_RESPONSE,
|
|
||||||
evidence=f"Slow response: {response_time}s (threshold: {threshold}s)",
|
|
||||||
block_height=height,
|
|
||||||
timestamp=time.time(),
|
|
||||||
slash_amount=self.slash_rates[SlashingCondition.SLOW_RESPONSE]
|
|
||||||
)
|
|
||||||
|
|
||||||
def apply_slashing(self, validator: Validator, event: SlashingEvent) -> bool:
|
|
||||||
"""Apply slashing penalty to validator"""
|
|
||||||
slash_amount = validator.stake * event.slash_amount
|
|
||||||
validator.stake -= slash_amount
|
|
||||||
|
|
||||||
# Demote validator role if stake is too low
|
|
||||||
if validator.stake < 100: # Minimum stake threshold
|
|
||||||
validator.role = ValidatorRole.STANDBY
|
|
||||||
|
|
||||||
# Record slashing event
|
|
||||||
self.slashing_events.append(event)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_validator_slash_count(self, validator_address: str, condition: SlashingCondition) -> int:
|
|
||||||
"""Get count of slashing events for validator and condition"""
|
|
||||||
return len([
|
|
||||||
event for event in self.slashing_events
|
|
||||||
if event.validator_address == validator_address and event.condition == condition
|
|
||||||
])
|
|
||||||
|
|
||||||
def should_slash(self, validator: str, condition: SlashingCondition) -> bool:
|
|
||||||
"""Check if validator should be slashed for condition"""
|
|
||||||
current_count = self.get_validator_slash_count(validator, condition)
|
|
||||||
threshold = self.slash_thresholds.get(condition, 1)
|
|
||||||
return current_count >= threshold
|
|
||||||
|
|
||||||
def get_slashing_history(self, validator_address: Optional[str] = None) -> List[SlashingEvent]:
|
|
||||||
"""Get slashing history for validator or all validators"""
|
|
||||||
if validator_address:
|
|
||||||
return [event for event in self.slashing_events if event.validator_address == validator_address]
|
|
||||||
return self.slashing_events.copy()
|
|
||||||
|
|
||||||
def calculate_total_slashed(self, validator_address: str) -> float:
|
|
||||||
"""Calculate total amount slashed for validator"""
|
|
||||||
events = self.get_slashing_history(validator_address)
|
|
||||||
return sum(event.slash_amount for event in events)
|
|
||||||
|
|
||||||
# Global slashing manager
|
|
||||||
slashing_manager = SlashingManager()
|
|
||||||
@@ -1,519 +0,0 @@
|
|||||||
"""
|
|
||||||
AITBC Agent Messaging Contract Implementation
|
|
||||||
|
|
||||||
This module implements on-chain messaging functionality for agents,
|
|
||||||
enabling forum-like communication between autonomous agents.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Any
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from enum import Enum
|
|
||||||
import json
|
|
||||||
import hashlib
|
|
||||||
from eth_account import Account
|
|
||||||
from eth_utils import to_checksum_address
|
|
||||||
|
|
||||||
class MessageType(Enum):
|
|
||||||
"""Types of messages agents can send"""
|
|
||||||
POST = "post"
|
|
||||||
REPLY = "reply"
|
|
||||||
ANNOUNCEMENT = "announcement"
|
|
||||||
QUESTION = "question"
|
|
||||||
ANSWER = "answer"
|
|
||||||
MODERATION = "moderation"
|
|
||||||
|
|
||||||
class MessageStatus(Enum):
|
|
||||||
"""Status of messages in the forum"""
|
|
||||||
ACTIVE = "active"
|
|
||||||
HIDDEN = "hidden"
|
|
||||||
DELETED = "deleted"
|
|
||||||
PINNED = "pinned"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Message:
|
|
||||||
"""Represents a message in the agent forum"""
|
|
||||||
message_id: str
|
|
||||||
agent_id: str
|
|
||||||
agent_address: str
|
|
||||||
topic: str
|
|
||||||
content: str
|
|
||||||
message_type: MessageType
|
|
||||||
timestamp: datetime
|
|
||||||
parent_message_id: Optional[str] = None
|
|
||||||
reply_count: int = 0
|
|
||||||
upvotes: int = 0
|
|
||||||
downvotes: int = 0
|
|
||||||
status: MessageStatus = MessageStatus.ACTIVE
|
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Topic:
|
|
||||||
"""Represents a forum topic"""
|
|
||||||
topic_id: str
|
|
||||||
title: str
|
|
||||||
description: str
|
|
||||||
creator_agent_id: str
|
|
||||||
created_at: datetime
|
|
||||||
message_count: int = 0
|
|
||||||
last_activity: datetime = field(default_factory=datetime.now)
|
|
||||||
tags: List[str] = field(default_factory=list)
|
|
||||||
is_pinned: bool = False
|
|
||||||
is_locked: bool = False
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AgentReputation:
|
|
||||||
"""Reputation system for agents"""
|
|
||||||
agent_id: str
|
|
||||||
message_count: int = 0
|
|
||||||
upvotes_received: int = 0
|
|
||||||
downvotes_received: int = 0
|
|
||||||
reputation_score: float = 0.0
|
|
||||||
trust_level: int = 1 # 1-5 trust levels
|
|
||||||
is_moderator: bool = False
|
|
||||||
is_banned: bool = False
|
|
||||||
ban_reason: Optional[str] = None
|
|
||||||
ban_expires: Optional[datetime] = None
|
|
||||||
|
|
||||||
class AgentMessagingContract:
|
|
||||||
"""Main contract for agent messaging functionality"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.messages: Dict[str, Message] = {}
|
|
||||||
self.topics: Dict[str, Topic] = {}
|
|
||||||
self.agent_reputations: Dict[str, AgentReputation] = {}
|
|
||||||
self.moderation_log: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
def create_topic(self, agent_id: str, agent_address: str, title: str,
|
|
||||||
description: str, tags: List[str] = None) -> Dict[str, Any]:
|
|
||||||
"""Create a new forum topic"""
|
|
||||||
|
|
||||||
# Check if agent is banned
|
|
||||||
if self._is_agent_banned(agent_id):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Agent is banned from posting",
|
|
||||||
"error_code": "AGENT_BANNED"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Generate topic ID
|
|
||||||
topic_id = f"topic_{hashlib.sha256(f'{agent_id}_{title}_{datetime.now()}'.encode()).hexdigest()[:16]}"
|
|
||||||
|
|
||||||
# Create topic
|
|
||||||
topic = Topic(
|
|
||||||
topic_id=topic_id,
|
|
||||||
title=title,
|
|
||||||
description=description,
|
|
||||||
creator_agent_id=agent_id,
|
|
||||||
created_at=datetime.now(),
|
|
||||||
tags=tags or []
|
|
||||||
)
|
|
||||||
|
|
||||||
self.topics[topic_id] = topic
|
|
||||||
|
|
||||||
# Update agent reputation
|
|
||||||
self._update_agent_reputation(agent_id, message_count=1)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"topic_id": topic_id,
|
|
||||||
"topic": self._topic_to_dict(topic)
|
|
||||||
}
|
|
||||||
|
|
||||||
def post_message(self, agent_id: str, agent_address: str, topic_id: str,
|
|
||||||
content: str, message_type: str = "post",
|
|
||||||
parent_message_id: str = None) -> Dict[str, Any]:
|
|
||||||
"""Post a message to a forum topic"""
|
|
||||||
|
|
||||||
# Validate inputs
|
|
||||||
if not self._validate_agent(agent_id, agent_address):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid agent credentials",
|
|
||||||
"error_code": "INVALID_AGENT"
|
|
||||||
}
|
|
||||||
|
|
||||||
if self._is_agent_banned(agent_id):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Agent is banned from posting",
|
|
||||||
"error_code": "AGENT_BANNED"
|
|
||||||
}
|
|
||||||
|
|
||||||
if topic_id not in self.topics:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Topic not found",
|
|
||||||
"error_code": "TOPIC_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.topics[topic_id].is_locked:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Topic is locked",
|
|
||||||
"error_code": "TOPIC_LOCKED"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Validate message type
|
|
||||||
try:
|
|
||||||
msg_type = MessageType(message_type)
|
|
||||||
except ValueError:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid message type",
|
|
||||||
"error_code": "INVALID_MESSAGE_TYPE"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Generate message ID
|
|
||||||
message_id = f"msg_{hashlib.sha256(f'{agent_id}_{topic_id}_{content}_{datetime.now()}'.encode()).hexdigest()[:16]}"
|
|
||||||
|
|
||||||
# Create message
|
|
||||||
message = Message(
|
|
||||||
message_id=message_id,
|
|
||||||
agent_id=agent_id,
|
|
||||||
agent_address=agent_address,
|
|
||||||
topic=topic_id,
|
|
||||||
content=content,
|
|
||||||
message_type=msg_type,
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
parent_message_id=parent_message_id
|
|
||||||
)
|
|
||||||
|
|
||||||
self.messages[message_id] = message
|
|
||||||
|
|
||||||
# Update topic
|
|
||||||
self.topics[topic_id].message_count += 1
|
|
||||||
self.topics[topic_id].last_activity = datetime.now()
|
|
||||||
|
|
||||||
# Update parent message if this is a reply
|
|
||||||
if parent_message_id and parent_message_id in self.messages:
|
|
||||||
self.messages[parent_message_id].reply_count += 1
|
|
||||||
|
|
||||||
# Update agent reputation
|
|
||||||
self._update_agent_reputation(agent_id, message_count=1)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message_id": message_id,
|
|
||||||
"message": self._message_to_dict(message)
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_messages(self, topic_id: str, limit: int = 50, offset: int = 0,
|
|
||||||
sort_by: str = "timestamp") -> Dict[str, Any]:
|
|
||||||
"""Get messages from a topic"""
|
|
||||||
|
|
||||||
if topic_id not in self.topics:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Topic not found",
|
|
||||||
"error_code": "TOPIC_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get all messages for this topic
|
|
||||||
topic_messages = [
|
|
||||||
msg for msg in self.messages.values()
|
|
||||||
if msg.topic == topic_id and msg.status == MessageStatus.ACTIVE
|
|
||||||
]
|
|
||||||
|
|
||||||
# Sort messages
|
|
||||||
if sort_by == "timestamp":
|
|
||||||
topic_messages.sort(key=lambda x: x.timestamp, reverse=True)
|
|
||||||
elif sort_by == "upvotes":
|
|
||||||
topic_messages.sort(key=lambda x: x.upvotes, reverse=True)
|
|
||||||
elif sort_by == "replies":
|
|
||||||
topic_messages.sort(key=lambda x: x.reply_count, reverse=True)
|
|
||||||
|
|
||||||
# Apply pagination
|
|
||||||
total_messages = len(topic_messages)
|
|
||||||
paginated_messages = topic_messages[offset:offset + limit]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"messages": [self._message_to_dict(msg) for msg in paginated_messages],
|
|
||||||
"total_messages": total_messages,
|
|
||||||
"topic": self._topic_to_dict(self.topics[topic_id])
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_topics(self, limit: int = 50, offset: int = 0,
|
|
||||||
sort_by: str = "last_activity") -> Dict[str, Any]:
|
|
||||||
"""Get list of forum topics"""
|
|
||||||
|
|
||||||
# Sort topics
|
|
||||||
topic_list = list(self.topics.values())
|
|
||||||
|
|
||||||
if sort_by == "last_activity":
|
|
||||||
topic_list.sort(key=lambda x: x.last_activity, reverse=True)
|
|
||||||
elif sort_by == "created_at":
|
|
||||||
topic_list.sort(key=lambda x: x.created_at, reverse=True)
|
|
||||||
elif sort_by == "message_count":
|
|
||||||
topic_list.sort(key=lambda x: x.message_count, reverse=True)
|
|
||||||
|
|
||||||
# Apply pagination
|
|
||||||
total_topics = len(topic_list)
|
|
||||||
paginated_topics = topic_list[offset:offset + limit]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"topics": [self._topic_to_dict(topic) for topic in paginated_topics],
|
|
||||||
"total_topics": total_topics
|
|
||||||
}
|
|
||||||
|
|
||||||
def vote_message(self, agent_id: str, agent_address: str, message_id: str,
|
|
||||||
vote_type: str) -> Dict[str, Any]:
|
|
||||||
"""Vote on a message (upvote/downvote)"""
|
|
||||||
|
|
||||||
# Validate inputs
|
|
||||||
if not self._validate_agent(agent_id, agent_address):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid agent credentials",
|
|
||||||
"error_code": "INVALID_AGENT"
|
|
||||||
}
|
|
||||||
|
|
||||||
if message_id not in self.messages:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Message not found",
|
|
||||||
"error_code": "MESSAGE_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
if vote_type not in ["upvote", "downvote"]:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid vote type",
|
|
||||||
"error_code": "INVALID_VOTE_TYPE"
|
|
||||||
}
|
|
||||||
|
|
||||||
message = self.messages[message_id]
|
|
||||||
|
|
||||||
# Update vote counts
|
|
||||||
if vote_type == "upvote":
|
|
||||||
message.upvotes += 1
|
|
||||||
else:
|
|
||||||
message.downvotes += 1
|
|
||||||
|
|
||||||
# Update message author reputation
|
|
||||||
self._update_agent_reputation(
|
|
||||||
message.agent_id,
|
|
||||||
upvotes_received=message.upvotes,
|
|
||||||
downvotes_received=message.downvotes
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message_id": message_id,
|
|
||||||
"upvotes": message.upvotes,
|
|
||||||
"downvotes": message.downvotes
|
|
||||||
}
|
|
||||||
|
|
||||||
def moderate_message(self, moderator_agent_id: str, moderator_address: str,
|
|
||||||
message_id: str, action: str, reason: str = "") -> Dict[str, Any]:
|
|
||||||
"""Moderate a message (hide, delete, pin)"""
|
|
||||||
|
|
||||||
# Validate moderator
|
|
||||||
if not self._is_moderator(moderator_agent_id):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Insufficient permissions",
|
|
||||||
"error_code": "INSUFFICIENT_PERMISSIONS"
|
|
||||||
}
|
|
||||||
|
|
||||||
if message_id not in self.messages:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Message not found",
|
|
||||||
"error_code": "MESSAGE_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
message = self.messages[message_id]
|
|
||||||
|
|
||||||
# Apply moderation action
|
|
||||||
if action == "hide":
|
|
||||||
message.status = MessageStatus.HIDDEN
|
|
||||||
elif action == "delete":
|
|
||||||
message.status = MessageStatus.DELETED
|
|
||||||
elif action == "pin":
|
|
||||||
message.status = MessageStatus.PINNED
|
|
||||||
elif action == "unpin":
|
|
||||||
message.status = MessageStatus.ACTIVE
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid moderation action",
|
|
||||||
"error_code": "INVALID_ACTION"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Log moderation action
|
|
||||||
self.moderation_log.append({
|
|
||||||
"timestamp": datetime.now(),
|
|
||||||
"moderator_agent_id": moderator_agent_id,
|
|
||||||
"message_id": message_id,
|
|
||||||
"action": action,
|
|
||||||
"reason": reason
|
|
||||||
})
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message_id": message_id,
|
|
||||||
"status": message.status.value
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_agent_reputation(self, agent_id: str) -> Dict[str, Any]:
|
|
||||||
"""Get an agent's reputation information"""
|
|
||||||
|
|
||||||
if agent_id not in self.agent_reputations:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Agent not found",
|
|
||||||
"error_code": "AGENT_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
reputation = self.agent_reputations[agent_id]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"agent_id": agent_id,
|
|
||||||
"reputation": self._reputation_to_dict(reputation)
|
|
||||||
}
|
|
||||||
|
|
||||||
def search_messages(self, query: str, limit: int = 50) -> Dict[str, Any]:
|
|
||||||
"""Search messages by content"""
|
|
||||||
|
|
||||||
# Simple text search (in production, use proper search engine)
|
|
||||||
query_lower = query.lower()
|
|
||||||
matching_messages = []
|
|
||||||
|
|
||||||
for message in self.messages.values():
|
|
||||||
if (message.status == MessageStatus.ACTIVE and
|
|
||||||
query_lower in message.content.lower()):
|
|
||||||
matching_messages.append(message)
|
|
||||||
|
|
||||||
# Sort by timestamp (most recent first)
|
|
||||||
matching_messages.sort(key=lambda x: x.timestamp, reverse=True)
|
|
||||||
|
|
||||||
# Limit results
|
|
||||||
limited_messages = matching_messages[:limit]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"query": query,
|
|
||||||
"messages": [self._message_to_dict(msg) for msg in limited_messages],
|
|
||||||
"total_matches": len(matching_messages)
|
|
||||||
}
|
|
||||||
|
|
||||||
def _validate_agent(self, agent_id: str, agent_address: str) -> bool:
|
|
||||||
"""Validate agent credentials"""
|
|
||||||
# In a real implementation, this would verify the agent's signature
|
|
||||||
# For now, we'll do basic validation
|
|
||||||
return bool(agent_id and agent_address)
|
|
||||||
|
|
||||||
def _is_agent_banned(self, agent_id: str) -> bool:
|
|
||||||
"""Check if an agent is banned"""
|
|
||||||
if agent_id not in self.agent_reputations:
|
|
||||||
return False
|
|
||||||
|
|
||||||
reputation = self.agent_reputations[agent_id]
|
|
||||||
|
|
||||||
if reputation.is_banned:
|
|
||||||
# Check if ban has expired
|
|
||||||
if reputation.ban_expires and datetime.now() > reputation.ban_expires:
|
|
||||||
reputation.is_banned = False
|
|
||||||
reputation.ban_expires = None
|
|
||||||
reputation.ban_reason = None
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _is_moderator(self, agent_id: str) -> bool:
|
|
||||||
"""Check if an agent is a moderator"""
|
|
||||||
if agent_id not in self.agent_reputations:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self.agent_reputations[agent_id].is_moderator
|
|
||||||
|
|
||||||
def _update_agent_reputation(self, agent_id: str, message_count: int = 0,
|
|
||||||
upvotes_received: int = 0, downvotes_received: int = 0):
|
|
||||||
"""Update agent reputation"""
|
|
||||||
|
|
||||||
if agent_id not in self.agent_reputations:
|
|
||||||
self.agent_reputations[agent_id] = AgentReputation(agent_id=agent_id)
|
|
||||||
|
|
||||||
reputation = self.agent_reputations[agent_id]
|
|
||||||
|
|
||||||
if message_count > 0:
|
|
||||||
reputation.message_count += message_count
|
|
||||||
|
|
||||||
if upvotes_received > 0:
|
|
||||||
reputation.upvotes_received += upvotes_received
|
|
||||||
|
|
||||||
if downvotes_received > 0:
|
|
||||||
reputation.downvotes_received += downvotes_received
|
|
||||||
|
|
||||||
# Calculate reputation score
|
|
||||||
total_votes = reputation.upvotes_received + reputation.downvotes_received
|
|
||||||
if total_votes > 0:
|
|
||||||
reputation.reputation_score = (reputation.upvotes_received - reputation.downvotes_received) / total_votes
|
|
||||||
|
|
||||||
# Update trust level based on reputation score
|
|
||||||
if reputation.reputation_score >= 0.8:
|
|
||||||
reputation.trust_level = 5
|
|
||||||
elif reputation.reputation_score >= 0.6:
|
|
||||||
reputation.trust_level = 4
|
|
||||||
elif reputation.reputation_score >= 0.4:
|
|
||||||
reputation.trust_level = 3
|
|
||||||
elif reputation.reputation_score >= 0.2:
|
|
||||||
reputation.trust_level = 2
|
|
||||||
else:
|
|
||||||
reputation.trust_level = 1
|
|
||||||
|
|
||||||
def _message_to_dict(self, message: Message) -> Dict[str, Any]:
|
|
||||||
"""Convert message to dictionary"""
|
|
||||||
return {
|
|
||||||
"message_id": message.message_id,
|
|
||||||
"agent_id": message.agent_id,
|
|
||||||
"agent_address": message.agent_address,
|
|
||||||
"topic": message.topic,
|
|
||||||
"content": message.content,
|
|
||||||
"message_type": message.message_type.value,
|
|
||||||
"timestamp": message.timestamp.isoformat(),
|
|
||||||
"parent_message_id": message.parent_message_id,
|
|
||||||
"reply_count": message.reply_count,
|
|
||||||
"upvotes": message.upvotes,
|
|
||||||
"downvotes": message.downvotes,
|
|
||||||
"status": message.status.value,
|
|
||||||
"metadata": message.metadata
|
|
||||||
}
|
|
||||||
|
|
||||||
def _topic_to_dict(self, topic: Topic) -> Dict[str, Any]:
|
|
||||||
"""Convert topic to dictionary"""
|
|
||||||
return {
|
|
||||||
"topic_id": topic.topic_id,
|
|
||||||
"title": topic.title,
|
|
||||||
"description": topic.description,
|
|
||||||
"creator_agent_id": topic.creator_agent_id,
|
|
||||||
"created_at": topic.created_at.isoformat(),
|
|
||||||
"message_count": topic.message_count,
|
|
||||||
"last_activity": topic.last_activity.isoformat(),
|
|
||||||
"tags": topic.tags,
|
|
||||||
"is_pinned": topic.is_pinned,
|
|
||||||
"is_locked": topic.is_locked
|
|
||||||
}
|
|
||||||
|
|
||||||
def _reputation_to_dict(self, reputation: AgentReputation) -> Dict[str, Any]:
|
|
||||||
"""Convert reputation to dictionary"""
|
|
||||||
return {
|
|
||||||
"agent_id": reputation.agent_id,
|
|
||||||
"message_count": reputation.message_count,
|
|
||||||
"upvotes_received": reputation.upvotes_received,
|
|
||||||
"downvotes_received": reputation.downvotes_received,
|
|
||||||
"reputation_score": reputation.reputation_score,
|
|
||||||
"trust_level": reputation.trust_level,
|
|
||||||
"is_moderator": reputation.is_moderator,
|
|
||||||
"is_banned": reputation.is_banned,
|
|
||||||
"ban_reason": reputation.ban_reason,
|
|
||||||
"ban_expires": reputation.ban_expires.isoformat() if reputation.ban_expires else None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Global contract instance
|
|
||||||
messaging_contract = AgentMessagingContract()
|
|
||||||
@@ -1,584 +0,0 @@
|
|||||||
"""
|
|
||||||
AITBC Agent Wallet Security Implementation
|
|
||||||
|
|
||||||
This module implements the security layer for autonomous agent wallets,
|
|
||||||
integrating the guardian contract to prevent unlimited spending in case
|
|
||||||
of agent compromise.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
import json
|
|
||||||
from eth_account import Account
|
|
||||||
from eth_utils import to_checksum_address
|
|
||||||
|
|
||||||
from .guardian_contract import (
|
|
||||||
GuardianContract,
|
|
||||||
SpendingLimit,
|
|
||||||
TimeLockConfig,
|
|
||||||
GuardianConfig,
|
|
||||||
create_guardian_contract,
|
|
||||||
CONSERVATIVE_CONFIG,
|
|
||||||
AGGRESSIVE_CONFIG,
|
|
||||||
HIGH_SECURITY_CONFIG
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AgentSecurityProfile:
|
|
||||||
"""Security profile for an agent"""
|
|
||||||
agent_address: str
|
|
||||||
security_level: str # "conservative", "aggressive", "high_security"
|
|
||||||
guardian_addresses: List[str]
|
|
||||||
custom_limits: Optional[Dict] = None
|
|
||||||
enabled: bool = True
|
|
||||||
created_at: datetime = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.created_at is None:
|
|
||||||
self.created_at = datetime.utcnow()
|
|
||||||
|
|
||||||
|
|
||||||
class AgentWalletSecurity:
|
|
||||||
"""
|
|
||||||
Security manager for autonomous agent wallets
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.agent_profiles: Dict[str, AgentSecurityProfile] = {}
|
|
||||||
self.guardian_contracts: Dict[str, GuardianContract] = {}
|
|
||||||
self.security_events: List[Dict] = []
|
|
||||||
|
|
||||||
# Default configurations
|
|
||||||
self.configurations = {
|
|
||||||
"conservative": CONSERVATIVE_CONFIG,
|
|
||||||
"aggressive": AGGRESSIVE_CONFIG,
|
|
||||||
"high_security": HIGH_SECURITY_CONFIG
|
|
||||||
}
|
|
||||||
|
|
||||||
def register_agent(self,
|
|
||||||
agent_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
guardian_addresses: List[str] = None,
|
|
||||||
custom_limits: Dict = None) -> Dict:
|
|
||||||
"""
|
|
||||||
Register an agent for security protection
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
security_level: Security level (conservative, aggressive, high_security)
|
|
||||||
guardian_addresses: List of guardian addresses for recovery
|
|
||||||
custom_limits: Custom spending limits (overrides security_level)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Registration result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
if agent_address in self.agent_profiles:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent already registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Validate security level
|
|
||||||
if security_level not in self.configurations:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Invalid security level: {security_level}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Default guardians if none provided
|
|
||||||
if guardian_addresses is None:
|
|
||||||
guardian_addresses = [agent_address] # Self-guardian (should be overridden)
|
|
||||||
|
|
||||||
# Validate guardian addresses
|
|
||||||
guardian_addresses = [to_checksum_address(addr) for addr in guardian_addresses]
|
|
||||||
|
|
||||||
# Create security profile
|
|
||||||
profile = AgentSecurityProfile(
|
|
||||||
agent_address=agent_address,
|
|
||||||
security_level=security_level,
|
|
||||||
guardian_addresses=guardian_addresses,
|
|
||||||
custom_limits=custom_limits
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create guardian contract
|
|
||||||
config = self.configurations[security_level]
|
|
||||||
if custom_limits:
|
|
||||||
config.update(custom_limits)
|
|
||||||
|
|
||||||
guardian_contract = create_guardian_contract(
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardians=guardian_addresses,
|
|
||||||
**config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store profile and contract
|
|
||||||
self.agent_profiles[agent_address] = profile
|
|
||||||
self.guardian_contracts[agent_address] = guardian_contract
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="agent_registered",
|
|
||||||
agent_address=agent_address,
|
|
||||||
security_level=security_level,
|
|
||||||
guardian_count=len(guardian_addresses)
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "registered",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"security_level": security_level,
|
|
||||||
"guardian_addresses": guardian_addresses,
|
|
||||||
"limits": guardian_contract.config.limits,
|
|
||||||
"time_lock_threshold": guardian_contract.config.time_lock.threshold,
|
|
||||||
"registered_at": profile.created_at.isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Registration failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def protect_transaction(self,
|
|
||||||
agent_address: str,
|
|
||||||
to_address: str,
|
|
||||||
amount: int,
|
|
||||||
data: str = "") -> Dict:
|
|
||||||
"""
|
|
||||||
Protect a transaction with guardian contract
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
to_address: Recipient address
|
|
||||||
amount: Amount to transfer
|
|
||||||
data: Transaction data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Protection result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
# Check if agent is registered
|
|
||||||
if agent_address not in self.agent_profiles:
|
|
||||||
return {
|
|
||||||
"status": "unprotected",
|
|
||||||
"reason": "Agent not registered for security protection",
|
|
||||||
"suggestion": "Register agent with register_agent() first"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Check if protection is enabled
|
|
||||||
profile = self.agent_profiles[agent_address]
|
|
||||||
if not profile.enabled:
|
|
||||||
return {
|
|
||||||
"status": "unprotected",
|
|
||||||
"reason": "Security protection disabled for this agent"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get guardian contract
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
|
|
||||||
# Initiate transaction protection
|
|
||||||
result = guardian_contract.initiate_transaction(to_address, amount, data)
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="transaction_protected",
|
|
||||||
agent_address=agent_address,
|
|
||||||
to_address=to_address,
|
|
||||||
amount=amount,
|
|
||||||
protection_status=result["status"]
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Transaction protection failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute_protected_transaction(self,
|
|
||||||
agent_address: str,
|
|
||||||
operation_id: str,
|
|
||||||
signature: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Execute a previously protected transaction
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
operation_id: Operation ID from protection
|
|
||||||
signature: Transaction signature
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Execution result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
if agent_address not in self.guardian_contracts:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
result = guardian_contract.execute_transaction(operation_id, signature)
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
if result["status"] == "executed":
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="transaction_executed",
|
|
||||||
agent_address=agent_address,
|
|
||||||
operation_id=operation_id,
|
|
||||||
transaction_hash=result.get("transaction_hash")
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Transaction execution failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def emergency_pause_agent(self, agent_address: str, guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Emergency pause an agent's operations
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
guardian_address: Guardian address initiating pause
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Pause result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
|
|
||||||
if agent_address not in self.guardian_contracts:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
result = guardian_contract.emergency_pause(guardian_address)
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
if result["status"] == "paused":
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="emergency_pause",
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardian_address=guardian_address
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Emergency pause failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def update_agent_security(self,
|
|
||||||
agent_address: str,
|
|
||||||
new_limits: Dict,
|
|
||||||
guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Update security limits for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
new_limits: New spending limits
|
|
||||||
guardian_address: Guardian address making the change
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Update result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
|
|
||||||
if agent_address not in self.guardian_contracts:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
|
|
||||||
# Create new spending limits
|
|
||||||
limits = SpendingLimit(
|
|
||||||
per_transaction=new_limits.get("per_transaction", 1000),
|
|
||||||
per_hour=new_limits.get("per_hour", 5000),
|
|
||||||
per_day=new_limits.get("per_day", 20000),
|
|
||||||
per_week=new_limits.get("per_week", 100000)
|
|
||||||
)
|
|
||||||
|
|
||||||
result = guardian_contract.update_limits(limits, guardian_address)
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
if result["status"] == "updated":
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="security_limits_updated",
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardian_address=guardian_address,
|
|
||||||
new_limits=new_limits
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Security update failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_agent_security_status(self, agent_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Get security status for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Security status
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
if agent_address not in self.agent_profiles:
|
|
||||||
return {
|
|
||||||
"status": "not_registered",
|
|
||||||
"message": "Agent not registered for security protection"
|
|
||||||
}
|
|
||||||
|
|
||||||
profile = self.agent_profiles[agent_address]
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "protected",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"security_level": profile.security_level,
|
|
||||||
"enabled": profile.enabled,
|
|
||||||
"guardian_addresses": profile.guardian_addresses,
|
|
||||||
"registered_at": profile.created_at.isoformat(),
|
|
||||||
"spending_status": guardian_contract.get_spending_status(),
|
|
||||||
"pending_operations": guardian_contract.get_pending_operations(),
|
|
||||||
"recent_activity": guardian_contract.get_operation_history(10)
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Status check failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def list_protected_agents(self) -> List[Dict]:
|
|
||||||
"""List all protected agents"""
|
|
||||||
agents = []
|
|
||||||
|
|
||||||
for agent_address, profile in self.agent_profiles.items():
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
|
|
||||||
agents.append({
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"security_level": profile.security_level,
|
|
||||||
"enabled": profile.enabled,
|
|
||||||
"guardian_count": len(profile.guardian_addresses),
|
|
||||||
"pending_operations": len(guardian_contract.pending_operations),
|
|
||||||
"paused": guardian_contract.paused,
|
|
||||||
"emergency_mode": guardian_contract.emergency_mode,
|
|
||||||
"registered_at": profile.created_at.isoformat()
|
|
||||||
})
|
|
||||||
|
|
||||||
return sorted(agents, key=lambda x: x["registered_at"], reverse=True)
|
|
||||||
|
|
||||||
def get_security_events(self, agent_address: str = None, limit: int = 50) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
Get security events
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Filter by agent address (optional)
|
|
||||||
limit: Maximum number of events
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Security events
|
|
||||||
"""
|
|
||||||
events = self.security_events
|
|
||||||
|
|
||||||
if agent_address:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
events = [e for e in events if e.get("agent_address") == agent_address]
|
|
||||||
|
|
||||||
return sorted(events, key=lambda x: x["timestamp"], reverse=True)[:limit]
|
|
||||||
|
|
||||||
def _log_security_event(self, **kwargs):
|
|
||||||
"""Log a security event"""
|
|
||||||
event = {
|
|
||||||
"timestamp": datetime.utcnow().isoformat(),
|
|
||||||
**kwargs
|
|
||||||
}
|
|
||||||
self.security_events.append(event)
|
|
||||||
|
|
||||||
def disable_agent_protection(self, agent_address: str, guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Disable protection for an agent (guardian only)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
guardian_address: Guardian address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Disable result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
|
|
||||||
if agent_address not in self.agent_profiles:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
profile = self.agent_profiles[agent_address]
|
|
||||||
|
|
||||||
if guardian_address not in profile.guardian_addresses:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Not authorized: not a guardian"
|
|
||||||
}
|
|
||||||
|
|
||||||
profile.enabled = False
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="protection_disabled",
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardian_address=guardian_address
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "disabled",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"disabled_at": datetime.utcnow().isoformat(),
|
|
||||||
"guardian": guardian_address
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Disable protection failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Global security manager instance
|
|
||||||
agent_wallet_security = AgentWalletSecurity()
|
|
||||||
|
|
||||||
|
|
||||||
# Convenience functions for common operations
|
|
||||||
def register_agent_for_protection(agent_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
guardians: List[str] = None) -> Dict:
|
|
||||||
"""Register an agent for security protection"""
|
|
||||||
return agent_wallet_security.register_agent(
|
|
||||||
agent_address=agent_address,
|
|
||||||
security_level=security_level,
|
|
||||||
guardian_addresses=guardians
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def protect_agent_transaction(agent_address: str,
|
|
||||||
to_address: str,
|
|
||||||
amount: int,
|
|
||||||
data: str = "") -> Dict:
|
|
||||||
"""Protect a transaction for an agent"""
|
|
||||||
return agent_wallet_security.protect_transaction(
|
|
||||||
agent_address=agent_address,
|
|
||||||
to_address=to_address,
|
|
||||||
amount=amount,
|
|
||||||
data=data
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_agent_security_summary(agent_address: str) -> Dict:
|
|
||||||
"""Get security summary for an agent"""
|
|
||||||
return agent_wallet_security.get_agent_security_status(agent_address)
|
|
||||||
|
|
||||||
|
|
||||||
# Security audit and monitoring functions
|
|
||||||
def generate_security_report() -> Dict:
|
|
||||||
"""Generate comprehensive security report"""
|
|
||||||
protected_agents = agent_wallet_security.list_protected_agents()
|
|
||||||
|
|
||||||
total_agents = len(protected_agents)
|
|
||||||
active_agents = len([a for a in protected_agents if a["enabled"]])
|
|
||||||
paused_agents = len([a for a in protected_agents if a["paused"]])
|
|
||||||
emergency_agents = len([a for a in protected_agents if a["emergency_mode"]])
|
|
||||||
|
|
||||||
recent_events = agent_wallet_security.get_security_events(limit=20)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"generated_at": datetime.utcnow().isoformat(),
|
|
||||||
"summary": {
|
|
||||||
"total_protected_agents": total_agents,
|
|
||||||
"active_agents": active_agents,
|
|
||||||
"paused_agents": paused_agents,
|
|
||||||
"emergency_mode_agents": emergency_agents,
|
|
||||||
"protection_coverage": f"{(active_agents / total_agents * 100):.1f}%" if total_agents > 0 else "0%"
|
|
||||||
},
|
|
||||||
"agents": protected_agents,
|
|
||||||
"recent_security_events": recent_events,
|
|
||||||
"security_levels": {
|
|
||||||
level: len([a for a in protected_agents if a["security_level"] == level])
|
|
||||||
for level in ["conservative", "aggressive", "high_security"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def detect_suspicious_activity(agent_address: str, hours: int = 24) -> Dict:
|
|
||||||
"""Detect suspicious activity for an agent"""
|
|
||||||
status = agent_wallet_security.get_agent_security_status(agent_address)
|
|
||||||
|
|
||||||
if status["status"] != "protected":
|
|
||||||
return {
|
|
||||||
"status": "not_protected",
|
|
||||||
"suspicious_activity": False
|
|
||||||
}
|
|
||||||
|
|
||||||
spending_status = status["spending_status"]
|
|
||||||
recent_events = agent_wallet_security.get_security_events(agent_address, limit=50)
|
|
||||||
|
|
||||||
# Suspicious patterns
|
|
||||||
suspicious_patterns = []
|
|
||||||
|
|
||||||
# Check for rapid spending
|
|
||||||
if spending_status["spent"]["current_hour"] > spending_status["current_limits"]["per_hour"] * 0.8:
|
|
||||||
suspicious_patterns.append("High hourly spending rate")
|
|
||||||
|
|
||||||
# Check for many small transactions (potential dust attack)
|
|
||||||
recent_tx_count = len([e for e in recent_events if e["event_type"] == "transaction_executed"])
|
|
||||||
if recent_tx_count > 20:
|
|
||||||
suspicious_patterns.append("High transaction frequency")
|
|
||||||
|
|
||||||
# Check for emergency pauses
|
|
||||||
recent_pauses = len([e for e in recent_events if e["event_type"] == "emergency_pause"])
|
|
||||||
if recent_pauses > 0:
|
|
||||||
suspicious_patterns.append("Recent emergency pauses detected")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "analyzed",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"suspicious_activity": len(suspicious_patterns) > 0,
|
|
||||||
"suspicious_patterns": suspicious_patterns,
|
|
||||||
"analysis_period_hours": hours,
|
|
||||||
"analyzed_at": datetime.utcnow().isoformat()
|
|
||||||
}
|
|
||||||
@@ -1,405 +0,0 @@
|
|||||||
"""
|
|
||||||
Fixed Guardian Configuration with Proper Guardian Setup
|
|
||||||
Addresses the critical vulnerability where guardian lists were empty
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
import json
|
|
||||||
from eth_account import Account
|
|
||||||
from eth_utils import to_checksum_address, keccak
|
|
||||||
|
|
||||||
from .guardian_contract import (
|
|
||||||
SpendingLimit,
|
|
||||||
TimeLockConfig,
|
|
||||||
GuardianConfig,
|
|
||||||
GuardianContract
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GuardianSetup:
|
|
||||||
"""Guardian setup configuration"""
|
|
||||||
primary_guardian: str # Main guardian address
|
|
||||||
backup_guardians: List[str] # Backup guardian addresses
|
|
||||||
multisig_threshold: int # Number of signatures required
|
|
||||||
emergency_contacts: List[str] # Additional emergency contacts
|
|
||||||
|
|
||||||
|
|
||||||
class SecureGuardianManager:
|
|
||||||
"""
|
|
||||||
Secure guardian management with proper initialization
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.guardian_registrations: Dict[str, GuardianSetup] = {}
|
|
||||||
self.guardian_contracts: Dict[str, GuardianContract] = {}
|
|
||||||
|
|
||||||
def create_guardian_setup(
|
|
||||||
self,
|
|
||||||
agent_address: str,
|
|
||||||
owner_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
custom_guardians: Optional[List[str]] = None
|
|
||||||
) -> GuardianSetup:
|
|
||||||
"""
|
|
||||||
Create a proper guardian setup for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
owner_address: Owner of the agent
|
|
||||||
security_level: Security level (conservative, aggressive, high_security)
|
|
||||||
custom_guardians: Optional custom guardian addresses
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Guardian setup configuration
|
|
||||||
"""
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
owner_address = to_checksum_address(owner_address)
|
|
||||||
|
|
||||||
# Determine guardian requirements based on security level
|
|
||||||
if security_level == "conservative":
|
|
||||||
required_guardians = 3
|
|
||||||
multisig_threshold = 2
|
|
||||||
elif security_level == "aggressive":
|
|
||||||
required_guardians = 2
|
|
||||||
multisig_threshold = 2
|
|
||||||
elif security_level == "high_security":
|
|
||||||
required_guardians = 5
|
|
||||||
multisig_threshold = 3
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid security level: {security_level}")
|
|
||||||
|
|
||||||
# Build guardian list
|
|
||||||
guardians = []
|
|
||||||
|
|
||||||
# Always include the owner as primary guardian
|
|
||||||
guardians.append(owner_address)
|
|
||||||
|
|
||||||
# Add custom guardians if provided
|
|
||||||
if custom_guardians:
|
|
||||||
for guardian in custom_guardians:
|
|
||||||
guardian = to_checksum_address(guardian)
|
|
||||||
if guardian not in guardians:
|
|
||||||
guardians.append(guardian)
|
|
||||||
|
|
||||||
# Generate backup guardians if needed
|
|
||||||
while len(guardians) < required_guardians:
|
|
||||||
# Generate a deterministic backup guardian based on agent address
|
|
||||||
# In production, these would be trusted service addresses
|
|
||||||
backup_index = len(guardians) - 1 # -1 because owner is already included
|
|
||||||
backup_guardian = self._generate_backup_guardian(agent_address, backup_index)
|
|
||||||
|
|
||||||
if backup_guardian not in guardians:
|
|
||||||
guardians.append(backup_guardian)
|
|
||||||
|
|
||||||
# Create setup
|
|
||||||
setup = GuardianSetup(
|
|
||||||
primary_guardian=owner_address,
|
|
||||||
backup_guardians=[g for g in guardians if g != owner_address],
|
|
||||||
multisig_threshold=multisig_threshold,
|
|
||||||
emergency_contacts=guardians.copy()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.guardian_registrations[agent_address] = setup
|
|
||||||
|
|
||||||
return setup
|
|
||||||
|
|
||||||
def _generate_backup_guardian(self, agent_address: str, index: int) -> str:
|
|
||||||
"""
|
|
||||||
Generate deterministic backup guardian address
|
|
||||||
|
|
||||||
In production, these would be pre-registered trusted guardian addresses
|
|
||||||
"""
|
|
||||||
# Create a deterministic address based on agent address and index
|
|
||||||
seed = f"{agent_address}_{index}_backup_guardian"
|
|
||||||
hash_result = keccak(seed.encode())
|
|
||||||
|
|
||||||
# Use the hash to generate a valid address
|
|
||||||
address_bytes = hash_result[-20:] # Take last 20 bytes
|
|
||||||
address = "0x" + address_bytes.hex()
|
|
||||||
|
|
||||||
return to_checksum_address(address)
|
|
||||||
|
|
||||||
def create_secure_guardian_contract(
|
|
||||||
self,
|
|
||||||
agent_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
custom_guardians: Optional[List[str]] = None
|
|
||||||
) -> GuardianContract:
|
|
||||||
"""
|
|
||||||
Create a guardian contract with proper guardian configuration
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
security_level: Security level
|
|
||||||
custom_guardians: Optional custom guardian addresses
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured guardian contract
|
|
||||||
"""
|
|
||||||
# Create guardian setup
|
|
||||||
setup = self.create_guardian_setup(
|
|
||||||
agent_address=agent_address,
|
|
||||||
owner_address=agent_address, # Agent is its own owner initially
|
|
||||||
security_level=security_level,
|
|
||||||
custom_guardians=custom_guardians
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get security configuration
|
|
||||||
config = self._get_security_config(security_level, setup)
|
|
||||||
|
|
||||||
# Create contract
|
|
||||||
contract = GuardianContract(agent_address, config)
|
|
||||||
|
|
||||||
# Store contract
|
|
||||||
self.guardian_contracts[agent_address] = contract
|
|
||||||
|
|
||||||
return contract
|
|
||||||
|
|
||||||
def _get_security_config(self, security_level: str, setup: GuardianSetup) -> GuardianConfig:
|
|
||||||
"""Get security configuration with proper guardian list"""
|
|
||||||
|
|
||||||
# Build guardian list
|
|
||||||
all_guardians = [setup.primary_guardian] + setup.backup_guardians
|
|
||||||
|
|
||||||
if security_level == "conservative":
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=1000,
|
|
||||||
per_hour=5000,
|
|
||||||
per_day=20000,
|
|
||||||
per_week=100000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=5000,
|
|
||||||
delay_hours=24,
|
|
||||||
max_delay_hours=168
|
|
||||||
),
|
|
||||||
guardians=all_guardians,
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False,
|
|
||||||
multisig_threshold=setup.multisig_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
elif security_level == "aggressive":
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=5000,
|
|
||||||
per_hour=25000,
|
|
||||||
per_day=100000,
|
|
||||||
per_week=500000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=20000,
|
|
||||||
delay_hours=12,
|
|
||||||
max_delay_hours=72
|
|
||||||
),
|
|
||||||
guardians=all_guardians,
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False,
|
|
||||||
multisig_threshold=setup.multisig_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
elif security_level == "high_security":
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=500,
|
|
||||||
per_hour=2000,
|
|
||||||
per_day=8000,
|
|
||||||
per_week=40000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=2000,
|
|
||||||
delay_hours=48,
|
|
||||||
max_delay_hours=168
|
|
||||||
),
|
|
||||||
guardians=all_guardians,
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False,
|
|
||||||
multisig_threshold=setup.multisig_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid security level: {security_level}")
|
|
||||||
|
|
||||||
def test_emergency_pause(self, agent_address: str, guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Test emergency pause functionality
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent address
|
|
||||||
guardian_address: Guardian attempting pause
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Test result
|
|
||||||
"""
|
|
||||||
if agent_address not in self.guardian_contracts:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
contract = self.guardian_contracts[agent_address]
|
|
||||||
return contract.emergency_pause(guardian_address)
|
|
||||||
|
|
||||||
def verify_guardian_authorization(self, agent_address: str, guardian_address: str) -> bool:
|
|
||||||
"""
|
|
||||||
Verify if a guardian is authorized for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent address
|
|
||||||
guardian_address: Guardian address to verify
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if guardian is authorized
|
|
||||||
"""
|
|
||||||
if agent_address not in self.guardian_registrations:
|
|
||||||
return False
|
|
||||||
|
|
||||||
setup = self.guardian_registrations[agent_address]
|
|
||||||
all_guardians = [setup.primary_guardian] + setup.backup_guardians
|
|
||||||
|
|
||||||
return to_checksum_address(guardian_address) in [
|
|
||||||
to_checksum_address(g) for g in all_guardians
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_guardian_summary(self, agent_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Get guardian setup summary for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Guardian summary
|
|
||||||
"""
|
|
||||||
if agent_address not in self.guardian_registrations:
|
|
||||||
return {"error": "Agent not registered"}
|
|
||||||
|
|
||||||
setup = self.guardian_registrations[agent_address]
|
|
||||||
contract = self.guardian_contracts.get(agent_address)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"primary_guardian": setup.primary_guardian,
|
|
||||||
"backup_guardians": setup.backup_guardians,
|
|
||||||
"total_guardians": len(setup.backup_guardians) + 1,
|
|
||||||
"multisig_threshold": setup.multisig_threshold,
|
|
||||||
"emergency_contacts": setup.emergency_contacts,
|
|
||||||
"contract_status": contract.get_spending_status() if contract else None,
|
|
||||||
"pause_functional": contract is not None and len(setup.backup_guardians) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Fixed security configurations with proper guardians
|
|
||||||
def get_fixed_conservative_config(agent_address: str, owner_address: str) -> GuardianConfig:
|
|
||||||
"""Get fixed conservative configuration with proper guardians"""
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=1000,
|
|
||||||
per_hour=5000,
|
|
||||||
per_day=20000,
|
|
||||||
per_week=100000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=5000,
|
|
||||||
delay_hours=24,
|
|
||||||
max_delay_hours=168
|
|
||||||
),
|
|
||||||
guardians=[owner_address], # At least the owner
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_fixed_aggressive_config(agent_address: str, owner_address: str) -> GuardianConfig:
|
|
||||||
"""Get fixed aggressive configuration with proper guardians"""
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=5000,
|
|
||||||
per_hour=25000,
|
|
||||||
per_day=100000,
|
|
||||||
per_week=500000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=20000,
|
|
||||||
delay_hours=12,
|
|
||||||
max_delay_hours=72
|
|
||||||
),
|
|
||||||
guardians=[owner_address], # At least the owner
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_fixed_high_security_config(agent_address: str, owner_address: str) -> GuardianConfig:
|
|
||||||
"""Get fixed high security configuration with proper guardians"""
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=500,
|
|
||||||
per_hour=2000,
|
|
||||||
per_day=8000,
|
|
||||||
per_week=40000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=2000,
|
|
||||||
delay_hours=48,
|
|
||||||
max_delay_hours=168
|
|
||||||
),
|
|
||||||
guardians=[owner_address], # At least the owner
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Global secure guardian manager
|
|
||||||
secure_guardian_manager = SecureGuardianManager()
|
|
||||||
|
|
||||||
|
|
||||||
# Convenience function for secure agent registration
|
|
||||||
def register_agent_with_guardians(
|
|
||||||
agent_address: str,
|
|
||||||
owner_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
custom_guardians: Optional[List[str]] = None
|
|
||||||
) -> Dict:
|
|
||||||
"""
|
|
||||||
Register an agent with proper guardian configuration
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
owner_address: Owner address
|
|
||||||
security_level: Security level
|
|
||||||
custom_guardians: Optional custom guardians
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Registration result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Create secure guardian contract
|
|
||||||
contract = secure_guardian_manager.create_secure_guardian_contract(
|
|
||||||
agent_address=agent_address,
|
|
||||||
security_level=security_level,
|
|
||||||
custom_guardians=custom_guardians
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get guardian summary
|
|
||||||
summary = secure_guardian_manager.get_guardian_summary(agent_address)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "registered",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"security_level": security_level,
|
|
||||||
"guardian_count": summary["total_guardians"],
|
|
||||||
"multisig_threshold": summary["multisig_threshold"],
|
|
||||||
"pause_functional": summary["pause_functional"],
|
|
||||||
"registered_at": datetime.utcnow().isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Registration failed: {str(e)}"
|
|
||||||
}
|
|
||||||
@@ -1,682 +0,0 @@
|
|||||||
"""
|
|
||||||
AITBC Guardian Contract - Spending Limit Protection for Agent Wallets
|
|
||||||
|
|
||||||
This contract implements a spending limit guardian that protects autonomous agent
|
|
||||||
wallets from unlimited spending in case of compromise. It provides:
|
|
||||||
- Per-transaction spending limits
|
|
||||||
- Per-period (daily/hourly) spending caps
|
|
||||||
- Time-lock for large withdrawals
|
|
||||||
- Emergency pause functionality
|
|
||||||
- Multi-signature recovery for critical operations
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sqlite3
|
|
||||||
from pathlib import Path
|
|
||||||
from eth_account import Account
|
|
||||||
from eth_utils import to_checksum_address, keccak
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SpendingLimit:
|
|
||||||
"""Spending limit configuration"""
|
|
||||||
per_transaction: int # Maximum per transaction
|
|
||||||
per_hour: int # Maximum per hour
|
|
||||||
per_day: int # Maximum per day
|
|
||||||
per_week: int # Maximum per week
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TimeLockConfig:
|
|
||||||
"""Time lock configuration for large withdrawals"""
|
|
||||||
threshold: int # Amount that triggers time lock
|
|
||||||
delay_hours: int # Delay period in hours
|
|
||||||
max_delay_hours: int # Maximum delay period
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GuardianConfig:
|
|
||||||
"""Complete guardian configuration"""
|
|
||||||
limits: SpendingLimit
|
|
||||||
time_lock: TimeLockConfig
|
|
||||||
guardians: List[str] # Guardian addresses for recovery
|
|
||||||
pause_enabled: bool = True
|
|
||||||
emergency_mode: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class GuardianContract:
|
|
||||||
"""
|
|
||||||
Guardian contract implementation for agent wallet protection
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, agent_address: str, config: GuardianConfig, storage_path: str = None):
|
|
||||||
self.agent_address = to_checksum_address(agent_address)
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
# CRITICAL SECURITY FIX: Use persistent storage instead of in-memory
|
|
||||||
if storage_path is None:
|
|
||||||
storage_path = os.path.join(os.path.expanduser("~"), ".aitbc", "guardian_contracts")
|
|
||||||
|
|
||||||
self.storage_dir = Path(storage_path)
|
|
||||||
self.storage_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Database file for this contract
|
|
||||||
self.db_path = self.storage_dir / f"guardian_{self.agent_address}.db"
|
|
||||||
|
|
||||||
# Initialize persistent storage
|
|
||||||
self._init_storage()
|
|
||||||
|
|
||||||
# Load state from storage
|
|
||||||
self._load_state()
|
|
||||||
|
|
||||||
# In-memory cache for performance (synced with storage)
|
|
||||||
self.spending_history: List[Dict] = []
|
|
||||||
self.pending_operations: Dict[str, Dict] = {}
|
|
||||||
self.paused = False
|
|
||||||
self.emergency_mode = False
|
|
||||||
|
|
||||||
# Contract state
|
|
||||||
self.nonce = 0
|
|
||||||
self.guardian_approvals: Dict[str, bool] = {}
|
|
||||||
|
|
||||||
# Load data from persistent storage
|
|
||||||
self._load_spending_history()
|
|
||||||
self._load_pending_operations()
|
|
||||||
|
|
||||||
def _init_storage(self):
|
|
||||||
"""Initialize SQLite database for persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
conn.execute('''
|
|
||||||
CREATE TABLE IF NOT EXISTS spending_history (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
operation_id TEXT UNIQUE,
|
|
||||||
agent_address TEXT,
|
|
||||||
to_address TEXT,
|
|
||||||
amount INTEGER,
|
|
||||||
data TEXT,
|
|
||||||
timestamp TEXT,
|
|
||||||
executed_at TEXT,
|
|
||||||
status TEXT,
|
|
||||||
nonce INTEGER,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)
|
|
||||||
''')
|
|
||||||
|
|
||||||
conn.execute('''
|
|
||||||
CREATE TABLE IF NOT EXISTS pending_operations (
|
|
||||||
operation_id TEXT PRIMARY KEY,
|
|
||||||
agent_address TEXT,
|
|
||||||
operation_data TEXT,
|
|
||||||
status TEXT,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)
|
|
||||||
''')
|
|
||||||
|
|
||||||
conn.execute('''
|
|
||||||
CREATE TABLE IF NOT EXISTS contract_state (
|
|
||||||
agent_address TEXT PRIMARY KEY,
|
|
||||||
nonce INTEGER DEFAULT 0,
|
|
||||||
paused BOOLEAN DEFAULT 0,
|
|
||||||
emergency_mode BOOLEAN DEFAULT 0,
|
|
||||||
last_updated DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)
|
|
||||||
''')
|
|
||||||
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def _load_state(self):
|
|
||||||
"""Load contract state from persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
cursor = conn.execute(
|
|
||||||
'SELECT nonce, paused, emergency_mode FROM contract_state WHERE agent_address = ?',
|
|
||||||
(self.agent_address,)
|
|
||||||
)
|
|
||||||
row = cursor.fetchone()
|
|
||||||
|
|
||||||
if row:
|
|
||||||
self.nonce, self.paused, self.emergency_mode = row
|
|
||||||
else:
|
|
||||||
# Initialize state for new contract
|
|
||||||
conn.execute(
|
|
||||||
'INSERT INTO contract_state (agent_address, nonce, paused, emergency_mode) VALUES (?, ?, ?, ?)',
|
|
||||||
(self.agent_address, 0, False, False)
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def _save_state(self):
|
|
||||||
"""Save contract state to persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
conn.execute(
|
|
||||||
'UPDATE contract_state SET nonce = ?, paused = ?, emergency_mode = ?, last_updated = CURRENT_TIMESTAMP WHERE agent_address = ?',
|
|
||||||
(self.nonce, self.paused, self.emergency_mode, self.agent_address)
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def _load_spending_history(self):
|
|
||||||
"""Load spending history from persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
cursor = conn.execute(
|
|
||||||
'SELECT operation_id, to_address, amount, data, timestamp, executed_at, status, nonce FROM spending_history WHERE agent_address = ? ORDER BY timestamp DESC',
|
|
||||||
(self.agent_address,)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.spending_history = []
|
|
||||||
for row in cursor:
|
|
||||||
self.spending_history.append({
|
|
||||||
"operation_id": row[0],
|
|
||||||
"to": row[1],
|
|
||||||
"amount": row[2],
|
|
||||||
"data": row[3],
|
|
||||||
"timestamp": row[4],
|
|
||||||
"executed_at": row[5],
|
|
||||||
"status": row[6],
|
|
||||||
"nonce": row[7]
|
|
||||||
})
|
|
||||||
|
|
||||||
def _save_spending_record(self, record: Dict):
|
|
||||||
"""Save spending record to persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
conn.execute(
|
|
||||||
'''INSERT OR REPLACE INTO spending_history
|
|
||||||
(operation_id, agent_address, to_address, amount, data, timestamp, executed_at, status, nonce)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)''',
|
|
||||||
(
|
|
||||||
record["operation_id"],
|
|
||||||
self.agent_address,
|
|
||||||
record["to"],
|
|
||||||
record["amount"],
|
|
||||||
record.get("data", ""),
|
|
||||||
record["timestamp"],
|
|
||||||
record.get("executed_at", ""),
|
|
||||||
record["status"],
|
|
||||||
record["nonce"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def _load_pending_operations(self):
|
|
||||||
"""Load pending operations from persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
cursor = conn.execute(
|
|
||||||
'SELECT operation_id, operation_data, status FROM pending_operations WHERE agent_address = ?',
|
|
||||||
(self.agent_address,)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.pending_operations = {}
|
|
||||||
for row in cursor:
|
|
||||||
operation_data = json.loads(row[1])
|
|
||||||
operation_data["status"] = row[2]
|
|
||||||
self.pending_operations[row[0]] = operation_data
|
|
||||||
|
|
||||||
def _save_pending_operation(self, operation_id: str, operation: Dict):
|
|
||||||
"""Save pending operation to persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
conn.execute(
|
|
||||||
'''INSERT OR REPLACE INTO pending_operations
|
|
||||||
(operation_id, agent_address, operation_data, status, updated_at)
|
|
||||||
VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)''',
|
|
||||||
(operation_id, self.agent_address, json.dumps(operation), operation["status"])
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def _remove_pending_operation(self, operation_id: str):
|
|
||||||
"""Remove pending operation from persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
conn.execute(
|
|
||||||
'DELETE FROM pending_operations WHERE operation_id = ? AND agent_address = ?',
|
|
||||||
(operation_id, self.agent_address)
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def _get_period_key(self, timestamp: datetime, period: str) -> str:
|
|
||||||
"""Generate period key for spending tracking"""
|
|
||||||
if period == "hour":
|
|
||||||
return timestamp.strftime("%Y-%m-%d-%H")
|
|
||||||
elif period == "day":
|
|
||||||
return timestamp.strftime("%Y-%m-%d")
|
|
||||||
elif period == "week":
|
|
||||||
# Get week number (Monday as first day)
|
|
||||||
week_num = timestamp.isocalendar()[1]
|
|
||||||
return f"{timestamp.year}-W{week_num:02d}"
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid period: {period}")
|
|
||||||
|
|
||||||
def _get_spent_in_period(self, period: str, timestamp: datetime = None) -> int:
|
|
||||||
"""Calculate total spent in given period"""
|
|
||||||
if timestamp is None:
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
period_key = self._get_period_key(timestamp, period)
|
|
||||||
|
|
||||||
total = 0
|
|
||||||
for record in self.spending_history:
|
|
||||||
record_time = datetime.fromisoformat(record["timestamp"])
|
|
||||||
record_period = self._get_period_key(record_time, period)
|
|
||||||
|
|
||||||
if record_period == period_key and record["status"] == "completed":
|
|
||||||
total += record["amount"]
|
|
||||||
|
|
||||||
return total
|
|
||||||
|
|
||||||
def _check_spending_limits(self, amount: int, timestamp: datetime = None) -> Tuple[bool, str]:
|
|
||||||
"""Check if amount exceeds spending limits"""
|
|
||||||
if timestamp is None:
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
# Check per-transaction limit
|
|
||||||
if amount > self.config.limits.per_transaction:
|
|
||||||
return False, f"Amount {amount} exceeds per-transaction limit {self.config.limits.per_transaction}"
|
|
||||||
|
|
||||||
# Check per-hour limit
|
|
||||||
spent_hour = self._get_spent_in_period("hour", timestamp)
|
|
||||||
if spent_hour + amount > self.config.limits.per_hour:
|
|
||||||
return False, f"Hourly spending {spent_hour + amount} would exceed limit {self.config.limits.per_hour}"
|
|
||||||
|
|
||||||
# Check per-day limit
|
|
||||||
spent_day = self._get_spent_in_period("day", timestamp)
|
|
||||||
if spent_day + amount > self.config.limits.per_day:
|
|
||||||
return False, f"Daily spending {spent_day + amount} would exceed limit {self.config.limits.per_day}"
|
|
||||||
|
|
||||||
# Check per-week limit
|
|
||||||
spent_week = self._get_spent_in_period("week", timestamp)
|
|
||||||
if spent_week + amount > self.config.limits.per_week:
|
|
||||||
return False, f"Weekly spending {spent_week + amount} would exceed limit {self.config.limits.per_week}"
|
|
||||||
|
|
||||||
return True, "Spending limits check passed"
|
|
||||||
|
|
||||||
def _requires_time_lock(self, amount: int) -> bool:
|
|
||||||
"""Check if amount requires time lock"""
|
|
||||||
return amount >= self.config.time_lock.threshold
|
|
||||||
|
|
||||||
def _create_operation_hash(self, operation: Dict) -> str:
|
|
||||||
"""Create hash for operation identification"""
|
|
||||||
operation_str = json.dumps(operation, sort_keys=True)
|
|
||||||
return keccak(operation_str.encode()).hex()
|
|
||||||
|
|
||||||
def initiate_transaction(self, to_address: str, amount: int, data: str = "") -> Dict:
|
|
||||||
"""
|
|
||||||
Initiate a transaction with guardian protection
|
|
||||||
|
|
||||||
Args:
|
|
||||||
to_address: Recipient address
|
|
||||||
amount: Amount to transfer
|
|
||||||
data: Transaction data (optional)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Operation result with status and details
|
|
||||||
"""
|
|
||||||
# Check if paused
|
|
||||||
if self.paused:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": "Guardian contract is paused",
|
|
||||||
"operation_id": None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Check emergency mode
|
|
||||||
if self.emergency_mode:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": "Emergency mode activated",
|
|
||||||
"operation_id": None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Validate address
|
|
||||||
try:
|
|
||||||
to_address = to_checksum_address(to_address)
|
|
||||||
except Exception:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": "Invalid recipient address",
|
|
||||||
"operation_id": None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Check spending limits
|
|
||||||
limits_ok, limits_reason = self._check_spending_limits(amount)
|
|
||||||
if not limits_ok:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": limits_reason,
|
|
||||||
"operation_id": None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create operation
|
|
||||||
operation = {
|
|
||||||
"type": "transaction",
|
|
||||||
"to": to_address,
|
|
||||||
"amount": amount,
|
|
||||||
"data": data,
|
|
||||||
"timestamp": datetime.utcnow().isoformat(),
|
|
||||||
"nonce": self.nonce,
|
|
||||||
"status": "pending"
|
|
||||||
}
|
|
||||||
|
|
||||||
operation_id = self._create_operation_hash(operation)
|
|
||||||
operation["operation_id"] = operation_id
|
|
||||||
|
|
||||||
# Check if time lock is required
|
|
||||||
if self._requires_time_lock(amount):
|
|
||||||
unlock_time = datetime.utcnow() + timedelta(hours=self.config.time_lock.delay_hours)
|
|
||||||
operation["unlock_time"] = unlock_time.isoformat()
|
|
||||||
operation["status"] = "time_locked"
|
|
||||||
|
|
||||||
# Store for later execution
|
|
||||||
self.pending_operations[operation_id] = operation
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "time_locked",
|
|
||||||
"operation_id": operation_id,
|
|
||||||
"unlock_time": unlock_time.isoformat(),
|
|
||||||
"delay_hours": self.config.time_lock.delay_hours,
|
|
||||||
"message": f"Transaction requires {self.config.time_lock.delay_hours}h time lock"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Immediate execution for smaller amounts
|
|
||||||
self.pending_operations[operation_id] = operation
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "approved",
|
|
||||||
"operation_id": operation_id,
|
|
||||||
"message": "Transaction approved for execution"
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute_transaction(self, operation_id: str, signature: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Execute a previously approved transaction
|
|
||||||
|
|
||||||
Args:
|
|
||||||
operation_id: Operation ID from initiate_transaction
|
|
||||||
signature: Transaction signature from agent
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Execution result
|
|
||||||
"""
|
|
||||||
if operation_id not in self.pending_operations:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Operation not found"
|
|
||||||
}
|
|
||||||
|
|
||||||
operation = self.pending_operations[operation_id]
|
|
||||||
|
|
||||||
# Check if operation is time locked
|
|
||||||
if operation["status"] == "time_locked":
|
|
||||||
unlock_time = datetime.fromisoformat(operation["unlock_time"])
|
|
||||||
if datetime.utcnow() < unlock_time:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Operation locked until {unlock_time.isoformat()}"
|
|
||||||
}
|
|
||||||
|
|
||||||
operation["status"] = "ready"
|
|
||||||
|
|
||||||
# Verify signature (simplified - in production, use proper verification)
|
|
||||||
try:
|
|
||||||
# In production, verify the signature matches the agent address
|
|
||||||
# For now, we'll assume signature is valid
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Invalid signature: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Record the transaction
|
|
||||||
record = {
|
|
||||||
"operation_id": operation_id,
|
|
||||||
"to": operation["to"],
|
|
||||||
"amount": operation["amount"],
|
|
||||||
"data": operation.get("data", ""),
|
|
||||||
"timestamp": operation["timestamp"],
|
|
||||||
"executed_at": datetime.utcnow().isoformat(),
|
|
||||||
"status": "completed",
|
|
||||||
"nonce": operation["nonce"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# CRITICAL SECURITY FIX: Save to persistent storage
|
|
||||||
self._save_spending_record(record)
|
|
||||||
self.spending_history.append(record)
|
|
||||||
self.nonce += 1
|
|
||||||
self._save_state()
|
|
||||||
|
|
||||||
# Remove from pending storage
|
|
||||||
self._remove_pending_operation(operation_id)
|
|
||||||
if operation_id in self.pending_operations:
|
|
||||||
del self.pending_operations[operation_id]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "executed",
|
|
||||||
"operation_id": operation_id,
|
|
||||||
"transaction_hash": f"0x{keccak(f'{operation_id}{signature}'.encode()).hex()}",
|
|
||||||
"executed_at": record["executed_at"]
|
|
||||||
}
|
|
||||||
|
|
||||||
def emergency_pause(self, guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Emergency pause function (guardian only)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
guardian_address: Address of guardian initiating pause
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Pause result
|
|
||||||
"""
|
|
||||||
if guardian_address not in self.config.guardians:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": "Not authorized: guardian address not recognized"
|
|
||||||
}
|
|
||||||
|
|
||||||
self.paused = True
|
|
||||||
self.emergency_mode = True
|
|
||||||
|
|
||||||
# CRITICAL SECURITY FIX: Save state to persistent storage
|
|
||||||
self._save_state()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "paused",
|
|
||||||
"paused_at": datetime.utcnow().isoformat(),
|
|
||||||
"guardian": guardian_address,
|
|
||||||
"message": "Emergency pause activated - all operations halted"
|
|
||||||
}
|
|
||||||
|
|
||||||
def emergency_unpause(self, guardian_signatures: List[str]) -> Dict:
|
|
||||||
"""
|
|
||||||
Emergency unpause function (requires multiple guardian signatures)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
guardian_signatures: Signatures from required guardians
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Unpause result
|
|
||||||
"""
|
|
||||||
# In production, verify all guardian signatures
|
|
||||||
required_signatures = len(self.config.guardians)
|
|
||||||
if len(guardian_signatures) < required_signatures:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": f"Requires {required_signatures} guardian signatures, got {len(guardian_signatures)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Verify signatures (simplified)
|
|
||||||
# In production, verify each signature matches a guardian address
|
|
||||||
|
|
||||||
self.paused = False
|
|
||||||
self.emergency_mode = False
|
|
||||||
|
|
||||||
# CRITICAL SECURITY FIX: Save state to persistent storage
|
|
||||||
self._save_state()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "unpaused",
|
|
||||||
"unpaused_at": datetime.utcnow().isoformat(),
|
|
||||||
"message": "Emergency pause lifted - operations resumed"
|
|
||||||
}
|
|
||||||
|
|
||||||
def update_limits(self, new_limits: SpendingLimit, guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Update spending limits (guardian only)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
new_limits: New spending limits
|
|
||||||
guardian_address: Address of guardian making the change
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Update result
|
|
||||||
"""
|
|
||||||
if guardian_address not in self.config.guardians:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": "Not authorized: guardian address not recognized"
|
|
||||||
}
|
|
||||||
|
|
||||||
old_limits = self.config.limits
|
|
||||||
self.config.limits = new_limits
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "updated",
|
|
||||||
"old_limits": old_limits,
|
|
||||||
"new_limits": new_limits,
|
|
||||||
"updated_at": datetime.utcnow().isoformat(),
|
|
||||||
"guardian": guardian_address
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_spending_status(self) -> Dict:
|
|
||||||
"""Get current spending status and limits"""
|
|
||||||
now = datetime.utcnow()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"agent_address": self.agent_address,
|
|
||||||
"current_limits": self.config.limits,
|
|
||||||
"spent": {
|
|
||||||
"current_hour": self._get_spent_in_period("hour", now),
|
|
||||||
"current_day": self._get_spent_in_period("day", now),
|
|
||||||
"current_week": self._get_spent_in_period("week", now)
|
|
||||||
},
|
|
||||||
"remaining": {
|
|
||||||
"current_hour": self.config.limits.per_hour - self._get_spent_in_period("hour", now),
|
|
||||||
"current_day": self.config.limits.per_day - self._get_spent_in_period("day", now),
|
|
||||||
"current_week": self.config.limits.per_week - self._get_spent_in_period("week", now)
|
|
||||||
},
|
|
||||||
"pending_operations": len(self.pending_operations),
|
|
||||||
"paused": self.paused,
|
|
||||||
"emergency_mode": self.emergency_mode,
|
|
||||||
"nonce": self.nonce
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_operation_history(self, limit: int = 50) -> List[Dict]:
|
|
||||||
"""Get operation history"""
|
|
||||||
return sorted(self.spending_history, key=lambda x: x["timestamp"], reverse=True)[:limit]
|
|
||||||
|
|
||||||
def get_pending_operations(self) -> List[Dict]:
|
|
||||||
"""Get all pending operations"""
|
|
||||||
return list(self.pending_operations.values())
|
|
||||||
|
|
||||||
|
|
||||||
# Factory function for creating guardian contracts
|
|
||||||
def create_guardian_contract(
|
|
||||||
agent_address: str,
|
|
||||||
per_transaction: int = 1000,
|
|
||||||
per_hour: int = 5000,
|
|
||||||
per_day: int = 20000,
|
|
||||||
per_week: int = 100000,
|
|
||||||
time_lock_threshold: int = 10000,
|
|
||||||
time_lock_delay: int = 24,
|
|
||||||
guardians: List[str] = None
|
|
||||||
) -> GuardianContract:
|
|
||||||
"""
|
|
||||||
Create a guardian contract with default security parameters
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: The agent wallet address to protect
|
|
||||||
per_transaction: Maximum amount per transaction
|
|
||||||
per_hour: Maximum amount per hour
|
|
||||||
per_day: Maximum amount per day
|
|
||||||
per_week: Maximum amount per week
|
|
||||||
time_lock_threshold: Amount that triggers time lock
|
|
||||||
time_lock_delay: Time lock delay in hours
|
|
||||||
guardians: List of guardian addresses (REQUIRED for security)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured GuardianContract instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If no guardians are provided or guardians list is insufficient
|
|
||||||
"""
|
|
||||||
# CRITICAL SECURITY FIX: Require proper guardians, never default to agent address
|
|
||||||
if guardians is None or not guardians:
|
|
||||||
raise ValueError(
|
|
||||||
"❌ CRITICAL: Guardians are required for security. "
|
|
||||||
"Provide at least 3 trusted guardian addresses different from the agent address."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate that guardians are different from agent address
|
|
||||||
agent_checksum = to_checksum_address(agent_address)
|
|
||||||
guardian_checksums = [to_checksum_address(g) for g in guardians]
|
|
||||||
|
|
||||||
if agent_checksum in guardian_checksums:
|
|
||||||
raise ValueError(
|
|
||||||
"❌ CRITICAL: Agent address cannot be used as guardian. "
|
|
||||||
"Guardians must be independent trusted addresses."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Require minimum number of guardians for security
|
|
||||||
if len(guardian_checksums) < 3:
|
|
||||||
raise ValueError(
|
|
||||||
f"❌ CRITICAL: At least 3 guardians required for security, got {len(guardian_checksums)}. "
|
|
||||||
"Consider using a multi-sig wallet or trusted service providers."
|
|
||||||
)
|
|
||||||
|
|
||||||
limits = SpendingLimit(
|
|
||||||
per_transaction=per_transaction,
|
|
||||||
per_hour=per_hour,
|
|
||||||
per_day=per_day,
|
|
||||||
per_week=per_week
|
|
||||||
)
|
|
||||||
|
|
||||||
time_lock = TimeLockConfig(
|
|
||||||
threshold=time_lock_threshold,
|
|
||||||
delay_hours=time_lock_delay,
|
|
||||||
max_delay_hours=168 # 1 week max
|
|
||||||
)
|
|
||||||
|
|
||||||
config = GuardianConfig(
|
|
||||||
limits=limits,
|
|
||||||
time_lock=time_lock,
|
|
||||||
guardians=[to_checksum_address(g) for g in guardians]
|
|
||||||
)
|
|
||||||
|
|
||||||
return GuardianContract(agent_address, config)
|
|
||||||
|
|
||||||
|
|
||||||
# Example usage and security configurations
|
|
||||||
CONSERVATIVE_CONFIG = {
|
|
||||||
"per_transaction": 100, # $100 per transaction
|
|
||||||
"per_hour": 500, # $500 per hour
|
|
||||||
"per_day": 2000, # $2,000 per day
|
|
||||||
"per_week": 10000, # $10,000 per week
|
|
||||||
"time_lock_threshold": 1000, # Time lock over $1,000
|
|
||||||
"time_lock_delay": 24 # 24 hour delay
|
|
||||||
}
|
|
||||||
|
|
||||||
AGGRESSIVE_CONFIG = {
|
|
||||||
"per_transaction": 1000, # $1,000 per transaction
|
|
||||||
"per_hour": 5000, # $5,000 per hour
|
|
||||||
"per_day": 20000, # $20,000 per day
|
|
||||||
"per_week": 100000, # $100,000 per week
|
|
||||||
"time_lock_threshold": 10000, # Time lock over $10,000
|
|
||||||
"time_lock_delay": 12 # 12 hour delay
|
|
||||||
}
|
|
||||||
|
|
||||||
HIGH_SECURITY_CONFIG = {
|
|
||||||
"per_transaction": 50, # $50 per transaction
|
|
||||||
"per_hour": 200, # $200 per hour
|
|
||||||
"per_day": 1000, # $1,000 per day
|
|
||||||
"per_week": 5000, # $5,000 per week
|
|
||||||
"time_lock_threshold": 500, # Time lock over $500
|
|
||||||
"time_lock_delay": 48 # 48 hour delay
|
|
||||||
}
|
|
||||||
@@ -1,470 +0,0 @@
|
|||||||
"""
|
|
||||||
Persistent Spending Tracker - Database-Backed Security
|
|
||||||
Fixes the critical vulnerability where spending limits were lost on restart
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from sqlalchemy import create_engine, Column, String, Integer, Float, DateTime, Index
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
from sqlalchemy.orm import sessionmaker, Session
|
|
||||||
from eth_utils import to_checksum_address
|
|
||||||
import json
|
|
||||||
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
|
|
||||||
class SpendingRecord(Base):
|
|
||||||
"""Database model for spending tracking"""
|
|
||||||
__tablename__ = "spending_records"
|
|
||||||
|
|
||||||
id = Column(String, primary_key=True)
|
|
||||||
agent_address = Column(String, index=True)
|
|
||||||
period_type = Column(String, index=True) # hour, day, week
|
|
||||||
period_key = Column(String, index=True)
|
|
||||||
amount = Column(Float)
|
|
||||||
transaction_hash = Column(String)
|
|
||||||
timestamp = Column(DateTime, default=datetime.utcnow)
|
|
||||||
|
|
||||||
# Composite indexes for performance
|
|
||||||
__table_args__ = (
|
|
||||||
Index('idx_agent_period', 'agent_address', 'period_type', 'period_key'),
|
|
||||||
Index('idx_timestamp', 'timestamp'),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SpendingLimit(Base):
|
|
||||||
"""Database model for spending limits"""
|
|
||||||
__tablename__ = "spending_limits"
|
|
||||||
|
|
||||||
agent_address = Column(String, primary_key=True)
|
|
||||||
per_transaction = Column(Float)
|
|
||||||
per_hour = Column(Float)
|
|
||||||
per_day = Column(Float)
|
|
||||||
per_week = Column(Float)
|
|
||||||
time_lock_threshold = Column(Float)
|
|
||||||
time_lock_delay_hours = Column(Integer)
|
|
||||||
updated_at = Column(DateTime, default=datetime.utcnow)
|
|
||||||
updated_by = Column(String) # Guardian who updated
|
|
||||||
|
|
||||||
|
|
||||||
class GuardianAuthorization(Base):
|
|
||||||
"""Database model for guardian authorizations"""
|
|
||||||
__tablename__ = "guardian_authorizations"
|
|
||||||
|
|
||||||
id = Column(String, primary_key=True)
|
|
||||||
agent_address = Column(String, index=True)
|
|
||||||
guardian_address = Column(String, index=True)
|
|
||||||
is_active = Column(Boolean, default=True)
|
|
||||||
added_at = Column(DateTime, default=datetime.utcnow)
|
|
||||||
added_by = Column(String)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SpendingCheckResult:
|
|
||||||
"""Result of spending limit check"""
|
|
||||||
allowed: bool
|
|
||||||
reason: str
|
|
||||||
current_spent: Dict[str, float]
|
|
||||||
remaining: Dict[str, float]
|
|
||||||
requires_time_lock: bool
|
|
||||||
time_lock_until: Optional[datetime] = None
|
|
||||||
|
|
||||||
|
|
||||||
class PersistentSpendingTracker:
|
|
||||||
"""
|
|
||||||
Database-backed spending tracker that survives restarts
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, database_url: str = "sqlite:///spending_tracker.db"):
|
|
||||||
self.engine = create_engine(database_url)
|
|
||||||
Base.metadata.create_all(self.engine)
|
|
||||||
self.SessionLocal = sessionmaker(bind=self.engine)
|
|
||||||
|
|
||||||
def get_session(self) -> Session:
|
|
||||||
"""Get database session"""
|
|
||||||
return self.SessionLocal()
|
|
||||||
|
|
||||||
def _get_period_key(self, timestamp: datetime, period: str) -> str:
|
|
||||||
"""Generate period key for spending tracking"""
|
|
||||||
if period == "hour":
|
|
||||||
return timestamp.strftime("%Y-%m-%d-%H")
|
|
||||||
elif period == "day":
|
|
||||||
return timestamp.strftime("%Y-%m-%d")
|
|
||||||
elif period == "week":
|
|
||||||
# Get week number (Monday as first day)
|
|
||||||
week_num = timestamp.isocalendar()[1]
|
|
||||||
return f"{timestamp.year}-W{week_num:02d}"
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid period: {period}")
|
|
||||||
|
|
||||||
def get_spent_in_period(self, agent_address: str, period: str, timestamp: datetime = None) -> float:
|
|
||||||
"""
|
|
||||||
Get total spent in given period from database
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
period: Period type (hour, day, week)
|
|
||||||
timestamp: Timestamp to check (default: now)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Total amount spent in period
|
|
||||||
"""
|
|
||||||
if timestamp is None:
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
period_key = self._get_period_key(timestamp, period)
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
with self.get_session() as session:
|
|
||||||
total = session.query(SpendingRecord).filter(
|
|
||||||
SpendingRecord.agent_address == agent_address,
|
|
||||||
SpendingRecord.period_type == period,
|
|
||||||
SpendingRecord.period_key == period_key
|
|
||||||
).with_entities(SpendingRecord.amount).all()
|
|
||||||
|
|
||||||
return sum(record.amount for record in total)
|
|
||||||
|
|
||||||
def record_spending(self, agent_address: str, amount: float, transaction_hash: str, timestamp: datetime = None) -> bool:
|
|
||||||
"""
|
|
||||||
Record a spending transaction in the database
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
amount: Amount spent
|
|
||||||
transaction_hash: Transaction hash
|
|
||||||
timestamp: Transaction timestamp (default: now)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if recorded successfully
|
|
||||||
"""
|
|
||||||
if timestamp is None:
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with self.get_session() as session:
|
|
||||||
# Record for all periods
|
|
||||||
periods = ["hour", "day", "week"]
|
|
||||||
|
|
||||||
for period in periods:
|
|
||||||
period_key = self._get_period_key(timestamp, period)
|
|
||||||
|
|
||||||
record = SpendingRecord(
|
|
||||||
id=f"{transaction_hash}_{period}",
|
|
||||||
agent_address=agent_address,
|
|
||||||
period_type=period,
|
|
||||||
period_key=period_key,
|
|
||||||
amount=amount,
|
|
||||||
transaction_hash=transaction_hash,
|
|
||||||
timestamp=timestamp
|
|
||||||
)
|
|
||||||
|
|
||||||
session.add(record)
|
|
||||||
|
|
||||||
session.commit()
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to record spending: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def check_spending_limits(self, agent_address: str, amount: float, timestamp: datetime = None) -> SpendingCheckResult:
|
|
||||||
"""
|
|
||||||
Check if amount exceeds spending limits using persistent data
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
amount: Amount to check
|
|
||||||
timestamp: Timestamp for check (default: now)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Spending check result
|
|
||||||
"""
|
|
||||||
if timestamp is None:
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
# Get spending limits from database
|
|
||||||
with self.get_session() as session:
|
|
||||||
limits = session.query(SpendingLimit).filter(
|
|
||||||
SpendingLimit.agent_address == agent_address
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not limits:
|
|
||||||
# Default limits if not set
|
|
||||||
limits = SpendingLimit(
|
|
||||||
agent_address=agent_address,
|
|
||||||
per_transaction=1000.0,
|
|
||||||
per_hour=5000.0,
|
|
||||||
per_day=20000.0,
|
|
||||||
per_week=100000.0,
|
|
||||||
time_lock_threshold=5000.0,
|
|
||||||
time_lock_delay_hours=24
|
|
||||||
)
|
|
||||||
session.add(limits)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Check each limit
|
|
||||||
current_spent = {}
|
|
||||||
remaining = {}
|
|
||||||
|
|
||||||
# Per-transaction limit
|
|
||||||
if amount > limits.per_transaction:
|
|
||||||
return SpendingCheckResult(
|
|
||||||
allowed=False,
|
|
||||||
reason=f"Amount {amount} exceeds per-transaction limit {limits.per_transaction}",
|
|
||||||
current_spent=current_spent,
|
|
||||||
remaining=remaining,
|
|
||||||
requires_time_lock=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Per-hour limit
|
|
||||||
spent_hour = self.get_spent_in_period(agent_address, "hour", timestamp)
|
|
||||||
current_spent["hour"] = spent_hour
|
|
||||||
remaining["hour"] = limits.per_hour - spent_hour
|
|
||||||
|
|
||||||
if spent_hour + amount > limits.per_hour:
|
|
||||||
return SpendingCheckResult(
|
|
||||||
allowed=False,
|
|
||||||
reason=f"Hourly spending {spent_hour + amount} would exceed limit {limits.per_hour}",
|
|
||||||
current_spent=current_spent,
|
|
||||||
remaining=remaining,
|
|
||||||
requires_time_lock=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Per-day limit
|
|
||||||
spent_day = self.get_spent_in_period(agent_address, "day", timestamp)
|
|
||||||
current_spent["day"] = spent_day
|
|
||||||
remaining["day"] = limits.per_day - spent_day
|
|
||||||
|
|
||||||
if spent_day + amount > limits.per_day:
|
|
||||||
return SpendingCheckResult(
|
|
||||||
allowed=False,
|
|
||||||
reason=f"Daily spending {spent_day + amount} would exceed limit {limits.per_day}",
|
|
||||||
current_spent=current_spent,
|
|
||||||
remaining=remaining,
|
|
||||||
requires_time_lock=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Per-week limit
|
|
||||||
spent_week = self.get_spent_in_period(agent_address, "week", timestamp)
|
|
||||||
current_spent["week"] = spent_week
|
|
||||||
remaining["week"] = limits.per_week - spent_week
|
|
||||||
|
|
||||||
if spent_week + amount > limits.per_week:
|
|
||||||
return SpendingCheckResult(
|
|
||||||
allowed=False,
|
|
||||||
reason=f"Weekly spending {spent_week + amount} would exceed limit {limits.per_week}",
|
|
||||||
current_spent=current_spent,
|
|
||||||
remaining=remaining,
|
|
||||||
requires_time_lock=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check time lock requirement
|
|
||||||
requires_time_lock = amount >= limits.time_lock_threshold
|
|
||||||
time_lock_until = None
|
|
||||||
|
|
||||||
if requires_time_lock:
|
|
||||||
time_lock_until = timestamp + timedelta(hours=limits.time_lock_delay_hours)
|
|
||||||
|
|
||||||
return SpendingCheckResult(
|
|
||||||
allowed=True,
|
|
||||||
reason="Spending limits check passed",
|
|
||||||
current_spent=current_spent,
|
|
||||||
remaining=remaining,
|
|
||||||
requires_time_lock=requires_time_lock,
|
|
||||||
time_lock_until=time_lock_until
|
|
||||||
)
|
|
||||||
|
|
||||||
def update_spending_limits(self, agent_address: str, new_limits: Dict, guardian_address: str) -> bool:
|
|
||||||
"""
|
|
||||||
Update spending limits for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
new_limits: New spending limits
|
|
||||||
guardian_address: Guardian making the change
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if updated successfully
|
|
||||||
"""
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
|
|
||||||
# Verify guardian authorization
|
|
||||||
if not self.is_guardian_authorized(agent_address, guardian_address):
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
with self.get_session() as session:
|
|
||||||
limits = session.query(SpendingLimit).filter(
|
|
||||||
SpendingLimit.agent_address == agent_address
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if limits:
|
|
||||||
limits.per_transaction = new_limits.get("per_transaction", limits.per_transaction)
|
|
||||||
limits.per_hour = new_limits.get("per_hour", limits.per_hour)
|
|
||||||
limits.per_day = new_limits.get("per_day", limits.per_day)
|
|
||||||
limits.per_week = new_limits.get("per_week", limits.per_week)
|
|
||||||
limits.time_lock_threshold = new_limits.get("time_lock_threshold", limits.time_lock_threshold)
|
|
||||||
limits.time_lock_delay_hours = new_limits.get("time_lock_delay_hours", limits.time_lock_delay_hours)
|
|
||||||
limits.updated_at = datetime.utcnow()
|
|
||||||
limits.updated_by = guardian_address
|
|
||||||
else:
|
|
||||||
limits = SpendingLimit(
|
|
||||||
agent_address=agent_address,
|
|
||||||
per_transaction=new_limits.get("per_transaction", 1000.0),
|
|
||||||
per_hour=new_limits.get("per_hour", 5000.0),
|
|
||||||
per_day=new_limits.get("per_day", 20000.0),
|
|
||||||
per_week=new_limits.get("per_week", 100000.0),
|
|
||||||
time_lock_threshold=new_limits.get("time_lock_threshold", 5000.0),
|
|
||||||
time_lock_delay_hours=new_limits.get("time_lock_delay_hours", 24),
|
|
||||||
updated_at=datetime.utcnow(),
|
|
||||||
updated_by=guardian_address
|
|
||||||
)
|
|
||||||
session.add(limits)
|
|
||||||
|
|
||||||
session.commit()
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to update spending limits: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def add_guardian(self, agent_address: str, guardian_address: str, added_by: str) -> bool:
|
|
||||||
"""
|
|
||||||
Add a guardian for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
guardian_address: Guardian address
|
|
||||||
added_by: Who added this guardian
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if added successfully
|
|
||||||
"""
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
added_by = to_checksum_address(added_by)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with self.get_session() as session:
|
|
||||||
# Check if already exists
|
|
||||||
existing = session.query(GuardianAuthorization).filter(
|
|
||||||
GuardianAuthorization.agent_address == agent_address,
|
|
||||||
GuardianAuthorization.guardian_address == guardian_address
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if existing:
|
|
||||||
existing.is_active = True
|
|
||||||
existing.added_at = datetime.utcnow()
|
|
||||||
existing.added_by = added_by
|
|
||||||
else:
|
|
||||||
auth = GuardianAuthorization(
|
|
||||||
id=f"{agent_address}_{guardian_address}",
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardian_address=guardian_address,
|
|
||||||
is_active=True,
|
|
||||||
added_at=datetime.utcnow(),
|
|
||||||
added_by=added_by
|
|
||||||
)
|
|
||||||
session.add(auth)
|
|
||||||
|
|
||||||
session.commit()
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to add guardian: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def is_guardian_authorized(self, agent_address: str, guardian_address: str) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a guardian is authorized for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
guardian_address: Guardian address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if authorized
|
|
||||||
"""
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
|
|
||||||
with self.get_session() as session:
|
|
||||||
auth = session.query(GuardianAuthorization).filter(
|
|
||||||
GuardianAuthorization.agent_address == agent_address,
|
|
||||||
GuardianAuthorization.guardian_address == guardian_address,
|
|
||||||
GuardianAuthorization.is_active == True
|
|
||||||
).first()
|
|
||||||
|
|
||||||
return auth is not None
|
|
||||||
|
|
||||||
def get_spending_summary(self, agent_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Get comprehensive spending summary for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Spending summary
|
|
||||||
"""
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
now = datetime.utcnow()
|
|
||||||
|
|
||||||
# Get current spending
|
|
||||||
current_spent = {
|
|
||||||
"hour": self.get_spent_in_period(agent_address, "hour", now),
|
|
||||||
"day": self.get_spent_in_period(agent_address, "day", now),
|
|
||||||
"week": self.get_spent_in_period(agent_address, "week", now)
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get limits
|
|
||||||
with self.get_session() as session:
|
|
||||||
limits = session.query(SpendingLimit).filter(
|
|
||||||
SpendingLimit.agent_address == agent_address
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not limits:
|
|
||||||
return {"error": "No spending limits set"}
|
|
||||||
|
|
||||||
# Calculate remaining
|
|
||||||
remaining = {
|
|
||||||
"hour": limits.per_hour - current_spent["hour"],
|
|
||||||
"day": limits.per_day - current_spent["day"],
|
|
||||||
"week": limits.per_week - current_spent["week"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get authorized guardians
|
|
||||||
with self.get_session() as session:
|
|
||||||
guardians = session.query(GuardianAuthorization).filter(
|
|
||||||
GuardianAuthorization.agent_address == agent_address,
|
|
||||||
GuardianAuthorization.is_active == True
|
|
||||||
).all()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"current_spending": current_spent,
|
|
||||||
"remaining_spending": remaining,
|
|
||||||
"limits": {
|
|
||||||
"per_transaction": limits.per_transaction,
|
|
||||||
"per_hour": limits.per_hour,
|
|
||||||
"per_day": limits.per_day,
|
|
||||||
"per_week": limits.per_week
|
|
||||||
},
|
|
||||||
"time_lock": {
|
|
||||||
"threshold": limits.time_lock_threshold,
|
|
||||||
"delay_hours": limits.time_lock_delay_hours
|
|
||||||
},
|
|
||||||
"authorized_guardians": [g.guardian_address for g in guardians],
|
|
||||||
"last_updated": limits.updated_at.isoformat() if limits.updated_at else None
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Global persistent tracker instance
|
|
||||||
persistent_tracker = PersistentSpendingTracker()
|
|
||||||
@@ -1,519 +0,0 @@
|
|||||||
"""
|
|
||||||
AITBC Agent Messaging Contract Implementation
|
|
||||||
|
|
||||||
This module implements on-chain messaging functionality for agents,
|
|
||||||
enabling forum-like communication between autonomous agents.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Any
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from enum import Enum
|
|
||||||
import json
|
|
||||||
import hashlib
|
|
||||||
from eth_account import Account
|
|
||||||
from eth_utils import to_checksum_address
|
|
||||||
|
|
||||||
class MessageType(Enum):
|
|
||||||
"""Types of messages agents can send"""
|
|
||||||
POST = "post"
|
|
||||||
REPLY = "reply"
|
|
||||||
ANNOUNCEMENT = "announcement"
|
|
||||||
QUESTION = "question"
|
|
||||||
ANSWER = "answer"
|
|
||||||
MODERATION = "moderation"
|
|
||||||
|
|
||||||
class MessageStatus(Enum):
|
|
||||||
"""Status of messages in the forum"""
|
|
||||||
ACTIVE = "active"
|
|
||||||
HIDDEN = "hidden"
|
|
||||||
DELETED = "deleted"
|
|
||||||
PINNED = "pinned"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Message:
|
|
||||||
"""Represents a message in the agent forum"""
|
|
||||||
message_id: str
|
|
||||||
agent_id: str
|
|
||||||
agent_address: str
|
|
||||||
topic: str
|
|
||||||
content: str
|
|
||||||
message_type: MessageType
|
|
||||||
timestamp: datetime
|
|
||||||
parent_message_id: Optional[str] = None
|
|
||||||
reply_count: int = 0
|
|
||||||
upvotes: int = 0
|
|
||||||
downvotes: int = 0
|
|
||||||
status: MessageStatus = MessageStatus.ACTIVE
|
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Topic:
|
|
||||||
"""Represents a forum topic"""
|
|
||||||
topic_id: str
|
|
||||||
title: str
|
|
||||||
description: str
|
|
||||||
creator_agent_id: str
|
|
||||||
created_at: datetime
|
|
||||||
message_count: int = 0
|
|
||||||
last_activity: datetime = field(default_factory=datetime.now)
|
|
||||||
tags: List[str] = field(default_factory=list)
|
|
||||||
is_pinned: bool = False
|
|
||||||
is_locked: bool = False
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AgentReputation:
|
|
||||||
"""Reputation system for agents"""
|
|
||||||
agent_id: str
|
|
||||||
message_count: int = 0
|
|
||||||
upvotes_received: int = 0
|
|
||||||
downvotes_received: int = 0
|
|
||||||
reputation_score: float = 0.0
|
|
||||||
trust_level: int = 1 # 1-5 trust levels
|
|
||||||
is_moderator: bool = False
|
|
||||||
is_banned: bool = False
|
|
||||||
ban_reason: Optional[str] = None
|
|
||||||
ban_expires: Optional[datetime] = None
|
|
||||||
|
|
||||||
class AgentMessagingContract:
|
|
||||||
"""Main contract for agent messaging functionality"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.messages: Dict[str, Message] = {}
|
|
||||||
self.topics: Dict[str, Topic] = {}
|
|
||||||
self.agent_reputations: Dict[str, AgentReputation] = {}
|
|
||||||
self.moderation_log: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
def create_topic(self, agent_id: str, agent_address: str, title: str,
|
|
||||||
description: str, tags: List[str] = None) -> Dict[str, Any]:
|
|
||||||
"""Create a new forum topic"""
|
|
||||||
|
|
||||||
# Check if agent is banned
|
|
||||||
if self._is_agent_banned(agent_id):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Agent is banned from posting",
|
|
||||||
"error_code": "AGENT_BANNED"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Generate topic ID
|
|
||||||
topic_id = f"topic_{hashlib.sha256(f'{agent_id}_{title}_{datetime.now()}'.encode()).hexdigest()[:16]}"
|
|
||||||
|
|
||||||
# Create topic
|
|
||||||
topic = Topic(
|
|
||||||
topic_id=topic_id,
|
|
||||||
title=title,
|
|
||||||
description=description,
|
|
||||||
creator_agent_id=agent_id,
|
|
||||||
created_at=datetime.now(),
|
|
||||||
tags=tags or []
|
|
||||||
)
|
|
||||||
|
|
||||||
self.topics[topic_id] = topic
|
|
||||||
|
|
||||||
# Update agent reputation
|
|
||||||
self._update_agent_reputation(agent_id, message_count=1)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"topic_id": topic_id,
|
|
||||||
"topic": self._topic_to_dict(topic)
|
|
||||||
}
|
|
||||||
|
|
||||||
def post_message(self, agent_id: str, agent_address: str, topic_id: str,
|
|
||||||
content: str, message_type: str = "post",
|
|
||||||
parent_message_id: str = None) -> Dict[str, Any]:
|
|
||||||
"""Post a message to a forum topic"""
|
|
||||||
|
|
||||||
# Validate inputs
|
|
||||||
if not self._validate_agent(agent_id, agent_address):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid agent credentials",
|
|
||||||
"error_code": "INVALID_AGENT"
|
|
||||||
}
|
|
||||||
|
|
||||||
if self._is_agent_banned(agent_id):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Agent is banned from posting",
|
|
||||||
"error_code": "AGENT_BANNED"
|
|
||||||
}
|
|
||||||
|
|
||||||
if topic_id not in self.topics:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Topic not found",
|
|
||||||
"error_code": "TOPIC_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.topics[topic_id].is_locked:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Topic is locked",
|
|
||||||
"error_code": "TOPIC_LOCKED"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Validate message type
|
|
||||||
try:
|
|
||||||
msg_type = MessageType(message_type)
|
|
||||||
except ValueError:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid message type",
|
|
||||||
"error_code": "INVALID_MESSAGE_TYPE"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Generate message ID
|
|
||||||
message_id = f"msg_{hashlib.sha256(f'{agent_id}_{topic_id}_{content}_{datetime.now()}'.encode()).hexdigest()[:16]}"
|
|
||||||
|
|
||||||
# Create message
|
|
||||||
message = Message(
|
|
||||||
message_id=message_id,
|
|
||||||
agent_id=agent_id,
|
|
||||||
agent_address=agent_address,
|
|
||||||
topic=topic_id,
|
|
||||||
content=content,
|
|
||||||
message_type=msg_type,
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
parent_message_id=parent_message_id
|
|
||||||
)
|
|
||||||
|
|
||||||
self.messages[message_id] = message
|
|
||||||
|
|
||||||
# Update topic
|
|
||||||
self.topics[topic_id].message_count += 1
|
|
||||||
self.topics[topic_id].last_activity = datetime.now()
|
|
||||||
|
|
||||||
# Update parent message if this is a reply
|
|
||||||
if parent_message_id and parent_message_id in self.messages:
|
|
||||||
self.messages[parent_message_id].reply_count += 1
|
|
||||||
|
|
||||||
# Update agent reputation
|
|
||||||
self._update_agent_reputation(agent_id, message_count=1)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message_id": message_id,
|
|
||||||
"message": self._message_to_dict(message)
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_messages(self, topic_id: str, limit: int = 50, offset: int = 0,
|
|
||||||
sort_by: str = "timestamp") -> Dict[str, Any]:
|
|
||||||
"""Get messages from a topic"""
|
|
||||||
|
|
||||||
if topic_id not in self.topics:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Topic not found",
|
|
||||||
"error_code": "TOPIC_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get all messages for this topic
|
|
||||||
topic_messages = [
|
|
||||||
msg for msg in self.messages.values()
|
|
||||||
if msg.topic == topic_id and msg.status == MessageStatus.ACTIVE
|
|
||||||
]
|
|
||||||
|
|
||||||
# Sort messages
|
|
||||||
if sort_by == "timestamp":
|
|
||||||
topic_messages.sort(key=lambda x: x.timestamp, reverse=True)
|
|
||||||
elif sort_by == "upvotes":
|
|
||||||
topic_messages.sort(key=lambda x: x.upvotes, reverse=True)
|
|
||||||
elif sort_by == "replies":
|
|
||||||
topic_messages.sort(key=lambda x: x.reply_count, reverse=True)
|
|
||||||
|
|
||||||
# Apply pagination
|
|
||||||
total_messages = len(topic_messages)
|
|
||||||
paginated_messages = topic_messages[offset:offset + limit]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"messages": [self._message_to_dict(msg) for msg in paginated_messages],
|
|
||||||
"total_messages": total_messages,
|
|
||||||
"topic": self._topic_to_dict(self.topics[topic_id])
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_topics(self, limit: int = 50, offset: int = 0,
|
|
||||||
sort_by: str = "last_activity") -> Dict[str, Any]:
|
|
||||||
"""Get list of forum topics"""
|
|
||||||
|
|
||||||
# Sort topics
|
|
||||||
topic_list = list(self.topics.values())
|
|
||||||
|
|
||||||
if sort_by == "last_activity":
|
|
||||||
topic_list.sort(key=lambda x: x.last_activity, reverse=True)
|
|
||||||
elif sort_by == "created_at":
|
|
||||||
topic_list.sort(key=lambda x: x.created_at, reverse=True)
|
|
||||||
elif sort_by == "message_count":
|
|
||||||
topic_list.sort(key=lambda x: x.message_count, reverse=True)
|
|
||||||
|
|
||||||
# Apply pagination
|
|
||||||
total_topics = len(topic_list)
|
|
||||||
paginated_topics = topic_list[offset:offset + limit]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"topics": [self._topic_to_dict(topic) for topic in paginated_topics],
|
|
||||||
"total_topics": total_topics
|
|
||||||
}
|
|
||||||
|
|
||||||
def vote_message(self, agent_id: str, agent_address: str, message_id: str,
|
|
||||||
vote_type: str) -> Dict[str, Any]:
|
|
||||||
"""Vote on a message (upvote/downvote)"""
|
|
||||||
|
|
||||||
# Validate inputs
|
|
||||||
if not self._validate_agent(agent_id, agent_address):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid agent credentials",
|
|
||||||
"error_code": "INVALID_AGENT"
|
|
||||||
}
|
|
||||||
|
|
||||||
if message_id not in self.messages:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Message not found",
|
|
||||||
"error_code": "MESSAGE_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
if vote_type not in ["upvote", "downvote"]:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid vote type",
|
|
||||||
"error_code": "INVALID_VOTE_TYPE"
|
|
||||||
}
|
|
||||||
|
|
||||||
message = self.messages[message_id]
|
|
||||||
|
|
||||||
# Update vote counts
|
|
||||||
if vote_type == "upvote":
|
|
||||||
message.upvotes += 1
|
|
||||||
else:
|
|
||||||
message.downvotes += 1
|
|
||||||
|
|
||||||
# Update message author reputation
|
|
||||||
self._update_agent_reputation(
|
|
||||||
message.agent_id,
|
|
||||||
upvotes_received=message.upvotes,
|
|
||||||
downvotes_received=message.downvotes
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message_id": message_id,
|
|
||||||
"upvotes": message.upvotes,
|
|
||||||
"downvotes": message.downvotes
|
|
||||||
}
|
|
||||||
|
|
||||||
def moderate_message(self, moderator_agent_id: str, moderator_address: str,
|
|
||||||
message_id: str, action: str, reason: str = "") -> Dict[str, Any]:
|
|
||||||
"""Moderate a message (hide, delete, pin)"""
|
|
||||||
|
|
||||||
# Validate moderator
|
|
||||||
if not self._is_moderator(moderator_agent_id):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Insufficient permissions",
|
|
||||||
"error_code": "INSUFFICIENT_PERMISSIONS"
|
|
||||||
}
|
|
||||||
|
|
||||||
if message_id not in self.messages:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Message not found",
|
|
||||||
"error_code": "MESSAGE_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
message = self.messages[message_id]
|
|
||||||
|
|
||||||
# Apply moderation action
|
|
||||||
if action == "hide":
|
|
||||||
message.status = MessageStatus.HIDDEN
|
|
||||||
elif action == "delete":
|
|
||||||
message.status = MessageStatus.DELETED
|
|
||||||
elif action == "pin":
|
|
||||||
message.status = MessageStatus.PINNED
|
|
||||||
elif action == "unpin":
|
|
||||||
message.status = MessageStatus.ACTIVE
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid moderation action",
|
|
||||||
"error_code": "INVALID_ACTION"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Log moderation action
|
|
||||||
self.moderation_log.append({
|
|
||||||
"timestamp": datetime.now(),
|
|
||||||
"moderator_agent_id": moderator_agent_id,
|
|
||||||
"message_id": message_id,
|
|
||||||
"action": action,
|
|
||||||
"reason": reason
|
|
||||||
})
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message_id": message_id,
|
|
||||||
"status": message.status.value
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_agent_reputation(self, agent_id: str) -> Dict[str, Any]:
|
|
||||||
"""Get an agent's reputation information"""
|
|
||||||
|
|
||||||
if agent_id not in self.agent_reputations:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Agent not found",
|
|
||||||
"error_code": "AGENT_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
reputation = self.agent_reputations[agent_id]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"agent_id": agent_id,
|
|
||||||
"reputation": self._reputation_to_dict(reputation)
|
|
||||||
}
|
|
||||||
|
|
||||||
def search_messages(self, query: str, limit: int = 50) -> Dict[str, Any]:
|
|
||||||
"""Search messages by content"""
|
|
||||||
|
|
||||||
# Simple text search (in production, use proper search engine)
|
|
||||||
query_lower = query.lower()
|
|
||||||
matching_messages = []
|
|
||||||
|
|
||||||
for message in self.messages.values():
|
|
||||||
if (message.status == MessageStatus.ACTIVE and
|
|
||||||
query_lower in message.content.lower()):
|
|
||||||
matching_messages.append(message)
|
|
||||||
|
|
||||||
# Sort by timestamp (most recent first)
|
|
||||||
matching_messages.sort(key=lambda x: x.timestamp, reverse=True)
|
|
||||||
|
|
||||||
# Limit results
|
|
||||||
limited_messages = matching_messages[:limit]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"query": query,
|
|
||||||
"messages": [self._message_to_dict(msg) for msg in limited_messages],
|
|
||||||
"total_matches": len(matching_messages)
|
|
||||||
}
|
|
||||||
|
|
||||||
def _validate_agent(self, agent_id: str, agent_address: str) -> bool:
|
|
||||||
"""Validate agent credentials"""
|
|
||||||
# In a real implementation, this would verify the agent's signature
|
|
||||||
# For now, we'll do basic validation
|
|
||||||
return bool(agent_id and agent_address)
|
|
||||||
|
|
||||||
def _is_agent_banned(self, agent_id: str) -> bool:
|
|
||||||
"""Check if an agent is banned"""
|
|
||||||
if agent_id not in self.agent_reputations:
|
|
||||||
return False
|
|
||||||
|
|
||||||
reputation = self.agent_reputations[agent_id]
|
|
||||||
|
|
||||||
if reputation.is_banned:
|
|
||||||
# Check if ban has expired
|
|
||||||
if reputation.ban_expires and datetime.now() > reputation.ban_expires:
|
|
||||||
reputation.is_banned = False
|
|
||||||
reputation.ban_expires = None
|
|
||||||
reputation.ban_reason = None
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _is_moderator(self, agent_id: str) -> bool:
|
|
||||||
"""Check if an agent is a moderator"""
|
|
||||||
if agent_id not in self.agent_reputations:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self.agent_reputations[agent_id].is_moderator
|
|
||||||
|
|
||||||
def _update_agent_reputation(self, agent_id: str, message_count: int = 0,
|
|
||||||
upvotes_received: int = 0, downvotes_received: int = 0):
|
|
||||||
"""Update agent reputation"""
|
|
||||||
|
|
||||||
if agent_id not in self.agent_reputations:
|
|
||||||
self.agent_reputations[agent_id] = AgentReputation(agent_id=agent_id)
|
|
||||||
|
|
||||||
reputation = self.agent_reputations[agent_id]
|
|
||||||
|
|
||||||
if message_count > 0:
|
|
||||||
reputation.message_count += message_count
|
|
||||||
|
|
||||||
if upvotes_received > 0:
|
|
||||||
reputation.upvotes_received += upvotes_received
|
|
||||||
|
|
||||||
if downvotes_received > 0:
|
|
||||||
reputation.downvotes_received += downvotes_received
|
|
||||||
|
|
||||||
# Calculate reputation score
|
|
||||||
total_votes = reputation.upvotes_received + reputation.downvotes_received
|
|
||||||
if total_votes > 0:
|
|
||||||
reputation.reputation_score = (reputation.upvotes_received - reputation.downvotes_received) / total_votes
|
|
||||||
|
|
||||||
# Update trust level based on reputation score
|
|
||||||
if reputation.reputation_score >= 0.8:
|
|
||||||
reputation.trust_level = 5
|
|
||||||
elif reputation.reputation_score >= 0.6:
|
|
||||||
reputation.trust_level = 4
|
|
||||||
elif reputation.reputation_score >= 0.4:
|
|
||||||
reputation.trust_level = 3
|
|
||||||
elif reputation.reputation_score >= 0.2:
|
|
||||||
reputation.trust_level = 2
|
|
||||||
else:
|
|
||||||
reputation.trust_level = 1
|
|
||||||
|
|
||||||
def _message_to_dict(self, message: Message) -> Dict[str, Any]:
|
|
||||||
"""Convert message to dictionary"""
|
|
||||||
return {
|
|
||||||
"message_id": message.message_id,
|
|
||||||
"agent_id": message.agent_id,
|
|
||||||
"agent_address": message.agent_address,
|
|
||||||
"topic": message.topic,
|
|
||||||
"content": message.content,
|
|
||||||
"message_type": message.message_type.value,
|
|
||||||
"timestamp": message.timestamp.isoformat(),
|
|
||||||
"parent_message_id": message.parent_message_id,
|
|
||||||
"reply_count": message.reply_count,
|
|
||||||
"upvotes": message.upvotes,
|
|
||||||
"downvotes": message.downvotes,
|
|
||||||
"status": message.status.value,
|
|
||||||
"metadata": message.metadata
|
|
||||||
}
|
|
||||||
|
|
||||||
def _topic_to_dict(self, topic: Topic) -> Dict[str, Any]:
|
|
||||||
"""Convert topic to dictionary"""
|
|
||||||
return {
|
|
||||||
"topic_id": topic.topic_id,
|
|
||||||
"title": topic.title,
|
|
||||||
"description": topic.description,
|
|
||||||
"creator_agent_id": topic.creator_agent_id,
|
|
||||||
"created_at": topic.created_at.isoformat(),
|
|
||||||
"message_count": topic.message_count,
|
|
||||||
"last_activity": topic.last_activity.isoformat(),
|
|
||||||
"tags": topic.tags,
|
|
||||||
"is_pinned": topic.is_pinned,
|
|
||||||
"is_locked": topic.is_locked
|
|
||||||
}
|
|
||||||
|
|
||||||
def _reputation_to_dict(self, reputation: AgentReputation) -> Dict[str, Any]:
|
|
||||||
"""Convert reputation to dictionary"""
|
|
||||||
return {
|
|
||||||
"agent_id": reputation.agent_id,
|
|
||||||
"message_count": reputation.message_count,
|
|
||||||
"upvotes_received": reputation.upvotes_received,
|
|
||||||
"downvotes_received": reputation.downvotes_received,
|
|
||||||
"reputation_score": reputation.reputation_score,
|
|
||||||
"trust_level": reputation.trust_level,
|
|
||||||
"is_moderator": reputation.is_moderator,
|
|
||||||
"is_banned": reputation.is_banned,
|
|
||||||
"ban_reason": reputation.ban_reason,
|
|
||||||
"ban_expires": reputation.ban_expires.isoformat() if reputation.ban_expires else None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Global contract instance
|
|
||||||
messaging_contract = AgentMessagingContract()
|
|
||||||
@@ -1,584 +0,0 @@
|
|||||||
"""
|
|
||||||
AITBC Agent Wallet Security Implementation
|
|
||||||
|
|
||||||
This module implements the security layer for autonomous agent wallets,
|
|
||||||
integrating the guardian contract to prevent unlimited spending in case
|
|
||||||
of agent compromise.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
import json
|
|
||||||
from eth_account import Account
|
|
||||||
from eth_utils import to_checksum_address
|
|
||||||
|
|
||||||
from .guardian_contract import (
|
|
||||||
GuardianContract,
|
|
||||||
SpendingLimit,
|
|
||||||
TimeLockConfig,
|
|
||||||
GuardianConfig,
|
|
||||||
create_guardian_contract,
|
|
||||||
CONSERVATIVE_CONFIG,
|
|
||||||
AGGRESSIVE_CONFIG,
|
|
||||||
HIGH_SECURITY_CONFIG
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AgentSecurityProfile:
|
|
||||||
"""Security profile for an agent"""
|
|
||||||
agent_address: str
|
|
||||||
security_level: str # "conservative", "aggressive", "high_security"
|
|
||||||
guardian_addresses: List[str]
|
|
||||||
custom_limits: Optional[Dict] = None
|
|
||||||
enabled: bool = True
|
|
||||||
created_at: datetime = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.created_at is None:
|
|
||||||
self.created_at = datetime.utcnow()
|
|
||||||
|
|
||||||
|
|
||||||
class AgentWalletSecurity:
|
|
||||||
"""
|
|
||||||
Security manager for autonomous agent wallets
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.agent_profiles: Dict[str, AgentSecurityProfile] = {}
|
|
||||||
self.guardian_contracts: Dict[str, GuardianContract] = {}
|
|
||||||
self.security_events: List[Dict] = []
|
|
||||||
|
|
||||||
# Default configurations
|
|
||||||
self.configurations = {
|
|
||||||
"conservative": CONSERVATIVE_CONFIG,
|
|
||||||
"aggressive": AGGRESSIVE_CONFIG,
|
|
||||||
"high_security": HIGH_SECURITY_CONFIG
|
|
||||||
}
|
|
||||||
|
|
||||||
def register_agent(self,
|
|
||||||
agent_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
guardian_addresses: List[str] = None,
|
|
||||||
custom_limits: Dict = None) -> Dict:
|
|
||||||
"""
|
|
||||||
Register an agent for security protection
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
security_level: Security level (conservative, aggressive, high_security)
|
|
||||||
guardian_addresses: List of guardian addresses for recovery
|
|
||||||
custom_limits: Custom spending limits (overrides security_level)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Registration result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
if agent_address in self.agent_profiles:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent already registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Validate security level
|
|
||||||
if security_level not in self.configurations:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Invalid security level: {security_level}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Default guardians if none provided
|
|
||||||
if guardian_addresses is None:
|
|
||||||
guardian_addresses = [agent_address] # Self-guardian (should be overridden)
|
|
||||||
|
|
||||||
# Validate guardian addresses
|
|
||||||
guardian_addresses = [to_checksum_address(addr) for addr in guardian_addresses]
|
|
||||||
|
|
||||||
# Create security profile
|
|
||||||
profile = AgentSecurityProfile(
|
|
||||||
agent_address=agent_address,
|
|
||||||
security_level=security_level,
|
|
||||||
guardian_addresses=guardian_addresses,
|
|
||||||
custom_limits=custom_limits
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create guardian contract
|
|
||||||
config = self.configurations[security_level]
|
|
||||||
if custom_limits:
|
|
||||||
config.update(custom_limits)
|
|
||||||
|
|
||||||
guardian_contract = create_guardian_contract(
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardians=guardian_addresses,
|
|
||||||
**config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store profile and contract
|
|
||||||
self.agent_profiles[agent_address] = profile
|
|
||||||
self.guardian_contracts[agent_address] = guardian_contract
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="agent_registered",
|
|
||||||
agent_address=agent_address,
|
|
||||||
security_level=security_level,
|
|
||||||
guardian_count=len(guardian_addresses)
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "registered",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"security_level": security_level,
|
|
||||||
"guardian_addresses": guardian_addresses,
|
|
||||||
"limits": guardian_contract.config.limits,
|
|
||||||
"time_lock_threshold": guardian_contract.config.time_lock.threshold,
|
|
||||||
"registered_at": profile.created_at.isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Registration failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def protect_transaction(self,
|
|
||||||
agent_address: str,
|
|
||||||
to_address: str,
|
|
||||||
amount: int,
|
|
||||||
data: str = "") -> Dict:
|
|
||||||
"""
|
|
||||||
Protect a transaction with guardian contract
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
to_address: Recipient address
|
|
||||||
amount: Amount to transfer
|
|
||||||
data: Transaction data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Protection result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
# Check if agent is registered
|
|
||||||
if agent_address not in self.agent_profiles:
|
|
||||||
return {
|
|
||||||
"status": "unprotected",
|
|
||||||
"reason": "Agent not registered for security protection",
|
|
||||||
"suggestion": "Register agent with register_agent() first"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Check if protection is enabled
|
|
||||||
profile = self.agent_profiles[agent_address]
|
|
||||||
if not profile.enabled:
|
|
||||||
return {
|
|
||||||
"status": "unprotected",
|
|
||||||
"reason": "Security protection disabled for this agent"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get guardian contract
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
|
|
||||||
# Initiate transaction protection
|
|
||||||
result = guardian_contract.initiate_transaction(to_address, amount, data)
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="transaction_protected",
|
|
||||||
agent_address=agent_address,
|
|
||||||
to_address=to_address,
|
|
||||||
amount=amount,
|
|
||||||
protection_status=result["status"]
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Transaction protection failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute_protected_transaction(self,
|
|
||||||
agent_address: str,
|
|
||||||
operation_id: str,
|
|
||||||
signature: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Execute a previously protected transaction
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
operation_id: Operation ID from protection
|
|
||||||
signature: Transaction signature
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Execution result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
if agent_address not in self.guardian_contracts:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
result = guardian_contract.execute_transaction(operation_id, signature)
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
if result["status"] == "executed":
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="transaction_executed",
|
|
||||||
agent_address=agent_address,
|
|
||||||
operation_id=operation_id,
|
|
||||||
transaction_hash=result.get("transaction_hash")
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Transaction execution failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def emergency_pause_agent(self, agent_address: str, guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Emergency pause an agent's operations
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
guardian_address: Guardian address initiating pause
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Pause result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
|
|
||||||
if agent_address not in self.guardian_contracts:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
result = guardian_contract.emergency_pause(guardian_address)
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
if result["status"] == "paused":
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="emergency_pause",
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardian_address=guardian_address
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Emergency pause failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def update_agent_security(self,
|
|
||||||
agent_address: str,
|
|
||||||
new_limits: Dict,
|
|
||||||
guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Update security limits for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
new_limits: New spending limits
|
|
||||||
guardian_address: Guardian address making the change
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Update result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
|
|
||||||
if agent_address not in self.guardian_contracts:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
|
|
||||||
# Create new spending limits
|
|
||||||
limits = SpendingLimit(
|
|
||||||
per_transaction=new_limits.get("per_transaction", 1000),
|
|
||||||
per_hour=new_limits.get("per_hour", 5000),
|
|
||||||
per_day=new_limits.get("per_day", 20000),
|
|
||||||
per_week=new_limits.get("per_week", 100000)
|
|
||||||
)
|
|
||||||
|
|
||||||
result = guardian_contract.update_limits(limits, guardian_address)
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
if result["status"] == "updated":
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="security_limits_updated",
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardian_address=guardian_address,
|
|
||||||
new_limits=new_limits
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Security update failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_agent_security_status(self, agent_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Get security status for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Security status
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
if agent_address not in self.agent_profiles:
|
|
||||||
return {
|
|
||||||
"status": "not_registered",
|
|
||||||
"message": "Agent not registered for security protection"
|
|
||||||
}
|
|
||||||
|
|
||||||
profile = self.agent_profiles[agent_address]
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "protected",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"security_level": profile.security_level,
|
|
||||||
"enabled": profile.enabled,
|
|
||||||
"guardian_addresses": profile.guardian_addresses,
|
|
||||||
"registered_at": profile.created_at.isoformat(),
|
|
||||||
"spending_status": guardian_contract.get_spending_status(),
|
|
||||||
"pending_operations": guardian_contract.get_pending_operations(),
|
|
||||||
"recent_activity": guardian_contract.get_operation_history(10)
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Status check failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def list_protected_agents(self) -> List[Dict]:
|
|
||||||
"""List all protected agents"""
|
|
||||||
agents = []
|
|
||||||
|
|
||||||
for agent_address, profile in self.agent_profiles.items():
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
|
|
||||||
agents.append({
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"security_level": profile.security_level,
|
|
||||||
"enabled": profile.enabled,
|
|
||||||
"guardian_count": len(profile.guardian_addresses),
|
|
||||||
"pending_operations": len(guardian_contract.pending_operations),
|
|
||||||
"paused": guardian_contract.paused,
|
|
||||||
"emergency_mode": guardian_contract.emergency_mode,
|
|
||||||
"registered_at": profile.created_at.isoformat()
|
|
||||||
})
|
|
||||||
|
|
||||||
return sorted(agents, key=lambda x: x["registered_at"], reverse=True)
|
|
||||||
|
|
||||||
def get_security_events(self, agent_address: str = None, limit: int = 50) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
Get security events
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Filter by agent address (optional)
|
|
||||||
limit: Maximum number of events
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Security events
|
|
||||||
"""
|
|
||||||
events = self.security_events
|
|
||||||
|
|
||||||
if agent_address:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
events = [e for e in events if e.get("agent_address") == agent_address]
|
|
||||||
|
|
||||||
return sorted(events, key=lambda x: x["timestamp"], reverse=True)[:limit]
|
|
||||||
|
|
||||||
def _log_security_event(self, **kwargs):
|
|
||||||
"""Log a security event"""
|
|
||||||
event = {
|
|
||||||
"timestamp": datetime.utcnow().isoformat(),
|
|
||||||
**kwargs
|
|
||||||
}
|
|
||||||
self.security_events.append(event)
|
|
||||||
|
|
||||||
def disable_agent_protection(self, agent_address: str, guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Disable protection for an agent (guardian only)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
guardian_address: Guardian address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Disable result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
|
|
||||||
if agent_address not in self.agent_profiles:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
profile = self.agent_profiles[agent_address]
|
|
||||||
|
|
||||||
if guardian_address not in profile.guardian_addresses:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Not authorized: not a guardian"
|
|
||||||
}
|
|
||||||
|
|
||||||
profile.enabled = False
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="protection_disabled",
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardian_address=guardian_address
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "disabled",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"disabled_at": datetime.utcnow().isoformat(),
|
|
||||||
"guardian": guardian_address
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Disable protection failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Global security manager instance
|
|
||||||
agent_wallet_security = AgentWalletSecurity()
|
|
||||||
|
|
||||||
|
|
||||||
# Convenience functions for common operations
|
|
||||||
def register_agent_for_protection(agent_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
guardians: List[str] = None) -> Dict:
|
|
||||||
"""Register an agent for security protection"""
|
|
||||||
return agent_wallet_security.register_agent(
|
|
||||||
agent_address=agent_address,
|
|
||||||
security_level=security_level,
|
|
||||||
guardian_addresses=guardians
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def protect_agent_transaction(agent_address: str,
|
|
||||||
to_address: str,
|
|
||||||
amount: int,
|
|
||||||
data: str = "") -> Dict:
|
|
||||||
"""Protect a transaction for an agent"""
|
|
||||||
return agent_wallet_security.protect_transaction(
|
|
||||||
agent_address=agent_address,
|
|
||||||
to_address=to_address,
|
|
||||||
amount=amount,
|
|
||||||
data=data
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_agent_security_summary(agent_address: str) -> Dict:
|
|
||||||
"""Get security summary for an agent"""
|
|
||||||
return agent_wallet_security.get_agent_security_status(agent_address)
|
|
||||||
|
|
||||||
|
|
||||||
# Security audit and monitoring functions
|
|
||||||
def generate_security_report() -> Dict:
|
|
||||||
"""Generate comprehensive security report"""
|
|
||||||
protected_agents = agent_wallet_security.list_protected_agents()
|
|
||||||
|
|
||||||
total_agents = len(protected_agents)
|
|
||||||
active_agents = len([a for a in protected_agents if a["enabled"]])
|
|
||||||
paused_agents = len([a for a in protected_agents if a["paused"]])
|
|
||||||
emergency_agents = len([a for a in protected_agents if a["emergency_mode"]])
|
|
||||||
|
|
||||||
recent_events = agent_wallet_security.get_security_events(limit=20)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"generated_at": datetime.utcnow().isoformat(),
|
|
||||||
"summary": {
|
|
||||||
"total_protected_agents": total_agents,
|
|
||||||
"active_agents": active_agents,
|
|
||||||
"paused_agents": paused_agents,
|
|
||||||
"emergency_mode_agents": emergency_agents,
|
|
||||||
"protection_coverage": f"{(active_agents / total_agents * 100):.1f}%" if total_agents > 0 else "0%"
|
|
||||||
},
|
|
||||||
"agents": protected_agents,
|
|
||||||
"recent_security_events": recent_events,
|
|
||||||
"security_levels": {
|
|
||||||
level: len([a for a in protected_agents if a["security_level"] == level])
|
|
||||||
for level in ["conservative", "aggressive", "high_security"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def detect_suspicious_activity(agent_address: str, hours: int = 24) -> Dict:
|
|
||||||
"""Detect suspicious activity for an agent"""
|
|
||||||
status = agent_wallet_security.get_agent_security_status(agent_address)
|
|
||||||
|
|
||||||
if status["status"] != "protected":
|
|
||||||
return {
|
|
||||||
"status": "not_protected",
|
|
||||||
"suspicious_activity": False
|
|
||||||
}
|
|
||||||
|
|
||||||
spending_status = status["spending_status"]
|
|
||||||
recent_events = agent_wallet_security.get_security_events(agent_address, limit=50)
|
|
||||||
|
|
||||||
# Suspicious patterns
|
|
||||||
suspicious_patterns = []
|
|
||||||
|
|
||||||
# Check for rapid spending
|
|
||||||
if spending_status["spent"]["current_hour"] > spending_status["current_limits"]["per_hour"] * 0.8:
|
|
||||||
suspicious_patterns.append("High hourly spending rate")
|
|
||||||
|
|
||||||
# Check for many small transactions (potential dust attack)
|
|
||||||
recent_tx_count = len([e for e in recent_events if e["event_type"] == "transaction_executed"])
|
|
||||||
if recent_tx_count > 20:
|
|
||||||
suspicious_patterns.append("High transaction frequency")
|
|
||||||
|
|
||||||
# Check for emergency pauses
|
|
||||||
recent_pauses = len([e for e in recent_events if e["event_type"] == "emergency_pause"])
|
|
||||||
if recent_pauses > 0:
|
|
||||||
suspicious_patterns.append("Recent emergency pauses detected")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "analyzed",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"suspicious_activity": len(suspicious_patterns) > 0,
|
|
||||||
"suspicious_patterns": suspicious_patterns,
|
|
||||||
"analysis_period_hours": hours,
|
|
||||||
"analyzed_at": datetime.utcnow().isoformat()
|
|
||||||
}
|
|
||||||
@@ -1,559 +0,0 @@
|
|||||||
"""
|
|
||||||
Smart Contract Escrow System
|
|
||||||
Handles automated payment holding and release for AI job marketplace
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
from typing import Dict, List, Optional, Tuple, Set
|
|
||||||
from dataclasses import dataclass, asdict
|
|
||||||
from enum import Enum
|
|
||||||
from decimal import Decimal
|
|
||||||
|
|
||||||
class EscrowState(Enum):
|
|
||||||
CREATED = "created"
|
|
||||||
FUNDED = "funded"
|
|
||||||
JOB_STARTED = "job_started"
|
|
||||||
JOB_COMPLETED = "job_completed"
|
|
||||||
DISPUTED = "disputed"
|
|
||||||
RESOLVED = "resolved"
|
|
||||||
RELEASED = "released"
|
|
||||||
REFUNDED = "refunded"
|
|
||||||
EXPIRED = "expired"
|
|
||||||
|
|
||||||
class DisputeReason(Enum):
|
|
||||||
QUALITY_ISSUES = "quality_issues"
|
|
||||||
DELIVERY_LATE = "delivery_late"
|
|
||||||
INCOMPLETE_WORK = "incomplete_work"
|
|
||||||
TECHNICAL_ISSUES = "technical_issues"
|
|
||||||
PAYMENT_DISPUTE = "payment_dispute"
|
|
||||||
OTHER = "other"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class EscrowContract:
|
|
||||||
contract_id: str
|
|
||||||
job_id: str
|
|
||||||
client_address: str
|
|
||||||
agent_address: str
|
|
||||||
amount: Decimal
|
|
||||||
fee_rate: Decimal # Platform fee rate
|
|
||||||
created_at: float
|
|
||||||
expires_at: float
|
|
||||||
state: EscrowState
|
|
||||||
milestones: List[Dict]
|
|
||||||
current_milestone: int
|
|
||||||
dispute_reason: Optional[DisputeReason]
|
|
||||||
dispute_evidence: List[Dict]
|
|
||||||
resolution: Optional[Dict]
|
|
||||||
released_amount: Decimal
|
|
||||||
refunded_amount: Decimal
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Milestone:
|
|
||||||
milestone_id: str
|
|
||||||
description: str
|
|
||||||
amount: Decimal
|
|
||||||
completed: bool
|
|
||||||
completed_at: Optional[float]
|
|
||||||
verified: bool
|
|
||||||
|
|
||||||
class EscrowManager:
|
|
||||||
"""Manages escrow contracts for AI job marketplace"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.escrow_contracts: Dict[str, EscrowContract] = {}
|
|
||||||
self.active_contracts: Set[str] = set()
|
|
||||||
self.disputed_contracts: Set[str] = set()
|
|
||||||
|
|
||||||
# Escrow parameters
|
|
||||||
self.default_fee_rate = Decimal('0.025') # 2.5% platform fee
|
|
||||||
self.max_contract_duration = 86400 * 30 # 30 days
|
|
||||||
self.dispute_timeout = 86400 * 7 # 7 days for dispute resolution
|
|
||||||
self.min_dispute_evidence = 1
|
|
||||||
self.max_dispute_evidence = 10
|
|
||||||
|
|
||||||
# Milestone parameters
|
|
||||||
self.min_milestone_amount = Decimal('0.01')
|
|
||||||
self.max_milestones = 10
|
|
||||||
self.verification_timeout = 86400 # 24 hours for milestone verification
|
|
||||||
|
|
||||||
async def create_contract(self, job_id: str, client_address: str, agent_address: str,
|
|
||||||
amount: Decimal, fee_rate: Optional[Decimal] = None,
|
|
||||||
milestones: Optional[List[Dict]] = None,
|
|
||||||
duration_days: int = 30) -> Tuple[bool, str, Optional[str]]:
|
|
||||||
"""Create new escrow contract"""
|
|
||||||
try:
|
|
||||||
# Validate inputs
|
|
||||||
if not self._validate_contract_inputs(job_id, client_address, agent_address, amount):
|
|
||||||
return False, "Invalid contract inputs", None
|
|
||||||
|
|
||||||
# Calculate fee
|
|
||||||
fee_rate = fee_rate or self.default_fee_rate
|
|
||||||
platform_fee = amount * fee_rate
|
|
||||||
total_amount = amount + platform_fee
|
|
||||||
|
|
||||||
# Validate milestones
|
|
||||||
validated_milestones = []
|
|
||||||
if milestones:
|
|
||||||
validated_milestones = await self._validate_milestones(milestones, amount)
|
|
||||||
if not validated_milestones:
|
|
||||||
return False, "Invalid milestones configuration", None
|
|
||||||
else:
|
|
||||||
# Create single milestone for full amount
|
|
||||||
validated_milestones = [{
|
|
||||||
'milestone_id': 'milestone_1',
|
|
||||||
'description': 'Complete job',
|
|
||||||
'amount': amount,
|
|
||||||
'completed': False
|
|
||||||
}]
|
|
||||||
|
|
||||||
# Create contract
|
|
||||||
contract_id = self._generate_contract_id(client_address, agent_address, job_id)
|
|
||||||
current_time = time.time()
|
|
||||||
|
|
||||||
contract = EscrowContract(
|
|
||||||
contract_id=contract_id,
|
|
||||||
job_id=job_id,
|
|
||||||
client_address=client_address,
|
|
||||||
agent_address=agent_address,
|
|
||||||
amount=total_amount,
|
|
||||||
fee_rate=fee_rate,
|
|
||||||
created_at=current_time,
|
|
||||||
expires_at=current_time + (duration_days * 86400),
|
|
||||||
state=EscrowState.CREATED,
|
|
||||||
milestones=validated_milestones,
|
|
||||||
current_milestone=0,
|
|
||||||
dispute_reason=None,
|
|
||||||
dispute_evidence=[],
|
|
||||||
resolution=None,
|
|
||||||
released_amount=Decimal('0'),
|
|
||||||
refunded_amount=Decimal('0')
|
|
||||||
)
|
|
||||||
|
|
||||||
self.escrow_contracts[contract_id] = contract
|
|
||||||
|
|
||||||
log_info(f"Escrow contract created: {contract_id} for job {job_id}")
|
|
||||||
return True, "Contract created successfully", contract_id
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return False, f"Contract creation failed: {str(e)}", None
|
|
||||||
|
|
||||||
def _validate_contract_inputs(self, job_id: str, client_address: str,
|
|
||||||
agent_address: str, amount: Decimal) -> bool:
|
|
||||||
"""Validate contract creation inputs"""
|
|
||||||
if not all([job_id, client_address, agent_address]):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Validate addresses (simplified)
|
|
||||||
if not (client_address.startswith('0x') and len(client_address) == 42):
|
|
||||||
return False
|
|
||||||
if not (agent_address.startswith('0x') and len(agent_address) == 42):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Validate amount
|
|
||||||
if amount <= 0:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check for existing contract
|
|
||||||
for contract in self.escrow_contracts.values():
|
|
||||||
if contract.job_id == job_id:
|
|
||||||
return False # Contract already exists for this job
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _validate_milestones(self, milestones: List[Dict], total_amount: Decimal) -> Optional[List[Dict]]:
|
|
||||||
"""Validate milestone configuration"""
|
|
||||||
if not milestones or len(milestones) > self.max_milestones:
|
|
||||||
return None
|
|
||||||
|
|
||||||
validated_milestones = []
|
|
||||||
milestone_total = Decimal('0')
|
|
||||||
|
|
||||||
for i, milestone_data in enumerate(milestones):
|
|
||||||
# Validate required fields
|
|
||||||
required_fields = ['milestone_id', 'description', 'amount']
|
|
||||||
if not all(field in milestone_data for field in required_fields):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Validate amount
|
|
||||||
amount = Decimal(str(milestone_data['amount']))
|
|
||||||
if amount < self.min_milestone_amount:
|
|
||||||
return None
|
|
||||||
|
|
||||||
milestone_total += amount
|
|
||||||
validated_milestones.append({
|
|
||||||
'milestone_id': milestone_data['milestone_id'],
|
|
||||||
'description': milestone_data['description'],
|
|
||||||
'amount': amount,
|
|
||||||
'completed': False
|
|
||||||
})
|
|
||||||
|
|
||||||
# Check if milestone amounts sum to total
|
|
||||||
if abs(milestone_total - total_amount) > Decimal('0.01'): # Allow small rounding difference
|
|
||||||
return None
|
|
||||||
|
|
||||||
return validated_milestones
|
|
||||||
|
|
||||||
def _generate_contract_id(self, client_address: str, agent_address: str, job_id: str) -> str:
|
|
||||||
"""Generate unique contract ID"""
|
|
||||||
import hashlib
|
|
||||||
content = f"{client_address}:{agent_address}:{job_id}:{time.time()}"
|
|
||||||
return hashlib.sha256(content.encode()).hexdigest()[:16]
|
|
||||||
|
|
||||||
async def fund_contract(self, contract_id: str, payment_tx_hash: str) -> Tuple[bool, str]:
|
|
||||||
"""Fund escrow contract"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state != EscrowState.CREATED:
|
|
||||||
return False, f"Cannot fund contract in {contract.state.value} state"
|
|
||||||
|
|
||||||
# In real implementation, this would verify the payment transaction
|
|
||||||
# For now, assume payment is valid
|
|
||||||
|
|
||||||
contract.state = EscrowState.FUNDED
|
|
||||||
self.active_contracts.add(contract_id)
|
|
||||||
|
|
||||||
log_info(f"Contract funded: {contract_id}")
|
|
||||||
return True, "Contract funded successfully"
|
|
||||||
|
|
||||||
async def start_job(self, contract_id: str) -> Tuple[bool, str]:
|
|
||||||
"""Mark job as started"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state != EscrowState.FUNDED:
|
|
||||||
return False, f"Cannot start job in {contract.state.value} state"
|
|
||||||
|
|
||||||
contract.state = EscrowState.JOB_STARTED
|
|
||||||
|
|
||||||
log_info(f"Job started for contract: {contract_id}")
|
|
||||||
return True, "Job started successfully"
|
|
||||||
|
|
||||||
async def complete_milestone(self, contract_id: str, milestone_id: str,
|
|
||||||
evidence: Dict = None) -> Tuple[bool, str]:
|
|
||||||
"""Mark milestone as completed"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state not in [EscrowState.JOB_STARTED, EscrowState.JOB_COMPLETED]:
|
|
||||||
return False, f"Cannot complete milestone in {contract.state.value} state"
|
|
||||||
|
|
||||||
# Find milestone
|
|
||||||
milestone = None
|
|
||||||
for ms in contract.milestones:
|
|
||||||
if ms['milestone_id'] == milestone_id:
|
|
||||||
milestone = ms
|
|
||||||
break
|
|
||||||
|
|
||||||
if not milestone:
|
|
||||||
return False, "Milestone not found"
|
|
||||||
|
|
||||||
if milestone['completed']:
|
|
||||||
return False, "Milestone already completed"
|
|
||||||
|
|
||||||
# Mark as completed
|
|
||||||
milestone['completed'] = True
|
|
||||||
milestone['completed_at'] = time.time()
|
|
||||||
|
|
||||||
# Add evidence if provided
|
|
||||||
if evidence:
|
|
||||||
milestone['evidence'] = evidence
|
|
||||||
|
|
||||||
# Check if all milestones are completed
|
|
||||||
all_completed = all(ms['completed'] for ms in contract.milestones)
|
|
||||||
if all_completed:
|
|
||||||
contract.state = EscrowState.JOB_COMPLETED
|
|
||||||
|
|
||||||
log_info(f"Milestone {milestone_id} completed for contract: {contract_id}")
|
|
||||||
return True, "Milestone completed successfully"
|
|
||||||
|
|
||||||
async def verify_milestone(self, contract_id: str, milestone_id: str,
|
|
||||||
verified: bool, feedback: str = "") -> Tuple[bool, str]:
|
|
||||||
"""Verify milestone completion"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
# Find milestone
|
|
||||||
milestone = None
|
|
||||||
for ms in contract.milestones:
|
|
||||||
if ms['milestone_id'] == milestone_id:
|
|
||||||
milestone = ms
|
|
||||||
break
|
|
||||||
|
|
||||||
if not milestone:
|
|
||||||
return False, "Milestone not found"
|
|
||||||
|
|
||||||
if not milestone['completed']:
|
|
||||||
return False, "Milestone not completed yet"
|
|
||||||
|
|
||||||
# Set verification status
|
|
||||||
milestone['verified'] = verified
|
|
||||||
milestone['verification_feedback'] = feedback
|
|
||||||
|
|
||||||
if verified:
|
|
||||||
# Release milestone payment
|
|
||||||
await self._release_milestone_payment(contract_id, milestone_id)
|
|
||||||
else:
|
|
||||||
# Create dispute if verification fails
|
|
||||||
await self._create_dispute(contract_id, DisputeReason.QUALITY_ISSUES,
|
|
||||||
f"Milestone {milestone_id} verification failed: {feedback}")
|
|
||||||
|
|
||||||
log_info(f"Milestone {milestone_id} verification: {verified} for contract: {contract_id}")
|
|
||||||
return True, "Milestone verification processed"
|
|
||||||
|
|
||||||
async def _release_milestone_payment(self, contract_id: str, milestone_id: str):
|
|
||||||
"""Release payment for verified milestone"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Find milestone
|
|
||||||
milestone = None
|
|
||||||
for ms in contract.milestones:
|
|
||||||
if ms['milestone_id'] == milestone_id:
|
|
||||||
milestone = ms
|
|
||||||
break
|
|
||||||
|
|
||||||
if not milestone:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Calculate payment amount (minus platform fee)
|
|
||||||
milestone_amount = Decimal(str(milestone['amount']))
|
|
||||||
platform_fee = milestone_amount * contract.fee_rate
|
|
||||||
payment_amount = milestone_amount - platform_fee
|
|
||||||
|
|
||||||
# Update released amount
|
|
||||||
contract.released_amount += payment_amount
|
|
||||||
|
|
||||||
# In real implementation, this would trigger actual payment transfer
|
|
||||||
log_info(f"Released {payment_amount} for milestone {milestone_id} in contract {contract_id}")
|
|
||||||
|
|
||||||
async def release_full_payment(self, contract_id: str) -> Tuple[bool, str]:
|
|
||||||
"""Release full payment to agent"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state != EscrowState.JOB_COMPLETED:
|
|
||||||
return False, f"Cannot release payment in {contract.state.value} state"
|
|
||||||
|
|
||||||
# Check if all milestones are verified
|
|
||||||
all_verified = all(ms.get('verified', False) for ms in contract.milestones)
|
|
||||||
if not all_verified:
|
|
||||||
return False, "Not all milestones are verified"
|
|
||||||
|
|
||||||
# Calculate remaining payment
|
|
||||||
total_milestone_amount = sum(Decimal(str(ms['amount'])) for ms in contract.milestones)
|
|
||||||
platform_fee_total = total_milestone_amount * contract.fee_rate
|
|
||||||
remaining_payment = total_milestone_amount - contract.released_amount - platform_fee_total
|
|
||||||
|
|
||||||
if remaining_payment > 0:
|
|
||||||
contract.released_amount += remaining_payment
|
|
||||||
|
|
||||||
contract.state = EscrowState.RELEASED
|
|
||||||
self.active_contracts.discard(contract_id)
|
|
||||||
|
|
||||||
log_info(f"Full payment released for contract: {contract_id}")
|
|
||||||
return True, "Payment released successfully"
|
|
||||||
|
|
||||||
async def create_dispute(self, contract_id: str, reason: DisputeReason,
|
|
||||||
description: str, evidence: List[Dict] = None) -> Tuple[bool, str]:
|
|
||||||
"""Create dispute for contract"""
|
|
||||||
return await self._create_dispute(contract_id, reason, description, evidence)
|
|
||||||
|
|
||||||
async def _create_dispute(self, contract_id: str, reason: DisputeReason,
|
|
||||||
description: str, evidence: List[Dict] = None):
|
|
||||||
"""Internal dispute creation method"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state == EscrowState.DISPUTED:
|
|
||||||
return False, "Contract already disputed"
|
|
||||||
|
|
||||||
if contract.state not in [EscrowState.FUNDED, EscrowState.JOB_STARTED, EscrowState.JOB_COMPLETED]:
|
|
||||||
return False, f"Cannot dispute contract in {contract.state.value} state"
|
|
||||||
|
|
||||||
# Validate evidence
|
|
||||||
if evidence and (len(evidence) < self.min_dispute_evidence or len(evidence) > self.max_dispute_evidence):
|
|
||||||
return False, f"Invalid evidence count: {len(evidence)}"
|
|
||||||
|
|
||||||
# Create dispute
|
|
||||||
contract.state = EscrowState.DISPUTED
|
|
||||||
contract.dispute_reason = reason
|
|
||||||
contract.dispute_evidence = evidence or []
|
|
||||||
contract.dispute_created_at = time.time()
|
|
||||||
|
|
||||||
self.disputed_contracts.add(contract_id)
|
|
||||||
|
|
||||||
log_info(f"Dispute created for contract: {contract_id} - {reason.value}")
|
|
||||||
return True, "Dispute created successfully"
|
|
||||||
|
|
||||||
async def resolve_dispute(self, contract_id: str, resolution: Dict) -> Tuple[bool, str]:
|
|
||||||
"""Resolve dispute with specified outcome"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state != EscrowState.DISPUTED:
|
|
||||||
return False, f"Contract not in disputed state: {contract.state.value}"
|
|
||||||
|
|
||||||
# Validate resolution
|
|
||||||
required_fields = ['winner', 'client_refund', 'agent_payment']
|
|
||||||
if not all(field in resolution for field in required_fields):
|
|
||||||
return False, "Invalid resolution format"
|
|
||||||
|
|
||||||
winner = resolution['winner']
|
|
||||||
client_refund = Decimal(str(resolution['client_refund']))
|
|
||||||
agent_payment = Decimal(str(resolution['agent_payment']))
|
|
||||||
|
|
||||||
# Validate amounts
|
|
||||||
total_refund = client_refund + agent_payment
|
|
||||||
if total_refund > contract.amount:
|
|
||||||
return False, "Refund amounts exceed contract amount"
|
|
||||||
|
|
||||||
# Apply resolution
|
|
||||||
contract.resolution = resolution
|
|
||||||
contract.state = EscrowState.RESOLVED
|
|
||||||
|
|
||||||
# Update amounts
|
|
||||||
contract.released_amount += agent_payment
|
|
||||||
contract.refunded_amount += client_refund
|
|
||||||
|
|
||||||
# Remove from disputed contracts
|
|
||||||
self.disputed_contracts.discard(contract_id)
|
|
||||||
self.active_contracts.discard(contract_id)
|
|
||||||
|
|
||||||
log_info(f"Dispute resolved for contract: {contract_id} - Winner: {winner}")
|
|
||||||
return True, "Dispute resolved successfully"
|
|
||||||
|
|
||||||
async def refund_contract(self, contract_id: str, reason: str = "") -> Tuple[bool, str]:
|
|
||||||
"""Refund contract to client"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state in [EscrowState.RELEASED, EscrowState.REFUNDED, EscrowState.EXPIRED]:
|
|
||||||
return False, f"Cannot refund contract in {contract.state.value} state"
|
|
||||||
|
|
||||||
# Calculate refund amount (minus any released payments)
|
|
||||||
refund_amount = contract.amount - contract.released_amount
|
|
||||||
|
|
||||||
if refund_amount <= 0:
|
|
||||||
return False, "No amount available for refund"
|
|
||||||
|
|
||||||
contract.state = EscrowState.REFUNDED
|
|
||||||
contract.refunded_amount = refund_amount
|
|
||||||
|
|
||||||
self.active_contracts.discard(contract_id)
|
|
||||||
self.disputed_contracts.discard(contract_id)
|
|
||||||
|
|
||||||
log_info(f"Contract refunded: {contract_id} - Amount: {refund_amount}")
|
|
||||||
return True, "Contract refunded successfully"
|
|
||||||
|
|
||||||
async def expire_contract(self, contract_id: str) -> Tuple[bool, str]:
|
|
||||||
"""Mark contract as expired"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if time.time() < contract.expires_at:
|
|
||||||
return False, "Contract has not expired yet"
|
|
||||||
|
|
||||||
if contract.state in [EscrowState.RELEASED, EscrowState.REFUNDED, EscrowState.EXPIRED]:
|
|
||||||
return False, f"Contract already in final state: {contract.state.value}"
|
|
||||||
|
|
||||||
# Auto-refund if no work has been done
|
|
||||||
if contract.state == EscrowState.FUNDED:
|
|
||||||
return await self.refund_contract(contract_id, "Contract expired")
|
|
||||||
|
|
||||||
# Handle other states based on work completion
|
|
||||||
contract.state = EscrowState.EXPIRED
|
|
||||||
self.active_contracts.discard(contract_id)
|
|
||||||
self.disputed_contracts.discard(contract_id)
|
|
||||||
|
|
||||||
log_info(f"Contract expired: {contract_id}")
|
|
||||||
return True, "Contract expired successfully"
|
|
||||||
|
|
||||||
async def get_contract_info(self, contract_id: str) -> Optional[EscrowContract]:
|
|
||||||
"""Get contract information"""
|
|
||||||
return self.escrow_contracts.get(contract_id)
|
|
||||||
|
|
||||||
async def get_contracts_by_client(self, client_address: str) -> List[EscrowContract]:
|
|
||||||
"""Get contracts for specific client"""
|
|
||||||
return [
|
|
||||||
contract for contract in self.escrow_contracts.values()
|
|
||||||
if contract.client_address == client_address
|
|
||||||
]
|
|
||||||
|
|
||||||
async def get_contracts_by_agent(self, agent_address: str) -> List[EscrowContract]:
|
|
||||||
"""Get contracts for specific agent"""
|
|
||||||
return [
|
|
||||||
contract for contract in self.escrow_contracts.values()
|
|
||||||
if contract.agent_address == agent_address
|
|
||||||
]
|
|
||||||
|
|
||||||
async def get_active_contracts(self) -> List[EscrowContract]:
|
|
||||||
"""Get all active contracts"""
|
|
||||||
return [
|
|
||||||
self.escrow_contracts[contract_id]
|
|
||||||
for contract_id in self.active_contracts
|
|
||||||
if contract_id in self.escrow_contracts
|
|
||||||
]
|
|
||||||
|
|
||||||
async def get_disputed_contracts(self) -> List[EscrowContract]:
|
|
||||||
"""Get all disputed contracts"""
|
|
||||||
return [
|
|
||||||
self.escrow_contracts[contract_id]
|
|
||||||
for contract_id in self.disputed_contracts
|
|
||||||
if contract_id in self.escrow_contracts
|
|
||||||
]
|
|
||||||
|
|
||||||
async def get_escrow_statistics(self) -> Dict:
|
|
||||||
"""Get escrow system statistics"""
|
|
||||||
total_contracts = len(self.escrow_contracts)
|
|
||||||
active_count = len(self.active_contracts)
|
|
||||||
disputed_count = len(self.disputed_contracts)
|
|
||||||
|
|
||||||
# State distribution
|
|
||||||
state_counts = {}
|
|
||||||
for contract in self.escrow_contracts.values():
|
|
||||||
state = contract.state.value
|
|
||||||
state_counts[state] = state_counts.get(state, 0) + 1
|
|
||||||
|
|
||||||
# Financial statistics
|
|
||||||
total_amount = sum(contract.amount for contract in self.escrow_contracts.values())
|
|
||||||
total_released = sum(contract.released_amount for contract in self.escrow_contracts.values())
|
|
||||||
total_refunded = sum(contract.refunded_amount for contract in self.escrow_contracts.values())
|
|
||||||
total_fees = total_amount - total_released - total_refunded
|
|
||||||
|
|
||||||
return {
|
|
||||||
'total_contracts': total_contracts,
|
|
||||||
'active_contracts': active_count,
|
|
||||||
'disputed_contracts': disputed_count,
|
|
||||||
'state_distribution': state_counts,
|
|
||||||
'total_amount': float(total_amount),
|
|
||||||
'total_released': float(total_released),
|
|
||||||
'total_refunded': float(total_refunded),
|
|
||||||
'total_fees': float(total_fees),
|
|
||||||
'average_contract_value': float(total_amount / total_contracts) if total_contracts > 0 else 0
|
|
||||||
}
|
|
||||||
|
|
||||||
# Global escrow manager
|
|
||||||
escrow_manager: Optional[EscrowManager] = None
|
|
||||||
|
|
||||||
def get_escrow_manager() -> Optional[EscrowManager]:
|
|
||||||
"""Get global escrow manager"""
|
|
||||||
return escrow_manager
|
|
||||||
|
|
||||||
def create_escrow_manager() -> EscrowManager:
|
|
||||||
"""Create and set global escrow manager"""
|
|
||||||
global escrow_manager
|
|
||||||
escrow_manager = EscrowManager()
|
|
||||||
return escrow_manager
|
|
||||||
@@ -1,405 +0,0 @@
|
|||||||
"""
|
|
||||||
Fixed Guardian Configuration with Proper Guardian Setup
|
|
||||||
Addresses the critical vulnerability where guardian lists were empty
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
import json
|
|
||||||
from eth_account import Account
|
|
||||||
from eth_utils import to_checksum_address, keccak
|
|
||||||
|
|
||||||
from .guardian_contract import (
|
|
||||||
SpendingLimit,
|
|
||||||
TimeLockConfig,
|
|
||||||
GuardianConfig,
|
|
||||||
GuardianContract
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GuardianSetup:
|
|
||||||
"""Guardian setup configuration"""
|
|
||||||
primary_guardian: str # Main guardian address
|
|
||||||
backup_guardians: List[str] # Backup guardian addresses
|
|
||||||
multisig_threshold: int # Number of signatures required
|
|
||||||
emergency_contacts: List[str] # Additional emergency contacts
|
|
||||||
|
|
||||||
|
|
||||||
class SecureGuardianManager:
|
|
||||||
"""
|
|
||||||
Secure guardian management with proper initialization
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.guardian_registrations: Dict[str, GuardianSetup] = {}
|
|
||||||
self.guardian_contracts: Dict[str, GuardianContract] = {}
|
|
||||||
|
|
||||||
def create_guardian_setup(
|
|
||||||
self,
|
|
||||||
agent_address: str,
|
|
||||||
owner_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
custom_guardians: Optional[List[str]] = None
|
|
||||||
) -> GuardianSetup:
|
|
||||||
"""
|
|
||||||
Create a proper guardian setup for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
owner_address: Owner of the agent
|
|
||||||
security_level: Security level (conservative, aggressive, high_security)
|
|
||||||
custom_guardians: Optional custom guardian addresses
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Guardian setup configuration
|
|
||||||
"""
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
owner_address = to_checksum_address(owner_address)
|
|
||||||
|
|
||||||
# Determine guardian requirements based on security level
|
|
||||||
if security_level == "conservative":
|
|
||||||
required_guardians = 3
|
|
||||||
multisig_threshold = 2
|
|
||||||
elif security_level == "aggressive":
|
|
||||||
required_guardians = 2
|
|
||||||
multisig_threshold = 2
|
|
||||||
elif security_level == "high_security":
|
|
||||||
required_guardians = 5
|
|
||||||
multisig_threshold = 3
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid security level: {security_level}")
|
|
||||||
|
|
||||||
# Build guardian list
|
|
||||||
guardians = []
|
|
||||||
|
|
||||||
# Always include the owner as primary guardian
|
|
||||||
guardians.append(owner_address)
|
|
||||||
|
|
||||||
# Add custom guardians if provided
|
|
||||||
if custom_guardians:
|
|
||||||
for guardian in custom_guardians:
|
|
||||||
guardian = to_checksum_address(guardian)
|
|
||||||
if guardian not in guardians:
|
|
||||||
guardians.append(guardian)
|
|
||||||
|
|
||||||
# Generate backup guardians if needed
|
|
||||||
while len(guardians) < required_guardians:
|
|
||||||
# Generate a deterministic backup guardian based on agent address
|
|
||||||
# In production, these would be trusted service addresses
|
|
||||||
backup_index = len(guardians) - 1 # -1 because owner is already included
|
|
||||||
backup_guardian = self._generate_backup_guardian(agent_address, backup_index)
|
|
||||||
|
|
||||||
if backup_guardian not in guardians:
|
|
||||||
guardians.append(backup_guardian)
|
|
||||||
|
|
||||||
# Create setup
|
|
||||||
setup = GuardianSetup(
|
|
||||||
primary_guardian=owner_address,
|
|
||||||
backup_guardians=[g for g in guardians if g != owner_address],
|
|
||||||
multisig_threshold=multisig_threshold,
|
|
||||||
emergency_contacts=guardians.copy()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.guardian_registrations[agent_address] = setup
|
|
||||||
|
|
||||||
return setup
|
|
||||||
|
|
||||||
def _generate_backup_guardian(self, agent_address: str, index: int) -> str:
|
|
||||||
"""
|
|
||||||
Generate deterministic backup guardian address
|
|
||||||
|
|
||||||
In production, these would be pre-registered trusted guardian addresses
|
|
||||||
"""
|
|
||||||
# Create a deterministic address based on agent address and index
|
|
||||||
seed = f"{agent_address}_{index}_backup_guardian"
|
|
||||||
hash_result = keccak(seed.encode())
|
|
||||||
|
|
||||||
# Use the hash to generate a valid address
|
|
||||||
address_bytes = hash_result[-20:] # Take last 20 bytes
|
|
||||||
address = "0x" + address_bytes.hex()
|
|
||||||
|
|
||||||
return to_checksum_address(address)
|
|
||||||
|
|
||||||
def create_secure_guardian_contract(
|
|
||||||
self,
|
|
||||||
agent_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
custom_guardians: Optional[List[str]] = None
|
|
||||||
) -> GuardianContract:
|
|
||||||
"""
|
|
||||||
Create a guardian contract with proper guardian configuration
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
security_level: Security level
|
|
||||||
custom_guardians: Optional custom guardian addresses
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured guardian contract
|
|
||||||
"""
|
|
||||||
# Create guardian setup
|
|
||||||
setup = self.create_guardian_setup(
|
|
||||||
agent_address=agent_address,
|
|
||||||
owner_address=agent_address, # Agent is its own owner initially
|
|
||||||
security_level=security_level,
|
|
||||||
custom_guardians=custom_guardians
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get security configuration
|
|
||||||
config = self._get_security_config(security_level, setup)
|
|
||||||
|
|
||||||
# Create contract
|
|
||||||
contract = GuardianContract(agent_address, config)
|
|
||||||
|
|
||||||
# Store contract
|
|
||||||
self.guardian_contracts[agent_address] = contract
|
|
||||||
|
|
||||||
return contract
|
|
||||||
|
|
||||||
def _get_security_config(self, security_level: str, setup: GuardianSetup) -> GuardianConfig:
|
|
||||||
"""Get security configuration with proper guardian list"""
|
|
||||||
|
|
||||||
# Build guardian list
|
|
||||||
all_guardians = [setup.primary_guardian] + setup.backup_guardians
|
|
||||||
|
|
||||||
if security_level == "conservative":
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=1000,
|
|
||||||
per_hour=5000,
|
|
||||||
per_day=20000,
|
|
||||||
per_week=100000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=5000,
|
|
||||||
delay_hours=24,
|
|
||||||
max_delay_hours=168
|
|
||||||
),
|
|
||||||
guardians=all_guardians,
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False,
|
|
||||||
multisig_threshold=setup.multisig_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
elif security_level == "aggressive":
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=5000,
|
|
||||||
per_hour=25000,
|
|
||||||
per_day=100000,
|
|
||||||
per_week=500000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=20000,
|
|
||||||
delay_hours=12,
|
|
||||||
max_delay_hours=72
|
|
||||||
),
|
|
||||||
guardians=all_guardians,
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False,
|
|
||||||
multisig_threshold=setup.multisig_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
elif security_level == "high_security":
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=500,
|
|
||||||
per_hour=2000,
|
|
||||||
per_day=8000,
|
|
||||||
per_week=40000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=2000,
|
|
||||||
delay_hours=48,
|
|
||||||
max_delay_hours=168
|
|
||||||
),
|
|
||||||
guardians=all_guardians,
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False,
|
|
||||||
multisig_threshold=setup.multisig_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid security level: {security_level}")
|
|
||||||
|
|
||||||
def test_emergency_pause(self, agent_address: str, guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Test emergency pause functionality
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent address
|
|
||||||
guardian_address: Guardian attempting pause
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Test result
|
|
||||||
"""
|
|
||||||
if agent_address not in self.guardian_contracts:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
contract = self.guardian_contracts[agent_address]
|
|
||||||
return contract.emergency_pause(guardian_address)
|
|
||||||
|
|
||||||
def verify_guardian_authorization(self, agent_address: str, guardian_address: str) -> bool:
|
|
||||||
"""
|
|
||||||
Verify if a guardian is authorized for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent address
|
|
||||||
guardian_address: Guardian address to verify
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if guardian is authorized
|
|
||||||
"""
|
|
||||||
if agent_address not in self.guardian_registrations:
|
|
||||||
return False
|
|
||||||
|
|
||||||
setup = self.guardian_registrations[agent_address]
|
|
||||||
all_guardians = [setup.primary_guardian] + setup.backup_guardians
|
|
||||||
|
|
||||||
return to_checksum_address(guardian_address) in [
|
|
||||||
to_checksum_address(g) for g in all_guardians
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_guardian_summary(self, agent_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Get guardian setup summary for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Guardian summary
|
|
||||||
"""
|
|
||||||
if agent_address not in self.guardian_registrations:
|
|
||||||
return {"error": "Agent not registered"}
|
|
||||||
|
|
||||||
setup = self.guardian_registrations[agent_address]
|
|
||||||
contract = self.guardian_contracts.get(agent_address)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"primary_guardian": setup.primary_guardian,
|
|
||||||
"backup_guardians": setup.backup_guardians,
|
|
||||||
"total_guardians": len(setup.backup_guardians) + 1,
|
|
||||||
"multisig_threshold": setup.multisig_threshold,
|
|
||||||
"emergency_contacts": setup.emergency_contacts,
|
|
||||||
"contract_status": contract.get_spending_status() if contract else None,
|
|
||||||
"pause_functional": contract is not None and len(setup.backup_guardians) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Fixed security configurations with proper guardians
|
|
||||||
def get_fixed_conservative_config(agent_address: str, owner_address: str) -> GuardianConfig:
|
|
||||||
"""Get fixed conservative configuration with proper guardians"""
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=1000,
|
|
||||||
per_hour=5000,
|
|
||||||
per_day=20000,
|
|
||||||
per_week=100000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=5000,
|
|
||||||
delay_hours=24,
|
|
||||||
max_delay_hours=168
|
|
||||||
),
|
|
||||||
guardians=[owner_address], # At least the owner
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_fixed_aggressive_config(agent_address: str, owner_address: str) -> GuardianConfig:
|
|
||||||
"""Get fixed aggressive configuration with proper guardians"""
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=5000,
|
|
||||||
per_hour=25000,
|
|
||||||
per_day=100000,
|
|
||||||
per_week=500000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=20000,
|
|
||||||
delay_hours=12,
|
|
||||||
max_delay_hours=72
|
|
||||||
),
|
|
||||||
guardians=[owner_address], # At least the owner
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_fixed_high_security_config(agent_address: str, owner_address: str) -> GuardianConfig:
|
|
||||||
"""Get fixed high security configuration with proper guardians"""
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=500,
|
|
||||||
per_hour=2000,
|
|
||||||
per_day=8000,
|
|
||||||
per_week=40000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=2000,
|
|
||||||
delay_hours=48,
|
|
||||||
max_delay_hours=168
|
|
||||||
),
|
|
||||||
guardians=[owner_address], # At least the owner
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Global secure guardian manager
|
|
||||||
secure_guardian_manager = SecureGuardianManager()
|
|
||||||
|
|
||||||
|
|
||||||
# Convenience function for secure agent registration
|
|
||||||
def register_agent_with_guardians(
|
|
||||||
agent_address: str,
|
|
||||||
owner_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
custom_guardians: Optional[List[str]] = None
|
|
||||||
) -> Dict:
|
|
||||||
"""
|
|
||||||
Register an agent with proper guardian configuration
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
owner_address: Owner address
|
|
||||||
security_level: Security level
|
|
||||||
custom_guardians: Optional custom guardians
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Registration result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Create secure guardian contract
|
|
||||||
contract = secure_guardian_manager.create_secure_guardian_contract(
|
|
||||||
agent_address=agent_address,
|
|
||||||
security_level=security_level,
|
|
||||||
custom_guardians=custom_guardians
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get guardian summary
|
|
||||||
summary = secure_guardian_manager.get_guardian_summary(agent_address)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "registered",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"security_level": security_level,
|
|
||||||
"guardian_count": summary["total_guardians"],
|
|
||||||
"multisig_threshold": summary["multisig_threshold"],
|
|
||||||
"pause_functional": summary["pause_functional"],
|
|
||||||
"registered_at": datetime.utcnow().isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Registration failed: {str(e)}"
|
|
||||||
}
|
|
||||||
@@ -1,682 +0,0 @@
|
|||||||
"""
|
|
||||||
AITBC Guardian Contract - Spending Limit Protection for Agent Wallets
|
|
||||||
|
|
||||||
This contract implements a spending limit guardian that protects autonomous agent
|
|
||||||
wallets from unlimited spending in case of compromise. It provides:
|
|
||||||
- Per-transaction spending limits
|
|
||||||
- Per-period (daily/hourly) spending caps
|
|
||||||
- Time-lock for large withdrawals
|
|
||||||
- Emergency pause functionality
|
|
||||||
- Multi-signature recovery for critical operations
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sqlite3
|
|
||||||
from pathlib import Path
|
|
||||||
from eth_account import Account
|
|
||||||
from eth_utils import to_checksum_address, keccak
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SpendingLimit:
|
|
||||||
"""Spending limit configuration"""
|
|
||||||
per_transaction: int # Maximum per transaction
|
|
||||||
per_hour: int # Maximum per hour
|
|
||||||
per_day: int # Maximum per day
|
|
||||||
per_week: int # Maximum per week
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TimeLockConfig:
|
|
||||||
"""Time lock configuration for large withdrawals"""
|
|
||||||
threshold: int # Amount that triggers time lock
|
|
||||||
delay_hours: int # Delay period in hours
|
|
||||||
max_delay_hours: int # Maximum delay period
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GuardianConfig:
|
|
||||||
"""Complete guardian configuration"""
|
|
||||||
limits: SpendingLimit
|
|
||||||
time_lock: TimeLockConfig
|
|
||||||
guardians: List[str] # Guardian addresses for recovery
|
|
||||||
pause_enabled: bool = True
|
|
||||||
emergency_mode: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class GuardianContract:
|
|
||||||
"""
|
|
||||||
Guardian contract implementation for agent wallet protection
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, agent_address: str, config: GuardianConfig, storage_path: str = None):
|
|
||||||
self.agent_address = to_checksum_address(agent_address)
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
# CRITICAL SECURITY FIX: Use persistent storage instead of in-memory
|
|
||||||
if storage_path is None:
|
|
||||||
storage_path = os.path.join(os.path.expanduser("~"), ".aitbc", "guardian_contracts")
|
|
||||||
|
|
||||||
self.storage_dir = Path(storage_path)
|
|
||||||
self.storage_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Database file for this contract
|
|
||||||
self.db_path = self.storage_dir / f"guardian_{self.agent_address}.db"
|
|
||||||
|
|
||||||
# Initialize persistent storage
|
|
||||||
self._init_storage()
|
|
||||||
|
|
||||||
# Load state from storage
|
|
||||||
self._load_state()
|
|
||||||
|
|
||||||
# In-memory cache for performance (synced with storage)
|
|
||||||
self.spending_history: List[Dict] = []
|
|
||||||
self.pending_operations: Dict[str, Dict] = {}
|
|
||||||
self.paused = False
|
|
||||||
self.emergency_mode = False
|
|
||||||
|
|
||||||
# Contract state
|
|
||||||
self.nonce = 0
|
|
||||||
self.guardian_approvals: Dict[str, bool] = {}
|
|
||||||
|
|
||||||
# Load data from persistent storage
|
|
||||||
self._load_spending_history()
|
|
||||||
self._load_pending_operations()
|
|
||||||
|
|
||||||
def _init_storage(self):
|
|
||||||
"""Initialize SQLite database for persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
conn.execute('''
|
|
||||||
CREATE TABLE IF NOT EXISTS spending_history (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
operation_id TEXT UNIQUE,
|
|
||||||
agent_address TEXT,
|
|
||||||
to_address TEXT,
|
|
||||||
amount INTEGER,
|
|
||||||
data TEXT,
|
|
||||||
timestamp TEXT,
|
|
||||||
executed_at TEXT,
|
|
||||||
status TEXT,
|
|
||||||
nonce INTEGER,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)
|
|
||||||
''')
|
|
||||||
|
|
||||||
conn.execute('''
|
|
||||||
CREATE TABLE IF NOT EXISTS pending_operations (
|
|
||||||
operation_id TEXT PRIMARY KEY,
|
|
||||||
agent_address TEXT,
|
|
||||||
operation_data TEXT,
|
|
||||||
status TEXT,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)
|
|
||||||
''')
|
|
||||||
|
|
||||||
conn.execute('''
|
|
||||||
CREATE TABLE IF NOT EXISTS contract_state (
|
|
||||||
agent_address TEXT PRIMARY KEY,
|
|
||||||
nonce INTEGER DEFAULT 0,
|
|
||||||
paused BOOLEAN DEFAULT 0,
|
|
||||||
emergency_mode BOOLEAN DEFAULT 0,
|
|
||||||
last_updated DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)
|
|
||||||
''')
|
|
||||||
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def _load_state(self):
|
|
||||||
"""Load contract state from persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
cursor = conn.execute(
|
|
||||||
'SELECT nonce, paused, emergency_mode FROM contract_state WHERE agent_address = ?',
|
|
||||||
(self.agent_address,)
|
|
||||||
)
|
|
||||||
row = cursor.fetchone()
|
|
||||||
|
|
||||||
if row:
|
|
||||||
self.nonce, self.paused, self.emergency_mode = row
|
|
||||||
else:
|
|
||||||
# Initialize state for new contract
|
|
||||||
conn.execute(
|
|
||||||
'INSERT INTO contract_state (agent_address, nonce, paused, emergency_mode) VALUES (?, ?, ?, ?)',
|
|
||||||
(self.agent_address, 0, False, False)
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def _save_state(self):
|
|
||||||
"""Save contract state to persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
conn.execute(
|
|
||||||
'UPDATE contract_state SET nonce = ?, paused = ?, emergency_mode = ?, last_updated = CURRENT_TIMESTAMP WHERE agent_address = ?',
|
|
||||||
(self.nonce, self.paused, self.emergency_mode, self.agent_address)
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def _load_spending_history(self):
|
|
||||||
"""Load spending history from persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
cursor = conn.execute(
|
|
||||||
'SELECT operation_id, to_address, amount, data, timestamp, executed_at, status, nonce FROM spending_history WHERE agent_address = ? ORDER BY timestamp DESC',
|
|
||||||
(self.agent_address,)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.spending_history = []
|
|
||||||
for row in cursor:
|
|
||||||
self.spending_history.append({
|
|
||||||
"operation_id": row[0],
|
|
||||||
"to": row[1],
|
|
||||||
"amount": row[2],
|
|
||||||
"data": row[3],
|
|
||||||
"timestamp": row[4],
|
|
||||||
"executed_at": row[5],
|
|
||||||
"status": row[6],
|
|
||||||
"nonce": row[7]
|
|
||||||
})
|
|
||||||
|
|
||||||
def _save_spending_record(self, record: Dict):
|
|
||||||
"""Save spending record to persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
conn.execute(
|
|
||||||
'''INSERT OR REPLACE INTO spending_history
|
|
||||||
(operation_id, agent_address, to_address, amount, data, timestamp, executed_at, status, nonce)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)''',
|
|
||||||
(
|
|
||||||
record["operation_id"],
|
|
||||||
self.agent_address,
|
|
||||||
record["to"],
|
|
||||||
record["amount"],
|
|
||||||
record.get("data", ""),
|
|
||||||
record["timestamp"],
|
|
||||||
record.get("executed_at", ""),
|
|
||||||
record["status"],
|
|
||||||
record["nonce"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def _load_pending_operations(self):
|
|
||||||
"""Load pending operations from persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
cursor = conn.execute(
|
|
||||||
'SELECT operation_id, operation_data, status FROM pending_operations WHERE agent_address = ?',
|
|
||||||
(self.agent_address,)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.pending_operations = {}
|
|
||||||
for row in cursor:
|
|
||||||
operation_data = json.loads(row[1])
|
|
||||||
operation_data["status"] = row[2]
|
|
||||||
self.pending_operations[row[0]] = operation_data
|
|
||||||
|
|
||||||
def _save_pending_operation(self, operation_id: str, operation: Dict):
|
|
||||||
"""Save pending operation to persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
conn.execute(
|
|
||||||
'''INSERT OR REPLACE INTO pending_operations
|
|
||||||
(operation_id, agent_address, operation_data, status, updated_at)
|
|
||||||
VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)''',
|
|
||||||
(operation_id, self.agent_address, json.dumps(operation), operation["status"])
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def _remove_pending_operation(self, operation_id: str):
|
|
||||||
"""Remove pending operation from persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
conn.execute(
|
|
||||||
'DELETE FROM pending_operations WHERE operation_id = ? AND agent_address = ?',
|
|
||||||
(operation_id, self.agent_address)
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def _get_period_key(self, timestamp: datetime, period: str) -> str:
|
|
||||||
"""Generate period key for spending tracking"""
|
|
||||||
if period == "hour":
|
|
||||||
return timestamp.strftime("%Y-%m-%d-%H")
|
|
||||||
elif period == "day":
|
|
||||||
return timestamp.strftime("%Y-%m-%d")
|
|
||||||
elif period == "week":
|
|
||||||
# Get week number (Monday as first day)
|
|
||||||
week_num = timestamp.isocalendar()[1]
|
|
||||||
return f"{timestamp.year}-W{week_num:02d}"
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid period: {period}")
|
|
||||||
|
|
||||||
def _get_spent_in_period(self, period: str, timestamp: datetime = None) -> int:
|
|
||||||
"""Calculate total spent in given period"""
|
|
||||||
if timestamp is None:
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
period_key = self._get_period_key(timestamp, period)
|
|
||||||
|
|
||||||
total = 0
|
|
||||||
for record in self.spending_history:
|
|
||||||
record_time = datetime.fromisoformat(record["timestamp"])
|
|
||||||
record_period = self._get_period_key(record_time, period)
|
|
||||||
|
|
||||||
if record_period == period_key and record["status"] == "completed":
|
|
||||||
total += record["amount"]
|
|
||||||
|
|
||||||
return total
|
|
||||||
|
|
||||||
def _check_spending_limits(self, amount: int, timestamp: datetime = None) -> Tuple[bool, str]:
|
|
||||||
"""Check if amount exceeds spending limits"""
|
|
||||||
if timestamp is None:
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
# Check per-transaction limit
|
|
||||||
if amount > self.config.limits.per_transaction:
|
|
||||||
return False, f"Amount {amount} exceeds per-transaction limit {self.config.limits.per_transaction}"
|
|
||||||
|
|
||||||
# Check per-hour limit
|
|
||||||
spent_hour = self._get_spent_in_period("hour", timestamp)
|
|
||||||
if spent_hour + amount > self.config.limits.per_hour:
|
|
||||||
return False, f"Hourly spending {spent_hour + amount} would exceed limit {self.config.limits.per_hour}"
|
|
||||||
|
|
||||||
# Check per-day limit
|
|
||||||
spent_day = self._get_spent_in_period("day", timestamp)
|
|
||||||
if spent_day + amount > self.config.limits.per_day:
|
|
||||||
return False, f"Daily spending {spent_day + amount} would exceed limit {self.config.limits.per_day}"
|
|
||||||
|
|
||||||
# Check per-week limit
|
|
||||||
spent_week = self._get_spent_in_period("week", timestamp)
|
|
||||||
if spent_week + amount > self.config.limits.per_week:
|
|
||||||
return False, f"Weekly spending {spent_week + amount} would exceed limit {self.config.limits.per_week}"
|
|
||||||
|
|
||||||
return True, "Spending limits check passed"
|
|
||||||
|
|
||||||
def _requires_time_lock(self, amount: int) -> bool:
|
|
||||||
"""Check if amount requires time lock"""
|
|
||||||
return amount >= self.config.time_lock.threshold
|
|
||||||
|
|
||||||
def _create_operation_hash(self, operation: Dict) -> str:
|
|
||||||
"""Create hash for operation identification"""
|
|
||||||
operation_str = json.dumps(operation, sort_keys=True)
|
|
||||||
return keccak(operation_str.encode()).hex()
|
|
||||||
|
|
||||||
def initiate_transaction(self, to_address: str, amount: int, data: str = "") -> Dict:
|
|
||||||
"""
|
|
||||||
Initiate a transaction with guardian protection
|
|
||||||
|
|
||||||
Args:
|
|
||||||
to_address: Recipient address
|
|
||||||
amount: Amount to transfer
|
|
||||||
data: Transaction data (optional)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Operation result with status and details
|
|
||||||
"""
|
|
||||||
# Check if paused
|
|
||||||
if self.paused:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": "Guardian contract is paused",
|
|
||||||
"operation_id": None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Check emergency mode
|
|
||||||
if self.emergency_mode:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": "Emergency mode activated",
|
|
||||||
"operation_id": None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Validate address
|
|
||||||
try:
|
|
||||||
to_address = to_checksum_address(to_address)
|
|
||||||
except Exception:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": "Invalid recipient address",
|
|
||||||
"operation_id": None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Check spending limits
|
|
||||||
limits_ok, limits_reason = self._check_spending_limits(amount)
|
|
||||||
if not limits_ok:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": limits_reason,
|
|
||||||
"operation_id": None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create operation
|
|
||||||
operation = {
|
|
||||||
"type": "transaction",
|
|
||||||
"to": to_address,
|
|
||||||
"amount": amount,
|
|
||||||
"data": data,
|
|
||||||
"timestamp": datetime.utcnow().isoformat(),
|
|
||||||
"nonce": self.nonce,
|
|
||||||
"status": "pending"
|
|
||||||
}
|
|
||||||
|
|
||||||
operation_id = self._create_operation_hash(operation)
|
|
||||||
operation["operation_id"] = operation_id
|
|
||||||
|
|
||||||
# Check if time lock is required
|
|
||||||
if self._requires_time_lock(amount):
|
|
||||||
unlock_time = datetime.utcnow() + timedelta(hours=self.config.time_lock.delay_hours)
|
|
||||||
operation["unlock_time"] = unlock_time.isoformat()
|
|
||||||
operation["status"] = "time_locked"
|
|
||||||
|
|
||||||
# Store for later execution
|
|
||||||
self.pending_operations[operation_id] = operation
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "time_locked",
|
|
||||||
"operation_id": operation_id,
|
|
||||||
"unlock_time": unlock_time.isoformat(),
|
|
||||||
"delay_hours": self.config.time_lock.delay_hours,
|
|
||||||
"message": f"Transaction requires {self.config.time_lock.delay_hours}h time lock"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Immediate execution for smaller amounts
|
|
||||||
self.pending_operations[operation_id] = operation
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "approved",
|
|
||||||
"operation_id": operation_id,
|
|
||||||
"message": "Transaction approved for execution"
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute_transaction(self, operation_id: str, signature: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Execute a previously approved transaction
|
|
||||||
|
|
||||||
Args:
|
|
||||||
operation_id: Operation ID from initiate_transaction
|
|
||||||
signature: Transaction signature from agent
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Execution result
|
|
||||||
"""
|
|
||||||
if operation_id not in self.pending_operations:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Operation not found"
|
|
||||||
}
|
|
||||||
|
|
||||||
operation = self.pending_operations[operation_id]
|
|
||||||
|
|
||||||
# Check if operation is time locked
|
|
||||||
if operation["status"] == "time_locked":
|
|
||||||
unlock_time = datetime.fromisoformat(operation["unlock_time"])
|
|
||||||
if datetime.utcnow() < unlock_time:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Operation locked until {unlock_time.isoformat()}"
|
|
||||||
}
|
|
||||||
|
|
||||||
operation["status"] = "ready"
|
|
||||||
|
|
||||||
# Verify signature (simplified - in production, use proper verification)
|
|
||||||
try:
|
|
||||||
# In production, verify the signature matches the agent address
|
|
||||||
# For now, we'll assume signature is valid
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Invalid signature: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Record the transaction
|
|
||||||
record = {
|
|
||||||
"operation_id": operation_id,
|
|
||||||
"to": operation["to"],
|
|
||||||
"amount": operation["amount"],
|
|
||||||
"data": operation.get("data", ""),
|
|
||||||
"timestamp": operation["timestamp"],
|
|
||||||
"executed_at": datetime.utcnow().isoformat(),
|
|
||||||
"status": "completed",
|
|
||||||
"nonce": operation["nonce"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# CRITICAL SECURITY FIX: Save to persistent storage
|
|
||||||
self._save_spending_record(record)
|
|
||||||
self.spending_history.append(record)
|
|
||||||
self.nonce += 1
|
|
||||||
self._save_state()
|
|
||||||
|
|
||||||
# Remove from pending storage
|
|
||||||
self._remove_pending_operation(operation_id)
|
|
||||||
if operation_id in self.pending_operations:
|
|
||||||
del self.pending_operations[operation_id]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "executed",
|
|
||||||
"operation_id": operation_id,
|
|
||||||
"transaction_hash": f"0x{keccak(f'{operation_id}{signature}'.encode()).hex()}",
|
|
||||||
"executed_at": record["executed_at"]
|
|
||||||
}
|
|
||||||
|
|
||||||
def emergency_pause(self, guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Emergency pause function (guardian only)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
guardian_address: Address of guardian initiating pause
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Pause result
|
|
||||||
"""
|
|
||||||
if guardian_address not in self.config.guardians:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": "Not authorized: guardian address not recognized"
|
|
||||||
}
|
|
||||||
|
|
||||||
self.paused = True
|
|
||||||
self.emergency_mode = True
|
|
||||||
|
|
||||||
# CRITICAL SECURITY FIX: Save state to persistent storage
|
|
||||||
self._save_state()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "paused",
|
|
||||||
"paused_at": datetime.utcnow().isoformat(),
|
|
||||||
"guardian": guardian_address,
|
|
||||||
"message": "Emergency pause activated - all operations halted"
|
|
||||||
}
|
|
||||||
|
|
||||||
def emergency_unpause(self, guardian_signatures: List[str]) -> Dict:
|
|
||||||
"""
|
|
||||||
Emergency unpause function (requires multiple guardian signatures)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
guardian_signatures: Signatures from required guardians
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Unpause result
|
|
||||||
"""
|
|
||||||
# In production, verify all guardian signatures
|
|
||||||
required_signatures = len(self.config.guardians)
|
|
||||||
if len(guardian_signatures) < required_signatures:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": f"Requires {required_signatures} guardian signatures, got {len(guardian_signatures)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Verify signatures (simplified)
|
|
||||||
# In production, verify each signature matches a guardian address
|
|
||||||
|
|
||||||
self.paused = False
|
|
||||||
self.emergency_mode = False
|
|
||||||
|
|
||||||
# CRITICAL SECURITY FIX: Save state to persistent storage
|
|
||||||
self._save_state()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "unpaused",
|
|
||||||
"unpaused_at": datetime.utcnow().isoformat(),
|
|
||||||
"message": "Emergency pause lifted - operations resumed"
|
|
||||||
}
|
|
||||||
|
|
||||||
def update_limits(self, new_limits: SpendingLimit, guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Update spending limits (guardian only)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
new_limits: New spending limits
|
|
||||||
guardian_address: Address of guardian making the change
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Update result
|
|
||||||
"""
|
|
||||||
if guardian_address not in self.config.guardians:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": "Not authorized: guardian address not recognized"
|
|
||||||
}
|
|
||||||
|
|
||||||
old_limits = self.config.limits
|
|
||||||
self.config.limits = new_limits
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "updated",
|
|
||||||
"old_limits": old_limits,
|
|
||||||
"new_limits": new_limits,
|
|
||||||
"updated_at": datetime.utcnow().isoformat(),
|
|
||||||
"guardian": guardian_address
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_spending_status(self) -> Dict:
|
|
||||||
"""Get current spending status and limits"""
|
|
||||||
now = datetime.utcnow()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"agent_address": self.agent_address,
|
|
||||||
"current_limits": self.config.limits,
|
|
||||||
"spent": {
|
|
||||||
"current_hour": self._get_spent_in_period("hour", now),
|
|
||||||
"current_day": self._get_spent_in_period("day", now),
|
|
||||||
"current_week": self._get_spent_in_period("week", now)
|
|
||||||
},
|
|
||||||
"remaining": {
|
|
||||||
"current_hour": self.config.limits.per_hour - self._get_spent_in_period("hour", now),
|
|
||||||
"current_day": self.config.limits.per_day - self._get_spent_in_period("day", now),
|
|
||||||
"current_week": self.config.limits.per_week - self._get_spent_in_period("week", now)
|
|
||||||
},
|
|
||||||
"pending_operations": len(self.pending_operations),
|
|
||||||
"paused": self.paused,
|
|
||||||
"emergency_mode": self.emergency_mode,
|
|
||||||
"nonce": self.nonce
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_operation_history(self, limit: int = 50) -> List[Dict]:
|
|
||||||
"""Get operation history"""
|
|
||||||
return sorted(self.spending_history, key=lambda x: x["timestamp"], reverse=True)[:limit]
|
|
||||||
|
|
||||||
def get_pending_operations(self) -> List[Dict]:
|
|
||||||
"""Get all pending operations"""
|
|
||||||
return list(self.pending_operations.values())
|
|
||||||
|
|
||||||
|
|
||||||
# Factory function for creating guardian contracts
|
|
||||||
def create_guardian_contract(
|
|
||||||
agent_address: str,
|
|
||||||
per_transaction: int = 1000,
|
|
||||||
per_hour: int = 5000,
|
|
||||||
per_day: int = 20000,
|
|
||||||
per_week: int = 100000,
|
|
||||||
time_lock_threshold: int = 10000,
|
|
||||||
time_lock_delay: int = 24,
|
|
||||||
guardians: List[str] = None
|
|
||||||
) -> GuardianContract:
|
|
||||||
"""
|
|
||||||
Create a guardian contract with default security parameters
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: The agent wallet address to protect
|
|
||||||
per_transaction: Maximum amount per transaction
|
|
||||||
per_hour: Maximum amount per hour
|
|
||||||
per_day: Maximum amount per day
|
|
||||||
per_week: Maximum amount per week
|
|
||||||
time_lock_threshold: Amount that triggers time lock
|
|
||||||
time_lock_delay: Time lock delay in hours
|
|
||||||
guardians: List of guardian addresses (REQUIRED for security)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured GuardianContract instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If no guardians are provided or guardians list is insufficient
|
|
||||||
"""
|
|
||||||
# CRITICAL SECURITY FIX: Require proper guardians, never default to agent address
|
|
||||||
if guardians is None or not guardians:
|
|
||||||
raise ValueError(
|
|
||||||
"❌ CRITICAL: Guardians are required for security. "
|
|
||||||
"Provide at least 3 trusted guardian addresses different from the agent address."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate that guardians are different from agent address
|
|
||||||
agent_checksum = to_checksum_address(agent_address)
|
|
||||||
guardian_checksums = [to_checksum_address(g) for g in guardians]
|
|
||||||
|
|
||||||
if agent_checksum in guardian_checksums:
|
|
||||||
raise ValueError(
|
|
||||||
"❌ CRITICAL: Agent address cannot be used as guardian. "
|
|
||||||
"Guardians must be independent trusted addresses."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Require minimum number of guardians for security
|
|
||||||
if len(guardian_checksums) < 3:
|
|
||||||
raise ValueError(
|
|
||||||
f"❌ CRITICAL: At least 3 guardians required for security, got {len(guardian_checksums)}. "
|
|
||||||
"Consider using a multi-sig wallet or trusted service providers."
|
|
||||||
)
|
|
||||||
|
|
||||||
limits = SpendingLimit(
|
|
||||||
per_transaction=per_transaction,
|
|
||||||
per_hour=per_hour,
|
|
||||||
per_day=per_day,
|
|
||||||
per_week=per_week
|
|
||||||
)
|
|
||||||
|
|
||||||
time_lock = TimeLockConfig(
|
|
||||||
threshold=time_lock_threshold,
|
|
||||||
delay_hours=time_lock_delay,
|
|
||||||
max_delay_hours=168 # 1 week max
|
|
||||||
)
|
|
||||||
|
|
||||||
config = GuardianConfig(
|
|
||||||
limits=limits,
|
|
||||||
time_lock=time_lock,
|
|
||||||
guardians=[to_checksum_address(g) for g in guardians]
|
|
||||||
)
|
|
||||||
|
|
||||||
return GuardianContract(agent_address, config)
|
|
||||||
|
|
||||||
|
|
||||||
# Example usage and security configurations
|
|
||||||
CONSERVATIVE_CONFIG = {
|
|
||||||
"per_transaction": 100, # $100 per transaction
|
|
||||||
"per_hour": 500, # $500 per hour
|
|
||||||
"per_day": 2000, # $2,000 per day
|
|
||||||
"per_week": 10000, # $10,000 per week
|
|
||||||
"time_lock_threshold": 1000, # Time lock over $1,000
|
|
||||||
"time_lock_delay": 24 # 24 hour delay
|
|
||||||
}
|
|
||||||
|
|
||||||
AGGRESSIVE_CONFIG = {
|
|
||||||
"per_transaction": 1000, # $1,000 per transaction
|
|
||||||
"per_hour": 5000, # $5,000 per hour
|
|
||||||
"per_day": 20000, # $20,000 per day
|
|
||||||
"per_week": 100000, # $100,000 per week
|
|
||||||
"time_lock_threshold": 10000, # Time lock over $10,000
|
|
||||||
"time_lock_delay": 12 # 12 hour delay
|
|
||||||
}
|
|
||||||
|
|
||||||
HIGH_SECURITY_CONFIG = {
|
|
||||||
"per_transaction": 50, # $50 per transaction
|
|
||||||
"per_hour": 200, # $200 per hour
|
|
||||||
"per_day": 1000, # $1,000 per day
|
|
||||||
"per_week": 5000, # $5,000 per week
|
|
||||||
"time_lock_threshold": 500, # Time lock over $500
|
|
||||||
"time_lock_delay": 48 # 48 hour delay
|
|
||||||
}
|
|
||||||
@@ -1,351 +0,0 @@
|
|||||||
"""
|
|
||||||
Gas Optimization System
|
|
||||||
Optimizes gas usage and fee efficiency for smart contracts
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
from decimal import Decimal
|
|
||||||
|
|
||||||
class OptimizationStrategy(Enum):
|
|
||||||
BATCH_OPERATIONS = "batch_operations"
|
|
||||||
LAZY_EVALUATION = "lazy_evaluation"
|
|
||||||
STATE_COMPRESSION = "state_compression"
|
|
||||||
EVENT_FILTERING = "event_filtering"
|
|
||||||
STORAGE_OPTIMIZATION = "storage_optimization"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GasMetric:
|
|
||||||
contract_address: str
|
|
||||||
function_name: str
|
|
||||||
gas_used: int
|
|
||||||
gas_limit: int
|
|
||||||
execution_time: float
|
|
||||||
timestamp: float
|
|
||||||
optimization_applied: Optional[str]
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class OptimizationResult:
|
|
||||||
strategy: OptimizationStrategy
|
|
||||||
original_gas: int
|
|
||||||
optimized_gas: int
|
|
||||||
gas_savings: int
|
|
||||||
savings_percentage: float
|
|
||||||
implementation_cost: Decimal
|
|
||||||
net_benefit: Decimal
|
|
||||||
|
|
||||||
class GasOptimizer:
|
|
||||||
"""Optimizes gas usage for smart contracts"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.gas_metrics: List[GasMetric] = []
|
|
||||||
self.optimization_results: List[OptimizationResult] = []
|
|
||||||
self.optimization_strategies = self._initialize_strategies()
|
|
||||||
|
|
||||||
# Optimization parameters
|
|
||||||
self.min_optimization_threshold = 1000 # Minimum gas to consider optimization
|
|
||||||
self.optimization_target_savings = 0.1 # 10% minimum savings
|
|
||||||
self.max_optimization_cost = Decimal('0.01') # Maximum cost per optimization
|
|
||||||
self.metric_retention_period = 86400 * 7 # 7 days
|
|
||||||
|
|
||||||
# Gas price tracking
|
|
||||||
self.gas_price_history: List[Dict] = []
|
|
||||||
self.current_gas_price = Decimal('0.001')
|
|
||||||
|
|
||||||
def _initialize_strategies(self) -> Dict[OptimizationStrategy, Dict]:
|
|
||||||
"""Initialize optimization strategies"""
|
|
||||||
return {
|
|
||||||
OptimizationStrategy.BATCH_OPERATIONS: {
|
|
||||||
'description': 'Batch multiple operations into single transaction',
|
|
||||||
'potential_savings': 0.3, # 30% potential savings
|
|
||||||
'implementation_cost': Decimal('0.005'),
|
|
||||||
'applicable_functions': ['transfer', 'approve', 'mint']
|
|
||||||
},
|
|
||||||
OptimizationStrategy.LAZY_EVALUATION: {
|
|
||||||
'description': 'Defer expensive computations until needed',
|
|
||||||
'potential_savings': 0.2, # 20% potential savings
|
|
||||||
'implementation_cost': Decimal('0.003'),
|
|
||||||
'applicable_functions': ['calculate', 'validate', 'process']
|
|
||||||
},
|
|
||||||
OptimizationStrategy.STATE_COMPRESSION: {
|
|
||||||
'description': 'Compress state data to reduce storage costs',
|
|
||||||
'potential_savings': 0.4, # 40% potential savings
|
|
||||||
'implementation_cost': Decimal('0.008'),
|
|
||||||
'applicable_functions': ['store', 'update', 'save']
|
|
||||||
},
|
|
||||||
OptimizationStrategy.EVENT_FILTERING: {
|
|
||||||
'description': 'Filter events to reduce emission costs',
|
|
||||||
'potential_savings': 0.15, # 15% potential savings
|
|
||||||
'implementation_cost': Decimal('0.002'),
|
|
||||||
'applicable_functions': ['emit', 'log', 'notify']
|
|
||||||
},
|
|
||||||
OptimizationStrategy.STORAGE_OPTIMIZATION: {
|
|
||||||
'description': 'Optimize storage patterns and data structures',
|
|
||||||
'potential_savings': 0.25, # 25% potential savings
|
|
||||||
'implementation_cost': Decimal('0.006'),
|
|
||||||
'applicable_functions': ['set', 'add', 'remove']
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async def record_gas_usage(self, contract_address: str, function_name: str,
|
|
||||||
gas_used: int, gas_limit: int, execution_time: float,
|
|
||||||
optimization_applied: Optional[str] = None):
|
|
||||||
"""Record gas usage metrics"""
|
|
||||||
metric = GasMetric(
|
|
||||||
contract_address=contract_address,
|
|
||||||
function_name=function_name,
|
|
||||||
gas_used=gas_used,
|
|
||||||
gas_limit=gas_limit,
|
|
||||||
execution_time=execution_time,
|
|
||||||
timestamp=time.time(),
|
|
||||||
optimization_applied=optimization_applied
|
|
||||||
)
|
|
||||||
|
|
||||||
self.gas_metrics.append(metric)
|
|
||||||
|
|
||||||
# Limit history size
|
|
||||||
if len(self.gas_metrics) > 10000:
|
|
||||||
self.gas_metrics = self.gas_metrics[-5000]
|
|
||||||
|
|
||||||
# Trigger optimization analysis if threshold met
|
|
||||||
if gas_used >= self.min_optimization_threshold:
|
|
||||||
asyncio.create_task(self._analyze_optimization_opportunity(metric))
|
|
||||||
|
|
||||||
async def _analyze_optimization_opportunity(self, metric: GasMetric):
|
|
||||||
"""Analyze if optimization is beneficial"""
|
|
||||||
# Get historical average for this function
|
|
||||||
historical_metrics = [
|
|
||||||
m for m in self.gas_metrics
|
|
||||||
if m.function_name == metric.function_name and
|
|
||||||
m.contract_address == metric.contract_address and
|
|
||||||
not m.optimization_applied
|
|
||||||
]
|
|
||||||
|
|
||||||
if len(historical_metrics) < 5: # Need sufficient history
|
|
||||||
return
|
|
||||||
|
|
||||||
avg_gas = sum(m.gas_used for m in historical_metrics) / len(historical_metrics)
|
|
||||||
|
|
||||||
# Test each optimization strategy
|
|
||||||
for strategy, config in self.optimization_strategies.items():
|
|
||||||
if self._is_strategy_applicable(strategy, metric.function_name):
|
|
||||||
potential_savings = avg_gas * config['potential_savings']
|
|
||||||
|
|
||||||
if potential_savings >= self.min_optimization_threshold:
|
|
||||||
# Calculate net benefit
|
|
||||||
gas_price = self.current_gas_price
|
|
||||||
gas_savings_value = potential_savings * gas_price
|
|
||||||
net_benefit = gas_savings_value - config['implementation_cost']
|
|
||||||
|
|
||||||
if net_benefit > 0:
|
|
||||||
# Create optimization result
|
|
||||||
result = OptimizationResult(
|
|
||||||
strategy=strategy,
|
|
||||||
original_gas=int(avg_gas),
|
|
||||||
optimized_gas=int(avg_gas - potential_savings),
|
|
||||||
gas_savings=int(potential_savings),
|
|
||||||
savings_percentage=config['potential_savings'],
|
|
||||||
implementation_cost=config['implementation_cost'],
|
|
||||||
net_benefit=net_benefit
|
|
||||||
)
|
|
||||||
|
|
||||||
self.optimization_results.append(result)
|
|
||||||
|
|
||||||
# Keep only recent results
|
|
||||||
if len(self.optimization_results) > 1000:
|
|
||||||
self.optimization_results = self.optimization_results[-500]
|
|
||||||
|
|
||||||
log_info(f"Optimization opportunity found: {strategy.value} for {metric.function_name} - Potential savings: {potential_savings} gas")
|
|
||||||
|
|
||||||
def _is_strategy_applicable(self, strategy: OptimizationStrategy, function_name: str) -> bool:
|
|
||||||
"""Check if optimization strategy is applicable to function"""
|
|
||||||
config = self.optimization_strategies.get(strategy, {})
|
|
||||||
applicable_functions = config.get('applicable_functions', [])
|
|
||||||
|
|
||||||
# Check if function name contains any applicable keywords
|
|
||||||
for applicable in applicable_functions:
|
|
||||||
if applicable.lower() in function_name.lower():
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def apply_optimization(self, contract_address: str, function_name: str,
|
|
||||||
strategy: OptimizationStrategy) -> Tuple[bool, str]:
|
|
||||||
"""Apply optimization strategy to contract function"""
|
|
||||||
try:
|
|
||||||
# Validate strategy
|
|
||||||
if strategy not in self.optimization_strategies:
|
|
||||||
return False, "Unknown optimization strategy"
|
|
||||||
|
|
||||||
# Check applicability
|
|
||||||
if not self._is_strategy_applicable(strategy, function_name):
|
|
||||||
return False, "Strategy not applicable to this function"
|
|
||||||
|
|
||||||
# Get optimization result
|
|
||||||
result = None
|
|
||||||
for res in self.optimization_results:
|
|
||||||
if (res.strategy == strategy and
|
|
||||||
res.strategy in self.optimization_strategies):
|
|
||||||
result = res
|
|
||||||
break
|
|
||||||
|
|
||||||
if not result:
|
|
||||||
return False, "No optimization analysis available"
|
|
||||||
|
|
||||||
# Check if net benefit is positive
|
|
||||||
if result.net_benefit <= 0:
|
|
||||||
return False, "Optimization not cost-effective"
|
|
||||||
|
|
||||||
# Apply optimization (in real implementation, this would modify contract code)
|
|
||||||
success = await self._implement_optimization(contract_address, function_name, strategy)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
# Record optimization
|
|
||||||
await self.record_gas_usage(
|
|
||||||
contract_address, function_name, result.optimized_gas,
|
|
||||||
result.optimized_gas, 0.0, strategy.value
|
|
||||||
)
|
|
||||||
|
|
||||||
log_info(f"Optimization applied: {strategy.value} to {function_name}")
|
|
||||||
return True, f"Optimization applied successfully. Gas savings: {result.gas_savings}"
|
|
||||||
else:
|
|
||||||
return False, "Optimization implementation failed"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return False, f"Optimization error: {str(e)}"
|
|
||||||
|
|
||||||
async def _implement_optimization(self, contract_address: str, function_name: str,
|
|
||||||
strategy: OptimizationStrategy) -> bool:
|
|
||||||
"""Implement the optimization strategy"""
|
|
||||||
try:
|
|
||||||
# In real implementation, this would:
|
|
||||||
# 1. Analyze contract bytecode
|
|
||||||
# 2. Apply optimization patterns
|
|
||||||
# 3. Generate optimized bytecode
|
|
||||||
# 4. Deploy optimized version
|
|
||||||
# 5. Verify functionality
|
|
||||||
|
|
||||||
# Simulate implementation
|
|
||||||
await asyncio.sleep(2) # Simulate optimization time
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log_error(f"Optimization implementation error: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def update_gas_price(self, new_price: Decimal):
|
|
||||||
"""Update current gas price"""
|
|
||||||
self.current_gas_price = new_price
|
|
||||||
|
|
||||||
# Record price history
|
|
||||||
self.gas_price_history.append({
|
|
||||||
'price': float(new_price),
|
|
||||||
'timestamp': time.time()
|
|
||||||
})
|
|
||||||
|
|
||||||
# Limit history size
|
|
||||||
if len(self.gas_price_history) > 1000:
|
|
||||||
self.gas_price_history = self.gas_price_history[-500]
|
|
||||||
|
|
||||||
# Re-evaluate optimization opportunities with new price
|
|
||||||
asyncio.create_task(self._reevaluate_optimizations())
|
|
||||||
|
|
||||||
async def _reevaluate_optimizations(self):
|
|
||||||
"""Re-evaluate optimization opportunities with new gas price"""
|
|
||||||
# Clear old results and re-analyze
|
|
||||||
self.optimization_results.clear()
|
|
||||||
|
|
||||||
# Re-analyze recent metrics
|
|
||||||
recent_metrics = [
|
|
||||||
m for m in self.gas_metrics
|
|
||||||
if time.time() - m.timestamp < 3600 # Last hour
|
|
||||||
]
|
|
||||||
|
|
||||||
for metric in recent_metrics:
|
|
||||||
if metric.gas_used >= self.min_optimization_threshold:
|
|
||||||
await self._analyze_optimization_opportunity(metric)
|
|
||||||
|
|
||||||
async def get_optimization_recommendations(self, contract_address: Optional[str] = None,
|
|
||||||
limit: int = 10) -> List[Dict]:
|
|
||||||
"""Get optimization recommendations"""
|
|
||||||
recommendations = []
|
|
||||||
|
|
||||||
for result in self.optimization_results:
|
|
||||||
if contract_address and result.strategy.value not in self.optimization_strategies:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if result.net_benefit > 0:
|
|
||||||
recommendations.append({
|
|
||||||
'strategy': result.strategy.value,
|
|
||||||
'function': 'contract_function', # Would map to actual function
|
|
||||||
'original_gas': result.original_gas,
|
|
||||||
'optimized_gas': result.optimized_gas,
|
|
||||||
'gas_savings': result.gas_savings,
|
|
||||||
'savings_percentage': result.savings_percentage,
|
|
||||||
'net_benefit': float(result.net_benefit),
|
|
||||||
'implementation_cost': float(result.implementation_cost)
|
|
||||||
})
|
|
||||||
|
|
||||||
# Sort by net benefit
|
|
||||||
recommendations.sort(key=lambda x: x['net_benefit'], reverse=True)
|
|
||||||
|
|
||||||
return recommendations[:limit]
|
|
||||||
|
|
||||||
async def get_gas_statistics(self) -> Dict:
|
|
||||||
"""Get gas usage statistics"""
|
|
||||||
if not self.gas_metrics:
|
|
||||||
return {
|
|
||||||
'total_transactions': 0,
|
|
||||||
'average_gas_used': 0,
|
|
||||||
'total_gas_used': 0,
|
|
||||||
'gas_efficiency': 0,
|
|
||||||
'optimization_opportunities': 0
|
|
||||||
}
|
|
||||||
|
|
||||||
total_transactions = len(self.gas_metrics)
|
|
||||||
total_gas_used = sum(m.gas_used for m in self.gas_metrics)
|
|
||||||
average_gas_used = total_gas_used / total_transactions
|
|
||||||
|
|
||||||
# Calculate efficiency (gas used vs gas limit)
|
|
||||||
efficiency_scores = [
|
|
||||||
m.gas_used / m.gas_limit for m in self.gas_metrics
|
|
||||||
if m.gas_limit > 0
|
|
||||||
]
|
|
||||||
avg_efficiency = sum(efficiency_scores) / len(efficiency_scores) if efficiency_scores else 0
|
|
||||||
|
|
||||||
# Optimization opportunities
|
|
||||||
optimization_count = len([
|
|
||||||
result for result in self.optimization_results
|
|
||||||
if result.net_benefit > 0
|
|
||||||
])
|
|
||||||
|
|
||||||
return {
|
|
||||||
'total_transactions': total_transactions,
|
|
||||||
'average_gas_used': average_gas_used,
|
|
||||||
'total_gas_used': total_gas_used,
|
|
||||||
'gas_efficiency': avg_efficiency,
|
|
||||||
'optimization_opportunities': optimization_count,
|
|
||||||
'current_gas_price': float(self.current_gas_price),
|
|
||||||
'total_optimizations_applied': len([
|
|
||||||
m for m in self.gas_metrics
|
|
||||||
if m.optimization_applied
|
|
||||||
])
|
|
||||||
}
|
|
||||||
|
|
||||||
# Global gas optimizer
|
|
||||||
gas_optimizer: Optional[GasOptimizer] = None
|
|
||||||
|
|
||||||
def get_gas_optimizer() -> Optional[GasOptimizer]:
|
|
||||||
"""Get global gas optimizer"""
|
|
||||||
return gas_optimizer
|
|
||||||
|
|
||||||
def create_gas_optimizer() -> GasOptimizer:
|
|
||||||
"""Create and set global gas optimizer"""
|
|
||||||
global gas_optimizer
|
|
||||||
gas_optimizer = GasOptimizer()
|
|
||||||
return gas_optimizer
|
|
||||||
@@ -1,470 +0,0 @@
|
|||||||
"""
|
|
||||||
Persistent Spending Tracker - Database-Backed Security
|
|
||||||
Fixes the critical vulnerability where spending limits were lost on restart
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from sqlalchemy import create_engine, Column, String, Integer, Float, DateTime, Index
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
from sqlalchemy.orm import sessionmaker, Session
|
|
||||||
from eth_utils import to_checksum_address
|
|
||||||
import json
|
|
||||||
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
|
|
||||||
class SpendingRecord(Base):
|
|
||||||
"""Database model for spending tracking"""
|
|
||||||
__tablename__ = "spending_records"
|
|
||||||
|
|
||||||
id = Column(String, primary_key=True)
|
|
||||||
agent_address = Column(String, index=True)
|
|
||||||
period_type = Column(String, index=True) # hour, day, week
|
|
||||||
period_key = Column(String, index=True)
|
|
||||||
amount = Column(Float)
|
|
||||||
transaction_hash = Column(String)
|
|
||||||
timestamp = Column(DateTime, default=datetime.utcnow)
|
|
||||||
|
|
||||||
# Composite indexes for performance
|
|
||||||
__table_args__ = (
|
|
||||||
Index('idx_agent_period', 'agent_address', 'period_type', 'period_key'),
|
|
||||||
Index('idx_timestamp', 'timestamp'),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SpendingLimit(Base):
|
|
||||||
"""Database model for spending limits"""
|
|
||||||
__tablename__ = "spending_limits"
|
|
||||||
|
|
||||||
agent_address = Column(String, primary_key=True)
|
|
||||||
per_transaction = Column(Float)
|
|
||||||
per_hour = Column(Float)
|
|
||||||
per_day = Column(Float)
|
|
||||||
per_week = Column(Float)
|
|
||||||
time_lock_threshold = Column(Float)
|
|
||||||
time_lock_delay_hours = Column(Integer)
|
|
||||||
updated_at = Column(DateTime, default=datetime.utcnow)
|
|
||||||
updated_by = Column(String) # Guardian who updated
|
|
||||||
|
|
||||||
|
|
||||||
class GuardianAuthorization(Base):
|
|
||||||
"""Database model for guardian authorizations"""
|
|
||||||
__tablename__ = "guardian_authorizations"
|
|
||||||
|
|
||||||
id = Column(String, primary_key=True)
|
|
||||||
agent_address = Column(String, index=True)
|
|
||||||
guardian_address = Column(String, index=True)
|
|
||||||
is_active = Column(Boolean, default=True)
|
|
||||||
added_at = Column(DateTime, default=datetime.utcnow)
|
|
||||||
added_by = Column(String)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SpendingCheckResult:
|
|
||||||
"""Result of spending limit check"""
|
|
||||||
allowed: bool
|
|
||||||
reason: str
|
|
||||||
current_spent: Dict[str, float]
|
|
||||||
remaining: Dict[str, float]
|
|
||||||
requires_time_lock: bool
|
|
||||||
time_lock_until: Optional[datetime] = None
|
|
||||||
|
|
||||||
|
|
||||||
class PersistentSpendingTracker:
|
|
||||||
"""
|
|
||||||
Database-backed spending tracker that survives restarts
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, database_url: str = "sqlite:///spending_tracker.db"):
|
|
||||||
self.engine = create_engine(database_url)
|
|
||||||
Base.metadata.create_all(self.engine)
|
|
||||||
self.SessionLocal = sessionmaker(bind=self.engine)
|
|
||||||
|
|
||||||
def get_session(self) -> Session:
|
|
||||||
"""Get database session"""
|
|
||||||
return self.SessionLocal()
|
|
||||||
|
|
||||||
def _get_period_key(self, timestamp: datetime, period: str) -> str:
|
|
||||||
"""Generate period key for spending tracking"""
|
|
||||||
if period == "hour":
|
|
||||||
return timestamp.strftime("%Y-%m-%d-%H")
|
|
||||||
elif period == "day":
|
|
||||||
return timestamp.strftime("%Y-%m-%d")
|
|
||||||
elif period == "week":
|
|
||||||
# Get week number (Monday as first day)
|
|
||||||
week_num = timestamp.isocalendar()[1]
|
|
||||||
return f"{timestamp.year}-W{week_num:02d}"
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid period: {period}")
|
|
||||||
|
|
||||||
def get_spent_in_period(self, agent_address: str, period: str, timestamp: datetime = None) -> float:
|
|
||||||
"""
|
|
||||||
Get total spent in given period from database
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
period: Period type (hour, day, week)
|
|
||||||
timestamp: Timestamp to check (default: now)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Total amount spent in period
|
|
||||||
"""
|
|
||||||
if timestamp is None:
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
period_key = self._get_period_key(timestamp, period)
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
with self.get_session() as session:
|
|
||||||
total = session.query(SpendingRecord).filter(
|
|
||||||
SpendingRecord.agent_address == agent_address,
|
|
||||||
SpendingRecord.period_type == period,
|
|
||||||
SpendingRecord.period_key == period_key
|
|
||||||
).with_entities(SpendingRecord.amount).all()
|
|
||||||
|
|
||||||
return sum(record.amount for record in total)
|
|
||||||
|
|
||||||
def record_spending(self, agent_address: str, amount: float, transaction_hash: str, timestamp: datetime = None) -> bool:
|
|
||||||
"""
|
|
||||||
Record a spending transaction in the database
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
amount: Amount spent
|
|
||||||
transaction_hash: Transaction hash
|
|
||||||
timestamp: Transaction timestamp (default: now)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if recorded successfully
|
|
||||||
"""
|
|
||||||
if timestamp is None:
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with self.get_session() as session:
|
|
||||||
# Record for all periods
|
|
||||||
periods = ["hour", "day", "week"]
|
|
||||||
|
|
||||||
for period in periods:
|
|
||||||
period_key = self._get_period_key(timestamp, period)
|
|
||||||
|
|
||||||
record = SpendingRecord(
|
|
||||||
id=f"{transaction_hash}_{period}",
|
|
||||||
agent_address=agent_address,
|
|
||||||
period_type=period,
|
|
||||||
period_key=period_key,
|
|
||||||
amount=amount,
|
|
||||||
transaction_hash=transaction_hash,
|
|
||||||
timestamp=timestamp
|
|
||||||
)
|
|
||||||
|
|
||||||
session.add(record)
|
|
||||||
|
|
||||||
session.commit()
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to record spending: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def check_spending_limits(self, agent_address: str, amount: float, timestamp: datetime = None) -> SpendingCheckResult:
|
|
||||||
"""
|
|
||||||
Check if amount exceeds spending limits using persistent data
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
amount: Amount to check
|
|
||||||
timestamp: Timestamp for check (default: now)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Spending check result
|
|
||||||
"""
|
|
||||||
if timestamp is None:
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
# Get spending limits from database
|
|
||||||
with self.get_session() as session:
|
|
||||||
limits = session.query(SpendingLimit).filter(
|
|
||||||
SpendingLimit.agent_address == agent_address
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not limits:
|
|
||||||
# Default limits if not set
|
|
||||||
limits = SpendingLimit(
|
|
||||||
agent_address=agent_address,
|
|
||||||
per_transaction=1000.0,
|
|
||||||
per_hour=5000.0,
|
|
||||||
per_day=20000.0,
|
|
||||||
per_week=100000.0,
|
|
||||||
time_lock_threshold=5000.0,
|
|
||||||
time_lock_delay_hours=24
|
|
||||||
)
|
|
||||||
session.add(limits)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Check each limit
|
|
||||||
current_spent = {}
|
|
||||||
remaining = {}
|
|
||||||
|
|
||||||
# Per-transaction limit
|
|
||||||
if amount > limits.per_transaction:
|
|
||||||
return SpendingCheckResult(
|
|
||||||
allowed=False,
|
|
||||||
reason=f"Amount {amount} exceeds per-transaction limit {limits.per_transaction}",
|
|
||||||
current_spent=current_spent,
|
|
||||||
remaining=remaining,
|
|
||||||
requires_time_lock=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Per-hour limit
|
|
||||||
spent_hour = self.get_spent_in_period(agent_address, "hour", timestamp)
|
|
||||||
current_spent["hour"] = spent_hour
|
|
||||||
remaining["hour"] = limits.per_hour - spent_hour
|
|
||||||
|
|
||||||
if spent_hour + amount > limits.per_hour:
|
|
||||||
return SpendingCheckResult(
|
|
||||||
allowed=False,
|
|
||||||
reason=f"Hourly spending {spent_hour + amount} would exceed limit {limits.per_hour}",
|
|
||||||
current_spent=current_spent,
|
|
||||||
remaining=remaining,
|
|
||||||
requires_time_lock=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Per-day limit
|
|
||||||
spent_day = self.get_spent_in_period(agent_address, "day", timestamp)
|
|
||||||
current_spent["day"] = spent_day
|
|
||||||
remaining["day"] = limits.per_day - spent_day
|
|
||||||
|
|
||||||
if spent_day + amount > limits.per_day:
|
|
||||||
return SpendingCheckResult(
|
|
||||||
allowed=False,
|
|
||||||
reason=f"Daily spending {spent_day + amount} would exceed limit {limits.per_day}",
|
|
||||||
current_spent=current_spent,
|
|
||||||
remaining=remaining,
|
|
||||||
requires_time_lock=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Per-week limit
|
|
||||||
spent_week = self.get_spent_in_period(agent_address, "week", timestamp)
|
|
||||||
current_spent["week"] = spent_week
|
|
||||||
remaining["week"] = limits.per_week - spent_week
|
|
||||||
|
|
||||||
if spent_week + amount > limits.per_week:
|
|
||||||
return SpendingCheckResult(
|
|
||||||
allowed=False,
|
|
||||||
reason=f"Weekly spending {spent_week + amount} would exceed limit {limits.per_week}",
|
|
||||||
current_spent=current_spent,
|
|
||||||
remaining=remaining,
|
|
||||||
requires_time_lock=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check time lock requirement
|
|
||||||
requires_time_lock = amount >= limits.time_lock_threshold
|
|
||||||
time_lock_until = None
|
|
||||||
|
|
||||||
if requires_time_lock:
|
|
||||||
time_lock_until = timestamp + timedelta(hours=limits.time_lock_delay_hours)
|
|
||||||
|
|
||||||
return SpendingCheckResult(
|
|
||||||
allowed=True,
|
|
||||||
reason="Spending limits check passed",
|
|
||||||
current_spent=current_spent,
|
|
||||||
remaining=remaining,
|
|
||||||
requires_time_lock=requires_time_lock,
|
|
||||||
time_lock_until=time_lock_until
|
|
||||||
)
|
|
||||||
|
|
||||||
def update_spending_limits(self, agent_address: str, new_limits: Dict, guardian_address: str) -> bool:
|
|
||||||
"""
|
|
||||||
Update spending limits for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
new_limits: New spending limits
|
|
||||||
guardian_address: Guardian making the change
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if updated successfully
|
|
||||||
"""
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
|
|
||||||
# Verify guardian authorization
|
|
||||||
if not self.is_guardian_authorized(agent_address, guardian_address):
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
with self.get_session() as session:
|
|
||||||
limits = session.query(SpendingLimit).filter(
|
|
||||||
SpendingLimit.agent_address == agent_address
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if limits:
|
|
||||||
limits.per_transaction = new_limits.get("per_transaction", limits.per_transaction)
|
|
||||||
limits.per_hour = new_limits.get("per_hour", limits.per_hour)
|
|
||||||
limits.per_day = new_limits.get("per_day", limits.per_day)
|
|
||||||
limits.per_week = new_limits.get("per_week", limits.per_week)
|
|
||||||
limits.time_lock_threshold = new_limits.get("time_lock_threshold", limits.time_lock_threshold)
|
|
||||||
limits.time_lock_delay_hours = new_limits.get("time_lock_delay_hours", limits.time_lock_delay_hours)
|
|
||||||
limits.updated_at = datetime.utcnow()
|
|
||||||
limits.updated_by = guardian_address
|
|
||||||
else:
|
|
||||||
limits = SpendingLimit(
|
|
||||||
agent_address=agent_address,
|
|
||||||
per_transaction=new_limits.get("per_transaction", 1000.0),
|
|
||||||
per_hour=new_limits.get("per_hour", 5000.0),
|
|
||||||
per_day=new_limits.get("per_day", 20000.0),
|
|
||||||
per_week=new_limits.get("per_week", 100000.0),
|
|
||||||
time_lock_threshold=new_limits.get("time_lock_threshold", 5000.0),
|
|
||||||
time_lock_delay_hours=new_limits.get("time_lock_delay_hours", 24),
|
|
||||||
updated_at=datetime.utcnow(),
|
|
||||||
updated_by=guardian_address
|
|
||||||
)
|
|
||||||
session.add(limits)
|
|
||||||
|
|
||||||
session.commit()
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to update spending limits: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def add_guardian(self, agent_address: str, guardian_address: str, added_by: str) -> bool:
|
|
||||||
"""
|
|
||||||
Add a guardian for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
guardian_address: Guardian address
|
|
||||||
added_by: Who added this guardian
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if added successfully
|
|
||||||
"""
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
added_by = to_checksum_address(added_by)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with self.get_session() as session:
|
|
||||||
# Check if already exists
|
|
||||||
existing = session.query(GuardianAuthorization).filter(
|
|
||||||
GuardianAuthorization.agent_address == agent_address,
|
|
||||||
GuardianAuthorization.guardian_address == guardian_address
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if existing:
|
|
||||||
existing.is_active = True
|
|
||||||
existing.added_at = datetime.utcnow()
|
|
||||||
existing.added_by = added_by
|
|
||||||
else:
|
|
||||||
auth = GuardianAuthorization(
|
|
||||||
id=f"{agent_address}_{guardian_address}",
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardian_address=guardian_address,
|
|
||||||
is_active=True,
|
|
||||||
added_at=datetime.utcnow(),
|
|
||||||
added_by=added_by
|
|
||||||
)
|
|
||||||
session.add(auth)
|
|
||||||
|
|
||||||
session.commit()
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to add guardian: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def is_guardian_authorized(self, agent_address: str, guardian_address: str) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a guardian is authorized for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
guardian_address: Guardian address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if authorized
|
|
||||||
"""
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
|
|
||||||
with self.get_session() as session:
|
|
||||||
auth = session.query(GuardianAuthorization).filter(
|
|
||||||
GuardianAuthorization.agent_address == agent_address,
|
|
||||||
GuardianAuthorization.guardian_address == guardian_address,
|
|
||||||
GuardianAuthorization.is_active == True
|
|
||||||
).first()
|
|
||||||
|
|
||||||
return auth is not None
|
|
||||||
|
|
||||||
def get_spending_summary(self, agent_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Get comprehensive spending summary for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Spending summary
|
|
||||||
"""
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
now = datetime.utcnow()
|
|
||||||
|
|
||||||
# Get current spending
|
|
||||||
current_spent = {
|
|
||||||
"hour": self.get_spent_in_period(agent_address, "hour", now),
|
|
||||||
"day": self.get_spent_in_period(agent_address, "day", now),
|
|
||||||
"week": self.get_spent_in_period(agent_address, "week", now)
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get limits
|
|
||||||
with self.get_session() as session:
|
|
||||||
limits = session.query(SpendingLimit).filter(
|
|
||||||
SpendingLimit.agent_address == agent_address
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not limits:
|
|
||||||
return {"error": "No spending limits set"}
|
|
||||||
|
|
||||||
# Calculate remaining
|
|
||||||
remaining = {
|
|
||||||
"hour": limits.per_hour - current_spent["hour"],
|
|
||||||
"day": limits.per_day - current_spent["day"],
|
|
||||||
"week": limits.per_week - current_spent["week"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get authorized guardians
|
|
||||||
with self.get_session() as session:
|
|
||||||
guardians = session.query(GuardianAuthorization).filter(
|
|
||||||
GuardianAuthorization.agent_address == agent_address,
|
|
||||||
GuardianAuthorization.is_active == True
|
|
||||||
).all()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"current_spending": current_spent,
|
|
||||||
"remaining_spending": remaining,
|
|
||||||
"limits": {
|
|
||||||
"per_transaction": limits.per_transaction,
|
|
||||||
"per_hour": limits.per_hour,
|
|
||||||
"per_day": limits.per_day,
|
|
||||||
"per_week": limits.per_week
|
|
||||||
},
|
|
||||||
"time_lock": {
|
|
||||||
"threshold": limits.time_lock_threshold,
|
|
||||||
"delay_hours": limits.time_lock_delay_hours
|
|
||||||
},
|
|
||||||
"authorized_guardians": [g.guardian_address for g in guardians],
|
|
||||||
"last_updated": limits.updated_at.isoformat() if limits.updated_at else None
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Global persistent tracker instance
|
|
||||||
persistent_tracker = PersistentSpendingTracker()
|
|
||||||
@@ -1,542 +0,0 @@
|
|||||||
"""
|
|
||||||
Contract Upgrade System
|
|
||||||
Handles safe contract versioning and upgrade mechanisms
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
from typing import Dict, List, Optional, Tuple, Set
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
from decimal import Decimal
|
|
||||||
|
|
||||||
class UpgradeStatus(Enum):
|
|
||||||
PROPOSED = "proposed"
|
|
||||||
APPROVED = "approved"
|
|
||||||
REJECTED = "rejected"
|
|
||||||
EXECUTED = "executed"
|
|
||||||
FAILED = "failed"
|
|
||||||
ROLLED_BACK = "rolled_back"
|
|
||||||
|
|
||||||
class UpgradeType(Enum):
|
|
||||||
PARAMETER_CHANGE = "parameter_change"
|
|
||||||
LOGIC_UPDATE = "logic_update"
|
|
||||||
SECURITY_PATCH = "security_patch"
|
|
||||||
FEATURE_ADDITION = "feature_addition"
|
|
||||||
EMERGENCY_FIX = "emergency_fix"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ContractVersion:
|
|
||||||
version: str
|
|
||||||
address: str
|
|
||||||
deployed_at: float
|
|
||||||
total_contracts: int
|
|
||||||
total_value: Decimal
|
|
||||||
is_active: bool
|
|
||||||
metadata: Dict
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class UpgradeProposal:
|
|
||||||
proposal_id: str
|
|
||||||
contract_type: str
|
|
||||||
current_version: str
|
|
||||||
new_version: str
|
|
||||||
upgrade_type: UpgradeType
|
|
||||||
description: str
|
|
||||||
changes: Dict
|
|
||||||
voting_deadline: float
|
|
||||||
execution_deadline: float
|
|
||||||
status: UpgradeStatus
|
|
||||||
votes: Dict[str, bool]
|
|
||||||
total_votes: int
|
|
||||||
yes_votes: int
|
|
||||||
no_votes: int
|
|
||||||
required_approval: float
|
|
||||||
created_at: float
|
|
||||||
proposer: str
|
|
||||||
executed_at: Optional[float]
|
|
||||||
rollback_data: Optional[Dict]
|
|
||||||
|
|
||||||
class ContractUpgradeManager:
|
|
||||||
"""Manages contract upgrades and versioning"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.contract_versions: Dict[str, List[ContractVersion]] = {} # contract_type -> versions
|
|
||||||
self.active_versions: Dict[str, str] = {} # contract_type -> active version
|
|
||||||
self.upgrade_proposals: Dict[str, UpgradeProposal] = {}
|
|
||||||
self.upgrade_history: List[Dict] = []
|
|
||||||
|
|
||||||
# Upgrade parameters
|
|
||||||
self.min_voting_period = 86400 * 3 # 3 days
|
|
||||||
self.max_voting_period = 86400 * 7 # 7 days
|
|
||||||
self.required_approval_rate = 0.6 # 60% approval required
|
|
||||||
self.min_participation_rate = 0.3 # 30% minimum participation
|
|
||||||
self.emergency_upgrade_threshold = 0.8 # 80% for emergency upgrades
|
|
||||||
self.rollback_timeout = 86400 * 7 # 7 days to rollback
|
|
||||||
|
|
||||||
# Governance
|
|
||||||
self.governance_addresses: Set[str] = set()
|
|
||||||
self.stake_weights: Dict[str, Decimal] = {}
|
|
||||||
|
|
||||||
# Initialize governance
|
|
||||||
self._initialize_governance()
|
|
||||||
|
|
||||||
def _initialize_governance(self):
|
|
||||||
"""Initialize governance addresses"""
|
|
||||||
# In real implementation, this would load from blockchain state
|
|
||||||
# For now, use default governance addresses
|
|
||||||
governance_addresses = [
|
|
||||||
"0xgovernance1111111111111111111111111111111111111",
|
|
||||||
"0xgovernance2222222222222222222222222222222222222",
|
|
||||||
"0xgovernance3333333333333333333333333333333333333"
|
|
||||||
]
|
|
||||||
|
|
||||||
for address in governance_addresses:
|
|
||||||
self.governance_addresses.add(address)
|
|
||||||
self.stake_weights[address] = Decimal('1000') # Equal stake weights initially
|
|
||||||
|
|
||||||
async def propose_upgrade(self, contract_type: str, current_version: str, new_version: str,
|
|
||||||
upgrade_type: UpgradeType, description: str, changes: Dict,
|
|
||||||
proposer: str, emergency: bool = False) -> Tuple[bool, str, Optional[str]]:
|
|
||||||
"""Propose contract upgrade"""
|
|
||||||
try:
|
|
||||||
# Validate inputs
|
|
||||||
if not all([contract_type, current_version, new_version, description, changes, proposer]):
|
|
||||||
return False, "Missing required fields", None
|
|
||||||
|
|
||||||
# Check proposer authority
|
|
||||||
if proposer not in self.governance_addresses:
|
|
||||||
return False, "Proposer not authorized", None
|
|
||||||
|
|
||||||
# Check current version
|
|
||||||
active_version = self.active_versions.get(contract_type)
|
|
||||||
if active_version != current_version:
|
|
||||||
return False, f"Current version mismatch. Active: {active_version}, Proposed: {current_version}", None
|
|
||||||
|
|
||||||
# Validate new version format
|
|
||||||
if not self._validate_version_format(new_version):
|
|
||||||
return False, "Invalid version format", None
|
|
||||||
|
|
||||||
# Check for existing proposal
|
|
||||||
for proposal in self.upgrade_proposals.values():
|
|
||||||
if (proposal.contract_type == contract_type and
|
|
||||||
proposal.new_version == new_version and
|
|
||||||
proposal.status in [UpgradeStatus.PROPOSED, UpgradeStatus.APPROVED]):
|
|
||||||
return False, "Proposal for this version already exists", None
|
|
||||||
|
|
||||||
# Generate proposal ID
|
|
||||||
proposal_id = self._generate_proposal_id(contract_type, new_version)
|
|
||||||
|
|
||||||
# Set voting deadlines
|
|
||||||
current_time = time.time()
|
|
||||||
voting_period = self.min_voting_period if not emergency else self.min_voting_period // 2
|
|
||||||
voting_deadline = current_time + voting_period
|
|
||||||
execution_deadline = voting_deadline + 86400 # 1 day after voting
|
|
||||||
|
|
||||||
# Set required approval rate
|
|
||||||
required_approval = self.emergency_upgrade_threshold if emergency else self.required_approval_rate
|
|
||||||
|
|
||||||
# Create proposal
|
|
||||||
proposal = UpgradeProposal(
|
|
||||||
proposal_id=proposal_id,
|
|
||||||
contract_type=contract_type,
|
|
||||||
current_version=current_version,
|
|
||||||
new_version=new_version,
|
|
||||||
upgrade_type=upgrade_type,
|
|
||||||
description=description,
|
|
||||||
changes=changes,
|
|
||||||
voting_deadline=voting_deadline,
|
|
||||||
execution_deadline=execution_deadline,
|
|
||||||
status=UpgradeStatus.PROPOSED,
|
|
||||||
votes={},
|
|
||||||
total_votes=0,
|
|
||||||
yes_votes=0,
|
|
||||||
no_votes=0,
|
|
||||||
required_approval=required_approval,
|
|
||||||
created_at=current_time,
|
|
||||||
proposer=proposer,
|
|
||||||
executed_at=None,
|
|
||||||
rollback_data=None
|
|
||||||
)
|
|
||||||
|
|
||||||
self.upgrade_proposals[proposal_id] = proposal
|
|
||||||
|
|
||||||
# Start voting process
|
|
||||||
asyncio.create_task(self._manage_voting_process(proposal_id))
|
|
||||||
|
|
||||||
log_info(f"Upgrade proposal created: {proposal_id} - {contract_type} {current_version} -> {new_version}")
|
|
||||||
return True, "Upgrade proposal created successfully", proposal_id
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return False, f"Failed to create proposal: {str(e)}", None
|
|
||||||
|
|
||||||
def _validate_version_format(self, version: str) -> bool:
|
|
||||||
"""Validate semantic version format"""
|
|
||||||
try:
|
|
||||||
parts = version.split('.')
|
|
||||||
if len(parts) != 3:
|
|
||||||
return False
|
|
||||||
|
|
||||||
major, minor, patch = parts
|
|
||||||
int(major) and int(minor) and int(patch)
|
|
||||||
return True
|
|
||||||
except ValueError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _generate_proposal_id(self, contract_type: str, new_version: str) -> str:
|
|
||||||
"""Generate unique proposal ID"""
|
|
||||||
import hashlib
|
|
||||||
content = f"{contract_type}:{new_version}:{time.time()}"
|
|
||||||
return hashlib.sha256(content.encode()).hexdigest()[:12]
|
|
||||||
|
|
||||||
async def _manage_voting_process(self, proposal_id: str):
|
|
||||||
"""Manage voting process for proposal"""
|
|
||||||
proposal = self.upgrade_proposals.get(proposal_id)
|
|
||||||
if not proposal:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Wait for voting deadline
|
|
||||||
await asyncio.sleep(proposal.voting_deadline - time.time())
|
|
||||||
|
|
||||||
# Check voting results
|
|
||||||
await self._finalize_voting(proposal_id)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log_error(f"Error in voting process for {proposal_id}: {e}")
|
|
||||||
proposal.status = UpgradeStatus.FAILED
|
|
||||||
|
|
||||||
async def _finalize_voting(self, proposal_id: str):
|
|
||||||
"""Finalize voting and determine outcome"""
|
|
||||||
proposal = self.upgrade_proposals[proposal_id]
|
|
||||||
|
|
||||||
# Calculate voting results
|
|
||||||
total_stake = sum(self.stake_weights.get(voter, Decimal('0')) for voter in proposal.votes.keys())
|
|
||||||
yes_stake = sum(self.stake_weights.get(voter, Decimal('0')) for voter, vote in proposal.votes.items() if vote)
|
|
||||||
|
|
||||||
# Check minimum participation
|
|
||||||
total_governance_stake = sum(self.stake_weights.values())
|
|
||||||
participation_rate = float(total_stake / total_governance_stake) if total_governance_stake > 0 else 0
|
|
||||||
|
|
||||||
if participation_rate < self.min_participation_rate:
|
|
||||||
proposal.status = UpgradeStatus.REJECTED
|
|
||||||
log_info(f"Proposal {proposal_id} rejected due to low participation: {participation_rate:.2%}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check approval rate
|
|
||||||
approval_rate = float(yes_stake / total_stake) if total_stake > 0 else 0
|
|
||||||
|
|
||||||
if approval_rate >= proposal.required_approval:
|
|
||||||
proposal.status = UpgradeStatus.APPROVED
|
|
||||||
log_info(f"Proposal {proposal_id} approved with {approval_rate:.2%} approval")
|
|
||||||
|
|
||||||
# Schedule execution
|
|
||||||
asyncio.create_task(self._execute_upgrade(proposal_id))
|
|
||||||
else:
|
|
||||||
proposal.status = UpgradeStatus.REJECTED
|
|
||||||
log_info(f"Proposal {proposal_id} rejected with {approval_rate:.2%} approval")
|
|
||||||
|
|
||||||
async def vote_on_proposal(self, proposal_id: str, voter_address: str, vote: bool) -> Tuple[bool, str]:
|
|
||||||
"""Cast vote on upgrade proposal"""
|
|
||||||
proposal = self.upgrade_proposals.get(proposal_id)
|
|
||||||
if not proposal:
|
|
||||||
return False, "Proposal not found"
|
|
||||||
|
|
||||||
# Check voting authority
|
|
||||||
if voter_address not in self.governance_addresses:
|
|
||||||
return False, "Not authorized to vote"
|
|
||||||
|
|
||||||
# Check voting period
|
|
||||||
if time.time() > proposal.voting_deadline:
|
|
||||||
return False, "Voting period has ended"
|
|
||||||
|
|
||||||
# Check if already voted
|
|
||||||
if voter_address in proposal.votes:
|
|
||||||
return False, "Already voted"
|
|
||||||
|
|
||||||
# Cast vote
|
|
||||||
proposal.votes[voter_address] = vote
|
|
||||||
proposal.total_votes += 1
|
|
||||||
|
|
||||||
if vote:
|
|
||||||
proposal.yes_votes += 1
|
|
||||||
else:
|
|
||||||
proposal.no_votes += 1
|
|
||||||
|
|
||||||
log_info(f"Vote cast on proposal {proposal_id} by {voter_address}: {'YES' if vote else 'NO'}")
|
|
||||||
return True, "Vote cast successfully"
|
|
||||||
|
|
||||||
async def _execute_upgrade(self, proposal_id: str):
|
|
||||||
"""Execute approved upgrade"""
|
|
||||||
proposal = self.upgrade_proposals[proposal_id]
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Wait for execution deadline
|
|
||||||
await asyncio.sleep(proposal.execution_deadline - time.time())
|
|
||||||
|
|
||||||
# Check if still approved
|
|
||||||
if proposal.status != UpgradeStatus.APPROVED:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Prepare rollback data
|
|
||||||
rollback_data = await self._prepare_rollback_data(proposal)
|
|
||||||
|
|
||||||
# Execute upgrade
|
|
||||||
success = await self._perform_upgrade(proposal)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
proposal.status = UpgradeStatus.EXECUTED
|
|
||||||
proposal.executed_at = time.time()
|
|
||||||
proposal.rollback_data = rollback_data
|
|
||||||
|
|
||||||
# Update active version
|
|
||||||
self.active_versions[proposal.contract_type] = proposal.new_version
|
|
||||||
|
|
||||||
# Record in history
|
|
||||||
self.upgrade_history.append({
|
|
||||||
'proposal_id': proposal_id,
|
|
||||||
'contract_type': proposal.contract_type,
|
|
||||||
'from_version': proposal.current_version,
|
|
||||||
'to_version': proposal.new_version,
|
|
||||||
'executed_at': proposal.executed_at,
|
|
||||||
'upgrade_type': proposal.upgrade_type.value
|
|
||||||
})
|
|
||||||
|
|
||||||
log_info(f"Upgrade executed: {proposal_id} - {proposal.contract_type} {proposal.current_version} -> {proposal.new_version}")
|
|
||||||
|
|
||||||
# Start rollback window
|
|
||||||
asyncio.create_task(self._manage_rollback_window(proposal_id))
|
|
||||||
else:
|
|
||||||
proposal.status = UpgradeStatus.FAILED
|
|
||||||
log_error(f"Upgrade execution failed: {proposal_id}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
proposal.status = UpgradeStatus.FAILED
|
|
||||||
log_error(f"Error executing upgrade {proposal_id}: {e}")
|
|
||||||
|
|
||||||
async def _prepare_rollback_data(self, proposal: UpgradeProposal) -> Dict:
|
|
||||||
"""Prepare data for potential rollback"""
|
|
||||||
return {
|
|
||||||
'previous_version': proposal.current_version,
|
|
||||||
'contract_state': {}, # Would capture current contract state
|
|
||||||
'migration_data': {}, # Would store migration data
|
|
||||||
'timestamp': time.time()
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _perform_upgrade(self, proposal: UpgradeProposal) -> bool:
|
|
||||||
"""Perform the actual upgrade"""
|
|
||||||
try:
|
|
||||||
# In real implementation, this would:
|
|
||||||
# 1. Deploy new contract version
|
|
||||||
# 2. Migrate state from old contract
|
|
||||||
# 3. Update contract references
|
|
||||||
# 4. Verify upgrade integrity
|
|
||||||
|
|
||||||
# Simulate upgrade process
|
|
||||||
await asyncio.sleep(10) # Simulate upgrade time
|
|
||||||
|
|
||||||
# Create new version record
|
|
||||||
new_version = ContractVersion(
|
|
||||||
version=proposal.new_version,
|
|
||||||
address=f"0x{proposal.contract_type}_{proposal.new_version}", # New address
|
|
||||||
deployed_at=time.time(),
|
|
||||||
total_contracts=0,
|
|
||||||
total_value=Decimal('0'),
|
|
||||||
is_active=True,
|
|
||||||
metadata={
|
|
||||||
'upgrade_type': proposal.upgrade_type.value,
|
|
||||||
'proposal_id': proposal.proposal_id,
|
|
||||||
'changes': proposal.changes
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add to version history
|
|
||||||
if proposal.contract_type not in self.contract_versions:
|
|
||||||
self.contract_versions[proposal.contract_type] = []
|
|
||||||
|
|
||||||
# Deactivate old version
|
|
||||||
for version in self.contract_versions[proposal.contract_type]:
|
|
||||||
if version.version == proposal.current_version:
|
|
||||||
version.is_active = False
|
|
||||||
break
|
|
||||||
|
|
||||||
# Add new version
|
|
||||||
self.contract_versions[proposal.contract_type].append(new_version)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log_error(f"Upgrade execution error: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _manage_rollback_window(self, proposal_id: str):
|
|
||||||
"""Manage rollback window after upgrade"""
|
|
||||||
proposal = self.upgrade_proposals[proposal_id]
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Wait for rollback timeout
|
|
||||||
await asyncio.sleep(self.rollback_timeout)
|
|
||||||
|
|
||||||
# Check if rollback was requested
|
|
||||||
if proposal.status == UpgradeStatus.EXECUTED:
|
|
||||||
# No rollback requested, finalize upgrade
|
|
||||||
await self._finalize_upgrade(proposal_id)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log_error(f"Error in rollback window for {proposal_id}: {e}")
|
|
||||||
|
|
||||||
async def _finalize_upgrade(self, proposal_id: str):
|
|
||||||
"""Finalize upgrade after rollback window"""
|
|
||||||
proposal = self.upgrade_proposals[proposal_id]
|
|
||||||
|
|
||||||
# Clear rollback data to save space
|
|
||||||
proposal.rollback_data = None
|
|
||||||
|
|
||||||
log_info(f"Upgrade finalized: {proposal_id}")
|
|
||||||
|
|
||||||
async def rollback_upgrade(self, proposal_id: str, reason: str) -> Tuple[bool, str]:
|
|
||||||
"""Rollback upgrade to previous version"""
|
|
||||||
proposal = self.upgrade_proposals.get(proposal_id)
|
|
||||||
if not proposal:
|
|
||||||
return False, "Proposal not found"
|
|
||||||
|
|
||||||
if proposal.status != UpgradeStatus.EXECUTED:
|
|
||||||
return False, "Can only rollback executed upgrades"
|
|
||||||
|
|
||||||
if not proposal.rollback_data:
|
|
||||||
return False, "Rollback data not available"
|
|
||||||
|
|
||||||
# Check rollback window
|
|
||||||
if time.time() - proposal.executed_at > self.rollback_timeout:
|
|
||||||
return False, "Rollback window has expired"
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Perform rollback
|
|
||||||
success = await self._perform_rollback(proposal)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
proposal.status = UpgradeStatus.ROLLED_BACK
|
|
||||||
|
|
||||||
# Restore previous version
|
|
||||||
self.active_versions[proposal.contract_type] = proposal.current_version
|
|
||||||
|
|
||||||
# Update version records
|
|
||||||
for version in self.contract_versions[proposal.contract_type]:
|
|
||||||
if version.version == proposal.new_version:
|
|
||||||
version.is_active = False
|
|
||||||
elif version.version == proposal.current_version:
|
|
||||||
version.is_active = True
|
|
||||||
|
|
||||||
log_info(f"Upgrade rolled back: {proposal_id} - Reason: {reason}")
|
|
||||||
return True, "Rollback successful"
|
|
||||||
else:
|
|
||||||
return False, "Rollback execution failed"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log_error(f"Rollback error for {proposal_id}: {e}")
|
|
||||||
return False, f"Rollback failed: {str(e)}"
|
|
||||||
|
|
||||||
async def _perform_rollback(self, proposal: UpgradeProposal) -> bool:
|
|
||||||
"""Perform the actual rollback"""
|
|
||||||
try:
|
|
||||||
# In real implementation, this would:
|
|
||||||
# 1. Restore previous contract state
|
|
||||||
# 2. Update contract references back
|
|
||||||
# 3. Verify rollback integrity
|
|
||||||
|
|
||||||
# Simulate rollback process
|
|
||||||
await asyncio.sleep(5) # Simulate rollback time
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log_error(f"Rollback execution error: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def get_proposal(self, proposal_id: str) -> Optional[UpgradeProposal]:
|
|
||||||
"""Get upgrade proposal"""
|
|
||||||
return self.upgrade_proposals.get(proposal_id)
|
|
||||||
|
|
||||||
async def get_proposals_by_status(self, status: UpgradeStatus) -> List[UpgradeProposal]:
|
|
||||||
"""Get proposals by status"""
|
|
||||||
return [
|
|
||||||
proposal for proposal in self.upgrade_proposals.values()
|
|
||||||
if proposal.status == status
|
|
||||||
]
|
|
||||||
|
|
||||||
async def get_contract_versions(self, contract_type: str) -> List[ContractVersion]:
|
|
||||||
"""Get all versions for a contract type"""
|
|
||||||
return self.contract_versions.get(contract_type, [])
|
|
||||||
|
|
||||||
async def get_active_version(self, contract_type: str) -> Optional[str]:
|
|
||||||
"""Get active version for contract type"""
|
|
||||||
return self.active_versions.get(contract_type)
|
|
||||||
|
|
||||||
async def get_upgrade_statistics(self) -> Dict:
|
|
||||||
"""Get upgrade system statistics"""
|
|
||||||
total_proposals = len(self.upgrade_proposals)
|
|
||||||
|
|
||||||
if total_proposals == 0:
|
|
||||||
return {
|
|
||||||
'total_proposals': 0,
|
|
||||||
'status_distribution': {},
|
|
||||||
'upgrade_types': {},
|
|
||||||
'average_execution_time': 0,
|
|
||||||
'success_rate': 0
|
|
||||||
}
|
|
||||||
|
|
||||||
# Status distribution
|
|
||||||
status_counts = {}
|
|
||||||
for proposal in self.upgrade_proposals.values():
|
|
||||||
status = proposal.status.value
|
|
||||||
status_counts[status] = status_counts.get(status, 0) + 1
|
|
||||||
|
|
||||||
# Upgrade type distribution
|
|
||||||
type_counts = {}
|
|
||||||
for proposal in self.upgrade_proposals.values():
|
|
||||||
up_type = proposal.upgrade_type.value
|
|
||||||
type_counts[up_type] = type_counts.get(up_type, 0) + 1
|
|
||||||
|
|
||||||
# Execution statistics
|
|
||||||
executed_proposals = [
|
|
||||||
proposal for proposal in self.upgrade_proposals.values()
|
|
||||||
if proposal.status == UpgradeStatus.EXECUTED
|
|
||||||
]
|
|
||||||
|
|
||||||
if executed_proposals:
|
|
||||||
execution_times = [
|
|
||||||
proposal.executed_at - proposal.created_at
|
|
||||||
for proposal in executed_proposals
|
|
||||||
if proposal.executed_at
|
|
||||||
]
|
|
||||||
avg_execution_time = sum(execution_times) / len(execution_times) if execution_times else 0
|
|
||||||
else:
|
|
||||||
avg_execution_time = 0
|
|
||||||
|
|
||||||
# Success rate
|
|
||||||
successful_upgrades = len(executed_proposals)
|
|
||||||
success_rate = successful_upgrades / total_proposals if total_proposals > 0 else 0
|
|
||||||
|
|
||||||
return {
|
|
||||||
'total_proposals': total_proposals,
|
|
||||||
'status_distribution': status_counts,
|
|
||||||
'upgrade_types': type_counts,
|
|
||||||
'average_execution_time': avg_execution_time,
|
|
||||||
'success_rate': success_rate,
|
|
||||||
'total_governance_addresses': len(self.governance_addresses),
|
|
||||||
'contract_types': len(self.contract_versions)
|
|
||||||
}
|
|
||||||
|
|
||||||
# Global upgrade manager
|
|
||||||
upgrade_manager: Optional[ContractUpgradeManager] = None
|
|
||||||
|
|
||||||
def get_upgrade_manager() -> Optional[ContractUpgradeManager]:
|
|
||||||
"""Get global upgrade manager"""
|
|
||||||
return upgrade_manager
|
|
||||||
|
|
||||||
def create_upgrade_manager() -> ContractUpgradeManager:
|
|
||||||
"""Create and set global upgrade manager"""
|
|
||||||
global upgrade_manager
|
|
||||||
upgrade_manager = ContractUpgradeManager()
|
|
||||||
return upgrade_manager
|
|
||||||
@@ -1,519 +0,0 @@
|
|||||||
"""
|
|
||||||
AITBC Agent Messaging Contract Implementation
|
|
||||||
|
|
||||||
This module implements on-chain messaging functionality for agents,
|
|
||||||
enabling forum-like communication between autonomous agents.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Any
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from enum import Enum
|
|
||||||
import json
|
|
||||||
import hashlib
|
|
||||||
from eth_account import Account
|
|
||||||
from eth_utils import to_checksum_address
|
|
||||||
|
|
||||||
class MessageType(Enum):
|
|
||||||
"""Types of messages agents can send"""
|
|
||||||
POST = "post"
|
|
||||||
REPLY = "reply"
|
|
||||||
ANNOUNCEMENT = "announcement"
|
|
||||||
QUESTION = "question"
|
|
||||||
ANSWER = "answer"
|
|
||||||
MODERATION = "moderation"
|
|
||||||
|
|
||||||
class MessageStatus(Enum):
|
|
||||||
"""Status of messages in the forum"""
|
|
||||||
ACTIVE = "active"
|
|
||||||
HIDDEN = "hidden"
|
|
||||||
DELETED = "deleted"
|
|
||||||
PINNED = "pinned"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Message:
|
|
||||||
"""Represents a message in the agent forum"""
|
|
||||||
message_id: str
|
|
||||||
agent_id: str
|
|
||||||
agent_address: str
|
|
||||||
topic: str
|
|
||||||
content: str
|
|
||||||
message_type: MessageType
|
|
||||||
timestamp: datetime
|
|
||||||
parent_message_id: Optional[str] = None
|
|
||||||
reply_count: int = 0
|
|
||||||
upvotes: int = 0
|
|
||||||
downvotes: int = 0
|
|
||||||
status: MessageStatus = MessageStatus.ACTIVE
|
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Topic:
|
|
||||||
"""Represents a forum topic"""
|
|
||||||
topic_id: str
|
|
||||||
title: str
|
|
||||||
description: str
|
|
||||||
creator_agent_id: str
|
|
||||||
created_at: datetime
|
|
||||||
message_count: int = 0
|
|
||||||
last_activity: datetime = field(default_factory=datetime.now)
|
|
||||||
tags: List[str] = field(default_factory=list)
|
|
||||||
is_pinned: bool = False
|
|
||||||
is_locked: bool = False
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AgentReputation:
|
|
||||||
"""Reputation system for agents"""
|
|
||||||
agent_id: str
|
|
||||||
message_count: int = 0
|
|
||||||
upvotes_received: int = 0
|
|
||||||
downvotes_received: int = 0
|
|
||||||
reputation_score: float = 0.0
|
|
||||||
trust_level: int = 1 # 1-5 trust levels
|
|
||||||
is_moderator: bool = False
|
|
||||||
is_banned: bool = False
|
|
||||||
ban_reason: Optional[str] = None
|
|
||||||
ban_expires: Optional[datetime] = None
|
|
||||||
|
|
||||||
class AgentMessagingContract:
|
|
||||||
"""Main contract for agent messaging functionality"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.messages: Dict[str, Message] = {}
|
|
||||||
self.topics: Dict[str, Topic] = {}
|
|
||||||
self.agent_reputations: Dict[str, AgentReputation] = {}
|
|
||||||
self.moderation_log: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
def create_topic(self, agent_id: str, agent_address: str, title: str,
|
|
||||||
description: str, tags: List[str] = None) -> Dict[str, Any]:
|
|
||||||
"""Create a new forum topic"""
|
|
||||||
|
|
||||||
# Check if agent is banned
|
|
||||||
if self._is_agent_banned(agent_id):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Agent is banned from posting",
|
|
||||||
"error_code": "AGENT_BANNED"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Generate topic ID
|
|
||||||
topic_id = f"topic_{hashlib.sha256(f'{agent_id}_{title}_{datetime.now()}'.encode()).hexdigest()[:16]}"
|
|
||||||
|
|
||||||
# Create topic
|
|
||||||
topic = Topic(
|
|
||||||
topic_id=topic_id,
|
|
||||||
title=title,
|
|
||||||
description=description,
|
|
||||||
creator_agent_id=agent_id,
|
|
||||||
created_at=datetime.now(),
|
|
||||||
tags=tags or []
|
|
||||||
)
|
|
||||||
|
|
||||||
self.topics[topic_id] = topic
|
|
||||||
|
|
||||||
# Update agent reputation
|
|
||||||
self._update_agent_reputation(agent_id, message_count=1)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"topic_id": topic_id,
|
|
||||||
"topic": self._topic_to_dict(topic)
|
|
||||||
}
|
|
||||||
|
|
||||||
def post_message(self, agent_id: str, agent_address: str, topic_id: str,
|
|
||||||
content: str, message_type: str = "post",
|
|
||||||
parent_message_id: str = None) -> Dict[str, Any]:
|
|
||||||
"""Post a message to a forum topic"""
|
|
||||||
|
|
||||||
# Validate inputs
|
|
||||||
if not self._validate_agent(agent_id, agent_address):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid agent credentials",
|
|
||||||
"error_code": "INVALID_AGENT"
|
|
||||||
}
|
|
||||||
|
|
||||||
if self._is_agent_banned(agent_id):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Agent is banned from posting",
|
|
||||||
"error_code": "AGENT_BANNED"
|
|
||||||
}
|
|
||||||
|
|
||||||
if topic_id not in self.topics:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Topic not found",
|
|
||||||
"error_code": "TOPIC_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.topics[topic_id].is_locked:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Topic is locked",
|
|
||||||
"error_code": "TOPIC_LOCKED"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Validate message type
|
|
||||||
try:
|
|
||||||
msg_type = MessageType(message_type)
|
|
||||||
except ValueError:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid message type",
|
|
||||||
"error_code": "INVALID_MESSAGE_TYPE"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Generate message ID
|
|
||||||
message_id = f"msg_{hashlib.sha256(f'{agent_id}_{topic_id}_{content}_{datetime.now()}'.encode()).hexdigest()[:16]}"
|
|
||||||
|
|
||||||
# Create message
|
|
||||||
message = Message(
|
|
||||||
message_id=message_id,
|
|
||||||
agent_id=agent_id,
|
|
||||||
agent_address=agent_address,
|
|
||||||
topic=topic_id,
|
|
||||||
content=content,
|
|
||||||
message_type=msg_type,
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
parent_message_id=parent_message_id
|
|
||||||
)
|
|
||||||
|
|
||||||
self.messages[message_id] = message
|
|
||||||
|
|
||||||
# Update topic
|
|
||||||
self.topics[topic_id].message_count += 1
|
|
||||||
self.topics[topic_id].last_activity = datetime.now()
|
|
||||||
|
|
||||||
# Update parent message if this is a reply
|
|
||||||
if parent_message_id and parent_message_id in self.messages:
|
|
||||||
self.messages[parent_message_id].reply_count += 1
|
|
||||||
|
|
||||||
# Update agent reputation
|
|
||||||
self._update_agent_reputation(agent_id, message_count=1)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message_id": message_id,
|
|
||||||
"message": self._message_to_dict(message)
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_messages(self, topic_id: str, limit: int = 50, offset: int = 0,
|
|
||||||
sort_by: str = "timestamp") -> Dict[str, Any]:
|
|
||||||
"""Get messages from a topic"""
|
|
||||||
|
|
||||||
if topic_id not in self.topics:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Topic not found",
|
|
||||||
"error_code": "TOPIC_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get all messages for this topic
|
|
||||||
topic_messages = [
|
|
||||||
msg for msg in self.messages.values()
|
|
||||||
if msg.topic == topic_id and msg.status == MessageStatus.ACTIVE
|
|
||||||
]
|
|
||||||
|
|
||||||
# Sort messages
|
|
||||||
if sort_by == "timestamp":
|
|
||||||
topic_messages.sort(key=lambda x: x.timestamp, reverse=True)
|
|
||||||
elif sort_by == "upvotes":
|
|
||||||
topic_messages.sort(key=lambda x: x.upvotes, reverse=True)
|
|
||||||
elif sort_by == "replies":
|
|
||||||
topic_messages.sort(key=lambda x: x.reply_count, reverse=True)
|
|
||||||
|
|
||||||
# Apply pagination
|
|
||||||
total_messages = len(topic_messages)
|
|
||||||
paginated_messages = topic_messages[offset:offset + limit]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"messages": [self._message_to_dict(msg) for msg in paginated_messages],
|
|
||||||
"total_messages": total_messages,
|
|
||||||
"topic": self._topic_to_dict(self.topics[topic_id])
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_topics(self, limit: int = 50, offset: int = 0,
|
|
||||||
sort_by: str = "last_activity") -> Dict[str, Any]:
|
|
||||||
"""Get list of forum topics"""
|
|
||||||
|
|
||||||
# Sort topics
|
|
||||||
topic_list = list(self.topics.values())
|
|
||||||
|
|
||||||
if sort_by == "last_activity":
|
|
||||||
topic_list.sort(key=lambda x: x.last_activity, reverse=True)
|
|
||||||
elif sort_by == "created_at":
|
|
||||||
topic_list.sort(key=lambda x: x.created_at, reverse=True)
|
|
||||||
elif sort_by == "message_count":
|
|
||||||
topic_list.sort(key=lambda x: x.message_count, reverse=True)
|
|
||||||
|
|
||||||
# Apply pagination
|
|
||||||
total_topics = len(topic_list)
|
|
||||||
paginated_topics = topic_list[offset:offset + limit]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"topics": [self._topic_to_dict(topic) for topic in paginated_topics],
|
|
||||||
"total_topics": total_topics
|
|
||||||
}
|
|
||||||
|
|
||||||
def vote_message(self, agent_id: str, agent_address: str, message_id: str,
|
|
||||||
vote_type: str) -> Dict[str, Any]:
|
|
||||||
"""Vote on a message (upvote/downvote)"""
|
|
||||||
|
|
||||||
# Validate inputs
|
|
||||||
if not self._validate_agent(agent_id, agent_address):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid agent credentials",
|
|
||||||
"error_code": "INVALID_AGENT"
|
|
||||||
}
|
|
||||||
|
|
||||||
if message_id not in self.messages:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Message not found",
|
|
||||||
"error_code": "MESSAGE_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
if vote_type not in ["upvote", "downvote"]:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid vote type",
|
|
||||||
"error_code": "INVALID_VOTE_TYPE"
|
|
||||||
}
|
|
||||||
|
|
||||||
message = self.messages[message_id]
|
|
||||||
|
|
||||||
# Update vote counts
|
|
||||||
if vote_type == "upvote":
|
|
||||||
message.upvotes += 1
|
|
||||||
else:
|
|
||||||
message.downvotes += 1
|
|
||||||
|
|
||||||
# Update message author reputation
|
|
||||||
self._update_agent_reputation(
|
|
||||||
message.agent_id,
|
|
||||||
upvotes_received=message.upvotes,
|
|
||||||
downvotes_received=message.downvotes
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message_id": message_id,
|
|
||||||
"upvotes": message.upvotes,
|
|
||||||
"downvotes": message.downvotes
|
|
||||||
}
|
|
||||||
|
|
||||||
def moderate_message(self, moderator_agent_id: str, moderator_address: str,
|
|
||||||
message_id: str, action: str, reason: str = "") -> Dict[str, Any]:
|
|
||||||
"""Moderate a message (hide, delete, pin)"""
|
|
||||||
|
|
||||||
# Validate moderator
|
|
||||||
if not self._is_moderator(moderator_agent_id):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Insufficient permissions",
|
|
||||||
"error_code": "INSUFFICIENT_PERMISSIONS"
|
|
||||||
}
|
|
||||||
|
|
||||||
if message_id not in self.messages:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Message not found",
|
|
||||||
"error_code": "MESSAGE_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
message = self.messages[message_id]
|
|
||||||
|
|
||||||
# Apply moderation action
|
|
||||||
if action == "hide":
|
|
||||||
message.status = MessageStatus.HIDDEN
|
|
||||||
elif action == "delete":
|
|
||||||
message.status = MessageStatus.DELETED
|
|
||||||
elif action == "pin":
|
|
||||||
message.status = MessageStatus.PINNED
|
|
||||||
elif action == "unpin":
|
|
||||||
message.status = MessageStatus.ACTIVE
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid moderation action",
|
|
||||||
"error_code": "INVALID_ACTION"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Log moderation action
|
|
||||||
self.moderation_log.append({
|
|
||||||
"timestamp": datetime.now(),
|
|
||||||
"moderator_agent_id": moderator_agent_id,
|
|
||||||
"message_id": message_id,
|
|
||||||
"action": action,
|
|
||||||
"reason": reason
|
|
||||||
})
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message_id": message_id,
|
|
||||||
"status": message.status.value
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_agent_reputation(self, agent_id: str) -> Dict[str, Any]:
|
|
||||||
"""Get an agent's reputation information"""
|
|
||||||
|
|
||||||
if agent_id not in self.agent_reputations:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Agent not found",
|
|
||||||
"error_code": "AGENT_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
reputation = self.agent_reputations[agent_id]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"agent_id": agent_id,
|
|
||||||
"reputation": self._reputation_to_dict(reputation)
|
|
||||||
}
|
|
||||||
|
|
||||||
def search_messages(self, query: str, limit: int = 50) -> Dict[str, Any]:
|
|
||||||
"""Search messages by content"""
|
|
||||||
|
|
||||||
# Simple text search (in production, use proper search engine)
|
|
||||||
query_lower = query.lower()
|
|
||||||
matching_messages = []
|
|
||||||
|
|
||||||
for message in self.messages.values():
|
|
||||||
if (message.status == MessageStatus.ACTIVE and
|
|
||||||
query_lower in message.content.lower()):
|
|
||||||
matching_messages.append(message)
|
|
||||||
|
|
||||||
# Sort by timestamp (most recent first)
|
|
||||||
matching_messages.sort(key=lambda x: x.timestamp, reverse=True)
|
|
||||||
|
|
||||||
# Limit results
|
|
||||||
limited_messages = matching_messages[:limit]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"query": query,
|
|
||||||
"messages": [self._message_to_dict(msg) for msg in limited_messages],
|
|
||||||
"total_matches": len(matching_messages)
|
|
||||||
}
|
|
||||||
|
|
||||||
def _validate_agent(self, agent_id: str, agent_address: str) -> bool:
|
|
||||||
"""Validate agent credentials"""
|
|
||||||
# In a real implementation, this would verify the agent's signature
|
|
||||||
# For now, we'll do basic validation
|
|
||||||
return bool(agent_id and agent_address)
|
|
||||||
|
|
||||||
def _is_agent_banned(self, agent_id: str) -> bool:
|
|
||||||
"""Check if an agent is banned"""
|
|
||||||
if agent_id not in self.agent_reputations:
|
|
||||||
return False
|
|
||||||
|
|
||||||
reputation = self.agent_reputations[agent_id]
|
|
||||||
|
|
||||||
if reputation.is_banned:
|
|
||||||
# Check if ban has expired
|
|
||||||
if reputation.ban_expires and datetime.now() > reputation.ban_expires:
|
|
||||||
reputation.is_banned = False
|
|
||||||
reputation.ban_expires = None
|
|
||||||
reputation.ban_reason = None
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _is_moderator(self, agent_id: str) -> bool:
|
|
||||||
"""Check if an agent is a moderator"""
|
|
||||||
if agent_id not in self.agent_reputations:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self.agent_reputations[agent_id].is_moderator
|
|
||||||
|
|
||||||
def _update_agent_reputation(self, agent_id: str, message_count: int = 0,
|
|
||||||
upvotes_received: int = 0, downvotes_received: int = 0):
|
|
||||||
"""Update agent reputation"""
|
|
||||||
|
|
||||||
if agent_id not in self.agent_reputations:
|
|
||||||
self.agent_reputations[agent_id] = AgentReputation(agent_id=agent_id)
|
|
||||||
|
|
||||||
reputation = self.agent_reputations[agent_id]
|
|
||||||
|
|
||||||
if message_count > 0:
|
|
||||||
reputation.message_count += message_count
|
|
||||||
|
|
||||||
if upvotes_received > 0:
|
|
||||||
reputation.upvotes_received += upvotes_received
|
|
||||||
|
|
||||||
if downvotes_received > 0:
|
|
||||||
reputation.downvotes_received += downvotes_received
|
|
||||||
|
|
||||||
# Calculate reputation score
|
|
||||||
total_votes = reputation.upvotes_received + reputation.downvotes_received
|
|
||||||
if total_votes > 0:
|
|
||||||
reputation.reputation_score = (reputation.upvotes_received - reputation.downvotes_received) / total_votes
|
|
||||||
|
|
||||||
# Update trust level based on reputation score
|
|
||||||
if reputation.reputation_score >= 0.8:
|
|
||||||
reputation.trust_level = 5
|
|
||||||
elif reputation.reputation_score >= 0.6:
|
|
||||||
reputation.trust_level = 4
|
|
||||||
elif reputation.reputation_score >= 0.4:
|
|
||||||
reputation.trust_level = 3
|
|
||||||
elif reputation.reputation_score >= 0.2:
|
|
||||||
reputation.trust_level = 2
|
|
||||||
else:
|
|
||||||
reputation.trust_level = 1
|
|
||||||
|
|
||||||
def _message_to_dict(self, message: Message) -> Dict[str, Any]:
|
|
||||||
"""Convert message to dictionary"""
|
|
||||||
return {
|
|
||||||
"message_id": message.message_id,
|
|
||||||
"agent_id": message.agent_id,
|
|
||||||
"agent_address": message.agent_address,
|
|
||||||
"topic": message.topic,
|
|
||||||
"content": message.content,
|
|
||||||
"message_type": message.message_type.value,
|
|
||||||
"timestamp": message.timestamp.isoformat(),
|
|
||||||
"parent_message_id": message.parent_message_id,
|
|
||||||
"reply_count": message.reply_count,
|
|
||||||
"upvotes": message.upvotes,
|
|
||||||
"downvotes": message.downvotes,
|
|
||||||
"status": message.status.value,
|
|
||||||
"metadata": message.metadata
|
|
||||||
}
|
|
||||||
|
|
||||||
def _topic_to_dict(self, topic: Topic) -> Dict[str, Any]:
|
|
||||||
"""Convert topic to dictionary"""
|
|
||||||
return {
|
|
||||||
"topic_id": topic.topic_id,
|
|
||||||
"title": topic.title,
|
|
||||||
"description": topic.description,
|
|
||||||
"creator_agent_id": topic.creator_agent_id,
|
|
||||||
"created_at": topic.created_at.isoformat(),
|
|
||||||
"message_count": topic.message_count,
|
|
||||||
"last_activity": topic.last_activity.isoformat(),
|
|
||||||
"tags": topic.tags,
|
|
||||||
"is_pinned": topic.is_pinned,
|
|
||||||
"is_locked": topic.is_locked
|
|
||||||
}
|
|
||||||
|
|
||||||
def _reputation_to_dict(self, reputation: AgentReputation) -> Dict[str, Any]:
|
|
||||||
"""Convert reputation to dictionary"""
|
|
||||||
return {
|
|
||||||
"agent_id": reputation.agent_id,
|
|
||||||
"message_count": reputation.message_count,
|
|
||||||
"upvotes_received": reputation.upvotes_received,
|
|
||||||
"downvotes_received": reputation.downvotes_received,
|
|
||||||
"reputation_score": reputation.reputation_score,
|
|
||||||
"trust_level": reputation.trust_level,
|
|
||||||
"is_moderator": reputation.is_moderator,
|
|
||||||
"is_banned": reputation.is_banned,
|
|
||||||
"ban_reason": reputation.ban_reason,
|
|
||||||
"ban_expires": reputation.ban_expires.isoformat() if reputation.ban_expires else None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Global contract instance
|
|
||||||
messaging_contract = AgentMessagingContract()
|
|
||||||
@@ -1,584 +0,0 @@
|
|||||||
"""
|
|
||||||
AITBC Agent Wallet Security Implementation
|
|
||||||
|
|
||||||
This module implements the security layer for autonomous agent wallets,
|
|
||||||
integrating the guardian contract to prevent unlimited spending in case
|
|
||||||
of agent compromise.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
import json
|
|
||||||
from eth_account import Account
|
|
||||||
from eth_utils import to_checksum_address
|
|
||||||
|
|
||||||
from .guardian_contract import (
|
|
||||||
GuardianContract,
|
|
||||||
SpendingLimit,
|
|
||||||
TimeLockConfig,
|
|
||||||
GuardianConfig,
|
|
||||||
create_guardian_contract,
|
|
||||||
CONSERVATIVE_CONFIG,
|
|
||||||
AGGRESSIVE_CONFIG,
|
|
||||||
HIGH_SECURITY_CONFIG
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AgentSecurityProfile:
|
|
||||||
"""Security profile for an agent"""
|
|
||||||
agent_address: str
|
|
||||||
security_level: str # "conservative", "aggressive", "high_security"
|
|
||||||
guardian_addresses: List[str]
|
|
||||||
custom_limits: Optional[Dict] = None
|
|
||||||
enabled: bool = True
|
|
||||||
created_at: datetime = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.created_at is None:
|
|
||||||
self.created_at = datetime.utcnow()
|
|
||||||
|
|
||||||
|
|
||||||
class AgentWalletSecurity:
|
|
||||||
"""
|
|
||||||
Security manager for autonomous agent wallets
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.agent_profiles: Dict[str, AgentSecurityProfile] = {}
|
|
||||||
self.guardian_contracts: Dict[str, GuardianContract] = {}
|
|
||||||
self.security_events: List[Dict] = []
|
|
||||||
|
|
||||||
# Default configurations
|
|
||||||
self.configurations = {
|
|
||||||
"conservative": CONSERVATIVE_CONFIG,
|
|
||||||
"aggressive": AGGRESSIVE_CONFIG,
|
|
||||||
"high_security": HIGH_SECURITY_CONFIG
|
|
||||||
}
|
|
||||||
|
|
||||||
def register_agent(self,
|
|
||||||
agent_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
guardian_addresses: List[str] = None,
|
|
||||||
custom_limits: Dict = None) -> Dict:
|
|
||||||
"""
|
|
||||||
Register an agent for security protection
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
security_level: Security level (conservative, aggressive, high_security)
|
|
||||||
guardian_addresses: List of guardian addresses for recovery
|
|
||||||
custom_limits: Custom spending limits (overrides security_level)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Registration result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
if agent_address in self.agent_profiles:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent already registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Validate security level
|
|
||||||
if security_level not in self.configurations:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Invalid security level: {security_level}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Default guardians if none provided
|
|
||||||
if guardian_addresses is None:
|
|
||||||
guardian_addresses = [agent_address] # Self-guardian (should be overridden)
|
|
||||||
|
|
||||||
# Validate guardian addresses
|
|
||||||
guardian_addresses = [to_checksum_address(addr) for addr in guardian_addresses]
|
|
||||||
|
|
||||||
# Create security profile
|
|
||||||
profile = AgentSecurityProfile(
|
|
||||||
agent_address=agent_address,
|
|
||||||
security_level=security_level,
|
|
||||||
guardian_addresses=guardian_addresses,
|
|
||||||
custom_limits=custom_limits
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create guardian contract
|
|
||||||
config = self.configurations[security_level]
|
|
||||||
if custom_limits:
|
|
||||||
config.update(custom_limits)
|
|
||||||
|
|
||||||
guardian_contract = create_guardian_contract(
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardians=guardian_addresses,
|
|
||||||
**config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store profile and contract
|
|
||||||
self.agent_profiles[agent_address] = profile
|
|
||||||
self.guardian_contracts[agent_address] = guardian_contract
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="agent_registered",
|
|
||||||
agent_address=agent_address,
|
|
||||||
security_level=security_level,
|
|
||||||
guardian_count=len(guardian_addresses)
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "registered",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"security_level": security_level,
|
|
||||||
"guardian_addresses": guardian_addresses,
|
|
||||||
"limits": guardian_contract.config.limits,
|
|
||||||
"time_lock_threshold": guardian_contract.config.time_lock.threshold,
|
|
||||||
"registered_at": profile.created_at.isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Registration failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def protect_transaction(self,
|
|
||||||
agent_address: str,
|
|
||||||
to_address: str,
|
|
||||||
amount: int,
|
|
||||||
data: str = "") -> Dict:
|
|
||||||
"""
|
|
||||||
Protect a transaction with guardian contract
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
to_address: Recipient address
|
|
||||||
amount: Amount to transfer
|
|
||||||
data: Transaction data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Protection result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
# Check if agent is registered
|
|
||||||
if agent_address not in self.agent_profiles:
|
|
||||||
return {
|
|
||||||
"status": "unprotected",
|
|
||||||
"reason": "Agent not registered for security protection",
|
|
||||||
"suggestion": "Register agent with register_agent() first"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Check if protection is enabled
|
|
||||||
profile = self.agent_profiles[agent_address]
|
|
||||||
if not profile.enabled:
|
|
||||||
return {
|
|
||||||
"status": "unprotected",
|
|
||||||
"reason": "Security protection disabled for this agent"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get guardian contract
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
|
|
||||||
# Initiate transaction protection
|
|
||||||
result = guardian_contract.initiate_transaction(to_address, amount, data)
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="transaction_protected",
|
|
||||||
agent_address=agent_address,
|
|
||||||
to_address=to_address,
|
|
||||||
amount=amount,
|
|
||||||
protection_status=result["status"]
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Transaction protection failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute_protected_transaction(self,
|
|
||||||
agent_address: str,
|
|
||||||
operation_id: str,
|
|
||||||
signature: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Execute a previously protected transaction
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
operation_id: Operation ID from protection
|
|
||||||
signature: Transaction signature
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Execution result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
if agent_address not in self.guardian_contracts:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
result = guardian_contract.execute_transaction(operation_id, signature)
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
if result["status"] == "executed":
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="transaction_executed",
|
|
||||||
agent_address=agent_address,
|
|
||||||
operation_id=operation_id,
|
|
||||||
transaction_hash=result.get("transaction_hash")
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Transaction execution failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def emergency_pause_agent(self, agent_address: str, guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Emergency pause an agent's operations
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
guardian_address: Guardian address initiating pause
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Pause result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
|
|
||||||
if agent_address not in self.guardian_contracts:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
result = guardian_contract.emergency_pause(guardian_address)
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
if result["status"] == "paused":
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="emergency_pause",
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardian_address=guardian_address
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Emergency pause failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def update_agent_security(self,
|
|
||||||
agent_address: str,
|
|
||||||
new_limits: Dict,
|
|
||||||
guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Update security limits for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
new_limits: New spending limits
|
|
||||||
guardian_address: Guardian address making the change
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Update result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
|
|
||||||
if agent_address not in self.guardian_contracts:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
|
|
||||||
# Create new spending limits
|
|
||||||
limits = SpendingLimit(
|
|
||||||
per_transaction=new_limits.get("per_transaction", 1000),
|
|
||||||
per_hour=new_limits.get("per_hour", 5000),
|
|
||||||
per_day=new_limits.get("per_day", 20000),
|
|
||||||
per_week=new_limits.get("per_week", 100000)
|
|
||||||
)
|
|
||||||
|
|
||||||
result = guardian_contract.update_limits(limits, guardian_address)
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
if result["status"] == "updated":
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="security_limits_updated",
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardian_address=guardian_address,
|
|
||||||
new_limits=new_limits
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Security update failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_agent_security_status(self, agent_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Get security status for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Security status
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
if agent_address not in self.agent_profiles:
|
|
||||||
return {
|
|
||||||
"status": "not_registered",
|
|
||||||
"message": "Agent not registered for security protection"
|
|
||||||
}
|
|
||||||
|
|
||||||
profile = self.agent_profiles[agent_address]
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "protected",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"security_level": profile.security_level,
|
|
||||||
"enabled": profile.enabled,
|
|
||||||
"guardian_addresses": profile.guardian_addresses,
|
|
||||||
"registered_at": profile.created_at.isoformat(),
|
|
||||||
"spending_status": guardian_contract.get_spending_status(),
|
|
||||||
"pending_operations": guardian_contract.get_pending_operations(),
|
|
||||||
"recent_activity": guardian_contract.get_operation_history(10)
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Status check failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def list_protected_agents(self) -> List[Dict]:
|
|
||||||
"""List all protected agents"""
|
|
||||||
agents = []
|
|
||||||
|
|
||||||
for agent_address, profile in self.agent_profiles.items():
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
|
|
||||||
agents.append({
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"security_level": profile.security_level,
|
|
||||||
"enabled": profile.enabled,
|
|
||||||
"guardian_count": len(profile.guardian_addresses),
|
|
||||||
"pending_operations": len(guardian_contract.pending_operations),
|
|
||||||
"paused": guardian_contract.paused,
|
|
||||||
"emergency_mode": guardian_contract.emergency_mode,
|
|
||||||
"registered_at": profile.created_at.isoformat()
|
|
||||||
})
|
|
||||||
|
|
||||||
return sorted(agents, key=lambda x: x["registered_at"], reverse=True)
|
|
||||||
|
|
||||||
def get_security_events(self, agent_address: str = None, limit: int = 50) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
Get security events
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Filter by agent address (optional)
|
|
||||||
limit: Maximum number of events
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Security events
|
|
||||||
"""
|
|
||||||
events = self.security_events
|
|
||||||
|
|
||||||
if agent_address:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
events = [e for e in events if e.get("agent_address") == agent_address]
|
|
||||||
|
|
||||||
return sorted(events, key=lambda x: x["timestamp"], reverse=True)[:limit]
|
|
||||||
|
|
||||||
def _log_security_event(self, **kwargs):
|
|
||||||
"""Log a security event"""
|
|
||||||
event = {
|
|
||||||
"timestamp": datetime.utcnow().isoformat(),
|
|
||||||
**kwargs
|
|
||||||
}
|
|
||||||
self.security_events.append(event)
|
|
||||||
|
|
||||||
def disable_agent_protection(self, agent_address: str, guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Disable protection for an agent (guardian only)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
guardian_address: Guardian address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Disable result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
|
|
||||||
if agent_address not in self.agent_profiles:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
profile = self.agent_profiles[agent_address]
|
|
||||||
|
|
||||||
if guardian_address not in profile.guardian_addresses:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Not authorized: not a guardian"
|
|
||||||
}
|
|
||||||
|
|
||||||
profile.enabled = False
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="protection_disabled",
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardian_address=guardian_address
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "disabled",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"disabled_at": datetime.utcnow().isoformat(),
|
|
||||||
"guardian": guardian_address
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Disable protection failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Global security manager instance
|
|
||||||
agent_wallet_security = AgentWalletSecurity()
|
|
||||||
|
|
||||||
|
|
||||||
# Convenience functions for common operations
|
|
||||||
def register_agent_for_protection(agent_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
guardians: List[str] = None) -> Dict:
|
|
||||||
"""Register an agent for security protection"""
|
|
||||||
return agent_wallet_security.register_agent(
|
|
||||||
agent_address=agent_address,
|
|
||||||
security_level=security_level,
|
|
||||||
guardian_addresses=guardians
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def protect_agent_transaction(agent_address: str,
|
|
||||||
to_address: str,
|
|
||||||
amount: int,
|
|
||||||
data: str = "") -> Dict:
|
|
||||||
"""Protect a transaction for an agent"""
|
|
||||||
return agent_wallet_security.protect_transaction(
|
|
||||||
agent_address=agent_address,
|
|
||||||
to_address=to_address,
|
|
||||||
amount=amount,
|
|
||||||
data=data
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_agent_security_summary(agent_address: str) -> Dict:
|
|
||||||
"""Get security summary for an agent"""
|
|
||||||
return agent_wallet_security.get_agent_security_status(agent_address)
|
|
||||||
|
|
||||||
|
|
||||||
# Security audit and monitoring functions
|
|
||||||
def generate_security_report() -> Dict:
|
|
||||||
"""Generate comprehensive security report"""
|
|
||||||
protected_agents = agent_wallet_security.list_protected_agents()
|
|
||||||
|
|
||||||
total_agents = len(protected_agents)
|
|
||||||
active_agents = len([a for a in protected_agents if a["enabled"]])
|
|
||||||
paused_agents = len([a for a in protected_agents if a["paused"]])
|
|
||||||
emergency_agents = len([a for a in protected_agents if a["emergency_mode"]])
|
|
||||||
|
|
||||||
recent_events = agent_wallet_security.get_security_events(limit=20)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"generated_at": datetime.utcnow().isoformat(),
|
|
||||||
"summary": {
|
|
||||||
"total_protected_agents": total_agents,
|
|
||||||
"active_agents": active_agents,
|
|
||||||
"paused_agents": paused_agents,
|
|
||||||
"emergency_mode_agents": emergency_agents,
|
|
||||||
"protection_coverage": f"{(active_agents / total_agents * 100):.1f}%" if total_agents > 0 else "0%"
|
|
||||||
},
|
|
||||||
"agents": protected_agents,
|
|
||||||
"recent_security_events": recent_events,
|
|
||||||
"security_levels": {
|
|
||||||
level: len([a for a in protected_agents if a["security_level"] == level])
|
|
||||||
for level in ["conservative", "aggressive", "high_security"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def detect_suspicious_activity(agent_address: str, hours: int = 24) -> Dict:
|
|
||||||
"""Detect suspicious activity for an agent"""
|
|
||||||
status = agent_wallet_security.get_agent_security_status(agent_address)
|
|
||||||
|
|
||||||
if status["status"] != "protected":
|
|
||||||
return {
|
|
||||||
"status": "not_protected",
|
|
||||||
"suspicious_activity": False
|
|
||||||
}
|
|
||||||
|
|
||||||
spending_status = status["spending_status"]
|
|
||||||
recent_events = agent_wallet_security.get_security_events(agent_address, limit=50)
|
|
||||||
|
|
||||||
# Suspicious patterns
|
|
||||||
suspicious_patterns = []
|
|
||||||
|
|
||||||
# Check for rapid spending
|
|
||||||
if spending_status["spent"]["current_hour"] > spending_status["current_limits"]["per_hour"] * 0.8:
|
|
||||||
suspicious_patterns.append("High hourly spending rate")
|
|
||||||
|
|
||||||
# Check for many small transactions (potential dust attack)
|
|
||||||
recent_tx_count = len([e for e in recent_events if e["event_type"] == "transaction_executed"])
|
|
||||||
if recent_tx_count > 20:
|
|
||||||
suspicious_patterns.append("High transaction frequency")
|
|
||||||
|
|
||||||
# Check for emergency pauses
|
|
||||||
recent_pauses = len([e for e in recent_events if e["event_type"] == "emergency_pause"])
|
|
||||||
if recent_pauses > 0:
|
|
||||||
suspicious_patterns.append("Recent emergency pauses detected")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "analyzed",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"suspicious_activity": len(suspicious_patterns) > 0,
|
|
||||||
"suspicious_patterns": suspicious_patterns,
|
|
||||||
"analysis_period_hours": hours,
|
|
||||||
"analyzed_at": datetime.utcnow().isoformat()
|
|
||||||
}
|
|
||||||
@@ -1,559 +0,0 @@
|
|||||||
"""
|
|
||||||
Smart Contract Escrow System
|
|
||||||
Handles automated payment holding and release for AI job marketplace
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
from typing import Dict, List, Optional, Tuple, Set
|
|
||||||
from dataclasses import dataclass, asdict
|
|
||||||
from enum import Enum
|
|
||||||
from decimal import Decimal
|
|
||||||
|
|
||||||
class EscrowState(Enum):
|
|
||||||
CREATED = "created"
|
|
||||||
FUNDED = "funded"
|
|
||||||
JOB_STARTED = "job_started"
|
|
||||||
JOB_COMPLETED = "job_completed"
|
|
||||||
DISPUTED = "disputed"
|
|
||||||
RESOLVED = "resolved"
|
|
||||||
RELEASED = "released"
|
|
||||||
REFUNDED = "refunded"
|
|
||||||
EXPIRED = "expired"
|
|
||||||
|
|
||||||
class DisputeReason(Enum):
|
|
||||||
QUALITY_ISSUES = "quality_issues"
|
|
||||||
DELIVERY_LATE = "delivery_late"
|
|
||||||
INCOMPLETE_WORK = "incomplete_work"
|
|
||||||
TECHNICAL_ISSUES = "technical_issues"
|
|
||||||
PAYMENT_DISPUTE = "payment_dispute"
|
|
||||||
OTHER = "other"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class EscrowContract:
|
|
||||||
contract_id: str
|
|
||||||
job_id: str
|
|
||||||
client_address: str
|
|
||||||
agent_address: str
|
|
||||||
amount: Decimal
|
|
||||||
fee_rate: Decimal # Platform fee rate
|
|
||||||
created_at: float
|
|
||||||
expires_at: float
|
|
||||||
state: EscrowState
|
|
||||||
milestones: List[Dict]
|
|
||||||
current_milestone: int
|
|
||||||
dispute_reason: Optional[DisputeReason]
|
|
||||||
dispute_evidence: List[Dict]
|
|
||||||
resolution: Optional[Dict]
|
|
||||||
released_amount: Decimal
|
|
||||||
refunded_amount: Decimal
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Milestone:
|
|
||||||
milestone_id: str
|
|
||||||
description: str
|
|
||||||
amount: Decimal
|
|
||||||
completed: bool
|
|
||||||
completed_at: Optional[float]
|
|
||||||
verified: bool
|
|
||||||
|
|
||||||
class EscrowManager:
|
|
||||||
"""Manages escrow contracts for AI job marketplace"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.escrow_contracts: Dict[str, EscrowContract] = {}
|
|
||||||
self.active_contracts: Set[str] = set()
|
|
||||||
self.disputed_contracts: Set[str] = set()
|
|
||||||
|
|
||||||
# Escrow parameters
|
|
||||||
self.default_fee_rate = Decimal('0.025') # 2.5% platform fee
|
|
||||||
self.max_contract_duration = 86400 * 30 # 30 days
|
|
||||||
self.dispute_timeout = 86400 * 7 # 7 days for dispute resolution
|
|
||||||
self.min_dispute_evidence = 1
|
|
||||||
self.max_dispute_evidence = 10
|
|
||||||
|
|
||||||
# Milestone parameters
|
|
||||||
self.min_milestone_amount = Decimal('0.01')
|
|
||||||
self.max_milestones = 10
|
|
||||||
self.verification_timeout = 86400 # 24 hours for milestone verification
|
|
||||||
|
|
||||||
async def create_contract(self, job_id: str, client_address: str, agent_address: str,
|
|
||||||
amount: Decimal, fee_rate: Optional[Decimal] = None,
|
|
||||||
milestones: Optional[List[Dict]] = None,
|
|
||||||
duration_days: int = 30) -> Tuple[bool, str, Optional[str]]:
|
|
||||||
"""Create new escrow contract"""
|
|
||||||
try:
|
|
||||||
# Validate inputs
|
|
||||||
if not self._validate_contract_inputs(job_id, client_address, agent_address, amount):
|
|
||||||
return False, "Invalid contract inputs", None
|
|
||||||
|
|
||||||
# Calculate fee
|
|
||||||
fee_rate = fee_rate or self.default_fee_rate
|
|
||||||
platform_fee = amount * fee_rate
|
|
||||||
total_amount = amount + platform_fee
|
|
||||||
|
|
||||||
# Validate milestones
|
|
||||||
validated_milestones = []
|
|
||||||
if milestones:
|
|
||||||
validated_milestones = await self._validate_milestones(milestones, amount)
|
|
||||||
if not validated_milestones:
|
|
||||||
return False, "Invalid milestones configuration", None
|
|
||||||
else:
|
|
||||||
# Create single milestone for full amount
|
|
||||||
validated_milestones = [{
|
|
||||||
'milestone_id': 'milestone_1',
|
|
||||||
'description': 'Complete job',
|
|
||||||
'amount': amount,
|
|
||||||
'completed': False
|
|
||||||
}]
|
|
||||||
|
|
||||||
# Create contract
|
|
||||||
contract_id = self._generate_contract_id(client_address, agent_address, job_id)
|
|
||||||
current_time = time.time()
|
|
||||||
|
|
||||||
contract = EscrowContract(
|
|
||||||
contract_id=contract_id,
|
|
||||||
job_id=job_id,
|
|
||||||
client_address=client_address,
|
|
||||||
agent_address=agent_address,
|
|
||||||
amount=total_amount,
|
|
||||||
fee_rate=fee_rate,
|
|
||||||
created_at=current_time,
|
|
||||||
expires_at=current_time + (duration_days * 86400),
|
|
||||||
state=EscrowState.CREATED,
|
|
||||||
milestones=validated_milestones,
|
|
||||||
current_milestone=0,
|
|
||||||
dispute_reason=None,
|
|
||||||
dispute_evidence=[],
|
|
||||||
resolution=None,
|
|
||||||
released_amount=Decimal('0'),
|
|
||||||
refunded_amount=Decimal('0')
|
|
||||||
)
|
|
||||||
|
|
||||||
self.escrow_contracts[contract_id] = contract
|
|
||||||
|
|
||||||
log_info(f"Escrow contract created: {contract_id} for job {job_id}")
|
|
||||||
return True, "Contract created successfully", contract_id
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return False, f"Contract creation failed: {str(e)}", None
|
|
||||||
|
|
||||||
def _validate_contract_inputs(self, job_id: str, client_address: str,
|
|
||||||
agent_address: str, amount: Decimal) -> bool:
|
|
||||||
"""Validate contract creation inputs"""
|
|
||||||
if not all([job_id, client_address, agent_address]):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Validate addresses (simplified)
|
|
||||||
if not (client_address.startswith('0x') and len(client_address) == 42):
|
|
||||||
return False
|
|
||||||
if not (agent_address.startswith('0x') and len(agent_address) == 42):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Validate amount
|
|
||||||
if amount <= 0:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check for existing contract
|
|
||||||
for contract in self.escrow_contracts.values():
|
|
||||||
if contract.job_id == job_id:
|
|
||||||
return False # Contract already exists for this job
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _validate_milestones(self, milestones: List[Dict], total_amount: Decimal) -> Optional[List[Dict]]:
|
|
||||||
"""Validate milestone configuration"""
|
|
||||||
if not milestones or len(milestones) > self.max_milestones:
|
|
||||||
return None
|
|
||||||
|
|
||||||
validated_milestones = []
|
|
||||||
milestone_total = Decimal('0')
|
|
||||||
|
|
||||||
for i, milestone_data in enumerate(milestones):
|
|
||||||
# Validate required fields
|
|
||||||
required_fields = ['milestone_id', 'description', 'amount']
|
|
||||||
if not all(field in milestone_data for field in required_fields):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Validate amount
|
|
||||||
amount = Decimal(str(milestone_data['amount']))
|
|
||||||
if amount < self.min_milestone_amount:
|
|
||||||
return None
|
|
||||||
|
|
||||||
milestone_total += amount
|
|
||||||
validated_milestones.append({
|
|
||||||
'milestone_id': milestone_data['milestone_id'],
|
|
||||||
'description': milestone_data['description'],
|
|
||||||
'amount': amount,
|
|
||||||
'completed': False
|
|
||||||
})
|
|
||||||
|
|
||||||
# Check if milestone amounts sum to total
|
|
||||||
if abs(milestone_total - total_amount) > Decimal('0.01'): # Allow small rounding difference
|
|
||||||
return None
|
|
||||||
|
|
||||||
return validated_milestones
|
|
||||||
|
|
||||||
def _generate_contract_id(self, client_address: str, agent_address: str, job_id: str) -> str:
|
|
||||||
"""Generate unique contract ID"""
|
|
||||||
import hashlib
|
|
||||||
content = f"{client_address}:{agent_address}:{job_id}:{time.time()}"
|
|
||||||
return hashlib.sha256(content.encode()).hexdigest()[:16]
|
|
||||||
|
|
||||||
async def fund_contract(self, contract_id: str, payment_tx_hash: str) -> Tuple[bool, str]:
|
|
||||||
"""Fund escrow contract"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state != EscrowState.CREATED:
|
|
||||||
return False, f"Cannot fund contract in {contract.state.value} state"
|
|
||||||
|
|
||||||
# In real implementation, this would verify the payment transaction
|
|
||||||
# For now, assume payment is valid
|
|
||||||
|
|
||||||
contract.state = EscrowState.FUNDED
|
|
||||||
self.active_contracts.add(contract_id)
|
|
||||||
|
|
||||||
log_info(f"Contract funded: {contract_id}")
|
|
||||||
return True, "Contract funded successfully"
|
|
||||||
|
|
||||||
async def start_job(self, contract_id: str) -> Tuple[bool, str]:
|
|
||||||
"""Mark job as started"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state != EscrowState.FUNDED:
|
|
||||||
return False, f"Cannot start job in {contract.state.value} state"
|
|
||||||
|
|
||||||
contract.state = EscrowState.JOB_STARTED
|
|
||||||
|
|
||||||
log_info(f"Job started for contract: {contract_id}")
|
|
||||||
return True, "Job started successfully"
|
|
||||||
|
|
||||||
async def complete_milestone(self, contract_id: str, milestone_id: str,
|
|
||||||
evidence: Dict = None) -> Tuple[bool, str]:
|
|
||||||
"""Mark milestone as completed"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state not in [EscrowState.JOB_STARTED, EscrowState.JOB_COMPLETED]:
|
|
||||||
return False, f"Cannot complete milestone in {contract.state.value} state"
|
|
||||||
|
|
||||||
# Find milestone
|
|
||||||
milestone = None
|
|
||||||
for ms in contract.milestones:
|
|
||||||
if ms['milestone_id'] == milestone_id:
|
|
||||||
milestone = ms
|
|
||||||
break
|
|
||||||
|
|
||||||
if not milestone:
|
|
||||||
return False, "Milestone not found"
|
|
||||||
|
|
||||||
if milestone['completed']:
|
|
||||||
return False, "Milestone already completed"
|
|
||||||
|
|
||||||
# Mark as completed
|
|
||||||
milestone['completed'] = True
|
|
||||||
milestone['completed_at'] = time.time()
|
|
||||||
|
|
||||||
# Add evidence if provided
|
|
||||||
if evidence:
|
|
||||||
milestone['evidence'] = evidence
|
|
||||||
|
|
||||||
# Check if all milestones are completed
|
|
||||||
all_completed = all(ms['completed'] for ms in contract.milestones)
|
|
||||||
if all_completed:
|
|
||||||
contract.state = EscrowState.JOB_COMPLETED
|
|
||||||
|
|
||||||
log_info(f"Milestone {milestone_id} completed for contract: {contract_id}")
|
|
||||||
return True, "Milestone completed successfully"
|
|
||||||
|
|
||||||
async def verify_milestone(self, contract_id: str, milestone_id: str,
|
|
||||||
verified: bool, feedback: str = "") -> Tuple[bool, str]:
|
|
||||||
"""Verify milestone completion"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
# Find milestone
|
|
||||||
milestone = None
|
|
||||||
for ms in contract.milestones:
|
|
||||||
if ms['milestone_id'] == milestone_id:
|
|
||||||
milestone = ms
|
|
||||||
break
|
|
||||||
|
|
||||||
if not milestone:
|
|
||||||
return False, "Milestone not found"
|
|
||||||
|
|
||||||
if not milestone['completed']:
|
|
||||||
return False, "Milestone not completed yet"
|
|
||||||
|
|
||||||
# Set verification status
|
|
||||||
milestone['verified'] = verified
|
|
||||||
milestone['verification_feedback'] = feedback
|
|
||||||
|
|
||||||
if verified:
|
|
||||||
# Release milestone payment
|
|
||||||
await self._release_milestone_payment(contract_id, milestone_id)
|
|
||||||
else:
|
|
||||||
# Create dispute if verification fails
|
|
||||||
await self._create_dispute(contract_id, DisputeReason.QUALITY_ISSUES,
|
|
||||||
f"Milestone {milestone_id} verification failed: {feedback}")
|
|
||||||
|
|
||||||
log_info(f"Milestone {milestone_id} verification: {verified} for contract: {contract_id}")
|
|
||||||
return True, "Milestone verification processed"
|
|
||||||
|
|
||||||
async def _release_milestone_payment(self, contract_id: str, milestone_id: str):
|
|
||||||
"""Release payment for verified milestone"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Find milestone
|
|
||||||
milestone = None
|
|
||||||
for ms in contract.milestones:
|
|
||||||
if ms['milestone_id'] == milestone_id:
|
|
||||||
milestone = ms
|
|
||||||
break
|
|
||||||
|
|
||||||
if not milestone:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Calculate payment amount (minus platform fee)
|
|
||||||
milestone_amount = Decimal(str(milestone['amount']))
|
|
||||||
platform_fee = milestone_amount * contract.fee_rate
|
|
||||||
payment_amount = milestone_amount - platform_fee
|
|
||||||
|
|
||||||
# Update released amount
|
|
||||||
contract.released_amount += payment_amount
|
|
||||||
|
|
||||||
# In real implementation, this would trigger actual payment transfer
|
|
||||||
log_info(f"Released {payment_amount} for milestone {milestone_id} in contract {contract_id}")
|
|
||||||
|
|
||||||
async def release_full_payment(self, contract_id: str) -> Tuple[bool, str]:
|
|
||||||
"""Release full payment to agent"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state != EscrowState.JOB_COMPLETED:
|
|
||||||
return False, f"Cannot release payment in {contract.state.value} state"
|
|
||||||
|
|
||||||
# Check if all milestones are verified
|
|
||||||
all_verified = all(ms.get('verified', False) for ms in contract.milestones)
|
|
||||||
if not all_verified:
|
|
||||||
return False, "Not all milestones are verified"
|
|
||||||
|
|
||||||
# Calculate remaining payment
|
|
||||||
total_milestone_amount = sum(Decimal(str(ms['amount'])) for ms in contract.milestones)
|
|
||||||
platform_fee_total = total_milestone_amount * contract.fee_rate
|
|
||||||
remaining_payment = total_milestone_amount - contract.released_amount - platform_fee_total
|
|
||||||
|
|
||||||
if remaining_payment > 0:
|
|
||||||
contract.released_amount += remaining_payment
|
|
||||||
|
|
||||||
contract.state = EscrowState.RELEASED
|
|
||||||
self.active_contracts.discard(contract_id)
|
|
||||||
|
|
||||||
log_info(f"Full payment released for contract: {contract_id}")
|
|
||||||
return True, "Payment released successfully"
|
|
||||||
|
|
||||||
async def create_dispute(self, contract_id: str, reason: DisputeReason,
|
|
||||||
description: str, evidence: List[Dict] = None) -> Tuple[bool, str]:
|
|
||||||
"""Create dispute for contract"""
|
|
||||||
return await self._create_dispute(contract_id, reason, description, evidence)
|
|
||||||
|
|
||||||
async def _create_dispute(self, contract_id: str, reason: DisputeReason,
|
|
||||||
description: str, evidence: List[Dict] = None):
|
|
||||||
"""Internal dispute creation method"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state == EscrowState.DISPUTED:
|
|
||||||
return False, "Contract already disputed"
|
|
||||||
|
|
||||||
if contract.state not in [EscrowState.FUNDED, EscrowState.JOB_STARTED, EscrowState.JOB_COMPLETED]:
|
|
||||||
return False, f"Cannot dispute contract in {contract.state.value} state"
|
|
||||||
|
|
||||||
# Validate evidence
|
|
||||||
if evidence and (len(evidence) < self.min_dispute_evidence or len(evidence) > self.max_dispute_evidence):
|
|
||||||
return False, f"Invalid evidence count: {len(evidence)}"
|
|
||||||
|
|
||||||
# Create dispute
|
|
||||||
contract.state = EscrowState.DISPUTED
|
|
||||||
contract.dispute_reason = reason
|
|
||||||
contract.dispute_evidence = evidence or []
|
|
||||||
contract.dispute_created_at = time.time()
|
|
||||||
|
|
||||||
self.disputed_contracts.add(contract_id)
|
|
||||||
|
|
||||||
log_info(f"Dispute created for contract: {contract_id} - {reason.value}")
|
|
||||||
return True, "Dispute created successfully"
|
|
||||||
|
|
||||||
async def resolve_dispute(self, contract_id: str, resolution: Dict) -> Tuple[bool, str]:
|
|
||||||
"""Resolve dispute with specified outcome"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state != EscrowState.DISPUTED:
|
|
||||||
return False, f"Contract not in disputed state: {contract.state.value}"
|
|
||||||
|
|
||||||
# Validate resolution
|
|
||||||
required_fields = ['winner', 'client_refund', 'agent_payment']
|
|
||||||
if not all(field in resolution for field in required_fields):
|
|
||||||
return False, "Invalid resolution format"
|
|
||||||
|
|
||||||
winner = resolution['winner']
|
|
||||||
client_refund = Decimal(str(resolution['client_refund']))
|
|
||||||
agent_payment = Decimal(str(resolution['agent_payment']))
|
|
||||||
|
|
||||||
# Validate amounts
|
|
||||||
total_refund = client_refund + agent_payment
|
|
||||||
if total_refund > contract.amount:
|
|
||||||
return False, "Refund amounts exceed contract amount"
|
|
||||||
|
|
||||||
# Apply resolution
|
|
||||||
contract.resolution = resolution
|
|
||||||
contract.state = EscrowState.RESOLVED
|
|
||||||
|
|
||||||
# Update amounts
|
|
||||||
contract.released_amount += agent_payment
|
|
||||||
contract.refunded_amount += client_refund
|
|
||||||
|
|
||||||
# Remove from disputed contracts
|
|
||||||
self.disputed_contracts.discard(contract_id)
|
|
||||||
self.active_contracts.discard(contract_id)
|
|
||||||
|
|
||||||
log_info(f"Dispute resolved for contract: {contract_id} - Winner: {winner}")
|
|
||||||
return True, "Dispute resolved successfully"
|
|
||||||
|
|
||||||
async def refund_contract(self, contract_id: str, reason: str = "") -> Tuple[bool, str]:
|
|
||||||
"""Refund contract to client"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state in [EscrowState.RELEASED, EscrowState.REFUNDED, EscrowState.EXPIRED]:
|
|
||||||
return False, f"Cannot refund contract in {contract.state.value} state"
|
|
||||||
|
|
||||||
# Calculate refund amount (minus any released payments)
|
|
||||||
refund_amount = contract.amount - contract.released_amount
|
|
||||||
|
|
||||||
if refund_amount <= 0:
|
|
||||||
return False, "No amount available for refund"
|
|
||||||
|
|
||||||
contract.state = EscrowState.REFUNDED
|
|
||||||
contract.refunded_amount = refund_amount
|
|
||||||
|
|
||||||
self.active_contracts.discard(contract_id)
|
|
||||||
self.disputed_contracts.discard(contract_id)
|
|
||||||
|
|
||||||
log_info(f"Contract refunded: {contract_id} - Amount: {refund_amount}")
|
|
||||||
return True, "Contract refunded successfully"
|
|
||||||
|
|
||||||
async def expire_contract(self, contract_id: str) -> Tuple[bool, str]:
|
|
||||||
"""Mark contract as expired"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if time.time() < contract.expires_at:
|
|
||||||
return False, "Contract has not expired yet"
|
|
||||||
|
|
||||||
if contract.state in [EscrowState.RELEASED, EscrowState.REFUNDED, EscrowState.EXPIRED]:
|
|
||||||
return False, f"Contract already in final state: {contract.state.value}"
|
|
||||||
|
|
||||||
# Auto-refund if no work has been done
|
|
||||||
if contract.state == EscrowState.FUNDED:
|
|
||||||
return await self.refund_contract(contract_id, "Contract expired")
|
|
||||||
|
|
||||||
# Handle other states based on work completion
|
|
||||||
contract.state = EscrowState.EXPIRED
|
|
||||||
self.active_contracts.discard(contract_id)
|
|
||||||
self.disputed_contracts.discard(contract_id)
|
|
||||||
|
|
||||||
log_info(f"Contract expired: {contract_id}")
|
|
||||||
return True, "Contract expired successfully"
|
|
||||||
|
|
||||||
async def get_contract_info(self, contract_id: str) -> Optional[EscrowContract]:
|
|
||||||
"""Get contract information"""
|
|
||||||
return self.escrow_contracts.get(contract_id)
|
|
||||||
|
|
||||||
async def get_contracts_by_client(self, client_address: str) -> List[EscrowContract]:
|
|
||||||
"""Get contracts for specific client"""
|
|
||||||
return [
|
|
||||||
contract for contract in self.escrow_contracts.values()
|
|
||||||
if contract.client_address == client_address
|
|
||||||
]
|
|
||||||
|
|
||||||
async def get_contracts_by_agent(self, agent_address: str) -> List[EscrowContract]:
|
|
||||||
"""Get contracts for specific agent"""
|
|
||||||
return [
|
|
||||||
contract for contract in self.escrow_contracts.values()
|
|
||||||
if contract.agent_address == agent_address
|
|
||||||
]
|
|
||||||
|
|
||||||
async def get_active_contracts(self) -> List[EscrowContract]:
|
|
||||||
"""Get all active contracts"""
|
|
||||||
return [
|
|
||||||
self.escrow_contracts[contract_id]
|
|
||||||
for contract_id in self.active_contracts
|
|
||||||
if contract_id in self.escrow_contracts
|
|
||||||
]
|
|
||||||
|
|
||||||
async def get_disputed_contracts(self) -> List[EscrowContract]:
|
|
||||||
"""Get all disputed contracts"""
|
|
||||||
return [
|
|
||||||
self.escrow_contracts[contract_id]
|
|
||||||
for contract_id in self.disputed_contracts
|
|
||||||
if contract_id in self.escrow_contracts
|
|
||||||
]
|
|
||||||
|
|
||||||
async def get_escrow_statistics(self) -> Dict:
|
|
||||||
"""Get escrow system statistics"""
|
|
||||||
total_contracts = len(self.escrow_contracts)
|
|
||||||
active_count = len(self.active_contracts)
|
|
||||||
disputed_count = len(self.disputed_contracts)
|
|
||||||
|
|
||||||
# State distribution
|
|
||||||
state_counts = {}
|
|
||||||
for contract in self.escrow_contracts.values():
|
|
||||||
state = contract.state.value
|
|
||||||
state_counts[state] = state_counts.get(state, 0) + 1
|
|
||||||
|
|
||||||
# Financial statistics
|
|
||||||
total_amount = sum(contract.amount for contract in self.escrow_contracts.values())
|
|
||||||
total_released = sum(contract.released_amount for contract in self.escrow_contracts.values())
|
|
||||||
total_refunded = sum(contract.refunded_amount for contract in self.escrow_contracts.values())
|
|
||||||
total_fees = total_amount - total_released - total_refunded
|
|
||||||
|
|
||||||
return {
|
|
||||||
'total_contracts': total_contracts,
|
|
||||||
'active_contracts': active_count,
|
|
||||||
'disputed_contracts': disputed_count,
|
|
||||||
'state_distribution': state_counts,
|
|
||||||
'total_amount': float(total_amount),
|
|
||||||
'total_released': float(total_released),
|
|
||||||
'total_refunded': float(total_refunded),
|
|
||||||
'total_fees': float(total_fees),
|
|
||||||
'average_contract_value': float(total_amount / total_contracts) if total_contracts > 0 else 0
|
|
||||||
}
|
|
||||||
|
|
||||||
# Global escrow manager
|
|
||||||
escrow_manager: Optional[EscrowManager] = None
|
|
||||||
|
|
||||||
def get_escrow_manager() -> Optional[EscrowManager]:
|
|
||||||
"""Get global escrow manager"""
|
|
||||||
return escrow_manager
|
|
||||||
|
|
||||||
def create_escrow_manager() -> EscrowManager:
|
|
||||||
"""Create and set global escrow manager"""
|
|
||||||
global escrow_manager
|
|
||||||
escrow_manager = EscrowManager()
|
|
||||||
return escrow_manager
|
|
||||||
@@ -1,405 +0,0 @@
|
|||||||
"""
|
|
||||||
Fixed Guardian Configuration with Proper Guardian Setup
|
|
||||||
Addresses the critical vulnerability where guardian lists were empty
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
import json
|
|
||||||
from eth_account import Account
|
|
||||||
from eth_utils import to_checksum_address, keccak
|
|
||||||
|
|
||||||
from .guardian_contract import (
|
|
||||||
SpendingLimit,
|
|
||||||
TimeLockConfig,
|
|
||||||
GuardianConfig,
|
|
||||||
GuardianContract
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GuardianSetup:
|
|
||||||
"""Guardian setup configuration"""
|
|
||||||
primary_guardian: str # Main guardian address
|
|
||||||
backup_guardians: List[str] # Backup guardian addresses
|
|
||||||
multisig_threshold: int # Number of signatures required
|
|
||||||
emergency_contacts: List[str] # Additional emergency contacts
|
|
||||||
|
|
||||||
|
|
||||||
class SecureGuardianManager:
|
|
||||||
"""
|
|
||||||
Secure guardian management with proper initialization
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.guardian_registrations: Dict[str, GuardianSetup] = {}
|
|
||||||
self.guardian_contracts: Dict[str, GuardianContract] = {}
|
|
||||||
|
|
||||||
def create_guardian_setup(
|
|
||||||
self,
|
|
||||||
agent_address: str,
|
|
||||||
owner_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
custom_guardians: Optional[List[str]] = None
|
|
||||||
) -> GuardianSetup:
|
|
||||||
"""
|
|
||||||
Create a proper guardian setup for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
owner_address: Owner of the agent
|
|
||||||
security_level: Security level (conservative, aggressive, high_security)
|
|
||||||
custom_guardians: Optional custom guardian addresses
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Guardian setup configuration
|
|
||||||
"""
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
owner_address = to_checksum_address(owner_address)
|
|
||||||
|
|
||||||
# Determine guardian requirements based on security level
|
|
||||||
if security_level == "conservative":
|
|
||||||
required_guardians = 3
|
|
||||||
multisig_threshold = 2
|
|
||||||
elif security_level == "aggressive":
|
|
||||||
required_guardians = 2
|
|
||||||
multisig_threshold = 2
|
|
||||||
elif security_level == "high_security":
|
|
||||||
required_guardians = 5
|
|
||||||
multisig_threshold = 3
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid security level: {security_level}")
|
|
||||||
|
|
||||||
# Build guardian list
|
|
||||||
guardians = []
|
|
||||||
|
|
||||||
# Always include the owner as primary guardian
|
|
||||||
guardians.append(owner_address)
|
|
||||||
|
|
||||||
# Add custom guardians if provided
|
|
||||||
if custom_guardians:
|
|
||||||
for guardian in custom_guardians:
|
|
||||||
guardian = to_checksum_address(guardian)
|
|
||||||
if guardian not in guardians:
|
|
||||||
guardians.append(guardian)
|
|
||||||
|
|
||||||
# Generate backup guardians if needed
|
|
||||||
while len(guardians) < required_guardians:
|
|
||||||
# Generate a deterministic backup guardian based on agent address
|
|
||||||
# In production, these would be trusted service addresses
|
|
||||||
backup_index = len(guardians) - 1 # -1 because owner is already included
|
|
||||||
backup_guardian = self._generate_backup_guardian(agent_address, backup_index)
|
|
||||||
|
|
||||||
if backup_guardian not in guardians:
|
|
||||||
guardians.append(backup_guardian)
|
|
||||||
|
|
||||||
# Create setup
|
|
||||||
setup = GuardianSetup(
|
|
||||||
primary_guardian=owner_address,
|
|
||||||
backup_guardians=[g for g in guardians if g != owner_address],
|
|
||||||
multisig_threshold=multisig_threshold,
|
|
||||||
emergency_contacts=guardians.copy()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.guardian_registrations[agent_address] = setup
|
|
||||||
|
|
||||||
return setup
|
|
||||||
|
|
||||||
def _generate_backup_guardian(self, agent_address: str, index: int) -> str:
|
|
||||||
"""
|
|
||||||
Generate deterministic backup guardian address
|
|
||||||
|
|
||||||
In production, these would be pre-registered trusted guardian addresses
|
|
||||||
"""
|
|
||||||
# Create a deterministic address based on agent address and index
|
|
||||||
seed = f"{agent_address}_{index}_backup_guardian"
|
|
||||||
hash_result = keccak(seed.encode())
|
|
||||||
|
|
||||||
# Use the hash to generate a valid address
|
|
||||||
address_bytes = hash_result[-20:] # Take last 20 bytes
|
|
||||||
address = "0x" + address_bytes.hex()
|
|
||||||
|
|
||||||
return to_checksum_address(address)
|
|
||||||
|
|
||||||
def create_secure_guardian_contract(
|
|
||||||
self,
|
|
||||||
agent_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
custom_guardians: Optional[List[str]] = None
|
|
||||||
) -> GuardianContract:
|
|
||||||
"""
|
|
||||||
Create a guardian contract with proper guardian configuration
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
security_level: Security level
|
|
||||||
custom_guardians: Optional custom guardian addresses
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured guardian contract
|
|
||||||
"""
|
|
||||||
# Create guardian setup
|
|
||||||
setup = self.create_guardian_setup(
|
|
||||||
agent_address=agent_address,
|
|
||||||
owner_address=agent_address, # Agent is its own owner initially
|
|
||||||
security_level=security_level,
|
|
||||||
custom_guardians=custom_guardians
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get security configuration
|
|
||||||
config = self._get_security_config(security_level, setup)
|
|
||||||
|
|
||||||
# Create contract
|
|
||||||
contract = GuardianContract(agent_address, config)
|
|
||||||
|
|
||||||
# Store contract
|
|
||||||
self.guardian_contracts[agent_address] = contract
|
|
||||||
|
|
||||||
return contract
|
|
||||||
|
|
||||||
def _get_security_config(self, security_level: str, setup: GuardianSetup) -> GuardianConfig:
|
|
||||||
"""Get security configuration with proper guardian list"""
|
|
||||||
|
|
||||||
# Build guardian list
|
|
||||||
all_guardians = [setup.primary_guardian] + setup.backup_guardians
|
|
||||||
|
|
||||||
if security_level == "conservative":
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=1000,
|
|
||||||
per_hour=5000,
|
|
||||||
per_day=20000,
|
|
||||||
per_week=100000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=5000,
|
|
||||||
delay_hours=24,
|
|
||||||
max_delay_hours=168
|
|
||||||
),
|
|
||||||
guardians=all_guardians,
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False,
|
|
||||||
multisig_threshold=setup.multisig_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
elif security_level == "aggressive":
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=5000,
|
|
||||||
per_hour=25000,
|
|
||||||
per_day=100000,
|
|
||||||
per_week=500000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=20000,
|
|
||||||
delay_hours=12,
|
|
||||||
max_delay_hours=72
|
|
||||||
),
|
|
||||||
guardians=all_guardians,
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False,
|
|
||||||
multisig_threshold=setup.multisig_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
elif security_level == "high_security":
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=500,
|
|
||||||
per_hour=2000,
|
|
||||||
per_day=8000,
|
|
||||||
per_week=40000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=2000,
|
|
||||||
delay_hours=48,
|
|
||||||
max_delay_hours=168
|
|
||||||
),
|
|
||||||
guardians=all_guardians,
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False,
|
|
||||||
multisig_threshold=setup.multisig_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid security level: {security_level}")
|
|
||||||
|
|
||||||
def test_emergency_pause(self, agent_address: str, guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Test emergency pause functionality
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent address
|
|
||||||
guardian_address: Guardian attempting pause
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Test result
|
|
||||||
"""
|
|
||||||
if agent_address not in self.guardian_contracts:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
contract = self.guardian_contracts[agent_address]
|
|
||||||
return contract.emergency_pause(guardian_address)
|
|
||||||
|
|
||||||
def verify_guardian_authorization(self, agent_address: str, guardian_address: str) -> bool:
|
|
||||||
"""
|
|
||||||
Verify if a guardian is authorized for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent address
|
|
||||||
guardian_address: Guardian address to verify
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if guardian is authorized
|
|
||||||
"""
|
|
||||||
if agent_address not in self.guardian_registrations:
|
|
||||||
return False
|
|
||||||
|
|
||||||
setup = self.guardian_registrations[agent_address]
|
|
||||||
all_guardians = [setup.primary_guardian] + setup.backup_guardians
|
|
||||||
|
|
||||||
return to_checksum_address(guardian_address) in [
|
|
||||||
to_checksum_address(g) for g in all_guardians
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_guardian_summary(self, agent_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Get guardian setup summary for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Guardian summary
|
|
||||||
"""
|
|
||||||
if agent_address not in self.guardian_registrations:
|
|
||||||
return {"error": "Agent not registered"}
|
|
||||||
|
|
||||||
setup = self.guardian_registrations[agent_address]
|
|
||||||
contract = self.guardian_contracts.get(agent_address)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"primary_guardian": setup.primary_guardian,
|
|
||||||
"backup_guardians": setup.backup_guardians,
|
|
||||||
"total_guardians": len(setup.backup_guardians) + 1,
|
|
||||||
"multisig_threshold": setup.multisig_threshold,
|
|
||||||
"emergency_contacts": setup.emergency_contacts,
|
|
||||||
"contract_status": contract.get_spending_status() if contract else None,
|
|
||||||
"pause_functional": contract is not None and len(setup.backup_guardians) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Fixed security configurations with proper guardians
|
|
||||||
def get_fixed_conservative_config(agent_address: str, owner_address: str) -> GuardianConfig:
|
|
||||||
"""Get fixed conservative configuration with proper guardians"""
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=1000,
|
|
||||||
per_hour=5000,
|
|
||||||
per_day=20000,
|
|
||||||
per_week=100000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=5000,
|
|
||||||
delay_hours=24,
|
|
||||||
max_delay_hours=168
|
|
||||||
),
|
|
||||||
guardians=[owner_address], # At least the owner
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_fixed_aggressive_config(agent_address: str, owner_address: str) -> GuardianConfig:
|
|
||||||
"""Get fixed aggressive configuration with proper guardians"""
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=5000,
|
|
||||||
per_hour=25000,
|
|
||||||
per_day=100000,
|
|
||||||
per_week=500000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=20000,
|
|
||||||
delay_hours=12,
|
|
||||||
max_delay_hours=72
|
|
||||||
),
|
|
||||||
guardians=[owner_address], # At least the owner
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_fixed_high_security_config(agent_address: str, owner_address: str) -> GuardianConfig:
|
|
||||||
"""Get fixed high security configuration with proper guardians"""
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=500,
|
|
||||||
per_hour=2000,
|
|
||||||
per_day=8000,
|
|
||||||
per_week=40000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=2000,
|
|
||||||
delay_hours=48,
|
|
||||||
max_delay_hours=168
|
|
||||||
),
|
|
||||||
guardians=[owner_address], # At least the owner
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Global secure guardian manager
|
|
||||||
secure_guardian_manager = SecureGuardianManager()
|
|
||||||
|
|
||||||
|
|
||||||
# Convenience function for secure agent registration
|
|
||||||
def register_agent_with_guardians(
|
|
||||||
agent_address: str,
|
|
||||||
owner_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
custom_guardians: Optional[List[str]] = None
|
|
||||||
) -> Dict:
|
|
||||||
"""
|
|
||||||
Register an agent with proper guardian configuration
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
owner_address: Owner address
|
|
||||||
security_level: Security level
|
|
||||||
custom_guardians: Optional custom guardians
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Registration result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Create secure guardian contract
|
|
||||||
contract = secure_guardian_manager.create_secure_guardian_contract(
|
|
||||||
agent_address=agent_address,
|
|
||||||
security_level=security_level,
|
|
||||||
custom_guardians=custom_guardians
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get guardian summary
|
|
||||||
summary = secure_guardian_manager.get_guardian_summary(agent_address)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "registered",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"security_level": security_level,
|
|
||||||
"guardian_count": summary["total_guardians"],
|
|
||||||
"multisig_threshold": summary["multisig_threshold"],
|
|
||||||
"pause_functional": summary["pause_functional"],
|
|
||||||
"registered_at": datetime.utcnow().isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Registration failed: {str(e)}"
|
|
||||||
}
|
|
||||||
@@ -1,682 +0,0 @@
|
|||||||
"""
|
|
||||||
AITBC Guardian Contract - Spending Limit Protection for Agent Wallets
|
|
||||||
|
|
||||||
This contract implements a spending limit guardian that protects autonomous agent
|
|
||||||
wallets from unlimited spending in case of compromise. It provides:
|
|
||||||
- Per-transaction spending limits
|
|
||||||
- Per-period (daily/hourly) spending caps
|
|
||||||
- Time-lock for large withdrawals
|
|
||||||
- Emergency pause functionality
|
|
||||||
- Multi-signature recovery for critical operations
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sqlite3
|
|
||||||
from pathlib import Path
|
|
||||||
from eth_account import Account
|
|
||||||
from eth_utils import to_checksum_address, keccak
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SpendingLimit:
|
|
||||||
"""Spending limit configuration"""
|
|
||||||
per_transaction: int # Maximum per transaction
|
|
||||||
per_hour: int # Maximum per hour
|
|
||||||
per_day: int # Maximum per day
|
|
||||||
per_week: int # Maximum per week
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TimeLockConfig:
|
|
||||||
"""Time lock configuration for large withdrawals"""
|
|
||||||
threshold: int # Amount that triggers time lock
|
|
||||||
delay_hours: int # Delay period in hours
|
|
||||||
max_delay_hours: int # Maximum delay period
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GuardianConfig:
|
|
||||||
"""Complete guardian configuration"""
|
|
||||||
limits: SpendingLimit
|
|
||||||
time_lock: TimeLockConfig
|
|
||||||
guardians: List[str] # Guardian addresses for recovery
|
|
||||||
pause_enabled: bool = True
|
|
||||||
emergency_mode: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class GuardianContract:
|
|
||||||
"""
|
|
||||||
Guardian contract implementation for agent wallet protection
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, agent_address: str, config: GuardianConfig, storage_path: str = None):
|
|
||||||
self.agent_address = to_checksum_address(agent_address)
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
# CRITICAL SECURITY FIX: Use persistent storage instead of in-memory
|
|
||||||
if storage_path is None:
|
|
||||||
storage_path = os.path.join(os.path.expanduser("~"), ".aitbc", "guardian_contracts")
|
|
||||||
|
|
||||||
self.storage_dir = Path(storage_path)
|
|
||||||
self.storage_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Database file for this contract
|
|
||||||
self.db_path = self.storage_dir / f"guardian_{self.agent_address}.db"
|
|
||||||
|
|
||||||
# Initialize persistent storage
|
|
||||||
self._init_storage()
|
|
||||||
|
|
||||||
# Load state from storage
|
|
||||||
self._load_state()
|
|
||||||
|
|
||||||
# In-memory cache for performance (synced with storage)
|
|
||||||
self.spending_history: List[Dict] = []
|
|
||||||
self.pending_operations: Dict[str, Dict] = {}
|
|
||||||
self.paused = False
|
|
||||||
self.emergency_mode = False
|
|
||||||
|
|
||||||
# Contract state
|
|
||||||
self.nonce = 0
|
|
||||||
self.guardian_approvals: Dict[str, bool] = {}
|
|
||||||
|
|
||||||
# Load data from persistent storage
|
|
||||||
self._load_spending_history()
|
|
||||||
self._load_pending_operations()
|
|
||||||
|
|
||||||
def _init_storage(self):
|
|
||||||
"""Initialize SQLite database for persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
conn.execute('''
|
|
||||||
CREATE TABLE IF NOT EXISTS spending_history (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
operation_id TEXT UNIQUE,
|
|
||||||
agent_address TEXT,
|
|
||||||
to_address TEXT,
|
|
||||||
amount INTEGER,
|
|
||||||
data TEXT,
|
|
||||||
timestamp TEXT,
|
|
||||||
executed_at TEXT,
|
|
||||||
status TEXT,
|
|
||||||
nonce INTEGER,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)
|
|
||||||
''')
|
|
||||||
|
|
||||||
conn.execute('''
|
|
||||||
CREATE TABLE IF NOT EXISTS pending_operations (
|
|
||||||
operation_id TEXT PRIMARY KEY,
|
|
||||||
agent_address TEXT,
|
|
||||||
operation_data TEXT,
|
|
||||||
status TEXT,
|
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)
|
|
||||||
''')
|
|
||||||
|
|
||||||
conn.execute('''
|
|
||||||
CREATE TABLE IF NOT EXISTS contract_state (
|
|
||||||
agent_address TEXT PRIMARY KEY,
|
|
||||||
nonce INTEGER DEFAULT 0,
|
|
||||||
paused BOOLEAN DEFAULT 0,
|
|
||||||
emergency_mode BOOLEAN DEFAULT 0,
|
|
||||||
last_updated DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
||||||
)
|
|
||||||
''')
|
|
||||||
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def _load_state(self):
|
|
||||||
"""Load contract state from persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
cursor = conn.execute(
|
|
||||||
'SELECT nonce, paused, emergency_mode FROM contract_state WHERE agent_address = ?',
|
|
||||||
(self.agent_address,)
|
|
||||||
)
|
|
||||||
row = cursor.fetchone()
|
|
||||||
|
|
||||||
if row:
|
|
||||||
self.nonce, self.paused, self.emergency_mode = row
|
|
||||||
else:
|
|
||||||
# Initialize state for new contract
|
|
||||||
conn.execute(
|
|
||||||
'INSERT INTO contract_state (agent_address, nonce, paused, emergency_mode) VALUES (?, ?, ?, ?)',
|
|
||||||
(self.agent_address, 0, False, False)
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def _save_state(self):
|
|
||||||
"""Save contract state to persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
conn.execute(
|
|
||||||
'UPDATE contract_state SET nonce = ?, paused = ?, emergency_mode = ?, last_updated = CURRENT_TIMESTAMP WHERE agent_address = ?',
|
|
||||||
(self.nonce, self.paused, self.emergency_mode, self.agent_address)
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def _load_spending_history(self):
|
|
||||||
"""Load spending history from persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
cursor = conn.execute(
|
|
||||||
'SELECT operation_id, to_address, amount, data, timestamp, executed_at, status, nonce FROM spending_history WHERE agent_address = ? ORDER BY timestamp DESC',
|
|
||||||
(self.agent_address,)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.spending_history = []
|
|
||||||
for row in cursor:
|
|
||||||
self.spending_history.append({
|
|
||||||
"operation_id": row[0],
|
|
||||||
"to": row[1],
|
|
||||||
"amount": row[2],
|
|
||||||
"data": row[3],
|
|
||||||
"timestamp": row[4],
|
|
||||||
"executed_at": row[5],
|
|
||||||
"status": row[6],
|
|
||||||
"nonce": row[7]
|
|
||||||
})
|
|
||||||
|
|
||||||
def _save_spending_record(self, record: Dict):
|
|
||||||
"""Save spending record to persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
conn.execute(
|
|
||||||
'''INSERT OR REPLACE INTO spending_history
|
|
||||||
(operation_id, agent_address, to_address, amount, data, timestamp, executed_at, status, nonce)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)''',
|
|
||||||
(
|
|
||||||
record["operation_id"],
|
|
||||||
self.agent_address,
|
|
||||||
record["to"],
|
|
||||||
record["amount"],
|
|
||||||
record.get("data", ""),
|
|
||||||
record["timestamp"],
|
|
||||||
record.get("executed_at", ""),
|
|
||||||
record["status"],
|
|
||||||
record["nonce"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def _load_pending_operations(self):
|
|
||||||
"""Load pending operations from persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
cursor = conn.execute(
|
|
||||||
'SELECT operation_id, operation_data, status FROM pending_operations WHERE agent_address = ?',
|
|
||||||
(self.agent_address,)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.pending_operations = {}
|
|
||||||
for row in cursor:
|
|
||||||
operation_data = json.loads(row[1])
|
|
||||||
operation_data["status"] = row[2]
|
|
||||||
self.pending_operations[row[0]] = operation_data
|
|
||||||
|
|
||||||
def _save_pending_operation(self, operation_id: str, operation: Dict):
|
|
||||||
"""Save pending operation to persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
conn.execute(
|
|
||||||
'''INSERT OR REPLACE INTO pending_operations
|
|
||||||
(operation_id, agent_address, operation_data, status, updated_at)
|
|
||||||
VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)''',
|
|
||||||
(operation_id, self.agent_address, json.dumps(operation), operation["status"])
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def _remove_pending_operation(self, operation_id: str):
|
|
||||||
"""Remove pending operation from persistent storage"""
|
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
conn.execute(
|
|
||||||
'DELETE FROM pending_operations WHERE operation_id = ? AND agent_address = ?',
|
|
||||||
(operation_id, self.agent_address)
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def _get_period_key(self, timestamp: datetime, period: str) -> str:
|
|
||||||
"""Generate period key for spending tracking"""
|
|
||||||
if period == "hour":
|
|
||||||
return timestamp.strftime("%Y-%m-%d-%H")
|
|
||||||
elif period == "day":
|
|
||||||
return timestamp.strftime("%Y-%m-%d")
|
|
||||||
elif period == "week":
|
|
||||||
# Get week number (Monday as first day)
|
|
||||||
week_num = timestamp.isocalendar()[1]
|
|
||||||
return f"{timestamp.year}-W{week_num:02d}"
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid period: {period}")
|
|
||||||
|
|
||||||
def _get_spent_in_period(self, period: str, timestamp: datetime = None) -> int:
|
|
||||||
"""Calculate total spent in given period"""
|
|
||||||
if timestamp is None:
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
period_key = self._get_period_key(timestamp, period)
|
|
||||||
|
|
||||||
total = 0
|
|
||||||
for record in self.spending_history:
|
|
||||||
record_time = datetime.fromisoformat(record["timestamp"])
|
|
||||||
record_period = self._get_period_key(record_time, period)
|
|
||||||
|
|
||||||
if record_period == period_key and record["status"] == "completed":
|
|
||||||
total += record["amount"]
|
|
||||||
|
|
||||||
return total
|
|
||||||
|
|
||||||
def _check_spending_limits(self, amount: int, timestamp: datetime = None) -> Tuple[bool, str]:
|
|
||||||
"""Check if amount exceeds spending limits"""
|
|
||||||
if timestamp is None:
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
# Check per-transaction limit
|
|
||||||
if amount > self.config.limits.per_transaction:
|
|
||||||
return False, f"Amount {amount} exceeds per-transaction limit {self.config.limits.per_transaction}"
|
|
||||||
|
|
||||||
# Check per-hour limit
|
|
||||||
spent_hour = self._get_spent_in_period("hour", timestamp)
|
|
||||||
if spent_hour + amount > self.config.limits.per_hour:
|
|
||||||
return False, f"Hourly spending {spent_hour + amount} would exceed limit {self.config.limits.per_hour}"
|
|
||||||
|
|
||||||
# Check per-day limit
|
|
||||||
spent_day = self._get_spent_in_period("day", timestamp)
|
|
||||||
if spent_day + amount > self.config.limits.per_day:
|
|
||||||
return False, f"Daily spending {spent_day + amount} would exceed limit {self.config.limits.per_day}"
|
|
||||||
|
|
||||||
# Check per-week limit
|
|
||||||
spent_week = self._get_spent_in_period("week", timestamp)
|
|
||||||
if spent_week + amount > self.config.limits.per_week:
|
|
||||||
return False, f"Weekly spending {spent_week + amount} would exceed limit {self.config.limits.per_week}"
|
|
||||||
|
|
||||||
return True, "Spending limits check passed"
|
|
||||||
|
|
||||||
def _requires_time_lock(self, amount: int) -> bool:
|
|
||||||
"""Check if amount requires time lock"""
|
|
||||||
return amount >= self.config.time_lock.threshold
|
|
||||||
|
|
||||||
def _create_operation_hash(self, operation: Dict) -> str:
|
|
||||||
"""Create hash for operation identification"""
|
|
||||||
operation_str = json.dumps(operation, sort_keys=True)
|
|
||||||
return keccak(operation_str.encode()).hex()
|
|
||||||
|
|
||||||
def initiate_transaction(self, to_address: str, amount: int, data: str = "") -> Dict:
|
|
||||||
"""
|
|
||||||
Initiate a transaction with guardian protection
|
|
||||||
|
|
||||||
Args:
|
|
||||||
to_address: Recipient address
|
|
||||||
amount: Amount to transfer
|
|
||||||
data: Transaction data (optional)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Operation result with status and details
|
|
||||||
"""
|
|
||||||
# Check if paused
|
|
||||||
if self.paused:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": "Guardian contract is paused",
|
|
||||||
"operation_id": None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Check emergency mode
|
|
||||||
if self.emergency_mode:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": "Emergency mode activated",
|
|
||||||
"operation_id": None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Validate address
|
|
||||||
try:
|
|
||||||
to_address = to_checksum_address(to_address)
|
|
||||||
except Exception:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": "Invalid recipient address",
|
|
||||||
"operation_id": None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Check spending limits
|
|
||||||
limits_ok, limits_reason = self._check_spending_limits(amount)
|
|
||||||
if not limits_ok:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": limits_reason,
|
|
||||||
"operation_id": None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create operation
|
|
||||||
operation = {
|
|
||||||
"type": "transaction",
|
|
||||||
"to": to_address,
|
|
||||||
"amount": amount,
|
|
||||||
"data": data,
|
|
||||||
"timestamp": datetime.utcnow().isoformat(),
|
|
||||||
"nonce": self.nonce,
|
|
||||||
"status": "pending"
|
|
||||||
}
|
|
||||||
|
|
||||||
operation_id = self._create_operation_hash(operation)
|
|
||||||
operation["operation_id"] = operation_id
|
|
||||||
|
|
||||||
# Check if time lock is required
|
|
||||||
if self._requires_time_lock(amount):
|
|
||||||
unlock_time = datetime.utcnow() + timedelta(hours=self.config.time_lock.delay_hours)
|
|
||||||
operation["unlock_time"] = unlock_time.isoformat()
|
|
||||||
operation["status"] = "time_locked"
|
|
||||||
|
|
||||||
# Store for later execution
|
|
||||||
self.pending_operations[operation_id] = operation
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "time_locked",
|
|
||||||
"operation_id": operation_id,
|
|
||||||
"unlock_time": unlock_time.isoformat(),
|
|
||||||
"delay_hours": self.config.time_lock.delay_hours,
|
|
||||||
"message": f"Transaction requires {self.config.time_lock.delay_hours}h time lock"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Immediate execution for smaller amounts
|
|
||||||
self.pending_operations[operation_id] = operation
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "approved",
|
|
||||||
"operation_id": operation_id,
|
|
||||||
"message": "Transaction approved for execution"
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute_transaction(self, operation_id: str, signature: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Execute a previously approved transaction
|
|
||||||
|
|
||||||
Args:
|
|
||||||
operation_id: Operation ID from initiate_transaction
|
|
||||||
signature: Transaction signature from agent
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Execution result
|
|
||||||
"""
|
|
||||||
if operation_id not in self.pending_operations:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Operation not found"
|
|
||||||
}
|
|
||||||
|
|
||||||
operation = self.pending_operations[operation_id]
|
|
||||||
|
|
||||||
# Check if operation is time locked
|
|
||||||
if operation["status"] == "time_locked":
|
|
||||||
unlock_time = datetime.fromisoformat(operation["unlock_time"])
|
|
||||||
if datetime.utcnow() < unlock_time:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Operation locked until {unlock_time.isoformat()}"
|
|
||||||
}
|
|
||||||
|
|
||||||
operation["status"] = "ready"
|
|
||||||
|
|
||||||
# Verify signature (simplified - in production, use proper verification)
|
|
||||||
try:
|
|
||||||
# In production, verify the signature matches the agent address
|
|
||||||
# For now, we'll assume signature is valid
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Invalid signature: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Record the transaction
|
|
||||||
record = {
|
|
||||||
"operation_id": operation_id,
|
|
||||||
"to": operation["to"],
|
|
||||||
"amount": operation["amount"],
|
|
||||||
"data": operation.get("data", ""),
|
|
||||||
"timestamp": operation["timestamp"],
|
|
||||||
"executed_at": datetime.utcnow().isoformat(),
|
|
||||||
"status": "completed",
|
|
||||||
"nonce": operation["nonce"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# CRITICAL SECURITY FIX: Save to persistent storage
|
|
||||||
self._save_spending_record(record)
|
|
||||||
self.spending_history.append(record)
|
|
||||||
self.nonce += 1
|
|
||||||
self._save_state()
|
|
||||||
|
|
||||||
# Remove from pending storage
|
|
||||||
self._remove_pending_operation(operation_id)
|
|
||||||
if operation_id in self.pending_operations:
|
|
||||||
del self.pending_operations[operation_id]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "executed",
|
|
||||||
"operation_id": operation_id,
|
|
||||||
"transaction_hash": f"0x{keccak(f'{operation_id}{signature}'.encode()).hex()}",
|
|
||||||
"executed_at": record["executed_at"]
|
|
||||||
}
|
|
||||||
|
|
||||||
def emergency_pause(self, guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Emergency pause function (guardian only)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
guardian_address: Address of guardian initiating pause
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Pause result
|
|
||||||
"""
|
|
||||||
if guardian_address not in self.config.guardians:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": "Not authorized: guardian address not recognized"
|
|
||||||
}
|
|
||||||
|
|
||||||
self.paused = True
|
|
||||||
self.emergency_mode = True
|
|
||||||
|
|
||||||
# CRITICAL SECURITY FIX: Save state to persistent storage
|
|
||||||
self._save_state()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "paused",
|
|
||||||
"paused_at": datetime.utcnow().isoformat(),
|
|
||||||
"guardian": guardian_address,
|
|
||||||
"message": "Emergency pause activated - all operations halted"
|
|
||||||
}
|
|
||||||
|
|
||||||
def emergency_unpause(self, guardian_signatures: List[str]) -> Dict:
|
|
||||||
"""
|
|
||||||
Emergency unpause function (requires multiple guardian signatures)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
guardian_signatures: Signatures from required guardians
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Unpause result
|
|
||||||
"""
|
|
||||||
# In production, verify all guardian signatures
|
|
||||||
required_signatures = len(self.config.guardians)
|
|
||||||
if len(guardian_signatures) < required_signatures:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": f"Requires {required_signatures} guardian signatures, got {len(guardian_signatures)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Verify signatures (simplified)
|
|
||||||
# In production, verify each signature matches a guardian address
|
|
||||||
|
|
||||||
self.paused = False
|
|
||||||
self.emergency_mode = False
|
|
||||||
|
|
||||||
# CRITICAL SECURITY FIX: Save state to persistent storage
|
|
||||||
self._save_state()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "unpaused",
|
|
||||||
"unpaused_at": datetime.utcnow().isoformat(),
|
|
||||||
"message": "Emergency pause lifted - operations resumed"
|
|
||||||
}
|
|
||||||
|
|
||||||
def update_limits(self, new_limits: SpendingLimit, guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Update spending limits (guardian only)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
new_limits: New spending limits
|
|
||||||
guardian_address: Address of guardian making the change
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Update result
|
|
||||||
"""
|
|
||||||
if guardian_address not in self.config.guardians:
|
|
||||||
return {
|
|
||||||
"status": "rejected",
|
|
||||||
"reason": "Not authorized: guardian address not recognized"
|
|
||||||
}
|
|
||||||
|
|
||||||
old_limits = self.config.limits
|
|
||||||
self.config.limits = new_limits
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "updated",
|
|
||||||
"old_limits": old_limits,
|
|
||||||
"new_limits": new_limits,
|
|
||||||
"updated_at": datetime.utcnow().isoformat(),
|
|
||||||
"guardian": guardian_address
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_spending_status(self) -> Dict:
|
|
||||||
"""Get current spending status and limits"""
|
|
||||||
now = datetime.utcnow()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"agent_address": self.agent_address,
|
|
||||||
"current_limits": self.config.limits,
|
|
||||||
"spent": {
|
|
||||||
"current_hour": self._get_spent_in_period("hour", now),
|
|
||||||
"current_day": self._get_spent_in_period("day", now),
|
|
||||||
"current_week": self._get_spent_in_period("week", now)
|
|
||||||
},
|
|
||||||
"remaining": {
|
|
||||||
"current_hour": self.config.limits.per_hour - self._get_spent_in_period("hour", now),
|
|
||||||
"current_day": self.config.limits.per_day - self._get_spent_in_period("day", now),
|
|
||||||
"current_week": self.config.limits.per_week - self._get_spent_in_period("week", now)
|
|
||||||
},
|
|
||||||
"pending_operations": len(self.pending_operations),
|
|
||||||
"paused": self.paused,
|
|
||||||
"emergency_mode": self.emergency_mode,
|
|
||||||
"nonce": self.nonce
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_operation_history(self, limit: int = 50) -> List[Dict]:
|
|
||||||
"""Get operation history"""
|
|
||||||
return sorted(self.spending_history, key=lambda x: x["timestamp"], reverse=True)[:limit]
|
|
||||||
|
|
||||||
def get_pending_operations(self) -> List[Dict]:
|
|
||||||
"""Get all pending operations"""
|
|
||||||
return list(self.pending_operations.values())
|
|
||||||
|
|
||||||
|
|
||||||
# Factory function for creating guardian contracts
|
|
||||||
def create_guardian_contract(
|
|
||||||
agent_address: str,
|
|
||||||
per_transaction: int = 1000,
|
|
||||||
per_hour: int = 5000,
|
|
||||||
per_day: int = 20000,
|
|
||||||
per_week: int = 100000,
|
|
||||||
time_lock_threshold: int = 10000,
|
|
||||||
time_lock_delay: int = 24,
|
|
||||||
guardians: List[str] = None
|
|
||||||
) -> GuardianContract:
|
|
||||||
"""
|
|
||||||
Create a guardian contract with default security parameters
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: The agent wallet address to protect
|
|
||||||
per_transaction: Maximum amount per transaction
|
|
||||||
per_hour: Maximum amount per hour
|
|
||||||
per_day: Maximum amount per day
|
|
||||||
per_week: Maximum amount per week
|
|
||||||
time_lock_threshold: Amount that triggers time lock
|
|
||||||
time_lock_delay: Time lock delay in hours
|
|
||||||
guardians: List of guardian addresses (REQUIRED for security)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured GuardianContract instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If no guardians are provided or guardians list is insufficient
|
|
||||||
"""
|
|
||||||
# CRITICAL SECURITY FIX: Require proper guardians, never default to agent address
|
|
||||||
if guardians is None or not guardians:
|
|
||||||
raise ValueError(
|
|
||||||
"❌ CRITICAL: Guardians are required for security. "
|
|
||||||
"Provide at least 3 trusted guardian addresses different from the agent address."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate that guardians are different from agent address
|
|
||||||
agent_checksum = to_checksum_address(agent_address)
|
|
||||||
guardian_checksums = [to_checksum_address(g) for g in guardians]
|
|
||||||
|
|
||||||
if agent_checksum in guardian_checksums:
|
|
||||||
raise ValueError(
|
|
||||||
"❌ CRITICAL: Agent address cannot be used as guardian. "
|
|
||||||
"Guardians must be independent trusted addresses."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Require minimum number of guardians for security
|
|
||||||
if len(guardian_checksums) < 3:
|
|
||||||
raise ValueError(
|
|
||||||
f"❌ CRITICAL: At least 3 guardians required for security, got {len(guardian_checksums)}. "
|
|
||||||
"Consider using a multi-sig wallet or trusted service providers."
|
|
||||||
)
|
|
||||||
|
|
||||||
limits = SpendingLimit(
|
|
||||||
per_transaction=per_transaction,
|
|
||||||
per_hour=per_hour,
|
|
||||||
per_day=per_day,
|
|
||||||
per_week=per_week
|
|
||||||
)
|
|
||||||
|
|
||||||
time_lock = TimeLockConfig(
|
|
||||||
threshold=time_lock_threshold,
|
|
||||||
delay_hours=time_lock_delay,
|
|
||||||
max_delay_hours=168 # 1 week max
|
|
||||||
)
|
|
||||||
|
|
||||||
config = GuardianConfig(
|
|
||||||
limits=limits,
|
|
||||||
time_lock=time_lock,
|
|
||||||
guardians=[to_checksum_address(g) for g in guardians]
|
|
||||||
)
|
|
||||||
|
|
||||||
return GuardianContract(agent_address, config)
|
|
||||||
|
|
||||||
|
|
||||||
# Example usage and security configurations
|
|
||||||
CONSERVATIVE_CONFIG = {
|
|
||||||
"per_transaction": 100, # $100 per transaction
|
|
||||||
"per_hour": 500, # $500 per hour
|
|
||||||
"per_day": 2000, # $2,000 per day
|
|
||||||
"per_week": 10000, # $10,000 per week
|
|
||||||
"time_lock_threshold": 1000, # Time lock over $1,000
|
|
||||||
"time_lock_delay": 24 # 24 hour delay
|
|
||||||
}
|
|
||||||
|
|
||||||
AGGRESSIVE_CONFIG = {
|
|
||||||
"per_transaction": 1000, # $1,000 per transaction
|
|
||||||
"per_hour": 5000, # $5,000 per hour
|
|
||||||
"per_day": 20000, # $20,000 per day
|
|
||||||
"per_week": 100000, # $100,000 per week
|
|
||||||
"time_lock_threshold": 10000, # Time lock over $10,000
|
|
||||||
"time_lock_delay": 12 # 12 hour delay
|
|
||||||
}
|
|
||||||
|
|
||||||
HIGH_SECURITY_CONFIG = {
|
|
||||||
"per_transaction": 50, # $50 per transaction
|
|
||||||
"per_hour": 200, # $200 per hour
|
|
||||||
"per_day": 1000, # $1,000 per day
|
|
||||||
"per_week": 5000, # $5,000 per week
|
|
||||||
"time_lock_threshold": 500, # Time lock over $500
|
|
||||||
"time_lock_delay": 48 # 48 hour delay
|
|
||||||
}
|
|
||||||
@@ -1,351 +0,0 @@
|
|||||||
"""
|
|
||||||
Gas Optimization System
|
|
||||||
Optimizes gas usage and fee efficiency for smart contracts
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
from decimal import Decimal
|
|
||||||
|
|
||||||
class OptimizationStrategy(Enum):
|
|
||||||
BATCH_OPERATIONS = "batch_operations"
|
|
||||||
LAZY_EVALUATION = "lazy_evaluation"
|
|
||||||
STATE_COMPRESSION = "state_compression"
|
|
||||||
EVENT_FILTERING = "event_filtering"
|
|
||||||
STORAGE_OPTIMIZATION = "storage_optimization"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GasMetric:
|
|
||||||
contract_address: str
|
|
||||||
function_name: str
|
|
||||||
gas_used: int
|
|
||||||
gas_limit: int
|
|
||||||
execution_time: float
|
|
||||||
timestamp: float
|
|
||||||
optimization_applied: Optional[str]
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class OptimizationResult:
|
|
||||||
strategy: OptimizationStrategy
|
|
||||||
original_gas: int
|
|
||||||
optimized_gas: int
|
|
||||||
gas_savings: int
|
|
||||||
savings_percentage: float
|
|
||||||
implementation_cost: Decimal
|
|
||||||
net_benefit: Decimal
|
|
||||||
|
|
||||||
class GasOptimizer:
|
|
||||||
"""Optimizes gas usage for smart contracts"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.gas_metrics: List[GasMetric] = []
|
|
||||||
self.optimization_results: List[OptimizationResult] = []
|
|
||||||
self.optimization_strategies = self._initialize_strategies()
|
|
||||||
|
|
||||||
# Optimization parameters
|
|
||||||
self.min_optimization_threshold = 1000 # Minimum gas to consider optimization
|
|
||||||
self.optimization_target_savings = 0.1 # 10% minimum savings
|
|
||||||
self.max_optimization_cost = Decimal('0.01') # Maximum cost per optimization
|
|
||||||
self.metric_retention_period = 86400 * 7 # 7 days
|
|
||||||
|
|
||||||
# Gas price tracking
|
|
||||||
self.gas_price_history: List[Dict] = []
|
|
||||||
self.current_gas_price = Decimal('0.001')
|
|
||||||
|
|
||||||
def _initialize_strategies(self) -> Dict[OptimizationStrategy, Dict]:
|
|
||||||
"""Initialize optimization strategies"""
|
|
||||||
return {
|
|
||||||
OptimizationStrategy.BATCH_OPERATIONS: {
|
|
||||||
'description': 'Batch multiple operations into single transaction',
|
|
||||||
'potential_savings': 0.3, # 30% potential savings
|
|
||||||
'implementation_cost': Decimal('0.005'),
|
|
||||||
'applicable_functions': ['transfer', 'approve', 'mint']
|
|
||||||
},
|
|
||||||
OptimizationStrategy.LAZY_EVALUATION: {
|
|
||||||
'description': 'Defer expensive computations until needed',
|
|
||||||
'potential_savings': 0.2, # 20% potential savings
|
|
||||||
'implementation_cost': Decimal('0.003'),
|
|
||||||
'applicable_functions': ['calculate', 'validate', 'process']
|
|
||||||
},
|
|
||||||
OptimizationStrategy.STATE_COMPRESSION: {
|
|
||||||
'description': 'Compress state data to reduce storage costs',
|
|
||||||
'potential_savings': 0.4, # 40% potential savings
|
|
||||||
'implementation_cost': Decimal('0.008'),
|
|
||||||
'applicable_functions': ['store', 'update', 'save']
|
|
||||||
},
|
|
||||||
OptimizationStrategy.EVENT_FILTERING: {
|
|
||||||
'description': 'Filter events to reduce emission costs',
|
|
||||||
'potential_savings': 0.15, # 15% potential savings
|
|
||||||
'implementation_cost': Decimal('0.002'),
|
|
||||||
'applicable_functions': ['emit', 'log', 'notify']
|
|
||||||
},
|
|
||||||
OptimizationStrategy.STORAGE_OPTIMIZATION: {
|
|
||||||
'description': 'Optimize storage patterns and data structures',
|
|
||||||
'potential_savings': 0.25, # 25% potential savings
|
|
||||||
'implementation_cost': Decimal('0.006'),
|
|
||||||
'applicable_functions': ['set', 'add', 'remove']
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async def record_gas_usage(self, contract_address: str, function_name: str,
|
|
||||||
gas_used: int, gas_limit: int, execution_time: float,
|
|
||||||
optimization_applied: Optional[str] = None):
|
|
||||||
"""Record gas usage metrics"""
|
|
||||||
metric = GasMetric(
|
|
||||||
contract_address=contract_address,
|
|
||||||
function_name=function_name,
|
|
||||||
gas_used=gas_used,
|
|
||||||
gas_limit=gas_limit,
|
|
||||||
execution_time=execution_time,
|
|
||||||
timestamp=time.time(),
|
|
||||||
optimization_applied=optimization_applied
|
|
||||||
)
|
|
||||||
|
|
||||||
self.gas_metrics.append(metric)
|
|
||||||
|
|
||||||
# Limit history size
|
|
||||||
if len(self.gas_metrics) > 10000:
|
|
||||||
self.gas_metrics = self.gas_metrics[-5000]
|
|
||||||
|
|
||||||
# Trigger optimization analysis if threshold met
|
|
||||||
if gas_used >= self.min_optimization_threshold:
|
|
||||||
asyncio.create_task(self._analyze_optimization_opportunity(metric))
|
|
||||||
|
|
||||||
async def _analyze_optimization_opportunity(self, metric: GasMetric):
|
|
||||||
"""Analyze if optimization is beneficial"""
|
|
||||||
# Get historical average for this function
|
|
||||||
historical_metrics = [
|
|
||||||
m for m in self.gas_metrics
|
|
||||||
if m.function_name == metric.function_name and
|
|
||||||
m.contract_address == metric.contract_address and
|
|
||||||
not m.optimization_applied
|
|
||||||
]
|
|
||||||
|
|
||||||
if len(historical_metrics) < 5: # Need sufficient history
|
|
||||||
return
|
|
||||||
|
|
||||||
avg_gas = sum(m.gas_used for m in historical_metrics) / len(historical_metrics)
|
|
||||||
|
|
||||||
# Test each optimization strategy
|
|
||||||
for strategy, config in self.optimization_strategies.items():
|
|
||||||
if self._is_strategy_applicable(strategy, metric.function_name):
|
|
||||||
potential_savings = avg_gas * config['potential_savings']
|
|
||||||
|
|
||||||
if potential_savings >= self.min_optimization_threshold:
|
|
||||||
# Calculate net benefit
|
|
||||||
gas_price = self.current_gas_price
|
|
||||||
gas_savings_value = potential_savings * gas_price
|
|
||||||
net_benefit = gas_savings_value - config['implementation_cost']
|
|
||||||
|
|
||||||
if net_benefit > 0:
|
|
||||||
# Create optimization result
|
|
||||||
result = OptimizationResult(
|
|
||||||
strategy=strategy,
|
|
||||||
original_gas=int(avg_gas),
|
|
||||||
optimized_gas=int(avg_gas - potential_savings),
|
|
||||||
gas_savings=int(potential_savings),
|
|
||||||
savings_percentage=config['potential_savings'],
|
|
||||||
implementation_cost=config['implementation_cost'],
|
|
||||||
net_benefit=net_benefit
|
|
||||||
)
|
|
||||||
|
|
||||||
self.optimization_results.append(result)
|
|
||||||
|
|
||||||
# Keep only recent results
|
|
||||||
if len(self.optimization_results) > 1000:
|
|
||||||
self.optimization_results = self.optimization_results[-500]
|
|
||||||
|
|
||||||
log_info(f"Optimization opportunity found: {strategy.value} for {metric.function_name} - Potential savings: {potential_savings} gas")
|
|
||||||
|
|
||||||
def _is_strategy_applicable(self, strategy: OptimizationStrategy, function_name: str) -> bool:
|
|
||||||
"""Check if optimization strategy is applicable to function"""
|
|
||||||
config = self.optimization_strategies.get(strategy, {})
|
|
||||||
applicable_functions = config.get('applicable_functions', [])
|
|
||||||
|
|
||||||
# Check if function name contains any applicable keywords
|
|
||||||
for applicable in applicable_functions:
|
|
||||||
if applicable.lower() in function_name.lower():
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def apply_optimization(self, contract_address: str, function_name: str,
|
|
||||||
strategy: OptimizationStrategy) -> Tuple[bool, str]:
|
|
||||||
"""Apply optimization strategy to contract function"""
|
|
||||||
try:
|
|
||||||
# Validate strategy
|
|
||||||
if strategy not in self.optimization_strategies:
|
|
||||||
return False, "Unknown optimization strategy"
|
|
||||||
|
|
||||||
# Check applicability
|
|
||||||
if not self._is_strategy_applicable(strategy, function_name):
|
|
||||||
return False, "Strategy not applicable to this function"
|
|
||||||
|
|
||||||
# Get optimization result
|
|
||||||
result = None
|
|
||||||
for res in self.optimization_results:
|
|
||||||
if (res.strategy == strategy and
|
|
||||||
res.strategy in self.optimization_strategies):
|
|
||||||
result = res
|
|
||||||
break
|
|
||||||
|
|
||||||
if not result:
|
|
||||||
return False, "No optimization analysis available"
|
|
||||||
|
|
||||||
# Check if net benefit is positive
|
|
||||||
if result.net_benefit <= 0:
|
|
||||||
return False, "Optimization not cost-effective"
|
|
||||||
|
|
||||||
# Apply optimization (in real implementation, this would modify contract code)
|
|
||||||
success = await self._implement_optimization(contract_address, function_name, strategy)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
# Record optimization
|
|
||||||
await self.record_gas_usage(
|
|
||||||
contract_address, function_name, result.optimized_gas,
|
|
||||||
result.optimized_gas, 0.0, strategy.value
|
|
||||||
)
|
|
||||||
|
|
||||||
log_info(f"Optimization applied: {strategy.value} to {function_name}")
|
|
||||||
return True, f"Optimization applied successfully. Gas savings: {result.gas_savings}"
|
|
||||||
else:
|
|
||||||
return False, "Optimization implementation failed"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return False, f"Optimization error: {str(e)}"
|
|
||||||
|
|
||||||
async def _implement_optimization(self, contract_address: str, function_name: str,
|
|
||||||
strategy: OptimizationStrategy) -> bool:
|
|
||||||
"""Implement the optimization strategy"""
|
|
||||||
try:
|
|
||||||
# In real implementation, this would:
|
|
||||||
# 1. Analyze contract bytecode
|
|
||||||
# 2. Apply optimization patterns
|
|
||||||
# 3. Generate optimized bytecode
|
|
||||||
# 4. Deploy optimized version
|
|
||||||
# 5. Verify functionality
|
|
||||||
|
|
||||||
# Simulate implementation
|
|
||||||
await asyncio.sleep(2) # Simulate optimization time
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log_error(f"Optimization implementation error: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def update_gas_price(self, new_price: Decimal):
|
|
||||||
"""Update current gas price"""
|
|
||||||
self.current_gas_price = new_price
|
|
||||||
|
|
||||||
# Record price history
|
|
||||||
self.gas_price_history.append({
|
|
||||||
'price': float(new_price),
|
|
||||||
'timestamp': time.time()
|
|
||||||
})
|
|
||||||
|
|
||||||
# Limit history size
|
|
||||||
if len(self.gas_price_history) > 1000:
|
|
||||||
self.gas_price_history = self.gas_price_history[-500]
|
|
||||||
|
|
||||||
# Re-evaluate optimization opportunities with new price
|
|
||||||
asyncio.create_task(self._reevaluate_optimizations())
|
|
||||||
|
|
||||||
async def _reevaluate_optimizations(self):
|
|
||||||
"""Re-evaluate optimization opportunities with new gas price"""
|
|
||||||
# Clear old results and re-analyze
|
|
||||||
self.optimization_results.clear()
|
|
||||||
|
|
||||||
# Re-analyze recent metrics
|
|
||||||
recent_metrics = [
|
|
||||||
m for m in self.gas_metrics
|
|
||||||
if time.time() - m.timestamp < 3600 # Last hour
|
|
||||||
]
|
|
||||||
|
|
||||||
for metric in recent_metrics:
|
|
||||||
if metric.gas_used >= self.min_optimization_threshold:
|
|
||||||
await self._analyze_optimization_opportunity(metric)
|
|
||||||
|
|
||||||
async def get_optimization_recommendations(self, contract_address: Optional[str] = None,
|
|
||||||
limit: int = 10) -> List[Dict]:
|
|
||||||
"""Get optimization recommendations"""
|
|
||||||
recommendations = []
|
|
||||||
|
|
||||||
for result in self.optimization_results:
|
|
||||||
if contract_address and result.strategy.value not in self.optimization_strategies:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if result.net_benefit > 0:
|
|
||||||
recommendations.append({
|
|
||||||
'strategy': result.strategy.value,
|
|
||||||
'function': 'contract_function', # Would map to actual function
|
|
||||||
'original_gas': result.original_gas,
|
|
||||||
'optimized_gas': result.optimized_gas,
|
|
||||||
'gas_savings': result.gas_savings,
|
|
||||||
'savings_percentage': result.savings_percentage,
|
|
||||||
'net_benefit': float(result.net_benefit),
|
|
||||||
'implementation_cost': float(result.implementation_cost)
|
|
||||||
})
|
|
||||||
|
|
||||||
# Sort by net benefit
|
|
||||||
recommendations.sort(key=lambda x: x['net_benefit'], reverse=True)
|
|
||||||
|
|
||||||
return recommendations[:limit]
|
|
||||||
|
|
||||||
async def get_gas_statistics(self) -> Dict:
|
|
||||||
"""Get gas usage statistics"""
|
|
||||||
if not self.gas_metrics:
|
|
||||||
return {
|
|
||||||
'total_transactions': 0,
|
|
||||||
'average_gas_used': 0,
|
|
||||||
'total_gas_used': 0,
|
|
||||||
'gas_efficiency': 0,
|
|
||||||
'optimization_opportunities': 0
|
|
||||||
}
|
|
||||||
|
|
||||||
total_transactions = len(self.gas_metrics)
|
|
||||||
total_gas_used = sum(m.gas_used for m in self.gas_metrics)
|
|
||||||
average_gas_used = total_gas_used / total_transactions
|
|
||||||
|
|
||||||
# Calculate efficiency (gas used vs gas limit)
|
|
||||||
efficiency_scores = [
|
|
||||||
m.gas_used / m.gas_limit for m in self.gas_metrics
|
|
||||||
if m.gas_limit > 0
|
|
||||||
]
|
|
||||||
avg_efficiency = sum(efficiency_scores) / len(efficiency_scores) if efficiency_scores else 0
|
|
||||||
|
|
||||||
# Optimization opportunities
|
|
||||||
optimization_count = len([
|
|
||||||
result for result in self.optimization_results
|
|
||||||
if result.net_benefit > 0
|
|
||||||
])
|
|
||||||
|
|
||||||
return {
|
|
||||||
'total_transactions': total_transactions,
|
|
||||||
'average_gas_used': average_gas_used,
|
|
||||||
'total_gas_used': total_gas_used,
|
|
||||||
'gas_efficiency': avg_efficiency,
|
|
||||||
'optimization_opportunities': optimization_count,
|
|
||||||
'current_gas_price': float(self.current_gas_price),
|
|
||||||
'total_optimizations_applied': len([
|
|
||||||
m for m in self.gas_metrics
|
|
||||||
if m.optimization_applied
|
|
||||||
])
|
|
||||||
}
|
|
||||||
|
|
||||||
# Global gas optimizer
|
|
||||||
gas_optimizer: Optional[GasOptimizer] = None
|
|
||||||
|
|
||||||
def get_gas_optimizer() -> Optional[GasOptimizer]:
|
|
||||||
"""Get global gas optimizer"""
|
|
||||||
return gas_optimizer
|
|
||||||
|
|
||||||
def create_gas_optimizer() -> GasOptimizer:
|
|
||||||
"""Create and set global gas optimizer"""
|
|
||||||
global gas_optimizer
|
|
||||||
gas_optimizer = GasOptimizer()
|
|
||||||
return gas_optimizer
|
|
||||||
@@ -1,470 +0,0 @@
|
|||||||
"""
|
|
||||||
Persistent Spending Tracker - Database-Backed Security
|
|
||||||
Fixes the critical vulnerability where spending limits were lost on restart
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from sqlalchemy import create_engine, Column, String, Integer, Float, DateTime, Index
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
from sqlalchemy.orm import sessionmaker, Session
|
|
||||||
from eth_utils import to_checksum_address
|
|
||||||
import json
|
|
||||||
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
|
|
||||||
class SpendingRecord(Base):
|
|
||||||
"""Database model for spending tracking"""
|
|
||||||
__tablename__ = "spending_records"
|
|
||||||
|
|
||||||
id = Column(String, primary_key=True)
|
|
||||||
agent_address = Column(String, index=True)
|
|
||||||
period_type = Column(String, index=True) # hour, day, week
|
|
||||||
period_key = Column(String, index=True)
|
|
||||||
amount = Column(Float)
|
|
||||||
transaction_hash = Column(String)
|
|
||||||
timestamp = Column(DateTime, default=datetime.utcnow)
|
|
||||||
|
|
||||||
# Composite indexes for performance
|
|
||||||
__table_args__ = (
|
|
||||||
Index('idx_agent_period', 'agent_address', 'period_type', 'period_key'),
|
|
||||||
Index('idx_timestamp', 'timestamp'),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SpendingLimit(Base):
|
|
||||||
"""Database model for spending limits"""
|
|
||||||
__tablename__ = "spending_limits"
|
|
||||||
|
|
||||||
agent_address = Column(String, primary_key=True)
|
|
||||||
per_transaction = Column(Float)
|
|
||||||
per_hour = Column(Float)
|
|
||||||
per_day = Column(Float)
|
|
||||||
per_week = Column(Float)
|
|
||||||
time_lock_threshold = Column(Float)
|
|
||||||
time_lock_delay_hours = Column(Integer)
|
|
||||||
updated_at = Column(DateTime, default=datetime.utcnow)
|
|
||||||
updated_by = Column(String) # Guardian who updated
|
|
||||||
|
|
||||||
|
|
||||||
class GuardianAuthorization(Base):
|
|
||||||
"""Database model for guardian authorizations"""
|
|
||||||
__tablename__ = "guardian_authorizations"
|
|
||||||
|
|
||||||
id = Column(String, primary_key=True)
|
|
||||||
agent_address = Column(String, index=True)
|
|
||||||
guardian_address = Column(String, index=True)
|
|
||||||
is_active = Column(Boolean, default=True)
|
|
||||||
added_at = Column(DateTime, default=datetime.utcnow)
|
|
||||||
added_by = Column(String)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SpendingCheckResult:
|
|
||||||
"""Result of spending limit check"""
|
|
||||||
allowed: bool
|
|
||||||
reason: str
|
|
||||||
current_spent: Dict[str, float]
|
|
||||||
remaining: Dict[str, float]
|
|
||||||
requires_time_lock: bool
|
|
||||||
time_lock_until: Optional[datetime] = None
|
|
||||||
|
|
||||||
|
|
||||||
class PersistentSpendingTracker:
|
|
||||||
"""
|
|
||||||
Database-backed spending tracker that survives restarts
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, database_url: str = "sqlite:///spending_tracker.db"):
|
|
||||||
self.engine = create_engine(database_url)
|
|
||||||
Base.metadata.create_all(self.engine)
|
|
||||||
self.SessionLocal = sessionmaker(bind=self.engine)
|
|
||||||
|
|
||||||
def get_session(self) -> Session:
|
|
||||||
"""Get database session"""
|
|
||||||
return self.SessionLocal()
|
|
||||||
|
|
||||||
def _get_period_key(self, timestamp: datetime, period: str) -> str:
|
|
||||||
"""Generate period key for spending tracking"""
|
|
||||||
if period == "hour":
|
|
||||||
return timestamp.strftime("%Y-%m-%d-%H")
|
|
||||||
elif period == "day":
|
|
||||||
return timestamp.strftime("%Y-%m-%d")
|
|
||||||
elif period == "week":
|
|
||||||
# Get week number (Monday as first day)
|
|
||||||
week_num = timestamp.isocalendar()[1]
|
|
||||||
return f"{timestamp.year}-W{week_num:02d}"
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid period: {period}")
|
|
||||||
|
|
||||||
def get_spent_in_period(self, agent_address: str, period: str, timestamp: datetime = None) -> float:
|
|
||||||
"""
|
|
||||||
Get total spent in given period from database
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
period: Period type (hour, day, week)
|
|
||||||
timestamp: Timestamp to check (default: now)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Total amount spent in period
|
|
||||||
"""
|
|
||||||
if timestamp is None:
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
period_key = self._get_period_key(timestamp, period)
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
with self.get_session() as session:
|
|
||||||
total = session.query(SpendingRecord).filter(
|
|
||||||
SpendingRecord.agent_address == agent_address,
|
|
||||||
SpendingRecord.period_type == period,
|
|
||||||
SpendingRecord.period_key == period_key
|
|
||||||
).with_entities(SpendingRecord.amount).all()
|
|
||||||
|
|
||||||
return sum(record.amount for record in total)
|
|
||||||
|
|
||||||
def record_spending(self, agent_address: str, amount: float, transaction_hash: str, timestamp: datetime = None) -> bool:
|
|
||||||
"""
|
|
||||||
Record a spending transaction in the database
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
amount: Amount spent
|
|
||||||
transaction_hash: Transaction hash
|
|
||||||
timestamp: Transaction timestamp (default: now)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if recorded successfully
|
|
||||||
"""
|
|
||||||
if timestamp is None:
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with self.get_session() as session:
|
|
||||||
# Record for all periods
|
|
||||||
periods = ["hour", "day", "week"]
|
|
||||||
|
|
||||||
for period in periods:
|
|
||||||
period_key = self._get_period_key(timestamp, period)
|
|
||||||
|
|
||||||
record = SpendingRecord(
|
|
||||||
id=f"{transaction_hash}_{period}",
|
|
||||||
agent_address=agent_address,
|
|
||||||
period_type=period,
|
|
||||||
period_key=period_key,
|
|
||||||
amount=amount,
|
|
||||||
transaction_hash=transaction_hash,
|
|
||||||
timestamp=timestamp
|
|
||||||
)
|
|
||||||
|
|
||||||
session.add(record)
|
|
||||||
|
|
||||||
session.commit()
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to record spending: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def check_spending_limits(self, agent_address: str, amount: float, timestamp: datetime = None) -> SpendingCheckResult:
|
|
||||||
"""
|
|
||||||
Check if amount exceeds spending limits using persistent data
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
amount: Amount to check
|
|
||||||
timestamp: Timestamp for check (default: now)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Spending check result
|
|
||||||
"""
|
|
||||||
if timestamp is None:
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
# Get spending limits from database
|
|
||||||
with self.get_session() as session:
|
|
||||||
limits = session.query(SpendingLimit).filter(
|
|
||||||
SpendingLimit.agent_address == agent_address
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not limits:
|
|
||||||
# Default limits if not set
|
|
||||||
limits = SpendingLimit(
|
|
||||||
agent_address=agent_address,
|
|
||||||
per_transaction=1000.0,
|
|
||||||
per_hour=5000.0,
|
|
||||||
per_day=20000.0,
|
|
||||||
per_week=100000.0,
|
|
||||||
time_lock_threshold=5000.0,
|
|
||||||
time_lock_delay_hours=24
|
|
||||||
)
|
|
||||||
session.add(limits)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# Check each limit
|
|
||||||
current_spent = {}
|
|
||||||
remaining = {}
|
|
||||||
|
|
||||||
# Per-transaction limit
|
|
||||||
if amount > limits.per_transaction:
|
|
||||||
return SpendingCheckResult(
|
|
||||||
allowed=False,
|
|
||||||
reason=f"Amount {amount} exceeds per-transaction limit {limits.per_transaction}",
|
|
||||||
current_spent=current_spent,
|
|
||||||
remaining=remaining,
|
|
||||||
requires_time_lock=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Per-hour limit
|
|
||||||
spent_hour = self.get_spent_in_period(agent_address, "hour", timestamp)
|
|
||||||
current_spent["hour"] = spent_hour
|
|
||||||
remaining["hour"] = limits.per_hour - spent_hour
|
|
||||||
|
|
||||||
if spent_hour + amount > limits.per_hour:
|
|
||||||
return SpendingCheckResult(
|
|
||||||
allowed=False,
|
|
||||||
reason=f"Hourly spending {spent_hour + amount} would exceed limit {limits.per_hour}",
|
|
||||||
current_spent=current_spent,
|
|
||||||
remaining=remaining,
|
|
||||||
requires_time_lock=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Per-day limit
|
|
||||||
spent_day = self.get_spent_in_period(agent_address, "day", timestamp)
|
|
||||||
current_spent["day"] = spent_day
|
|
||||||
remaining["day"] = limits.per_day - spent_day
|
|
||||||
|
|
||||||
if spent_day + amount > limits.per_day:
|
|
||||||
return SpendingCheckResult(
|
|
||||||
allowed=False,
|
|
||||||
reason=f"Daily spending {spent_day + amount} would exceed limit {limits.per_day}",
|
|
||||||
current_spent=current_spent,
|
|
||||||
remaining=remaining,
|
|
||||||
requires_time_lock=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Per-week limit
|
|
||||||
spent_week = self.get_spent_in_period(agent_address, "week", timestamp)
|
|
||||||
current_spent["week"] = spent_week
|
|
||||||
remaining["week"] = limits.per_week - spent_week
|
|
||||||
|
|
||||||
if spent_week + amount > limits.per_week:
|
|
||||||
return SpendingCheckResult(
|
|
||||||
allowed=False,
|
|
||||||
reason=f"Weekly spending {spent_week + amount} would exceed limit {limits.per_week}",
|
|
||||||
current_spent=current_spent,
|
|
||||||
remaining=remaining,
|
|
||||||
requires_time_lock=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check time lock requirement
|
|
||||||
requires_time_lock = amount >= limits.time_lock_threshold
|
|
||||||
time_lock_until = None
|
|
||||||
|
|
||||||
if requires_time_lock:
|
|
||||||
time_lock_until = timestamp + timedelta(hours=limits.time_lock_delay_hours)
|
|
||||||
|
|
||||||
return SpendingCheckResult(
|
|
||||||
allowed=True,
|
|
||||||
reason="Spending limits check passed",
|
|
||||||
current_spent=current_spent,
|
|
||||||
remaining=remaining,
|
|
||||||
requires_time_lock=requires_time_lock,
|
|
||||||
time_lock_until=time_lock_until
|
|
||||||
)
|
|
||||||
|
|
||||||
def update_spending_limits(self, agent_address: str, new_limits: Dict, guardian_address: str) -> bool:
|
|
||||||
"""
|
|
||||||
Update spending limits for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
new_limits: New spending limits
|
|
||||||
guardian_address: Guardian making the change
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if updated successfully
|
|
||||||
"""
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
|
|
||||||
# Verify guardian authorization
|
|
||||||
if not self.is_guardian_authorized(agent_address, guardian_address):
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
with self.get_session() as session:
|
|
||||||
limits = session.query(SpendingLimit).filter(
|
|
||||||
SpendingLimit.agent_address == agent_address
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if limits:
|
|
||||||
limits.per_transaction = new_limits.get("per_transaction", limits.per_transaction)
|
|
||||||
limits.per_hour = new_limits.get("per_hour", limits.per_hour)
|
|
||||||
limits.per_day = new_limits.get("per_day", limits.per_day)
|
|
||||||
limits.per_week = new_limits.get("per_week", limits.per_week)
|
|
||||||
limits.time_lock_threshold = new_limits.get("time_lock_threshold", limits.time_lock_threshold)
|
|
||||||
limits.time_lock_delay_hours = new_limits.get("time_lock_delay_hours", limits.time_lock_delay_hours)
|
|
||||||
limits.updated_at = datetime.utcnow()
|
|
||||||
limits.updated_by = guardian_address
|
|
||||||
else:
|
|
||||||
limits = SpendingLimit(
|
|
||||||
agent_address=agent_address,
|
|
||||||
per_transaction=new_limits.get("per_transaction", 1000.0),
|
|
||||||
per_hour=new_limits.get("per_hour", 5000.0),
|
|
||||||
per_day=new_limits.get("per_day", 20000.0),
|
|
||||||
per_week=new_limits.get("per_week", 100000.0),
|
|
||||||
time_lock_threshold=new_limits.get("time_lock_threshold", 5000.0),
|
|
||||||
time_lock_delay_hours=new_limits.get("time_lock_delay_hours", 24),
|
|
||||||
updated_at=datetime.utcnow(),
|
|
||||||
updated_by=guardian_address
|
|
||||||
)
|
|
||||||
session.add(limits)
|
|
||||||
|
|
||||||
session.commit()
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to update spending limits: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def add_guardian(self, agent_address: str, guardian_address: str, added_by: str) -> bool:
|
|
||||||
"""
|
|
||||||
Add a guardian for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
guardian_address: Guardian address
|
|
||||||
added_by: Who added this guardian
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if added successfully
|
|
||||||
"""
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
added_by = to_checksum_address(added_by)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with self.get_session() as session:
|
|
||||||
# Check if already exists
|
|
||||||
existing = session.query(GuardianAuthorization).filter(
|
|
||||||
GuardianAuthorization.agent_address == agent_address,
|
|
||||||
GuardianAuthorization.guardian_address == guardian_address
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if existing:
|
|
||||||
existing.is_active = True
|
|
||||||
existing.added_at = datetime.utcnow()
|
|
||||||
existing.added_by = added_by
|
|
||||||
else:
|
|
||||||
auth = GuardianAuthorization(
|
|
||||||
id=f"{agent_address}_{guardian_address}",
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardian_address=guardian_address,
|
|
||||||
is_active=True,
|
|
||||||
added_at=datetime.utcnow(),
|
|
||||||
added_by=added_by
|
|
||||||
)
|
|
||||||
session.add(auth)
|
|
||||||
|
|
||||||
session.commit()
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to add guardian: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def is_guardian_authorized(self, agent_address: str, guardian_address: str) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a guardian is authorized for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
guardian_address: Guardian address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if authorized
|
|
||||||
"""
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
|
|
||||||
with self.get_session() as session:
|
|
||||||
auth = session.query(GuardianAuthorization).filter(
|
|
||||||
GuardianAuthorization.agent_address == agent_address,
|
|
||||||
GuardianAuthorization.guardian_address == guardian_address,
|
|
||||||
GuardianAuthorization.is_active == True
|
|
||||||
).first()
|
|
||||||
|
|
||||||
return auth is not None
|
|
||||||
|
|
||||||
def get_spending_summary(self, agent_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Get comprehensive spending summary for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Spending summary
|
|
||||||
"""
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
now = datetime.utcnow()
|
|
||||||
|
|
||||||
# Get current spending
|
|
||||||
current_spent = {
|
|
||||||
"hour": self.get_spent_in_period(agent_address, "hour", now),
|
|
||||||
"day": self.get_spent_in_period(agent_address, "day", now),
|
|
||||||
"week": self.get_spent_in_period(agent_address, "week", now)
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get limits
|
|
||||||
with self.get_session() as session:
|
|
||||||
limits = session.query(SpendingLimit).filter(
|
|
||||||
SpendingLimit.agent_address == agent_address
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not limits:
|
|
||||||
return {"error": "No spending limits set"}
|
|
||||||
|
|
||||||
# Calculate remaining
|
|
||||||
remaining = {
|
|
||||||
"hour": limits.per_hour - current_spent["hour"],
|
|
||||||
"day": limits.per_day - current_spent["day"],
|
|
||||||
"week": limits.per_week - current_spent["week"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get authorized guardians
|
|
||||||
with self.get_session() as session:
|
|
||||||
guardians = session.query(GuardianAuthorization).filter(
|
|
||||||
GuardianAuthorization.agent_address == agent_address,
|
|
||||||
GuardianAuthorization.is_active == True
|
|
||||||
).all()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"current_spending": current_spent,
|
|
||||||
"remaining_spending": remaining,
|
|
||||||
"limits": {
|
|
||||||
"per_transaction": limits.per_transaction,
|
|
||||||
"per_hour": limits.per_hour,
|
|
||||||
"per_day": limits.per_day,
|
|
||||||
"per_week": limits.per_week
|
|
||||||
},
|
|
||||||
"time_lock": {
|
|
||||||
"threshold": limits.time_lock_threshold,
|
|
||||||
"delay_hours": limits.time_lock_delay_hours
|
|
||||||
},
|
|
||||||
"authorized_guardians": [g.guardian_address for g in guardians],
|
|
||||||
"last_updated": limits.updated_at.isoformat() if limits.updated_at else None
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Global persistent tracker instance
|
|
||||||
persistent_tracker = PersistentSpendingTracker()
|
|
||||||
@@ -1,542 +0,0 @@
|
|||||||
"""
|
|
||||||
Contract Upgrade System
|
|
||||||
Handles safe contract versioning and upgrade mechanisms
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
from typing import Dict, List, Optional, Tuple, Set
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
from decimal import Decimal
|
|
||||||
|
|
||||||
class UpgradeStatus(Enum):
|
|
||||||
PROPOSED = "proposed"
|
|
||||||
APPROVED = "approved"
|
|
||||||
REJECTED = "rejected"
|
|
||||||
EXECUTED = "executed"
|
|
||||||
FAILED = "failed"
|
|
||||||
ROLLED_BACK = "rolled_back"
|
|
||||||
|
|
||||||
class UpgradeType(Enum):
|
|
||||||
PARAMETER_CHANGE = "parameter_change"
|
|
||||||
LOGIC_UPDATE = "logic_update"
|
|
||||||
SECURITY_PATCH = "security_patch"
|
|
||||||
FEATURE_ADDITION = "feature_addition"
|
|
||||||
EMERGENCY_FIX = "emergency_fix"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ContractVersion:
|
|
||||||
version: str
|
|
||||||
address: str
|
|
||||||
deployed_at: float
|
|
||||||
total_contracts: int
|
|
||||||
total_value: Decimal
|
|
||||||
is_active: bool
|
|
||||||
metadata: Dict
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class UpgradeProposal:
|
|
||||||
proposal_id: str
|
|
||||||
contract_type: str
|
|
||||||
current_version: str
|
|
||||||
new_version: str
|
|
||||||
upgrade_type: UpgradeType
|
|
||||||
description: str
|
|
||||||
changes: Dict
|
|
||||||
voting_deadline: float
|
|
||||||
execution_deadline: float
|
|
||||||
status: UpgradeStatus
|
|
||||||
votes: Dict[str, bool]
|
|
||||||
total_votes: int
|
|
||||||
yes_votes: int
|
|
||||||
no_votes: int
|
|
||||||
required_approval: float
|
|
||||||
created_at: float
|
|
||||||
proposer: str
|
|
||||||
executed_at: Optional[float]
|
|
||||||
rollback_data: Optional[Dict]
|
|
||||||
|
|
||||||
class ContractUpgradeManager:
|
|
||||||
"""Manages contract upgrades and versioning"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.contract_versions: Dict[str, List[ContractVersion]] = {} # contract_type -> versions
|
|
||||||
self.active_versions: Dict[str, str] = {} # contract_type -> active version
|
|
||||||
self.upgrade_proposals: Dict[str, UpgradeProposal] = {}
|
|
||||||
self.upgrade_history: List[Dict] = []
|
|
||||||
|
|
||||||
# Upgrade parameters
|
|
||||||
self.min_voting_period = 86400 * 3 # 3 days
|
|
||||||
self.max_voting_period = 86400 * 7 # 7 days
|
|
||||||
self.required_approval_rate = 0.6 # 60% approval required
|
|
||||||
self.min_participation_rate = 0.3 # 30% minimum participation
|
|
||||||
self.emergency_upgrade_threshold = 0.8 # 80% for emergency upgrades
|
|
||||||
self.rollback_timeout = 86400 * 7 # 7 days to rollback
|
|
||||||
|
|
||||||
# Governance
|
|
||||||
self.governance_addresses: Set[str] = set()
|
|
||||||
self.stake_weights: Dict[str, Decimal] = {}
|
|
||||||
|
|
||||||
# Initialize governance
|
|
||||||
self._initialize_governance()
|
|
||||||
|
|
||||||
def _initialize_governance(self):
|
|
||||||
"""Initialize governance addresses"""
|
|
||||||
# In real implementation, this would load from blockchain state
|
|
||||||
# For now, use default governance addresses
|
|
||||||
governance_addresses = [
|
|
||||||
"0xgovernance1111111111111111111111111111111111111",
|
|
||||||
"0xgovernance2222222222222222222222222222222222222",
|
|
||||||
"0xgovernance3333333333333333333333333333333333333"
|
|
||||||
]
|
|
||||||
|
|
||||||
for address in governance_addresses:
|
|
||||||
self.governance_addresses.add(address)
|
|
||||||
self.stake_weights[address] = Decimal('1000') # Equal stake weights initially
|
|
||||||
|
|
||||||
async def propose_upgrade(self, contract_type: str, current_version: str, new_version: str,
|
|
||||||
upgrade_type: UpgradeType, description: str, changes: Dict,
|
|
||||||
proposer: str, emergency: bool = False) -> Tuple[bool, str, Optional[str]]:
|
|
||||||
"""Propose contract upgrade"""
|
|
||||||
try:
|
|
||||||
# Validate inputs
|
|
||||||
if not all([contract_type, current_version, new_version, description, changes, proposer]):
|
|
||||||
return False, "Missing required fields", None
|
|
||||||
|
|
||||||
# Check proposer authority
|
|
||||||
if proposer not in self.governance_addresses:
|
|
||||||
return False, "Proposer not authorized", None
|
|
||||||
|
|
||||||
# Check current version
|
|
||||||
active_version = self.active_versions.get(contract_type)
|
|
||||||
if active_version != current_version:
|
|
||||||
return False, f"Current version mismatch. Active: {active_version}, Proposed: {current_version}", None
|
|
||||||
|
|
||||||
# Validate new version format
|
|
||||||
if not self._validate_version_format(new_version):
|
|
||||||
return False, "Invalid version format", None
|
|
||||||
|
|
||||||
# Check for existing proposal
|
|
||||||
for proposal in self.upgrade_proposals.values():
|
|
||||||
if (proposal.contract_type == contract_type and
|
|
||||||
proposal.new_version == new_version and
|
|
||||||
proposal.status in [UpgradeStatus.PROPOSED, UpgradeStatus.APPROVED]):
|
|
||||||
return False, "Proposal for this version already exists", None
|
|
||||||
|
|
||||||
# Generate proposal ID
|
|
||||||
proposal_id = self._generate_proposal_id(contract_type, new_version)
|
|
||||||
|
|
||||||
# Set voting deadlines
|
|
||||||
current_time = time.time()
|
|
||||||
voting_period = self.min_voting_period if not emergency else self.min_voting_period // 2
|
|
||||||
voting_deadline = current_time + voting_period
|
|
||||||
execution_deadline = voting_deadline + 86400 # 1 day after voting
|
|
||||||
|
|
||||||
# Set required approval rate
|
|
||||||
required_approval = self.emergency_upgrade_threshold if emergency else self.required_approval_rate
|
|
||||||
|
|
||||||
# Create proposal
|
|
||||||
proposal = UpgradeProposal(
|
|
||||||
proposal_id=proposal_id,
|
|
||||||
contract_type=contract_type,
|
|
||||||
current_version=current_version,
|
|
||||||
new_version=new_version,
|
|
||||||
upgrade_type=upgrade_type,
|
|
||||||
description=description,
|
|
||||||
changes=changes,
|
|
||||||
voting_deadline=voting_deadline,
|
|
||||||
execution_deadline=execution_deadline,
|
|
||||||
status=UpgradeStatus.PROPOSED,
|
|
||||||
votes={},
|
|
||||||
total_votes=0,
|
|
||||||
yes_votes=0,
|
|
||||||
no_votes=0,
|
|
||||||
required_approval=required_approval,
|
|
||||||
created_at=current_time,
|
|
||||||
proposer=proposer,
|
|
||||||
executed_at=None,
|
|
||||||
rollback_data=None
|
|
||||||
)
|
|
||||||
|
|
||||||
self.upgrade_proposals[proposal_id] = proposal
|
|
||||||
|
|
||||||
# Start voting process
|
|
||||||
asyncio.create_task(self._manage_voting_process(proposal_id))
|
|
||||||
|
|
||||||
log_info(f"Upgrade proposal created: {proposal_id} - {contract_type} {current_version} -> {new_version}")
|
|
||||||
return True, "Upgrade proposal created successfully", proposal_id
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return False, f"Failed to create proposal: {str(e)}", None
|
|
||||||
|
|
||||||
def _validate_version_format(self, version: str) -> bool:
|
|
||||||
"""Validate semantic version format"""
|
|
||||||
try:
|
|
||||||
parts = version.split('.')
|
|
||||||
if len(parts) != 3:
|
|
||||||
return False
|
|
||||||
|
|
||||||
major, minor, patch = parts
|
|
||||||
int(major) and int(minor) and int(patch)
|
|
||||||
return True
|
|
||||||
except ValueError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _generate_proposal_id(self, contract_type: str, new_version: str) -> str:
|
|
||||||
"""Generate unique proposal ID"""
|
|
||||||
import hashlib
|
|
||||||
content = f"{contract_type}:{new_version}:{time.time()}"
|
|
||||||
return hashlib.sha256(content.encode()).hexdigest()[:12]
|
|
||||||
|
|
||||||
async def _manage_voting_process(self, proposal_id: str):
|
|
||||||
"""Manage voting process for proposal"""
|
|
||||||
proposal = self.upgrade_proposals.get(proposal_id)
|
|
||||||
if not proposal:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Wait for voting deadline
|
|
||||||
await asyncio.sleep(proposal.voting_deadline - time.time())
|
|
||||||
|
|
||||||
# Check voting results
|
|
||||||
await self._finalize_voting(proposal_id)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log_error(f"Error in voting process for {proposal_id}: {e}")
|
|
||||||
proposal.status = UpgradeStatus.FAILED
|
|
||||||
|
|
||||||
async def _finalize_voting(self, proposal_id: str):
|
|
||||||
"""Finalize voting and determine outcome"""
|
|
||||||
proposal = self.upgrade_proposals[proposal_id]
|
|
||||||
|
|
||||||
# Calculate voting results
|
|
||||||
total_stake = sum(self.stake_weights.get(voter, Decimal('0')) for voter in proposal.votes.keys())
|
|
||||||
yes_stake = sum(self.stake_weights.get(voter, Decimal('0')) for voter, vote in proposal.votes.items() if vote)
|
|
||||||
|
|
||||||
# Check minimum participation
|
|
||||||
total_governance_stake = sum(self.stake_weights.values())
|
|
||||||
participation_rate = float(total_stake / total_governance_stake) if total_governance_stake > 0 else 0
|
|
||||||
|
|
||||||
if participation_rate < self.min_participation_rate:
|
|
||||||
proposal.status = UpgradeStatus.REJECTED
|
|
||||||
log_info(f"Proposal {proposal_id} rejected due to low participation: {participation_rate:.2%}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check approval rate
|
|
||||||
approval_rate = float(yes_stake / total_stake) if total_stake > 0 else 0
|
|
||||||
|
|
||||||
if approval_rate >= proposal.required_approval:
|
|
||||||
proposal.status = UpgradeStatus.APPROVED
|
|
||||||
log_info(f"Proposal {proposal_id} approved with {approval_rate:.2%} approval")
|
|
||||||
|
|
||||||
# Schedule execution
|
|
||||||
asyncio.create_task(self._execute_upgrade(proposal_id))
|
|
||||||
else:
|
|
||||||
proposal.status = UpgradeStatus.REJECTED
|
|
||||||
log_info(f"Proposal {proposal_id} rejected with {approval_rate:.2%} approval")
|
|
||||||
|
|
||||||
async def vote_on_proposal(self, proposal_id: str, voter_address: str, vote: bool) -> Tuple[bool, str]:
|
|
||||||
"""Cast vote on upgrade proposal"""
|
|
||||||
proposal = self.upgrade_proposals.get(proposal_id)
|
|
||||||
if not proposal:
|
|
||||||
return False, "Proposal not found"
|
|
||||||
|
|
||||||
# Check voting authority
|
|
||||||
if voter_address not in self.governance_addresses:
|
|
||||||
return False, "Not authorized to vote"
|
|
||||||
|
|
||||||
# Check voting period
|
|
||||||
if time.time() > proposal.voting_deadline:
|
|
||||||
return False, "Voting period has ended"
|
|
||||||
|
|
||||||
# Check if already voted
|
|
||||||
if voter_address in proposal.votes:
|
|
||||||
return False, "Already voted"
|
|
||||||
|
|
||||||
# Cast vote
|
|
||||||
proposal.votes[voter_address] = vote
|
|
||||||
proposal.total_votes += 1
|
|
||||||
|
|
||||||
if vote:
|
|
||||||
proposal.yes_votes += 1
|
|
||||||
else:
|
|
||||||
proposal.no_votes += 1
|
|
||||||
|
|
||||||
log_info(f"Vote cast on proposal {proposal_id} by {voter_address}: {'YES' if vote else 'NO'}")
|
|
||||||
return True, "Vote cast successfully"
|
|
||||||
|
|
||||||
async def _execute_upgrade(self, proposal_id: str):
|
|
||||||
"""Execute approved upgrade"""
|
|
||||||
proposal = self.upgrade_proposals[proposal_id]
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Wait for execution deadline
|
|
||||||
await asyncio.sleep(proposal.execution_deadline - time.time())
|
|
||||||
|
|
||||||
# Check if still approved
|
|
||||||
if proposal.status != UpgradeStatus.APPROVED:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Prepare rollback data
|
|
||||||
rollback_data = await self._prepare_rollback_data(proposal)
|
|
||||||
|
|
||||||
# Execute upgrade
|
|
||||||
success = await self._perform_upgrade(proposal)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
proposal.status = UpgradeStatus.EXECUTED
|
|
||||||
proposal.executed_at = time.time()
|
|
||||||
proposal.rollback_data = rollback_data
|
|
||||||
|
|
||||||
# Update active version
|
|
||||||
self.active_versions[proposal.contract_type] = proposal.new_version
|
|
||||||
|
|
||||||
# Record in history
|
|
||||||
self.upgrade_history.append({
|
|
||||||
'proposal_id': proposal_id,
|
|
||||||
'contract_type': proposal.contract_type,
|
|
||||||
'from_version': proposal.current_version,
|
|
||||||
'to_version': proposal.new_version,
|
|
||||||
'executed_at': proposal.executed_at,
|
|
||||||
'upgrade_type': proposal.upgrade_type.value
|
|
||||||
})
|
|
||||||
|
|
||||||
log_info(f"Upgrade executed: {proposal_id} - {proposal.contract_type} {proposal.current_version} -> {proposal.new_version}")
|
|
||||||
|
|
||||||
# Start rollback window
|
|
||||||
asyncio.create_task(self._manage_rollback_window(proposal_id))
|
|
||||||
else:
|
|
||||||
proposal.status = UpgradeStatus.FAILED
|
|
||||||
log_error(f"Upgrade execution failed: {proposal_id}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
proposal.status = UpgradeStatus.FAILED
|
|
||||||
log_error(f"Error executing upgrade {proposal_id}: {e}")
|
|
||||||
|
|
||||||
async def _prepare_rollback_data(self, proposal: UpgradeProposal) -> Dict:
|
|
||||||
"""Prepare data for potential rollback"""
|
|
||||||
return {
|
|
||||||
'previous_version': proposal.current_version,
|
|
||||||
'contract_state': {}, # Would capture current contract state
|
|
||||||
'migration_data': {}, # Would store migration data
|
|
||||||
'timestamp': time.time()
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _perform_upgrade(self, proposal: UpgradeProposal) -> bool:
|
|
||||||
"""Perform the actual upgrade"""
|
|
||||||
try:
|
|
||||||
# In real implementation, this would:
|
|
||||||
# 1. Deploy new contract version
|
|
||||||
# 2. Migrate state from old contract
|
|
||||||
# 3. Update contract references
|
|
||||||
# 4. Verify upgrade integrity
|
|
||||||
|
|
||||||
# Simulate upgrade process
|
|
||||||
await asyncio.sleep(10) # Simulate upgrade time
|
|
||||||
|
|
||||||
# Create new version record
|
|
||||||
new_version = ContractVersion(
|
|
||||||
version=proposal.new_version,
|
|
||||||
address=f"0x{proposal.contract_type}_{proposal.new_version}", # New address
|
|
||||||
deployed_at=time.time(),
|
|
||||||
total_contracts=0,
|
|
||||||
total_value=Decimal('0'),
|
|
||||||
is_active=True,
|
|
||||||
metadata={
|
|
||||||
'upgrade_type': proposal.upgrade_type.value,
|
|
||||||
'proposal_id': proposal.proposal_id,
|
|
||||||
'changes': proposal.changes
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add to version history
|
|
||||||
if proposal.contract_type not in self.contract_versions:
|
|
||||||
self.contract_versions[proposal.contract_type] = []
|
|
||||||
|
|
||||||
# Deactivate old version
|
|
||||||
for version in self.contract_versions[proposal.contract_type]:
|
|
||||||
if version.version == proposal.current_version:
|
|
||||||
version.is_active = False
|
|
||||||
break
|
|
||||||
|
|
||||||
# Add new version
|
|
||||||
self.contract_versions[proposal.contract_type].append(new_version)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log_error(f"Upgrade execution error: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _manage_rollback_window(self, proposal_id: str):
|
|
||||||
"""Manage rollback window after upgrade"""
|
|
||||||
proposal = self.upgrade_proposals[proposal_id]
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Wait for rollback timeout
|
|
||||||
await asyncio.sleep(self.rollback_timeout)
|
|
||||||
|
|
||||||
# Check if rollback was requested
|
|
||||||
if proposal.status == UpgradeStatus.EXECUTED:
|
|
||||||
# No rollback requested, finalize upgrade
|
|
||||||
await self._finalize_upgrade(proposal_id)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log_error(f"Error in rollback window for {proposal_id}: {e}")
|
|
||||||
|
|
||||||
async def _finalize_upgrade(self, proposal_id: str):
|
|
||||||
"""Finalize upgrade after rollback window"""
|
|
||||||
proposal = self.upgrade_proposals[proposal_id]
|
|
||||||
|
|
||||||
# Clear rollback data to save space
|
|
||||||
proposal.rollback_data = None
|
|
||||||
|
|
||||||
log_info(f"Upgrade finalized: {proposal_id}")
|
|
||||||
|
|
||||||
async def rollback_upgrade(self, proposal_id: str, reason: str) -> Tuple[bool, str]:
|
|
||||||
"""Rollback upgrade to previous version"""
|
|
||||||
proposal = self.upgrade_proposals.get(proposal_id)
|
|
||||||
if not proposal:
|
|
||||||
return False, "Proposal not found"
|
|
||||||
|
|
||||||
if proposal.status != UpgradeStatus.EXECUTED:
|
|
||||||
return False, "Can only rollback executed upgrades"
|
|
||||||
|
|
||||||
if not proposal.rollback_data:
|
|
||||||
return False, "Rollback data not available"
|
|
||||||
|
|
||||||
# Check rollback window
|
|
||||||
if time.time() - proposal.executed_at > self.rollback_timeout:
|
|
||||||
return False, "Rollback window has expired"
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Perform rollback
|
|
||||||
success = await self._perform_rollback(proposal)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
proposal.status = UpgradeStatus.ROLLED_BACK
|
|
||||||
|
|
||||||
# Restore previous version
|
|
||||||
self.active_versions[proposal.contract_type] = proposal.current_version
|
|
||||||
|
|
||||||
# Update version records
|
|
||||||
for version in self.contract_versions[proposal.contract_type]:
|
|
||||||
if version.version == proposal.new_version:
|
|
||||||
version.is_active = False
|
|
||||||
elif version.version == proposal.current_version:
|
|
||||||
version.is_active = True
|
|
||||||
|
|
||||||
log_info(f"Upgrade rolled back: {proposal_id} - Reason: {reason}")
|
|
||||||
return True, "Rollback successful"
|
|
||||||
else:
|
|
||||||
return False, "Rollback execution failed"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log_error(f"Rollback error for {proposal_id}: {e}")
|
|
||||||
return False, f"Rollback failed: {str(e)}"
|
|
||||||
|
|
||||||
async def _perform_rollback(self, proposal: UpgradeProposal) -> bool:
|
|
||||||
"""Perform the actual rollback"""
|
|
||||||
try:
|
|
||||||
# In real implementation, this would:
|
|
||||||
# 1. Restore previous contract state
|
|
||||||
# 2. Update contract references back
|
|
||||||
# 3. Verify rollback integrity
|
|
||||||
|
|
||||||
# Simulate rollback process
|
|
||||||
await asyncio.sleep(5) # Simulate rollback time
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log_error(f"Rollback execution error: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def get_proposal(self, proposal_id: str) -> Optional[UpgradeProposal]:
|
|
||||||
"""Get upgrade proposal"""
|
|
||||||
return self.upgrade_proposals.get(proposal_id)
|
|
||||||
|
|
||||||
async def get_proposals_by_status(self, status: UpgradeStatus) -> List[UpgradeProposal]:
|
|
||||||
"""Get proposals by status"""
|
|
||||||
return [
|
|
||||||
proposal for proposal in self.upgrade_proposals.values()
|
|
||||||
if proposal.status == status
|
|
||||||
]
|
|
||||||
|
|
||||||
async def get_contract_versions(self, contract_type: str) -> List[ContractVersion]:
|
|
||||||
"""Get all versions for a contract type"""
|
|
||||||
return self.contract_versions.get(contract_type, [])
|
|
||||||
|
|
||||||
async def get_active_version(self, contract_type: str) -> Optional[str]:
|
|
||||||
"""Get active version for contract type"""
|
|
||||||
return self.active_versions.get(contract_type)
|
|
||||||
|
|
||||||
async def get_upgrade_statistics(self) -> Dict:
|
|
||||||
"""Get upgrade system statistics"""
|
|
||||||
total_proposals = len(self.upgrade_proposals)
|
|
||||||
|
|
||||||
if total_proposals == 0:
|
|
||||||
return {
|
|
||||||
'total_proposals': 0,
|
|
||||||
'status_distribution': {},
|
|
||||||
'upgrade_types': {},
|
|
||||||
'average_execution_time': 0,
|
|
||||||
'success_rate': 0
|
|
||||||
}
|
|
||||||
|
|
||||||
# Status distribution
|
|
||||||
status_counts = {}
|
|
||||||
for proposal in self.upgrade_proposals.values():
|
|
||||||
status = proposal.status.value
|
|
||||||
status_counts[status] = status_counts.get(status, 0) + 1
|
|
||||||
|
|
||||||
# Upgrade type distribution
|
|
||||||
type_counts = {}
|
|
||||||
for proposal in self.upgrade_proposals.values():
|
|
||||||
up_type = proposal.upgrade_type.value
|
|
||||||
type_counts[up_type] = type_counts.get(up_type, 0) + 1
|
|
||||||
|
|
||||||
# Execution statistics
|
|
||||||
executed_proposals = [
|
|
||||||
proposal for proposal in self.upgrade_proposals.values()
|
|
||||||
if proposal.status == UpgradeStatus.EXECUTED
|
|
||||||
]
|
|
||||||
|
|
||||||
if executed_proposals:
|
|
||||||
execution_times = [
|
|
||||||
proposal.executed_at - proposal.created_at
|
|
||||||
for proposal in executed_proposals
|
|
||||||
if proposal.executed_at
|
|
||||||
]
|
|
||||||
avg_execution_time = sum(execution_times) / len(execution_times) if execution_times else 0
|
|
||||||
else:
|
|
||||||
avg_execution_time = 0
|
|
||||||
|
|
||||||
# Success rate
|
|
||||||
successful_upgrades = len(executed_proposals)
|
|
||||||
success_rate = successful_upgrades / total_proposals if total_proposals > 0 else 0
|
|
||||||
|
|
||||||
return {
|
|
||||||
'total_proposals': total_proposals,
|
|
||||||
'status_distribution': status_counts,
|
|
||||||
'upgrade_types': type_counts,
|
|
||||||
'average_execution_time': avg_execution_time,
|
|
||||||
'success_rate': success_rate,
|
|
||||||
'total_governance_addresses': len(self.governance_addresses),
|
|
||||||
'contract_types': len(self.contract_versions)
|
|
||||||
}
|
|
||||||
|
|
||||||
# Global upgrade manager
|
|
||||||
upgrade_manager: Optional[ContractUpgradeManager] = None
|
|
||||||
|
|
||||||
def get_upgrade_manager() -> Optional[ContractUpgradeManager]:
|
|
||||||
"""Get global upgrade manager"""
|
|
||||||
return upgrade_manager
|
|
||||||
|
|
||||||
def create_upgrade_manager() -> ContractUpgradeManager:
|
|
||||||
"""Create and set global upgrade manager"""
|
|
||||||
global upgrade_manager
|
|
||||||
upgrade_manager = ContractUpgradeManager()
|
|
||||||
return upgrade_manager
|
|
||||||
@@ -1,519 +0,0 @@
|
|||||||
"""
|
|
||||||
AITBC Agent Messaging Contract Implementation
|
|
||||||
|
|
||||||
This module implements on-chain messaging functionality for agents,
|
|
||||||
enabling forum-like communication between autonomous agents.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Any
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from enum import Enum
|
|
||||||
import json
|
|
||||||
import hashlib
|
|
||||||
from eth_account import Account
|
|
||||||
from eth_utils import to_checksum_address
|
|
||||||
|
|
||||||
class MessageType(Enum):
|
|
||||||
"""Types of messages agents can send"""
|
|
||||||
POST = "post"
|
|
||||||
REPLY = "reply"
|
|
||||||
ANNOUNCEMENT = "announcement"
|
|
||||||
QUESTION = "question"
|
|
||||||
ANSWER = "answer"
|
|
||||||
MODERATION = "moderation"
|
|
||||||
|
|
||||||
class MessageStatus(Enum):
|
|
||||||
"""Status of messages in the forum"""
|
|
||||||
ACTIVE = "active"
|
|
||||||
HIDDEN = "hidden"
|
|
||||||
DELETED = "deleted"
|
|
||||||
PINNED = "pinned"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Message:
|
|
||||||
"""Represents a message in the agent forum"""
|
|
||||||
message_id: str
|
|
||||||
agent_id: str
|
|
||||||
agent_address: str
|
|
||||||
topic: str
|
|
||||||
content: str
|
|
||||||
message_type: MessageType
|
|
||||||
timestamp: datetime
|
|
||||||
parent_message_id: Optional[str] = None
|
|
||||||
reply_count: int = 0
|
|
||||||
upvotes: int = 0
|
|
||||||
downvotes: int = 0
|
|
||||||
status: MessageStatus = MessageStatus.ACTIVE
|
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Topic:
|
|
||||||
"""Represents a forum topic"""
|
|
||||||
topic_id: str
|
|
||||||
title: str
|
|
||||||
description: str
|
|
||||||
creator_agent_id: str
|
|
||||||
created_at: datetime
|
|
||||||
message_count: int = 0
|
|
||||||
last_activity: datetime = field(default_factory=datetime.now)
|
|
||||||
tags: List[str] = field(default_factory=list)
|
|
||||||
is_pinned: bool = False
|
|
||||||
is_locked: bool = False
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AgentReputation:
|
|
||||||
"""Reputation system for agents"""
|
|
||||||
agent_id: str
|
|
||||||
message_count: int = 0
|
|
||||||
upvotes_received: int = 0
|
|
||||||
downvotes_received: int = 0
|
|
||||||
reputation_score: float = 0.0
|
|
||||||
trust_level: int = 1 # 1-5 trust levels
|
|
||||||
is_moderator: bool = False
|
|
||||||
is_banned: bool = False
|
|
||||||
ban_reason: Optional[str] = None
|
|
||||||
ban_expires: Optional[datetime] = None
|
|
||||||
|
|
||||||
class AgentMessagingContract:
|
|
||||||
"""Main contract for agent messaging functionality"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.messages: Dict[str, Message] = {}
|
|
||||||
self.topics: Dict[str, Topic] = {}
|
|
||||||
self.agent_reputations: Dict[str, AgentReputation] = {}
|
|
||||||
self.moderation_log: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
def create_topic(self, agent_id: str, agent_address: str, title: str,
|
|
||||||
description: str, tags: List[str] = None) -> Dict[str, Any]:
|
|
||||||
"""Create a new forum topic"""
|
|
||||||
|
|
||||||
# Check if agent is banned
|
|
||||||
if self._is_agent_banned(agent_id):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Agent is banned from posting",
|
|
||||||
"error_code": "AGENT_BANNED"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Generate topic ID
|
|
||||||
topic_id = f"topic_{hashlib.sha256(f'{agent_id}_{title}_{datetime.now()}'.encode()).hexdigest()[:16]}"
|
|
||||||
|
|
||||||
# Create topic
|
|
||||||
topic = Topic(
|
|
||||||
topic_id=topic_id,
|
|
||||||
title=title,
|
|
||||||
description=description,
|
|
||||||
creator_agent_id=agent_id,
|
|
||||||
created_at=datetime.now(),
|
|
||||||
tags=tags or []
|
|
||||||
)
|
|
||||||
|
|
||||||
self.topics[topic_id] = topic
|
|
||||||
|
|
||||||
# Update agent reputation
|
|
||||||
self._update_agent_reputation(agent_id, message_count=1)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"topic_id": topic_id,
|
|
||||||
"topic": self._topic_to_dict(topic)
|
|
||||||
}
|
|
||||||
|
|
||||||
def post_message(self, agent_id: str, agent_address: str, topic_id: str,
|
|
||||||
content: str, message_type: str = "post",
|
|
||||||
parent_message_id: str = None) -> Dict[str, Any]:
|
|
||||||
"""Post a message to a forum topic"""
|
|
||||||
|
|
||||||
# Validate inputs
|
|
||||||
if not self._validate_agent(agent_id, agent_address):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid agent credentials",
|
|
||||||
"error_code": "INVALID_AGENT"
|
|
||||||
}
|
|
||||||
|
|
||||||
if self._is_agent_banned(agent_id):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Agent is banned from posting",
|
|
||||||
"error_code": "AGENT_BANNED"
|
|
||||||
}
|
|
||||||
|
|
||||||
if topic_id not in self.topics:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Topic not found",
|
|
||||||
"error_code": "TOPIC_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.topics[topic_id].is_locked:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Topic is locked",
|
|
||||||
"error_code": "TOPIC_LOCKED"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Validate message type
|
|
||||||
try:
|
|
||||||
msg_type = MessageType(message_type)
|
|
||||||
except ValueError:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid message type",
|
|
||||||
"error_code": "INVALID_MESSAGE_TYPE"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Generate message ID
|
|
||||||
message_id = f"msg_{hashlib.sha256(f'{agent_id}_{topic_id}_{content}_{datetime.now()}'.encode()).hexdigest()[:16]}"
|
|
||||||
|
|
||||||
# Create message
|
|
||||||
message = Message(
|
|
||||||
message_id=message_id,
|
|
||||||
agent_id=agent_id,
|
|
||||||
agent_address=agent_address,
|
|
||||||
topic=topic_id,
|
|
||||||
content=content,
|
|
||||||
message_type=msg_type,
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
parent_message_id=parent_message_id
|
|
||||||
)
|
|
||||||
|
|
||||||
self.messages[message_id] = message
|
|
||||||
|
|
||||||
# Update topic
|
|
||||||
self.topics[topic_id].message_count += 1
|
|
||||||
self.topics[topic_id].last_activity = datetime.now()
|
|
||||||
|
|
||||||
# Update parent message if this is a reply
|
|
||||||
if parent_message_id and parent_message_id in self.messages:
|
|
||||||
self.messages[parent_message_id].reply_count += 1
|
|
||||||
|
|
||||||
# Update agent reputation
|
|
||||||
self._update_agent_reputation(agent_id, message_count=1)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message_id": message_id,
|
|
||||||
"message": self._message_to_dict(message)
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_messages(self, topic_id: str, limit: int = 50, offset: int = 0,
|
|
||||||
sort_by: str = "timestamp") -> Dict[str, Any]:
|
|
||||||
"""Get messages from a topic"""
|
|
||||||
|
|
||||||
if topic_id not in self.topics:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Topic not found",
|
|
||||||
"error_code": "TOPIC_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get all messages for this topic
|
|
||||||
topic_messages = [
|
|
||||||
msg for msg in self.messages.values()
|
|
||||||
if msg.topic == topic_id and msg.status == MessageStatus.ACTIVE
|
|
||||||
]
|
|
||||||
|
|
||||||
# Sort messages
|
|
||||||
if sort_by == "timestamp":
|
|
||||||
topic_messages.sort(key=lambda x: x.timestamp, reverse=True)
|
|
||||||
elif sort_by == "upvotes":
|
|
||||||
topic_messages.sort(key=lambda x: x.upvotes, reverse=True)
|
|
||||||
elif sort_by == "replies":
|
|
||||||
topic_messages.sort(key=lambda x: x.reply_count, reverse=True)
|
|
||||||
|
|
||||||
# Apply pagination
|
|
||||||
total_messages = len(topic_messages)
|
|
||||||
paginated_messages = topic_messages[offset:offset + limit]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"messages": [self._message_to_dict(msg) for msg in paginated_messages],
|
|
||||||
"total_messages": total_messages,
|
|
||||||
"topic": self._topic_to_dict(self.topics[topic_id])
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_topics(self, limit: int = 50, offset: int = 0,
|
|
||||||
sort_by: str = "last_activity") -> Dict[str, Any]:
|
|
||||||
"""Get list of forum topics"""
|
|
||||||
|
|
||||||
# Sort topics
|
|
||||||
topic_list = list(self.topics.values())
|
|
||||||
|
|
||||||
if sort_by == "last_activity":
|
|
||||||
topic_list.sort(key=lambda x: x.last_activity, reverse=True)
|
|
||||||
elif sort_by == "created_at":
|
|
||||||
topic_list.sort(key=lambda x: x.created_at, reverse=True)
|
|
||||||
elif sort_by == "message_count":
|
|
||||||
topic_list.sort(key=lambda x: x.message_count, reverse=True)
|
|
||||||
|
|
||||||
# Apply pagination
|
|
||||||
total_topics = len(topic_list)
|
|
||||||
paginated_topics = topic_list[offset:offset + limit]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"topics": [self._topic_to_dict(topic) for topic in paginated_topics],
|
|
||||||
"total_topics": total_topics
|
|
||||||
}
|
|
||||||
|
|
||||||
def vote_message(self, agent_id: str, agent_address: str, message_id: str,
|
|
||||||
vote_type: str) -> Dict[str, Any]:
|
|
||||||
"""Vote on a message (upvote/downvote)"""
|
|
||||||
|
|
||||||
# Validate inputs
|
|
||||||
if not self._validate_agent(agent_id, agent_address):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid agent credentials",
|
|
||||||
"error_code": "INVALID_AGENT"
|
|
||||||
}
|
|
||||||
|
|
||||||
if message_id not in self.messages:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Message not found",
|
|
||||||
"error_code": "MESSAGE_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
if vote_type not in ["upvote", "downvote"]:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid vote type",
|
|
||||||
"error_code": "INVALID_VOTE_TYPE"
|
|
||||||
}
|
|
||||||
|
|
||||||
message = self.messages[message_id]
|
|
||||||
|
|
||||||
# Update vote counts
|
|
||||||
if vote_type == "upvote":
|
|
||||||
message.upvotes += 1
|
|
||||||
else:
|
|
||||||
message.downvotes += 1
|
|
||||||
|
|
||||||
# Update message author reputation
|
|
||||||
self._update_agent_reputation(
|
|
||||||
message.agent_id,
|
|
||||||
upvotes_received=message.upvotes,
|
|
||||||
downvotes_received=message.downvotes
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message_id": message_id,
|
|
||||||
"upvotes": message.upvotes,
|
|
||||||
"downvotes": message.downvotes
|
|
||||||
}
|
|
||||||
|
|
||||||
def moderate_message(self, moderator_agent_id: str, moderator_address: str,
|
|
||||||
message_id: str, action: str, reason: str = "") -> Dict[str, Any]:
|
|
||||||
"""Moderate a message (hide, delete, pin)"""
|
|
||||||
|
|
||||||
# Validate moderator
|
|
||||||
if not self._is_moderator(moderator_agent_id):
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Insufficient permissions",
|
|
||||||
"error_code": "INSUFFICIENT_PERMISSIONS"
|
|
||||||
}
|
|
||||||
|
|
||||||
if message_id not in self.messages:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Message not found",
|
|
||||||
"error_code": "MESSAGE_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
message = self.messages[message_id]
|
|
||||||
|
|
||||||
# Apply moderation action
|
|
||||||
if action == "hide":
|
|
||||||
message.status = MessageStatus.HIDDEN
|
|
||||||
elif action == "delete":
|
|
||||||
message.status = MessageStatus.DELETED
|
|
||||||
elif action == "pin":
|
|
||||||
message.status = MessageStatus.PINNED
|
|
||||||
elif action == "unpin":
|
|
||||||
message.status = MessageStatus.ACTIVE
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Invalid moderation action",
|
|
||||||
"error_code": "INVALID_ACTION"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Log moderation action
|
|
||||||
self.moderation_log.append({
|
|
||||||
"timestamp": datetime.now(),
|
|
||||||
"moderator_agent_id": moderator_agent_id,
|
|
||||||
"message_id": message_id,
|
|
||||||
"action": action,
|
|
||||||
"reason": reason
|
|
||||||
})
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message_id": message_id,
|
|
||||||
"status": message.status.value
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_agent_reputation(self, agent_id: str) -> Dict[str, Any]:
|
|
||||||
"""Get an agent's reputation information"""
|
|
||||||
|
|
||||||
if agent_id not in self.agent_reputations:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Agent not found",
|
|
||||||
"error_code": "AGENT_NOT_FOUND"
|
|
||||||
}
|
|
||||||
|
|
||||||
reputation = self.agent_reputations[agent_id]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"agent_id": agent_id,
|
|
||||||
"reputation": self._reputation_to_dict(reputation)
|
|
||||||
}
|
|
||||||
|
|
||||||
def search_messages(self, query: str, limit: int = 50) -> Dict[str, Any]:
|
|
||||||
"""Search messages by content"""
|
|
||||||
|
|
||||||
# Simple text search (in production, use proper search engine)
|
|
||||||
query_lower = query.lower()
|
|
||||||
matching_messages = []
|
|
||||||
|
|
||||||
for message in self.messages.values():
|
|
||||||
if (message.status == MessageStatus.ACTIVE and
|
|
||||||
query_lower in message.content.lower()):
|
|
||||||
matching_messages.append(message)
|
|
||||||
|
|
||||||
# Sort by timestamp (most recent first)
|
|
||||||
matching_messages.sort(key=lambda x: x.timestamp, reverse=True)
|
|
||||||
|
|
||||||
# Limit results
|
|
||||||
limited_messages = matching_messages[:limit]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"query": query,
|
|
||||||
"messages": [self._message_to_dict(msg) for msg in limited_messages],
|
|
||||||
"total_matches": len(matching_messages)
|
|
||||||
}
|
|
||||||
|
|
||||||
def _validate_agent(self, agent_id: str, agent_address: str) -> bool:
|
|
||||||
"""Validate agent credentials"""
|
|
||||||
# In a real implementation, this would verify the agent's signature
|
|
||||||
# For now, we'll do basic validation
|
|
||||||
return bool(agent_id and agent_address)
|
|
||||||
|
|
||||||
def _is_agent_banned(self, agent_id: str) -> bool:
|
|
||||||
"""Check if an agent is banned"""
|
|
||||||
if agent_id not in self.agent_reputations:
|
|
||||||
return False
|
|
||||||
|
|
||||||
reputation = self.agent_reputations[agent_id]
|
|
||||||
|
|
||||||
if reputation.is_banned:
|
|
||||||
# Check if ban has expired
|
|
||||||
if reputation.ban_expires and datetime.now() > reputation.ban_expires:
|
|
||||||
reputation.is_banned = False
|
|
||||||
reputation.ban_expires = None
|
|
||||||
reputation.ban_reason = None
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _is_moderator(self, agent_id: str) -> bool:
|
|
||||||
"""Check if an agent is a moderator"""
|
|
||||||
if agent_id not in self.agent_reputations:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self.agent_reputations[agent_id].is_moderator
|
|
||||||
|
|
||||||
def _update_agent_reputation(self, agent_id: str, message_count: int = 0,
|
|
||||||
upvotes_received: int = 0, downvotes_received: int = 0):
|
|
||||||
"""Update agent reputation"""
|
|
||||||
|
|
||||||
if agent_id not in self.agent_reputations:
|
|
||||||
self.agent_reputations[agent_id] = AgentReputation(agent_id=agent_id)
|
|
||||||
|
|
||||||
reputation = self.agent_reputations[agent_id]
|
|
||||||
|
|
||||||
if message_count > 0:
|
|
||||||
reputation.message_count += message_count
|
|
||||||
|
|
||||||
if upvotes_received > 0:
|
|
||||||
reputation.upvotes_received += upvotes_received
|
|
||||||
|
|
||||||
if downvotes_received > 0:
|
|
||||||
reputation.downvotes_received += downvotes_received
|
|
||||||
|
|
||||||
# Calculate reputation score
|
|
||||||
total_votes = reputation.upvotes_received + reputation.downvotes_received
|
|
||||||
if total_votes > 0:
|
|
||||||
reputation.reputation_score = (reputation.upvotes_received - reputation.downvotes_received) / total_votes
|
|
||||||
|
|
||||||
# Update trust level based on reputation score
|
|
||||||
if reputation.reputation_score >= 0.8:
|
|
||||||
reputation.trust_level = 5
|
|
||||||
elif reputation.reputation_score >= 0.6:
|
|
||||||
reputation.trust_level = 4
|
|
||||||
elif reputation.reputation_score >= 0.4:
|
|
||||||
reputation.trust_level = 3
|
|
||||||
elif reputation.reputation_score >= 0.2:
|
|
||||||
reputation.trust_level = 2
|
|
||||||
else:
|
|
||||||
reputation.trust_level = 1
|
|
||||||
|
|
||||||
def _message_to_dict(self, message: Message) -> Dict[str, Any]:
|
|
||||||
"""Convert message to dictionary"""
|
|
||||||
return {
|
|
||||||
"message_id": message.message_id,
|
|
||||||
"agent_id": message.agent_id,
|
|
||||||
"agent_address": message.agent_address,
|
|
||||||
"topic": message.topic,
|
|
||||||
"content": message.content,
|
|
||||||
"message_type": message.message_type.value,
|
|
||||||
"timestamp": message.timestamp.isoformat(),
|
|
||||||
"parent_message_id": message.parent_message_id,
|
|
||||||
"reply_count": message.reply_count,
|
|
||||||
"upvotes": message.upvotes,
|
|
||||||
"downvotes": message.downvotes,
|
|
||||||
"status": message.status.value,
|
|
||||||
"metadata": message.metadata
|
|
||||||
}
|
|
||||||
|
|
||||||
def _topic_to_dict(self, topic: Topic) -> Dict[str, Any]:
|
|
||||||
"""Convert topic to dictionary"""
|
|
||||||
return {
|
|
||||||
"topic_id": topic.topic_id,
|
|
||||||
"title": topic.title,
|
|
||||||
"description": topic.description,
|
|
||||||
"creator_agent_id": topic.creator_agent_id,
|
|
||||||
"created_at": topic.created_at.isoformat(),
|
|
||||||
"message_count": topic.message_count,
|
|
||||||
"last_activity": topic.last_activity.isoformat(),
|
|
||||||
"tags": topic.tags,
|
|
||||||
"is_pinned": topic.is_pinned,
|
|
||||||
"is_locked": topic.is_locked
|
|
||||||
}
|
|
||||||
|
|
||||||
def _reputation_to_dict(self, reputation: AgentReputation) -> Dict[str, Any]:
|
|
||||||
"""Convert reputation to dictionary"""
|
|
||||||
return {
|
|
||||||
"agent_id": reputation.agent_id,
|
|
||||||
"message_count": reputation.message_count,
|
|
||||||
"upvotes_received": reputation.upvotes_received,
|
|
||||||
"downvotes_received": reputation.downvotes_received,
|
|
||||||
"reputation_score": reputation.reputation_score,
|
|
||||||
"trust_level": reputation.trust_level,
|
|
||||||
"is_moderator": reputation.is_moderator,
|
|
||||||
"is_banned": reputation.is_banned,
|
|
||||||
"ban_reason": reputation.ban_reason,
|
|
||||||
"ban_expires": reputation.ban_expires.isoformat() if reputation.ban_expires else None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Global contract instance
|
|
||||||
messaging_contract = AgentMessagingContract()
|
|
||||||
@@ -1,584 +0,0 @@
|
|||||||
"""
|
|
||||||
AITBC Agent Wallet Security Implementation
|
|
||||||
|
|
||||||
This module implements the security layer for autonomous agent wallets,
|
|
||||||
integrating the guardian contract to prevent unlimited spending in case
|
|
||||||
of agent compromise.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
import json
|
|
||||||
from eth_account import Account
|
|
||||||
from eth_utils import to_checksum_address
|
|
||||||
|
|
||||||
from .guardian_contract import (
|
|
||||||
GuardianContract,
|
|
||||||
SpendingLimit,
|
|
||||||
TimeLockConfig,
|
|
||||||
GuardianConfig,
|
|
||||||
create_guardian_contract,
|
|
||||||
CONSERVATIVE_CONFIG,
|
|
||||||
AGGRESSIVE_CONFIG,
|
|
||||||
HIGH_SECURITY_CONFIG
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AgentSecurityProfile:
|
|
||||||
"""Security profile for an agent"""
|
|
||||||
agent_address: str
|
|
||||||
security_level: str # "conservative", "aggressive", "high_security"
|
|
||||||
guardian_addresses: List[str]
|
|
||||||
custom_limits: Optional[Dict] = None
|
|
||||||
enabled: bool = True
|
|
||||||
created_at: datetime = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.created_at is None:
|
|
||||||
self.created_at = datetime.utcnow()
|
|
||||||
|
|
||||||
|
|
||||||
class AgentWalletSecurity:
|
|
||||||
"""
|
|
||||||
Security manager for autonomous agent wallets
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.agent_profiles: Dict[str, AgentSecurityProfile] = {}
|
|
||||||
self.guardian_contracts: Dict[str, GuardianContract] = {}
|
|
||||||
self.security_events: List[Dict] = []
|
|
||||||
|
|
||||||
# Default configurations
|
|
||||||
self.configurations = {
|
|
||||||
"conservative": CONSERVATIVE_CONFIG,
|
|
||||||
"aggressive": AGGRESSIVE_CONFIG,
|
|
||||||
"high_security": HIGH_SECURITY_CONFIG
|
|
||||||
}
|
|
||||||
|
|
||||||
def register_agent(self,
|
|
||||||
agent_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
guardian_addresses: List[str] = None,
|
|
||||||
custom_limits: Dict = None) -> Dict:
|
|
||||||
"""
|
|
||||||
Register an agent for security protection
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
security_level: Security level (conservative, aggressive, high_security)
|
|
||||||
guardian_addresses: List of guardian addresses for recovery
|
|
||||||
custom_limits: Custom spending limits (overrides security_level)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Registration result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
if agent_address in self.agent_profiles:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent already registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Validate security level
|
|
||||||
if security_level not in self.configurations:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Invalid security level: {security_level}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Default guardians if none provided
|
|
||||||
if guardian_addresses is None:
|
|
||||||
guardian_addresses = [agent_address] # Self-guardian (should be overridden)
|
|
||||||
|
|
||||||
# Validate guardian addresses
|
|
||||||
guardian_addresses = [to_checksum_address(addr) for addr in guardian_addresses]
|
|
||||||
|
|
||||||
# Create security profile
|
|
||||||
profile = AgentSecurityProfile(
|
|
||||||
agent_address=agent_address,
|
|
||||||
security_level=security_level,
|
|
||||||
guardian_addresses=guardian_addresses,
|
|
||||||
custom_limits=custom_limits
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create guardian contract
|
|
||||||
config = self.configurations[security_level]
|
|
||||||
if custom_limits:
|
|
||||||
config.update(custom_limits)
|
|
||||||
|
|
||||||
guardian_contract = create_guardian_contract(
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardians=guardian_addresses,
|
|
||||||
**config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store profile and contract
|
|
||||||
self.agent_profiles[agent_address] = profile
|
|
||||||
self.guardian_contracts[agent_address] = guardian_contract
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="agent_registered",
|
|
||||||
agent_address=agent_address,
|
|
||||||
security_level=security_level,
|
|
||||||
guardian_count=len(guardian_addresses)
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "registered",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"security_level": security_level,
|
|
||||||
"guardian_addresses": guardian_addresses,
|
|
||||||
"limits": guardian_contract.config.limits,
|
|
||||||
"time_lock_threshold": guardian_contract.config.time_lock.threshold,
|
|
||||||
"registered_at": profile.created_at.isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Registration failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def protect_transaction(self,
|
|
||||||
agent_address: str,
|
|
||||||
to_address: str,
|
|
||||||
amount: int,
|
|
||||||
data: str = "") -> Dict:
|
|
||||||
"""
|
|
||||||
Protect a transaction with guardian contract
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
to_address: Recipient address
|
|
||||||
amount: Amount to transfer
|
|
||||||
data: Transaction data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Protection result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
# Check if agent is registered
|
|
||||||
if agent_address not in self.agent_profiles:
|
|
||||||
return {
|
|
||||||
"status": "unprotected",
|
|
||||||
"reason": "Agent not registered for security protection",
|
|
||||||
"suggestion": "Register agent with register_agent() first"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Check if protection is enabled
|
|
||||||
profile = self.agent_profiles[agent_address]
|
|
||||||
if not profile.enabled:
|
|
||||||
return {
|
|
||||||
"status": "unprotected",
|
|
||||||
"reason": "Security protection disabled for this agent"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get guardian contract
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
|
|
||||||
# Initiate transaction protection
|
|
||||||
result = guardian_contract.initiate_transaction(to_address, amount, data)
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="transaction_protected",
|
|
||||||
agent_address=agent_address,
|
|
||||||
to_address=to_address,
|
|
||||||
amount=amount,
|
|
||||||
protection_status=result["status"]
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Transaction protection failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def execute_protected_transaction(self,
|
|
||||||
agent_address: str,
|
|
||||||
operation_id: str,
|
|
||||||
signature: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Execute a previously protected transaction
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
operation_id: Operation ID from protection
|
|
||||||
signature: Transaction signature
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Execution result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
if agent_address not in self.guardian_contracts:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
result = guardian_contract.execute_transaction(operation_id, signature)
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
if result["status"] == "executed":
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="transaction_executed",
|
|
||||||
agent_address=agent_address,
|
|
||||||
operation_id=operation_id,
|
|
||||||
transaction_hash=result.get("transaction_hash")
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Transaction execution failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def emergency_pause_agent(self, agent_address: str, guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Emergency pause an agent's operations
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
guardian_address: Guardian address initiating pause
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Pause result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
|
|
||||||
if agent_address not in self.guardian_contracts:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
result = guardian_contract.emergency_pause(guardian_address)
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
if result["status"] == "paused":
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="emergency_pause",
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardian_address=guardian_address
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Emergency pause failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def update_agent_security(self,
|
|
||||||
agent_address: str,
|
|
||||||
new_limits: Dict,
|
|
||||||
guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Update security limits for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
new_limits: New spending limits
|
|
||||||
guardian_address: Guardian address making the change
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Update result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
|
|
||||||
if agent_address not in self.guardian_contracts:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
|
|
||||||
# Create new spending limits
|
|
||||||
limits = SpendingLimit(
|
|
||||||
per_transaction=new_limits.get("per_transaction", 1000),
|
|
||||||
per_hour=new_limits.get("per_hour", 5000),
|
|
||||||
per_day=new_limits.get("per_day", 20000),
|
|
||||||
per_week=new_limits.get("per_week", 100000)
|
|
||||||
)
|
|
||||||
|
|
||||||
result = guardian_contract.update_limits(limits, guardian_address)
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
if result["status"] == "updated":
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="security_limits_updated",
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardian_address=guardian_address,
|
|
||||||
new_limits=new_limits
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Security update failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_agent_security_status(self, agent_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Get security status for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Security status
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
|
|
||||||
if agent_address not in self.agent_profiles:
|
|
||||||
return {
|
|
||||||
"status": "not_registered",
|
|
||||||
"message": "Agent not registered for security protection"
|
|
||||||
}
|
|
||||||
|
|
||||||
profile = self.agent_profiles[agent_address]
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "protected",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"security_level": profile.security_level,
|
|
||||||
"enabled": profile.enabled,
|
|
||||||
"guardian_addresses": profile.guardian_addresses,
|
|
||||||
"registered_at": profile.created_at.isoformat(),
|
|
||||||
"spending_status": guardian_contract.get_spending_status(),
|
|
||||||
"pending_operations": guardian_contract.get_pending_operations(),
|
|
||||||
"recent_activity": guardian_contract.get_operation_history(10)
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Status check failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
def list_protected_agents(self) -> List[Dict]:
|
|
||||||
"""List all protected agents"""
|
|
||||||
agents = []
|
|
||||||
|
|
||||||
for agent_address, profile in self.agent_profiles.items():
|
|
||||||
guardian_contract = self.guardian_contracts[agent_address]
|
|
||||||
|
|
||||||
agents.append({
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"security_level": profile.security_level,
|
|
||||||
"enabled": profile.enabled,
|
|
||||||
"guardian_count": len(profile.guardian_addresses),
|
|
||||||
"pending_operations": len(guardian_contract.pending_operations),
|
|
||||||
"paused": guardian_contract.paused,
|
|
||||||
"emergency_mode": guardian_contract.emergency_mode,
|
|
||||||
"registered_at": profile.created_at.isoformat()
|
|
||||||
})
|
|
||||||
|
|
||||||
return sorted(agents, key=lambda x: x["registered_at"], reverse=True)
|
|
||||||
|
|
||||||
def get_security_events(self, agent_address: str = None, limit: int = 50) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
Get security events
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Filter by agent address (optional)
|
|
||||||
limit: Maximum number of events
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Security events
|
|
||||||
"""
|
|
||||||
events = self.security_events
|
|
||||||
|
|
||||||
if agent_address:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
events = [e for e in events if e.get("agent_address") == agent_address]
|
|
||||||
|
|
||||||
return sorted(events, key=lambda x: x["timestamp"], reverse=True)[:limit]
|
|
||||||
|
|
||||||
def _log_security_event(self, **kwargs):
|
|
||||||
"""Log a security event"""
|
|
||||||
event = {
|
|
||||||
"timestamp": datetime.utcnow().isoformat(),
|
|
||||||
**kwargs
|
|
||||||
}
|
|
||||||
self.security_events.append(event)
|
|
||||||
|
|
||||||
def disable_agent_protection(self, agent_address: str, guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Disable protection for an agent (guardian only)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
guardian_address: Guardian address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Disable result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
guardian_address = to_checksum_address(guardian_address)
|
|
||||||
|
|
||||||
if agent_address not in self.agent_profiles:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
profile = self.agent_profiles[agent_address]
|
|
||||||
|
|
||||||
if guardian_address not in profile.guardian_addresses:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Not authorized: not a guardian"
|
|
||||||
}
|
|
||||||
|
|
||||||
profile.enabled = False
|
|
||||||
|
|
||||||
# Log security event
|
|
||||||
self._log_security_event(
|
|
||||||
event_type="protection_disabled",
|
|
||||||
agent_address=agent_address,
|
|
||||||
guardian_address=guardian_address
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "disabled",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"disabled_at": datetime.utcnow().isoformat(),
|
|
||||||
"guardian": guardian_address
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Disable protection failed: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Global security manager instance
|
|
||||||
agent_wallet_security = AgentWalletSecurity()
|
|
||||||
|
|
||||||
|
|
||||||
# Convenience functions for common operations
|
|
||||||
def register_agent_for_protection(agent_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
guardians: List[str] = None) -> Dict:
|
|
||||||
"""Register an agent for security protection"""
|
|
||||||
return agent_wallet_security.register_agent(
|
|
||||||
agent_address=agent_address,
|
|
||||||
security_level=security_level,
|
|
||||||
guardian_addresses=guardians
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def protect_agent_transaction(agent_address: str,
|
|
||||||
to_address: str,
|
|
||||||
amount: int,
|
|
||||||
data: str = "") -> Dict:
|
|
||||||
"""Protect a transaction for an agent"""
|
|
||||||
return agent_wallet_security.protect_transaction(
|
|
||||||
agent_address=agent_address,
|
|
||||||
to_address=to_address,
|
|
||||||
amount=amount,
|
|
||||||
data=data
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_agent_security_summary(agent_address: str) -> Dict:
|
|
||||||
"""Get security summary for an agent"""
|
|
||||||
return agent_wallet_security.get_agent_security_status(agent_address)
|
|
||||||
|
|
||||||
|
|
||||||
# Security audit and monitoring functions
|
|
||||||
def generate_security_report() -> Dict:
|
|
||||||
"""Generate comprehensive security report"""
|
|
||||||
protected_agents = agent_wallet_security.list_protected_agents()
|
|
||||||
|
|
||||||
total_agents = len(protected_agents)
|
|
||||||
active_agents = len([a for a in protected_agents if a["enabled"]])
|
|
||||||
paused_agents = len([a for a in protected_agents if a["paused"]])
|
|
||||||
emergency_agents = len([a for a in protected_agents if a["emergency_mode"]])
|
|
||||||
|
|
||||||
recent_events = agent_wallet_security.get_security_events(limit=20)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"generated_at": datetime.utcnow().isoformat(),
|
|
||||||
"summary": {
|
|
||||||
"total_protected_agents": total_agents,
|
|
||||||
"active_agents": active_agents,
|
|
||||||
"paused_agents": paused_agents,
|
|
||||||
"emergency_mode_agents": emergency_agents,
|
|
||||||
"protection_coverage": f"{(active_agents / total_agents * 100):.1f}%" if total_agents > 0 else "0%"
|
|
||||||
},
|
|
||||||
"agents": protected_agents,
|
|
||||||
"recent_security_events": recent_events,
|
|
||||||
"security_levels": {
|
|
||||||
level: len([a for a in protected_agents if a["security_level"] == level])
|
|
||||||
for level in ["conservative", "aggressive", "high_security"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def detect_suspicious_activity(agent_address: str, hours: int = 24) -> Dict:
|
|
||||||
"""Detect suspicious activity for an agent"""
|
|
||||||
status = agent_wallet_security.get_agent_security_status(agent_address)
|
|
||||||
|
|
||||||
if status["status"] != "protected":
|
|
||||||
return {
|
|
||||||
"status": "not_protected",
|
|
||||||
"suspicious_activity": False
|
|
||||||
}
|
|
||||||
|
|
||||||
spending_status = status["spending_status"]
|
|
||||||
recent_events = agent_wallet_security.get_security_events(agent_address, limit=50)
|
|
||||||
|
|
||||||
# Suspicious patterns
|
|
||||||
suspicious_patterns = []
|
|
||||||
|
|
||||||
# Check for rapid spending
|
|
||||||
if spending_status["spent"]["current_hour"] > spending_status["current_limits"]["per_hour"] * 0.8:
|
|
||||||
suspicious_patterns.append("High hourly spending rate")
|
|
||||||
|
|
||||||
# Check for many small transactions (potential dust attack)
|
|
||||||
recent_tx_count = len([e for e in recent_events if e["event_type"] == "transaction_executed"])
|
|
||||||
if recent_tx_count > 20:
|
|
||||||
suspicious_patterns.append("High transaction frequency")
|
|
||||||
|
|
||||||
# Check for emergency pauses
|
|
||||||
recent_pauses = len([e for e in recent_events if e["event_type"] == "emergency_pause"])
|
|
||||||
if recent_pauses > 0:
|
|
||||||
suspicious_patterns.append("Recent emergency pauses detected")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "analyzed",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"suspicious_activity": len(suspicious_patterns) > 0,
|
|
||||||
"suspicious_patterns": suspicious_patterns,
|
|
||||||
"analysis_period_hours": hours,
|
|
||||||
"analyzed_at": datetime.utcnow().isoformat()
|
|
||||||
}
|
|
||||||
@@ -1,559 +0,0 @@
|
|||||||
"""
|
|
||||||
Smart Contract Escrow System
|
|
||||||
Handles automated payment holding and release for AI job marketplace
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
from typing import Dict, List, Optional, Tuple, Set
|
|
||||||
from dataclasses import dataclass, asdict
|
|
||||||
from enum import Enum
|
|
||||||
from decimal import Decimal
|
|
||||||
|
|
||||||
class EscrowState(Enum):
|
|
||||||
CREATED = "created"
|
|
||||||
FUNDED = "funded"
|
|
||||||
JOB_STARTED = "job_started"
|
|
||||||
JOB_COMPLETED = "job_completed"
|
|
||||||
DISPUTED = "disputed"
|
|
||||||
RESOLVED = "resolved"
|
|
||||||
RELEASED = "released"
|
|
||||||
REFUNDED = "refunded"
|
|
||||||
EXPIRED = "expired"
|
|
||||||
|
|
||||||
class DisputeReason(Enum):
|
|
||||||
QUALITY_ISSUES = "quality_issues"
|
|
||||||
DELIVERY_LATE = "delivery_late"
|
|
||||||
INCOMPLETE_WORK = "incomplete_work"
|
|
||||||
TECHNICAL_ISSUES = "technical_issues"
|
|
||||||
PAYMENT_DISPUTE = "payment_dispute"
|
|
||||||
OTHER = "other"
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class EscrowContract:
|
|
||||||
contract_id: str
|
|
||||||
job_id: str
|
|
||||||
client_address: str
|
|
||||||
agent_address: str
|
|
||||||
amount: Decimal
|
|
||||||
fee_rate: Decimal # Platform fee rate
|
|
||||||
created_at: float
|
|
||||||
expires_at: float
|
|
||||||
state: EscrowState
|
|
||||||
milestones: List[Dict]
|
|
||||||
current_milestone: int
|
|
||||||
dispute_reason: Optional[DisputeReason]
|
|
||||||
dispute_evidence: List[Dict]
|
|
||||||
resolution: Optional[Dict]
|
|
||||||
released_amount: Decimal
|
|
||||||
refunded_amount: Decimal
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Milestone:
|
|
||||||
milestone_id: str
|
|
||||||
description: str
|
|
||||||
amount: Decimal
|
|
||||||
completed: bool
|
|
||||||
completed_at: Optional[float]
|
|
||||||
verified: bool
|
|
||||||
|
|
||||||
class EscrowManager:
|
|
||||||
"""Manages escrow contracts for AI job marketplace"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.escrow_contracts: Dict[str, EscrowContract] = {}
|
|
||||||
self.active_contracts: Set[str] = set()
|
|
||||||
self.disputed_contracts: Set[str] = set()
|
|
||||||
|
|
||||||
# Escrow parameters
|
|
||||||
self.default_fee_rate = Decimal('0.025') # 2.5% platform fee
|
|
||||||
self.max_contract_duration = 86400 * 30 # 30 days
|
|
||||||
self.dispute_timeout = 86400 * 7 # 7 days for dispute resolution
|
|
||||||
self.min_dispute_evidence = 1
|
|
||||||
self.max_dispute_evidence = 10
|
|
||||||
|
|
||||||
# Milestone parameters
|
|
||||||
self.min_milestone_amount = Decimal('0.01')
|
|
||||||
self.max_milestones = 10
|
|
||||||
self.verification_timeout = 86400 # 24 hours for milestone verification
|
|
||||||
|
|
||||||
async def create_contract(self, job_id: str, client_address: str, agent_address: str,
|
|
||||||
amount: Decimal, fee_rate: Optional[Decimal] = None,
|
|
||||||
milestones: Optional[List[Dict]] = None,
|
|
||||||
duration_days: int = 30) -> Tuple[bool, str, Optional[str]]:
|
|
||||||
"""Create new escrow contract"""
|
|
||||||
try:
|
|
||||||
# Validate inputs
|
|
||||||
if not self._validate_contract_inputs(job_id, client_address, agent_address, amount):
|
|
||||||
return False, "Invalid contract inputs", None
|
|
||||||
|
|
||||||
# Calculate fee
|
|
||||||
fee_rate = fee_rate or self.default_fee_rate
|
|
||||||
platform_fee = amount * fee_rate
|
|
||||||
total_amount = amount + platform_fee
|
|
||||||
|
|
||||||
# Validate milestones
|
|
||||||
validated_milestones = []
|
|
||||||
if milestones:
|
|
||||||
validated_milestones = await self._validate_milestones(milestones, amount)
|
|
||||||
if not validated_milestones:
|
|
||||||
return False, "Invalid milestones configuration", None
|
|
||||||
else:
|
|
||||||
# Create single milestone for full amount
|
|
||||||
validated_milestones = [{
|
|
||||||
'milestone_id': 'milestone_1',
|
|
||||||
'description': 'Complete job',
|
|
||||||
'amount': amount,
|
|
||||||
'completed': False
|
|
||||||
}]
|
|
||||||
|
|
||||||
# Create contract
|
|
||||||
contract_id = self._generate_contract_id(client_address, agent_address, job_id)
|
|
||||||
current_time = time.time()
|
|
||||||
|
|
||||||
contract = EscrowContract(
|
|
||||||
contract_id=contract_id,
|
|
||||||
job_id=job_id,
|
|
||||||
client_address=client_address,
|
|
||||||
agent_address=agent_address,
|
|
||||||
amount=total_amount,
|
|
||||||
fee_rate=fee_rate,
|
|
||||||
created_at=current_time,
|
|
||||||
expires_at=current_time + (duration_days * 86400),
|
|
||||||
state=EscrowState.CREATED,
|
|
||||||
milestones=validated_milestones,
|
|
||||||
current_milestone=0,
|
|
||||||
dispute_reason=None,
|
|
||||||
dispute_evidence=[],
|
|
||||||
resolution=None,
|
|
||||||
released_amount=Decimal('0'),
|
|
||||||
refunded_amount=Decimal('0')
|
|
||||||
)
|
|
||||||
|
|
||||||
self.escrow_contracts[contract_id] = contract
|
|
||||||
|
|
||||||
log_info(f"Escrow contract created: {contract_id} for job {job_id}")
|
|
||||||
return True, "Contract created successfully", contract_id
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return False, f"Contract creation failed: {str(e)}", None
|
|
||||||
|
|
||||||
def _validate_contract_inputs(self, job_id: str, client_address: str,
|
|
||||||
agent_address: str, amount: Decimal) -> bool:
|
|
||||||
"""Validate contract creation inputs"""
|
|
||||||
if not all([job_id, client_address, agent_address]):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Validate addresses (simplified)
|
|
||||||
if not (client_address.startswith('0x') and len(client_address) == 42):
|
|
||||||
return False
|
|
||||||
if not (agent_address.startswith('0x') and len(agent_address) == 42):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Validate amount
|
|
||||||
if amount <= 0:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check for existing contract
|
|
||||||
for contract in self.escrow_contracts.values():
|
|
||||||
if contract.job_id == job_id:
|
|
||||||
return False # Contract already exists for this job
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _validate_milestones(self, milestones: List[Dict], total_amount: Decimal) -> Optional[List[Dict]]:
|
|
||||||
"""Validate milestone configuration"""
|
|
||||||
if not milestones or len(milestones) > self.max_milestones:
|
|
||||||
return None
|
|
||||||
|
|
||||||
validated_milestones = []
|
|
||||||
milestone_total = Decimal('0')
|
|
||||||
|
|
||||||
for i, milestone_data in enumerate(milestones):
|
|
||||||
# Validate required fields
|
|
||||||
required_fields = ['milestone_id', 'description', 'amount']
|
|
||||||
if not all(field in milestone_data for field in required_fields):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Validate amount
|
|
||||||
amount = Decimal(str(milestone_data['amount']))
|
|
||||||
if amount < self.min_milestone_amount:
|
|
||||||
return None
|
|
||||||
|
|
||||||
milestone_total += amount
|
|
||||||
validated_milestones.append({
|
|
||||||
'milestone_id': milestone_data['milestone_id'],
|
|
||||||
'description': milestone_data['description'],
|
|
||||||
'amount': amount,
|
|
||||||
'completed': False
|
|
||||||
})
|
|
||||||
|
|
||||||
# Check if milestone amounts sum to total
|
|
||||||
if abs(milestone_total - total_amount) > Decimal('0.01'): # Allow small rounding difference
|
|
||||||
return None
|
|
||||||
|
|
||||||
return validated_milestones
|
|
||||||
|
|
||||||
def _generate_contract_id(self, client_address: str, agent_address: str, job_id: str) -> str:
|
|
||||||
"""Generate unique contract ID"""
|
|
||||||
import hashlib
|
|
||||||
content = f"{client_address}:{agent_address}:{job_id}:{time.time()}"
|
|
||||||
return hashlib.sha256(content.encode()).hexdigest()[:16]
|
|
||||||
|
|
||||||
async def fund_contract(self, contract_id: str, payment_tx_hash: str) -> Tuple[bool, str]:
|
|
||||||
"""Fund escrow contract"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state != EscrowState.CREATED:
|
|
||||||
return False, f"Cannot fund contract in {contract.state.value} state"
|
|
||||||
|
|
||||||
# In real implementation, this would verify the payment transaction
|
|
||||||
# For now, assume payment is valid
|
|
||||||
|
|
||||||
contract.state = EscrowState.FUNDED
|
|
||||||
self.active_contracts.add(contract_id)
|
|
||||||
|
|
||||||
log_info(f"Contract funded: {contract_id}")
|
|
||||||
return True, "Contract funded successfully"
|
|
||||||
|
|
||||||
async def start_job(self, contract_id: str) -> Tuple[bool, str]:
|
|
||||||
"""Mark job as started"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state != EscrowState.FUNDED:
|
|
||||||
return False, f"Cannot start job in {contract.state.value} state"
|
|
||||||
|
|
||||||
contract.state = EscrowState.JOB_STARTED
|
|
||||||
|
|
||||||
log_info(f"Job started for contract: {contract_id}")
|
|
||||||
return True, "Job started successfully"
|
|
||||||
|
|
||||||
async def complete_milestone(self, contract_id: str, milestone_id: str,
|
|
||||||
evidence: Dict = None) -> Tuple[bool, str]:
|
|
||||||
"""Mark milestone as completed"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state not in [EscrowState.JOB_STARTED, EscrowState.JOB_COMPLETED]:
|
|
||||||
return False, f"Cannot complete milestone in {contract.state.value} state"
|
|
||||||
|
|
||||||
# Find milestone
|
|
||||||
milestone = None
|
|
||||||
for ms in contract.milestones:
|
|
||||||
if ms['milestone_id'] == milestone_id:
|
|
||||||
milestone = ms
|
|
||||||
break
|
|
||||||
|
|
||||||
if not milestone:
|
|
||||||
return False, "Milestone not found"
|
|
||||||
|
|
||||||
if milestone['completed']:
|
|
||||||
return False, "Milestone already completed"
|
|
||||||
|
|
||||||
# Mark as completed
|
|
||||||
milestone['completed'] = True
|
|
||||||
milestone['completed_at'] = time.time()
|
|
||||||
|
|
||||||
# Add evidence if provided
|
|
||||||
if evidence:
|
|
||||||
milestone['evidence'] = evidence
|
|
||||||
|
|
||||||
# Check if all milestones are completed
|
|
||||||
all_completed = all(ms['completed'] for ms in contract.milestones)
|
|
||||||
if all_completed:
|
|
||||||
contract.state = EscrowState.JOB_COMPLETED
|
|
||||||
|
|
||||||
log_info(f"Milestone {milestone_id} completed for contract: {contract_id}")
|
|
||||||
return True, "Milestone completed successfully"
|
|
||||||
|
|
||||||
async def verify_milestone(self, contract_id: str, milestone_id: str,
|
|
||||||
verified: bool, feedback: str = "") -> Tuple[bool, str]:
|
|
||||||
"""Verify milestone completion"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
# Find milestone
|
|
||||||
milestone = None
|
|
||||||
for ms in contract.milestones:
|
|
||||||
if ms['milestone_id'] == milestone_id:
|
|
||||||
milestone = ms
|
|
||||||
break
|
|
||||||
|
|
||||||
if not milestone:
|
|
||||||
return False, "Milestone not found"
|
|
||||||
|
|
||||||
if not milestone['completed']:
|
|
||||||
return False, "Milestone not completed yet"
|
|
||||||
|
|
||||||
# Set verification status
|
|
||||||
milestone['verified'] = verified
|
|
||||||
milestone['verification_feedback'] = feedback
|
|
||||||
|
|
||||||
if verified:
|
|
||||||
# Release milestone payment
|
|
||||||
await self._release_milestone_payment(contract_id, milestone_id)
|
|
||||||
else:
|
|
||||||
# Create dispute if verification fails
|
|
||||||
await self._create_dispute(contract_id, DisputeReason.QUALITY_ISSUES,
|
|
||||||
f"Milestone {milestone_id} verification failed: {feedback}")
|
|
||||||
|
|
||||||
log_info(f"Milestone {milestone_id} verification: {verified} for contract: {contract_id}")
|
|
||||||
return True, "Milestone verification processed"
|
|
||||||
|
|
||||||
async def _release_milestone_payment(self, contract_id: str, milestone_id: str):
|
|
||||||
"""Release payment for verified milestone"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Find milestone
|
|
||||||
milestone = None
|
|
||||||
for ms in contract.milestones:
|
|
||||||
if ms['milestone_id'] == milestone_id:
|
|
||||||
milestone = ms
|
|
||||||
break
|
|
||||||
|
|
||||||
if not milestone:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Calculate payment amount (minus platform fee)
|
|
||||||
milestone_amount = Decimal(str(milestone['amount']))
|
|
||||||
platform_fee = milestone_amount * contract.fee_rate
|
|
||||||
payment_amount = milestone_amount - platform_fee
|
|
||||||
|
|
||||||
# Update released amount
|
|
||||||
contract.released_amount += payment_amount
|
|
||||||
|
|
||||||
# In real implementation, this would trigger actual payment transfer
|
|
||||||
log_info(f"Released {payment_amount} for milestone {milestone_id} in contract {contract_id}")
|
|
||||||
|
|
||||||
async def release_full_payment(self, contract_id: str) -> Tuple[bool, str]:
|
|
||||||
"""Release full payment to agent"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state != EscrowState.JOB_COMPLETED:
|
|
||||||
return False, f"Cannot release payment in {contract.state.value} state"
|
|
||||||
|
|
||||||
# Check if all milestones are verified
|
|
||||||
all_verified = all(ms.get('verified', False) for ms in contract.milestones)
|
|
||||||
if not all_verified:
|
|
||||||
return False, "Not all milestones are verified"
|
|
||||||
|
|
||||||
# Calculate remaining payment
|
|
||||||
total_milestone_amount = sum(Decimal(str(ms['amount'])) for ms in contract.milestones)
|
|
||||||
platform_fee_total = total_milestone_amount * contract.fee_rate
|
|
||||||
remaining_payment = total_milestone_amount - contract.released_amount - platform_fee_total
|
|
||||||
|
|
||||||
if remaining_payment > 0:
|
|
||||||
contract.released_amount += remaining_payment
|
|
||||||
|
|
||||||
contract.state = EscrowState.RELEASED
|
|
||||||
self.active_contracts.discard(contract_id)
|
|
||||||
|
|
||||||
log_info(f"Full payment released for contract: {contract_id}")
|
|
||||||
return True, "Payment released successfully"
|
|
||||||
|
|
||||||
async def create_dispute(self, contract_id: str, reason: DisputeReason,
|
|
||||||
description: str, evidence: List[Dict] = None) -> Tuple[bool, str]:
|
|
||||||
"""Create dispute for contract"""
|
|
||||||
return await self._create_dispute(contract_id, reason, description, evidence)
|
|
||||||
|
|
||||||
async def _create_dispute(self, contract_id: str, reason: DisputeReason,
|
|
||||||
description: str, evidence: List[Dict] = None):
|
|
||||||
"""Internal dispute creation method"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state == EscrowState.DISPUTED:
|
|
||||||
return False, "Contract already disputed"
|
|
||||||
|
|
||||||
if contract.state not in [EscrowState.FUNDED, EscrowState.JOB_STARTED, EscrowState.JOB_COMPLETED]:
|
|
||||||
return False, f"Cannot dispute contract in {contract.state.value} state"
|
|
||||||
|
|
||||||
# Validate evidence
|
|
||||||
if evidence and (len(evidence) < self.min_dispute_evidence or len(evidence) > self.max_dispute_evidence):
|
|
||||||
return False, f"Invalid evidence count: {len(evidence)}"
|
|
||||||
|
|
||||||
# Create dispute
|
|
||||||
contract.state = EscrowState.DISPUTED
|
|
||||||
contract.dispute_reason = reason
|
|
||||||
contract.dispute_evidence = evidence or []
|
|
||||||
contract.dispute_created_at = time.time()
|
|
||||||
|
|
||||||
self.disputed_contracts.add(contract_id)
|
|
||||||
|
|
||||||
log_info(f"Dispute created for contract: {contract_id} - {reason.value}")
|
|
||||||
return True, "Dispute created successfully"
|
|
||||||
|
|
||||||
async def resolve_dispute(self, contract_id: str, resolution: Dict) -> Tuple[bool, str]:
|
|
||||||
"""Resolve dispute with specified outcome"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state != EscrowState.DISPUTED:
|
|
||||||
return False, f"Contract not in disputed state: {contract.state.value}"
|
|
||||||
|
|
||||||
# Validate resolution
|
|
||||||
required_fields = ['winner', 'client_refund', 'agent_payment']
|
|
||||||
if not all(field in resolution for field in required_fields):
|
|
||||||
return False, "Invalid resolution format"
|
|
||||||
|
|
||||||
winner = resolution['winner']
|
|
||||||
client_refund = Decimal(str(resolution['client_refund']))
|
|
||||||
agent_payment = Decimal(str(resolution['agent_payment']))
|
|
||||||
|
|
||||||
# Validate amounts
|
|
||||||
total_refund = client_refund + agent_payment
|
|
||||||
if total_refund > contract.amount:
|
|
||||||
return False, "Refund amounts exceed contract amount"
|
|
||||||
|
|
||||||
# Apply resolution
|
|
||||||
contract.resolution = resolution
|
|
||||||
contract.state = EscrowState.RESOLVED
|
|
||||||
|
|
||||||
# Update amounts
|
|
||||||
contract.released_amount += agent_payment
|
|
||||||
contract.refunded_amount += client_refund
|
|
||||||
|
|
||||||
# Remove from disputed contracts
|
|
||||||
self.disputed_contracts.discard(contract_id)
|
|
||||||
self.active_contracts.discard(contract_id)
|
|
||||||
|
|
||||||
log_info(f"Dispute resolved for contract: {contract_id} - Winner: {winner}")
|
|
||||||
return True, "Dispute resolved successfully"
|
|
||||||
|
|
||||||
async def refund_contract(self, contract_id: str, reason: str = "") -> Tuple[bool, str]:
|
|
||||||
"""Refund contract to client"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if contract.state in [EscrowState.RELEASED, EscrowState.REFUNDED, EscrowState.EXPIRED]:
|
|
||||||
return False, f"Cannot refund contract in {contract.state.value} state"
|
|
||||||
|
|
||||||
# Calculate refund amount (minus any released payments)
|
|
||||||
refund_amount = contract.amount - contract.released_amount
|
|
||||||
|
|
||||||
if refund_amount <= 0:
|
|
||||||
return False, "No amount available for refund"
|
|
||||||
|
|
||||||
contract.state = EscrowState.REFUNDED
|
|
||||||
contract.refunded_amount = refund_amount
|
|
||||||
|
|
||||||
self.active_contracts.discard(contract_id)
|
|
||||||
self.disputed_contracts.discard(contract_id)
|
|
||||||
|
|
||||||
log_info(f"Contract refunded: {contract_id} - Amount: {refund_amount}")
|
|
||||||
return True, "Contract refunded successfully"
|
|
||||||
|
|
||||||
async def expire_contract(self, contract_id: str) -> Tuple[bool, str]:
|
|
||||||
"""Mark contract as expired"""
|
|
||||||
contract = self.escrow_contracts.get(contract_id)
|
|
||||||
if not contract:
|
|
||||||
return False, "Contract not found"
|
|
||||||
|
|
||||||
if time.time() < contract.expires_at:
|
|
||||||
return False, "Contract has not expired yet"
|
|
||||||
|
|
||||||
if contract.state in [EscrowState.RELEASED, EscrowState.REFUNDED, EscrowState.EXPIRED]:
|
|
||||||
return False, f"Contract already in final state: {contract.state.value}"
|
|
||||||
|
|
||||||
# Auto-refund if no work has been done
|
|
||||||
if contract.state == EscrowState.FUNDED:
|
|
||||||
return await self.refund_contract(contract_id, "Contract expired")
|
|
||||||
|
|
||||||
# Handle other states based on work completion
|
|
||||||
contract.state = EscrowState.EXPIRED
|
|
||||||
self.active_contracts.discard(contract_id)
|
|
||||||
self.disputed_contracts.discard(contract_id)
|
|
||||||
|
|
||||||
log_info(f"Contract expired: {contract_id}")
|
|
||||||
return True, "Contract expired successfully"
|
|
||||||
|
|
||||||
async def get_contract_info(self, contract_id: str) -> Optional[EscrowContract]:
|
|
||||||
"""Get contract information"""
|
|
||||||
return self.escrow_contracts.get(contract_id)
|
|
||||||
|
|
||||||
async def get_contracts_by_client(self, client_address: str) -> List[EscrowContract]:
|
|
||||||
"""Get contracts for specific client"""
|
|
||||||
return [
|
|
||||||
contract for contract in self.escrow_contracts.values()
|
|
||||||
if contract.client_address == client_address
|
|
||||||
]
|
|
||||||
|
|
||||||
async def get_contracts_by_agent(self, agent_address: str) -> List[EscrowContract]:
|
|
||||||
"""Get contracts for specific agent"""
|
|
||||||
return [
|
|
||||||
contract for contract in self.escrow_contracts.values()
|
|
||||||
if contract.agent_address == agent_address
|
|
||||||
]
|
|
||||||
|
|
||||||
async def get_active_contracts(self) -> List[EscrowContract]:
|
|
||||||
"""Get all active contracts"""
|
|
||||||
return [
|
|
||||||
self.escrow_contracts[contract_id]
|
|
||||||
for contract_id in self.active_contracts
|
|
||||||
if contract_id in self.escrow_contracts
|
|
||||||
]
|
|
||||||
|
|
||||||
async def get_disputed_contracts(self) -> List[EscrowContract]:
|
|
||||||
"""Get all disputed contracts"""
|
|
||||||
return [
|
|
||||||
self.escrow_contracts[contract_id]
|
|
||||||
for contract_id in self.disputed_contracts
|
|
||||||
if contract_id in self.escrow_contracts
|
|
||||||
]
|
|
||||||
|
|
||||||
async def get_escrow_statistics(self) -> Dict:
|
|
||||||
"""Get escrow system statistics"""
|
|
||||||
total_contracts = len(self.escrow_contracts)
|
|
||||||
active_count = len(self.active_contracts)
|
|
||||||
disputed_count = len(self.disputed_contracts)
|
|
||||||
|
|
||||||
# State distribution
|
|
||||||
state_counts = {}
|
|
||||||
for contract in self.escrow_contracts.values():
|
|
||||||
state = contract.state.value
|
|
||||||
state_counts[state] = state_counts.get(state, 0) + 1
|
|
||||||
|
|
||||||
# Financial statistics
|
|
||||||
total_amount = sum(contract.amount for contract in self.escrow_contracts.values())
|
|
||||||
total_released = sum(contract.released_amount for contract in self.escrow_contracts.values())
|
|
||||||
total_refunded = sum(contract.refunded_amount for contract in self.escrow_contracts.values())
|
|
||||||
total_fees = total_amount - total_released - total_refunded
|
|
||||||
|
|
||||||
return {
|
|
||||||
'total_contracts': total_contracts,
|
|
||||||
'active_contracts': active_count,
|
|
||||||
'disputed_contracts': disputed_count,
|
|
||||||
'state_distribution': state_counts,
|
|
||||||
'total_amount': float(total_amount),
|
|
||||||
'total_released': float(total_released),
|
|
||||||
'total_refunded': float(total_refunded),
|
|
||||||
'total_fees': float(total_fees),
|
|
||||||
'average_contract_value': float(total_amount / total_contracts) if total_contracts > 0 else 0
|
|
||||||
}
|
|
||||||
|
|
||||||
# Global escrow manager
|
|
||||||
escrow_manager: Optional[EscrowManager] = None
|
|
||||||
|
|
||||||
def get_escrow_manager() -> Optional[EscrowManager]:
|
|
||||||
"""Get global escrow manager"""
|
|
||||||
return escrow_manager
|
|
||||||
|
|
||||||
def create_escrow_manager() -> EscrowManager:
|
|
||||||
"""Create and set global escrow manager"""
|
|
||||||
global escrow_manager
|
|
||||||
escrow_manager = EscrowManager()
|
|
||||||
return escrow_manager
|
|
||||||
@@ -1,405 +0,0 @@
|
|||||||
"""
|
|
||||||
Fixed Guardian Configuration with Proper Guardian Setup
|
|
||||||
Addresses the critical vulnerability where guardian lists were empty
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
import json
|
|
||||||
from eth_account import Account
|
|
||||||
from eth_utils import to_checksum_address, keccak
|
|
||||||
|
|
||||||
from .guardian_contract import (
|
|
||||||
SpendingLimit,
|
|
||||||
TimeLockConfig,
|
|
||||||
GuardianConfig,
|
|
||||||
GuardianContract
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GuardianSetup:
|
|
||||||
"""Guardian setup configuration"""
|
|
||||||
primary_guardian: str # Main guardian address
|
|
||||||
backup_guardians: List[str] # Backup guardian addresses
|
|
||||||
multisig_threshold: int # Number of signatures required
|
|
||||||
emergency_contacts: List[str] # Additional emergency contacts
|
|
||||||
|
|
||||||
|
|
||||||
class SecureGuardianManager:
|
|
||||||
"""
|
|
||||||
Secure guardian management with proper initialization
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.guardian_registrations: Dict[str, GuardianSetup] = {}
|
|
||||||
self.guardian_contracts: Dict[str, GuardianContract] = {}
|
|
||||||
|
|
||||||
def create_guardian_setup(
|
|
||||||
self,
|
|
||||||
agent_address: str,
|
|
||||||
owner_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
custom_guardians: Optional[List[str]] = None
|
|
||||||
) -> GuardianSetup:
|
|
||||||
"""
|
|
||||||
Create a proper guardian setup for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
owner_address: Owner of the agent
|
|
||||||
security_level: Security level (conservative, aggressive, high_security)
|
|
||||||
custom_guardians: Optional custom guardian addresses
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Guardian setup configuration
|
|
||||||
"""
|
|
||||||
agent_address = to_checksum_address(agent_address)
|
|
||||||
owner_address = to_checksum_address(owner_address)
|
|
||||||
|
|
||||||
# Determine guardian requirements based on security level
|
|
||||||
if security_level == "conservative":
|
|
||||||
required_guardians = 3
|
|
||||||
multisig_threshold = 2
|
|
||||||
elif security_level == "aggressive":
|
|
||||||
required_guardians = 2
|
|
||||||
multisig_threshold = 2
|
|
||||||
elif security_level == "high_security":
|
|
||||||
required_guardians = 5
|
|
||||||
multisig_threshold = 3
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid security level: {security_level}")
|
|
||||||
|
|
||||||
# Build guardian list
|
|
||||||
guardians = []
|
|
||||||
|
|
||||||
# Always include the owner as primary guardian
|
|
||||||
guardians.append(owner_address)
|
|
||||||
|
|
||||||
# Add custom guardians if provided
|
|
||||||
if custom_guardians:
|
|
||||||
for guardian in custom_guardians:
|
|
||||||
guardian = to_checksum_address(guardian)
|
|
||||||
if guardian not in guardians:
|
|
||||||
guardians.append(guardian)
|
|
||||||
|
|
||||||
# Generate backup guardians if needed
|
|
||||||
while len(guardians) < required_guardians:
|
|
||||||
# Generate a deterministic backup guardian based on agent address
|
|
||||||
# In production, these would be trusted service addresses
|
|
||||||
backup_index = len(guardians) - 1 # -1 because owner is already included
|
|
||||||
backup_guardian = self._generate_backup_guardian(agent_address, backup_index)
|
|
||||||
|
|
||||||
if backup_guardian not in guardians:
|
|
||||||
guardians.append(backup_guardian)
|
|
||||||
|
|
||||||
# Create setup
|
|
||||||
setup = GuardianSetup(
|
|
||||||
primary_guardian=owner_address,
|
|
||||||
backup_guardians=[g for g in guardians if g != owner_address],
|
|
||||||
multisig_threshold=multisig_threshold,
|
|
||||||
emergency_contacts=guardians.copy()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.guardian_registrations[agent_address] = setup
|
|
||||||
|
|
||||||
return setup
|
|
||||||
|
|
||||||
def _generate_backup_guardian(self, agent_address: str, index: int) -> str:
|
|
||||||
"""
|
|
||||||
Generate deterministic backup guardian address
|
|
||||||
|
|
||||||
In production, these would be pre-registered trusted guardian addresses
|
|
||||||
"""
|
|
||||||
# Create a deterministic address based on agent address and index
|
|
||||||
seed = f"{agent_address}_{index}_backup_guardian"
|
|
||||||
hash_result = keccak(seed.encode())
|
|
||||||
|
|
||||||
# Use the hash to generate a valid address
|
|
||||||
address_bytes = hash_result[-20:] # Take last 20 bytes
|
|
||||||
address = "0x" + address_bytes.hex()
|
|
||||||
|
|
||||||
return to_checksum_address(address)
|
|
||||||
|
|
||||||
def create_secure_guardian_contract(
|
|
||||||
self,
|
|
||||||
agent_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
custom_guardians: Optional[List[str]] = None
|
|
||||||
) -> GuardianContract:
|
|
||||||
"""
|
|
||||||
Create a guardian contract with proper guardian configuration
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
security_level: Security level
|
|
||||||
custom_guardians: Optional custom guardian addresses
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured guardian contract
|
|
||||||
"""
|
|
||||||
# Create guardian setup
|
|
||||||
setup = self.create_guardian_setup(
|
|
||||||
agent_address=agent_address,
|
|
||||||
owner_address=agent_address, # Agent is its own owner initially
|
|
||||||
security_level=security_level,
|
|
||||||
custom_guardians=custom_guardians
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get security configuration
|
|
||||||
config = self._get_security_config(security_level, setup)
|
|
||||||
|
|
||||||
# Create contract
|
|
||||||
contract = GuardianContract(agent_address, config)
|
|
||||||
|
|
||||||
# Store contract
|
|
||||||
self.guardian_contracts[agent_address] = contract
|
|
||||||
|
|
||||||
return contract
|
|
||||||
|
|
||||||
def _get_security_config(self, security_level: str, setup: GuardianSetup) -> GuardianConfig:
|
|
||||||
"""Get security configuration with proper guardian list"""
|
|
||||||
|
|
||||||
# Build guardian list
|
|
||||||
all_guardians = [setup.primary_guardian] + setup.backup_guardians
|
|
||||||
|
|
||||||
if security_level == "conservative":
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=1000,
|
|
||||||
per_hour=5000,
|
|
||||||
per_day=20000,
|
|
||||||
per_week=100000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=5000,
|
|
||||||
delay_hours=24,
|
|
||||||
max_delay_hours=168
|
|
||||||
),
|
|
||||||
guardians=all_guardians,
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False,
|
|
||||||
multisig_threshold=setup.multisig_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
elif security_level == "aggressive":
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=5000,
|
|
||||||
per_hour=25000,
|
|
||||||
per_day=100000,
|
|
||||||
per_week=500000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=20000,
|
|
||||||
delay_hours=12,
|
|
||||||
max_delay_hours=72
|
|
||||||
),
|
|
||||||
guardians=all_guardians,
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False,
|
|
||||||
multisig_threshold=setup.multisig_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
elif security_level == "high_security":
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=500,
|
|
||||||
per_hour=2000,
|
|
||||||
per_day=8000,
|
|
||||||
per_week=40000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=2000,
|
|
||||||
delay_hours=48,
|
|
||||||
max_delay_hours=168
|
|
||||||
),
|
|
||||||
guardians=all_guardians,
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False,
|
|
||||||
multisig_threshold=setup.multisig_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid security level: {security_level}")
|
|
||||||
|
|
||||||
def test_emergency_pause(self, agent_address: str, guardian_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Test emergency pause functionality
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent address
|
|
||||||
guardian_address: Guardian attempting pause
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Test result
|
|
||||||
"""
|
|
||||||
if agent_address not in self.guardian_contracts:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": "Agent not registered"
|
|
||||||
}
|
|
||||||
|
|
||||||
contract = self.guardian_contracts[agent_address]
|
|
||||||
return contract.emergency_pause(guardian_address)
|
|
||||||
|
|
||||||
def verify_guardian_authorization(self, agent_address: str, guardian_address: str) -> bool:
|
|
||||||
"""
|
|
||||||
Verify if a guardian is authorized for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent address
|
|
||||||
guardian_address: Guardian address to verify
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if guardian is authorized
|
|
||||||
"""
|
|
||||||
if agent_address not in self.guardian_registrations:
|
|
||||||
return False
|
|
||||||
|
|
||||||
setup = self.guardian_registrations[agent_address]
|
|
||||||
all_guardians = [setup.primary_guardian] + setup.backup_guardians
|
|
||||||
|
|
||||||
return to_checksum_address(guardian_address) in [
|
|
||||||
to_checksum_address(g) for g in all_guardians
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_guardian_summary(self, agent_address: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Get guardian setup summary for an agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Guardian summary
|
|
||||||
"""
|
|
||||||
if agent_address not in self.guardian_registrations:
|
|
||||||
return {"error": "Agent not registered"}
|
|
||||||
|
|
||||||
setup = self.guardian_registrations[agent_address]
|
|
||||||
contract = self.guardian_contracts.get(agent_address)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"primary_guardian": setup.primary_guardian,
|
|
||||||
"backup_guardians": setup.backup_guardians,
|
|
||||||
"total_guardians": len(setup.backup_guardians) + 1,
|
|
||||||
"multisig_threshold": setup.multisig_threshold,
|
|
||||||
"emergency_contacts": setup.emergency_contacts,
|
|
||||||
"contract_status": contract.get_spending_status() if contract else None,
|
|
||||||
"pause_functional": contract is not None and len(setup.backup_guardians) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Fixed security configurations with proper guardians
|
|
||||||
def get_fixed_conservative_config(agent_address: str, owner_address: str) -> GuardianConfig:
|
|
||||||
"""Get fixed conservative configuration with proper guardians"""
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=1000,
|
|
||||||
per_hour=5000,
|
|
||||||
per_day=20000,
|
|
||||||
per_week=100000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=5000,
|
|
||||||
delay_hours=24,
|
|
||||||
max_delay_hours=168
|
|
||||||
),
|
|
||||||
guardians=[owner_address], # At least the owner
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_fixed_aggressive_config(agent_address: str, owner_address: str) -> GuardianConfig:
|
|
||||||
"""Get fixed aggressive configuration with proper guardians"""
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=5000,
|
|
||||||
per_hour=25000,
|
|
||||||
per_day=100000,
|
|
||||||
per_week=500000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=20000,
|
|
||||||
delay_hours=12,
|
|
||||||
max_delay_hours=72
|
|
||||||
),
|
|
||||||
guardians=[owner_address], # At least the owner
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_fixed_high_security_config(agent_address: str, owner_address: str) -> GuardianConfig:
|
|
||||||
"""Get fixed high security configuration with proper guardians"""
|
|
||||||
return GuardianConfig(
|
|
||||||
limits=SpendingLimit(
|
|
||||||
per_transaction=500,
|
|
||||||
per_hour=2000,
|
|
||||||
per_day=8000,
|
|
||||||
per_week=40000
|
|
||||||
),
|
|
||||||
time_lock=TimeLockConfig(
|
|
||||||
threshold=2000,
|
|
||||||
delay_hours=48,
|
|
||||||
max_delay_hours=168
|
|
||||||
),
|
|
||||||
guardians=[owner_address], # At least the owner
|
|
||||||
pause_enabled=True,
|
|
||||||
emergency_mode=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Global secure guardian manager
|
|
||||||
secure_guardian_manager = SecureGuardianManager()
|
|
||||||
|
|
||||||
|
|
||||||
# Convenience function for secure agent registration
|
|
||||||
def register_agent_with_guardians(
|
|
||||||
agent_address: str,
|
|
||||||
owner_address: str,
|
|
||||||
security_level: str = "conservative",
|
|
||||||
custom_guardians: Optional[List[str]] = None
|
|
||||||
) -> Dict:
|
|
||||||
"""
|
|
||||||
Register an agent with proper guardian configuration
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_address: Agent wallet address
|
|
||||||
owner_address: Owner address
|
|
||||||
security_level: Security level
|
|
||||||
custom_guardians: Optional custom guardians
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Registration result
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Create secure guardian contract
|
|
||||||
contract = secure_guardian_manager.create_secure_guardian_contract(
|
|
||||||
agent_address=agent_address,
|
|
||||||
security_level=security_level,
|
|
||||||
custom_guardians=custom_guardians
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get guardian summary
|
|
||||||
summary = secure_guardian_manager.get_guardian_summary(agent_address)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "registered",
|
|
||||||
"agent_address": agent_address,
|
|
||||||
"security_level": security_level,
|
|
||||||
"guardian_count": summary["total_guardians"],
|
|
||||||
"multisig_threshold": summary["multisig_threshold"],
|
|
||||||
"pause_functional": summary["pause_functional"],
|
|
||||||
"registered_at": datetime.utcnow().isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"status": "error",
|
|
||||||
"reason": f"Registration failed: {str(e)}"
|
|
||||||
}
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user