diff --git a/aitbc/dependency_scanner.py b/aitbc/dependency_scanner.py new file mode 100644 index 00000000..a647b95a --- /dev/null +++ b/aitbc/dependency_scanner.py @@ -0,0 +1,262 @@ +""" +Dependency vulnerability scanning utilities for AITBC +Provides automated vulnerability scanning for Python dependencies +""" + +import subprocess +import json +import re +from typing import List, Dict, Any, Optional +from dataclasses import dataclass +from pathlib import Path +from datetime import datetime + +from .aitbc_logging import get_logger + +logger = get_logger(__name__) + + +@dataclass +class VulnerabilityReport: + """Vulnerability scan report""" + package: str + version: str + vulnerability_id: str + severity: str + description: str + fix_available: bool + fixed_version: Optional[str] + + +class DependencyScanner: + """ + Dependency vulnerability scanner. + Scans Python dependencies for known vulnerabilities. + """ + + def __init__(self, requirements_file: Optional[Path] = None): + """ + Initialize dependency scanner + + Args: + requirements_file: Path to requirements.txt or pyproject.toml + """ + self.requirements_file = requirements_file or Path("requirements.txt") + self._vulnerabilities: List[VulnerabilityReport] = [] + + def scan_with_pip_audit(self) -> List[VulnerabilityReport]: + """ + Scan dependencies using pip-audit + + Returns: + List of vulnerability reports + """ + logger.info("Running pip-audit vulnerability scan") + + try: + result = subprocess.run( + ["pip-audit", "--format", "json"], + capture_output=True, + text=True, + timeout=300 + ) + + if result.returncode == 0: + logger.info("No vulnerabilities found") + return [] + + # Parse JSON output + try: + audit_data = json.loads(result.stdout) + return self._parse_pip_audit_output(audit_data) + except json.JSONDecodeError: + logger.warning("Failed to parse pip-audit JSON output") + return [] + + except FileNotFoundError: + logger.warning("pip-audit not found, skipping scan") + return [] + except subprocess.TimeoutExpired: + logger.error("pip-audit scan timed out") + return [] + except Exception as e: + logger.error(f"pip-audit scan failed: {e}") + return [] + + def scan_with_bandit(self, target_dir: Optional[Path] = None) -> List[Dict[str, Any]]: + """ + Scan code for security issues using Bandit + + Args: + target_dir: Directory to scan (default: current directory) + + Returns: + List of security issues + """ + target_dir = target_dir or Path(".") + logger.info(f"Running Bandit security scan on {target_dir}") + + try: + result = subprocess.run( + ["bandit", "-r", str(target_dir), "-f", "json"], + capture_output=True, + text=True, + timeout=300 + ) + + try: + bandit_data = json.loads(result.stdout) + return bandit_data.get("results", []) + except json.JSONDecodeError: + logger.warning("Failed to parse Bandit JSON output") + return [] + + except FileNotFoundError: + logger.warning("Bandit not found, skipping scan") + return [] + except subprocess.TimeoutExpired: + logger.error("Bandit scan timed out") + return [] + except Exception as e: + logger.error(f"Bandit scan failed: {e}") + return [] + + def _parse_pip_audit_output(self, audit_data: Dict[str, Any]) -> List[VulnerabilityReport]: + """ + Parse pip-audit JSON output + + Args: + audit_data: Raw audit data from pip-audit + + Returns: + List of vulnerability reports + """ + vulnerabilities = [] + + for dep in audit_data.get("dependencies", []): + for vuln in dep.get("vulnerabilities", []): + report = VulnerabilityReport( + package=dep.get("name", ""), + version=dep.get("version", ""), + vulnerability_id=vuln.get("id", ""), + severity=vuln.get("severity", "UNKNOWN"), + description=vuln.get("description", ""), + fix_available=vuln.get("fix_versions", []) != [], + fixed_version=vuln.get("fix_versions", [None])[0] if vuln.get("fix_versions") else None + ) + vulnerabilities.append(report) + + return vulnerabilities + + def generate_report(self) -> Dict[str, Any]: + """ + Generate comprehensive vulnerability report + + Returns: + Dictionary with scan results + """ + pip_audit_results = self.scan_with_pip_audit() + bandit_results = self.scan_with_bandit() + + # Count vulnerabilities by severity + severity_counts = {"CRITICAL": 0, "HIGH": 0, "MEDIUM": 0, "LOW": 0, "UNKNOWN": 0} + for vuln in pip_audit_results: + severity = vuln.severity.upper() + severity_counts[severity] = severity_counts.get(severity, 0) + 1 + + return { + "timestamp": datetime.now().isoformat(), + "dependency_vulnerabilities": len(pip_audit_results), + "security_issues": len(bandit_results), + "severity_breakdown": severity_counts, + "vulnerabilities": [ + { + "package": v.package, + "version": v.version, + "id": v.vulnerability_id, + "severity": v.severity, + "description": v.description, + "fix_available": v.fix_available, + "fixed_version": v.fixed_version + } + for v in pip_audit_results + ], + "bandit_issues": bandit_results + } + + def save_report(self, output_file: Path) -> None: + """ + Save vulnerability report to file + + Args: + output_file: Path to output file + """ + report = self.generate_report() + + output_file.parent.mkdir(parents=True, exist_ok=True) + with open(output_file, 'w') as f: + json.dump(report, f, indent=2, default=str) + + logger.info(f"Vulnerability report saved to {output_file}") + + +def run_dependency_scan( + requirements_file: Optional[Path] = None, + output_file: Optional[Path] = None +) -> Dict[str, Any]: + """ + Run comprehensive dependency vulnerability scan + + Args: + requirements_file: Path to requirements file + output_file: Path to save report + + Returns: + Vulnerability scan report + """ + scanner = DependencyScanner(requirements_file) + report = scanner.generate_report() + + if output_file: + scanner.save_report(output_file) + + return report + + +def check_vulnerability_thresholds( + report: Dict[str, Any], + max_critical: int = 0, + max_high: int = 0, + max_medium: int = 10, + max_low: int = 50 +) -> bool: + """ + Check if vulnerability counts are within acceptable thresholds + + Args: + report: Vulnerability scan report + max_critical: Maximum allowed critical vulnerabilities + max_high: Maximum allowed high vulnerabilities + max_medium: Maximum allowed medium vulnerabilities + max_low: Maximum allowed low vulnerabilities + + Returns: + True if within thresholds, False otherwise + """ + severity = report.get("severity_breakdown", {}) + + if severity.get("CRITICAL", 0) > max_critical: + logger.error(f"Critical vulnerabilities exceed threshold: {severity.get('CRITICAL')} > {max_critical}") + return False + + if severity.get("HIGH", 0) > max_high: + logger.error(f"High vulnerabilities exceed threshold: {severity.get('HIGH')} > {max_high}") + return False + + if severity.get("MEDIUM", 0) > max_medium: + logger.warning(f"Medium vulnerabilities exceed threshold: {severity.get('MEDIUM')} > {max_medium}") + + if severity.get("LOW", 0) > max_low: + logger.warning(f"Low vulnerabilities exceed threshold: {severity.get('LOW')} > {max_low}") + + return True diff --git a/aitbc/feature_flags.py b/aitbc/feature_flags.py new file mode 100644 index 00000000..e3d085f5 --- /dev/null +++ b/aitbc/feature_flags.py @@ -0,0 +1,278 @@ +""" +Feature flags utilities for AITBC +Provides feature flag management for gradual rollouts +""" + +from typing import Dict, Any, Optional, Set +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +import json + +from .aitbc_logging import get_logger + +logger = get_logger(__name__) + + +@dataclass +class FeatureFlag: + """Feature flag configuration""" + name: str + enabled: bool + description: str + rollout_percentage: float = 100.0 + whitelisted_users: Optional[Set[str]] = None + blacklisted_users: Optional[Set[str]] = None + enabled_since: Optional[datetime] = None + + +class FeatureFlagManager: + """ + Feature flag manager for gradual rollouts. + Provides feature flag management with user whitelisting and percentage-based rollouts. + """ + + def __init__(self, config_file: Optional[Path] = None): + """ + Initialize feature flag manager + + Args: + config_file: Path to feature flags configuration file + """ + self.config_file = config_file or Path("feature_flags.json") + self._flags: Dict[str, FeatureFlag] = {} + self._load_flags() + + def _load_flags(self) -> None: + """Load feature flags from configuration file""" + if not self.config_file.exists(): + logger.info(f"No feature flags file found at {self.config_file}, using defaults") + return + + try: + with open(self.config_file, 'r') as f: + data = json.load(f) + + for name, config in data.items(): + self._flags[name] = FeatureFlag( + name=name, + enabled=config.get("enabled", False), + description=config.get("description", ""), + rollout_percentage=config.get("rollout_percentage", 100.0), + whitelisted_users=set(config.get("whitelisted_users", [])), + blacklisted_users=set(config.get("blacklisted_users", [])), + enabled_since=datetime.fromisoformat(config["enabled_since"]) if config.get("enabled_since") else None + ) + + logger.info(f"Loaded {len(self._flags)} feature flags from {self.config_file}") + + except Exception as e: + logger.error(f"Failed to load feature flags: {e}") + + def save_flags(self) -> None: + """Save feature flags to configuration file""" + data = {} + for name, flag in self._flags.items(): + data[name] = { + "enabled": flag.enabled, + "description": flag.description, + "rollout_percentage": flag.rollout_percentage, + "whitelisted_users": list(flag.whitelisted_users) if flag.whitelisted_users else [], + "blacklisted_users": list(flag.blacklisted_users) if flag.blacklisted_users else [], + "enabled_since": flag.enabled_since.isoformat() if flag.enabled_since else None + } + + self.config_file.parent.mkdir(parents=True, exist_ok=True) + with open(self.config_file, 'w') as f: + json.dump(data, f, indent=2) + + logger.info(f"Saved {len(self._flags)} feature flags to {self.config_file}") + + def is_enabled( + self, + feature_name: str, + user_id: Optional[str] = None, + user_hash: Optional[int] = None + ) -> bool: + """ + Check if a feature is enabled for a user + + Args: + feature_name: Name of the feature flag + user_id: User identifier + user_hash: Hash of user identifier for percentage-based rollout + + Returns: + True if feature is enabled, False otherwise + """ + flag = self._flags.get(feature_name) + + if not flag: + logger.warning(f"Feature flag {feature_name} not found, defaulting to disabled") + return False + + # Check if globally disabled + if not flag.enabled: + return False + + # Check if user is blacklisted + if flag.blacklisted_users and user_id in flag.blacklisted_users: + return False + + # Check if user is whitelisted + if flag.whitelisted_users and user_id in flag.whitelisted_users: + return True + + # Check percentage-based rollout + if flag.rollout_percentage < 100.0 and user_hash is not None: + # Use modulo to determine if user is in rollout percentage + if (user_hash % 100) < flag.rollout_percentage: + return True + return False + + # Feature is globally enabled + return True + + def enable_feature(self, feature_name: str, rollout_percentage: float = 100.0) -> None: + """ + Enable a feature flag + + Args: + feature_name: Name of the feature flag + rollout_percentage: Rollout percentage (0-100) + """ + if feature_name not in self._flags: + self._flags[feature_name] = FeatureFlag( + name=feature_name, + enabled=True, + description="", + rollout_percentage=rollout_percentage, + enabled_since=datetime.now() + ) + else: + self._flags[feature_name].enabled = True + self._flags[feature_name].rollout_percentage = rollout_percentage + if not self._flags[feature_name].enabled_since: + self._flags[feature_name].enabled_since = datetime.now() + + logger.info(f"Enabled feature flag {feature_name} with {rollout_percentage}% rollout") + self.save_flags() + + def disable_feature(self, feature_name: str) -> None: + """ + Disable a feature flag + + Args: + feature_name: Name of the feature flag + """ + if feature_name in self._flags: + self._flags[feature_name].enabled = False + logger.info(f"Disabled feature flag {feature_name}") + self.save_flags() + + def add_whitelisted_user(self, feature_name: str, user_id: str) -> None: + """ + Add user to feature whitelist + + Args: + feature_name: Name of the feature flag + user_id: User identifier + """ + if feature_name not in self._flags: + self._flags[feature_name] = FeatureFlag( + name=feature_name, + enabled=False, + description="", + whitelisted_users=set() + ) + + if not self._flags[feature_name].whitelisted_users: + self._flags[feature_name].whitelisted_users = set() + + self._flags[feature_name].whitelisted_users.add(user_id) + logger.info(f"Added {user_id} to whitelist for {feature_name}") + self.save_flags() + + def add_blacklisted_user(self, feature_name: str, user_id: str) -> None: + """ + Add user to feature blacklist + + Args: + feature_name: Name of the feature flag + user_id: User identifier + """ + if feature_name not in self._flags: + self._flags[feature_name] = FeatureFlag( + name=feature_name, + enabled=False, + description="", + blacklisted_users=set() + ) + + if not self._flags[feature_name].blacklisted_users: + self._flags[feature_name].blacklisted_users = set() + + self._flags[feature_name].blacklisted_users.add(user_id) + logger.info(f"Added {user_id} to blacklist for {feature_name}") + self.save_flags() + + def get_all_flags(self) -> Dict[str, FeatureFlag]: + """ + Get all feature flags + + Returns: + Dictionary of all feature flags + """ + return self._flags.copy() + + def get_flag_status(self, feature_name: str) -> Optional[FeatureFlag]: + """ + Get status of a specific feature flag + + Args: + feature_name: Name of the feature flag + + Returns: + Feature flag or None if not found + """ + return self._flags.get(feature_name) + + +# Global feature flag manager instance +_global_feature_flag_manager: Optional[FeatureFlagManager] = None + + +def get_feature_flag_manager(config_file: Optional[Path] = None) -> FeatureFlagManager: + """ + Get the global feature flag manager instance + + Args: + config_file: Path to feature flags configuration file + + Returns: + FeatureFlagManager instance + """ + global _global_feature_flag_manager + if _global_feature_flag_manager is None: + _global_feature_flag_manager = FeatureFlagManager(config_file) + return _global_feature_flag_manager + + +def is_feature_enabled( + feature_name: str, + user_id: Optional[str] = None, + user_hash: Optional[int] = None +) -> bool: + """ + Check if a feature is enabled using global manager + + Args: + feature_name: Name of the feature flag + user_id: User identifier + user_hash: Hash of user identifier + + Returns: + True if feature is enabled, False otherwise + """ + manager = get_feature_flag_manager() + return manager.is_enabled(feature_name, user_id, user_hash)