Files
aitbc/apps/zk-circuits/zk_cache.py
oib 15427c96c0 chore: update file permissions to executable across repository
- Change file mode from 644 to 755 for all project files
- Add chain_id parameter to get_balance RPC endpoint with default "ait-devnet"
- Rename Miner.extra_meta_data to extra_metadata for consistency
2026-03-06 22:17:54 +01:00

220 lines
7.5 KiB
Python
Executable File

#!/usr/bin/env python3
"""
ZK Circuit Compilation Cache System
Caches compiled circuit artifacts to speed up iterative development.
Tracks file dependencies and invalidates cache when source files change.
"""
import hashlib
import json
import os
from pathlib import Path
from typing import Dict, List, Optional
import time
class ZKCircuitCache:
"""Cache system for ZK circuit compilation artifacts"""
def __init__(self, cache_dir: Path = Path(".zk_cache")):
self.cache_dir = cache_dir
self.cache_dir.mkdir(exist_ok=True)
self.cache_manifest = self.cache_dir / "manifest.json"
def _calculate_file_hash(self, file_path: Path) -> str:
"""Calculate SHA256 hash of a file"""
if not file_path.exists():
return ""
with open(file_path, 'rb') as f:
return hashlib.sha256(f.read()).hexdigest()
def _get_cache_key(self, circuit_file: Path, output_dir: Path) -> str:
"""Generate cache key based on circuit file and dependencies"""
circuit_hash = self._calculate_file_hash(circuit_file)
# Include any imported files in hash calculation
dependencies = self._find_dependencies(circuit_file)
dep_hashes = [self._calculate_file_hash(dep) for dep in dependencies]
# Create composite hash
composite = f"{circuit_hash}|{'|'.join(dep_hashes)}|{output_dir}"
return hashlib.sha256(composite.encode()).hexdigest()[:16]
def _find_dependencies(self, circuit_file: Path) -> List[Path]:
"""Find Circom include dependencies"""
dependencies = []
try:
with open(circuit_file, 'r') as f:
content = f.read()
# Find include statements
import re
includes = re.findall(r'include\s+["\']([^"\']+)["\']', content)
circuit_dir = circuit_file.parent
for include in includes:
dep_path = circuit_dir / include
if dep_path.exists():
dependencies.append(dep_path)
# Recursively find dependencies
dependencies.extend(self._find_dependencies(dep_path))
except Exception:
pass
return list(set(dependencies)) # Remove duplicates
def is_cache_valid(self, circuit_file: Path, output_dir: Path) -> bool:
"""Check if cached artifacts are still valid"""
cache_key = self._get_cache_key(circuit_file, output_dir)
cache_entry = self._load_cache_entry(cache_key)
if not cache_entry:
return False
# Check if source files have changed
circuit_hash = self._calculate_file_hash(circuit_file)
if circuit_hash != cache_entry.get('circuit_hash'):
return False
# Check dependencies
dependencies = self._find_dependencies(circuit_file)
cached_deps = cache_entry.get('dependencies', {})
if len(dependencies) != len(cached_deps):
return False
for dep in dependencies:
dep_hash = self._calculate_file_hash(dep)
if dep_hash != cached_deps.get(str(dep)):
return False
# Check if output files exist
expected_files = cache_entry.get('output_files', [])
for file_path in expected_files:
if not Path(file_path).exists():
return False
return True
def _load_cache_entry(self, cache_key: str) -> Optional[Dict]:
"""Load cache entry from manifest"""
try:
if self.cache_manifest.exists():
with open(self.cache_manifest, 'r') as f:
manifest = json.load(f)
return manifest.get(cache_key)
except Exception:
pass
return None
def _save_cache_entry(self, cache_key: str, entry: Dict):
"""Save cache entry to manifest"""
try:
manifest = {}
if self.cache_manifest.exists():
with open(self.cache_manifest, 'r') as f:
manifest = json.load(f)
manifest[cache_key] = entry
with open(self.cache_manifest, 'w') as f:
json.dump(manifest, f, indent=2)
except Exception as e:
print(f"Warning: Failed to save cache entry: {e}")
def get_cached_artifacts(self, circuit_file: Path, output_dir: Path) -> Optional[Dict]:
"""Retrieve cached artifacts if valid"""
if self.is_cache_valid(circuit_file, output_dir):
cache_key = self._get_cache_key(circuit_file, output_dir)
cache_entry = self._load_cache_entry(cache_key)
return cache_entry
return None
def cache_artifacts(self, circuit_file: Path, output_dir: Path, compilation_time: float):
"""Cache successful compilation artifacts"""
cache_key = self._get_cache_key(circuit_file, output_dir)
# Find all output files
output_files = []
if output_dir.exists():
for ext in ['.r1cs', '.wasm', '.sym', '.c', '.dat']:
for file_path in output_dir.rglob(f'*{ext}'):
output_files.append(str(file_path))
# Calculate dependency hashes
dependencies = self._find_dependencies(circuit_file)
dep_hashes = {str(dep): self._calculate_file_hash(dep) for dep in dependencies}
entry = {
'circuit_file': str(circuit_file),
'output_dir': str(output_dir),
'circuit_hash': self._calculate_file_hash(circuit_file),
'dependencies': dep_hashes,
'output_files': output_files,
'compilation_time': compilation_time,
'cached_at': time.time()
}
self._save_cache_entry(cache_key, entry)
def clear_cache(self):
"""Clear all cached artifacts"""
import shutil
if self.cache_dir.exists():
shutil.rmtree(self.cache_dir)
self.cache_dir.mkdir(exist_ok=True)
def get_cache_stats(self) -> Dict:
"""Get cache statistics"""
try:
if self.cache_manifest.exists():
with open(self.cache_manifest, 'r') as f:
manifest = json.load(f)
total_entries = len(manifest)
total_size = 0
for entry in manifest.values():
for file_path in entry.get('output_files', []):
try:
total_size += Path(file_path).stat().st_size
except:
pass
return {
'entries': total_entries,
'total_size_mb': total_size / (1024 * 1024),
'cache_dir': str(self.cache_dir)
}
except Exception:
pass
return {'entries': 0, 'total_size_mb': 0, 'cache_dir': str(self.cache_dir)}
def main():
"""CLI interface for cache management"""
import argparse
parser = argparse.ArgumentParser(description='ZK Circuit Compilation Cache')
parser.add_argument('action', choices=['stats', 'clear'], help='Action to perform')
args = parser.parse_args()
cache = ZKCircuitCache()
if args.action == 'stats':
stats = cache.get_cache_stats()
print(f"Cache Statistics:")
print(f" Entries: {stats['entries']}")
print(f" Total Size: {stats['total_size_mb']:.2f} MB")
print(f" Cache Directory: {stats['cache_dir']}")
elif args.action == 'clear':
cache.clear_cache()
print("Cache cleared successfully")
if __name__ == "__main__":
main()