"""Static security validation of Python source code via AST inspection.

``SecurityValidator`` walks a parsed module and records violations
(disallowed imports, blocked names and attributes, suspicious format
strings, ...). ``TaskAnalyzer`` wraps it with a bounded result cache and
raises ``SecurityViolationError`` when any violation is found.
"""

import ast
import hashlib
import re
from collections import OrderedDict

from src.errors import SecurityViolationError
from src.import_validation import validate_module_import
from src.config.security_config import SecurityConfig
from src.constants import (
    MAX_VALIDATION_CACHE_SIZE,
    ERROR_RELATIVE_IMPORT,
    ERROR_DANGEROUS_NAME,
    ERROR_DANGEROUS_ATTRIBUTE,
    ERROR_NAME_MANGLED_ATTRIBUTE,
    ERROR_DYNAMIC_IMPORT,
    ERROR_DANGEROUS_STRING_PATTERN,
    ERROR_MATCH_PATTERN_ATTRIBUTE,
    ERROR_MATCH_POSITIONAL_PATTERN,
    ERROR_GLOBAL_BLOCKED_NAME,
    ERROR_FUNCDEF_BLOCKED_NAME,
    ERROR_CLASSDEF_BLOCKED_NAME,
    BLOCKED_ATTRIBUTES,
    BLOCKED_NAMES,
)

CacheKey = tuple[str, tuple]  # (code_hash, allowlists_tuple)
CachedViolations = list[str]
ValidationCache = OrderedDict[CacheKey, CachedViolations]

# Matches a "{...}" replacement field; group 1 is the text up to the first "}".
FORMAT_FIELD_PATTERN = re.compile(r"\{([^}]*)\}")

# Attribute lookup inside a replacement field, e.g. ".__class__".
_FORMAT_ATTRIBUTE_PATTERN = re.compile(r"\.(\w+)")
# Subscript lookup inside a replacement field, e.g. "[key]", "['key']", '["key"]'.
_FORMAT_SUBSCRIPT_PATTERN = re.compile(r"\[(['\"]?)(\w+)\1\]")


class SecurityValidator(ast.NodeVisitor):
    """AST visitor that enforces import allowlists and blocks dangerous attribute access.

    Violations are collected in ``self.violations`` as ``"Line N: message"``
    strings; the visitor itself never raises for a violation.
    """

    def __init__(self, security_config: SecurityConfig):
        # Top-level module names already validated, so each is checked once.
        self.checked_modules: set[str] = set()
        self.violations: list[str] = []
        self.security_config = security_config

    # ========== Detection ==========

    def visit_Import(self, node: ast.Import) -> None:
        """Detect bare import statements (e.g., import os), including aliased (e.g., import numpy as np)."""
        for alias in node.names:
            self._validate_import(alias.name, node.lineno)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Detect from import statements (e.g., from os import path)."""
        if node.level > 0:
            # level > 0 means a relative import such as `from . import x`.
            self._add_violation(node.lineno, ERROR_RELATIVE_IMPORT)
        elif node.module:
            self._validate_import(node.module, node.lineno)
        self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> None:
        """Detect any reference to a name listed in BLOCKED_NAMES."""
        if node.id in BLOCKED_NAMES:
            self._add_violation(node.lineno, ERROR_DANGEROUS_NAME.format(name=node.id))
        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        """Detect access to unsafe attributes that could bypass security restrictions."""
        if node.attr in BLOCKED_ATTRIBUTES:
            self._add_violation(
                node.lineno, ERROR_DANGEROUS_ATTRIBUTE.format(attr=node.attr)
            )

        # Name-mangled form: an underscore-led prefix followed by "__",
        # e.g. `obj._Foo__bar`. Plain dunders such as `__init__` have an empty
        # prefix before the first "__" and are therefore not matched here.
        prefix, separator, _ = node.attr.partition("__")
        if separator and prefix.startswith("_"):
            self._add_violation(node.lineno, ERROR_NAME_MANGLED_ATTRIBUTE)

        self.generic_visit(node)

    def visit_Call(self, node: ast.Call) -> None:
        """Detect calls to __import__() that could bypass security restrictions."""
        is_import_call = (
            # __import__()
            (isinstance(node.func, ast.Name) and node.func.id == "__import__")
            or
            # builtins.__import__() or __builtins__.__import__()
            (
                isinstance(node.func, ast.Attribute)
                and node.func.attr == "__import__"
                and isinstance(node.func.value, ast.Name)
                and node.func.value.id in {"builtins", "__builtins__"}
            )
        )

        if is_import_call:
            if (
                node.args
                and isinstance(node.args[0], ast.Constant)
                and isinstance(node.args[0].value, str)
            ):
                # Literal module name: validate it like a regular import.
                module_name = node.args[0].value
                self._validate_import(module_name, node.lineno)
            else:
                # The module name is not a string literal, so it cannot be
                # checked statically.
                self._add_violation(node.lineno, ERROR_DYNAMIC_IMPORT)

        self.generic_visit(node)

    def visit_Subscript(self, node: ast.Subscript) -> None:
        """Detect dict access to blocked attributes, e.g. __builtins__['__spec__']"""
        is_builtins_access = (
            # __builtins__['__spec__']
            (
                isinstance(node.value, ast.Name)
                and node.value.id in {"__builtins__", "builtins"}
            )
            # obj.__builtins__['__spec__']
            or (
                isinstance(node.value, ast.Attribute)
                and node.value.attr in {"__builtins__", "builtins"}
            )
        )

        if (
            is_builtins_access
            and isinstance(node.slice, ast.Constant)
            and isinstance(node.slice.value, str)
        ):
            key = node.slice.value
            if key in BLOCKED_ATTRIBUTES:
                self._add_violation(
                    node.lineno, ERROR_DANGEROUS_ATTRIBUTE.format(attr=key)
                )

        self.generic_visit(node)

    def visit_Constant(self, node: ast.Constant) -> None:
        """Detect string constants containing dangerous format patterns."""
        if isinstance(node.value, str):
            self._check_format_string(node.value, node.lineno)
        self.generic_visit(node)

    def visit_Global(self, node: ast.Global) -> None:
        """Detect global declarations of blocked names, e.g. `global __builtins__`"""
        for name in node.names:
            if name in BLOCKED_NAMES:
                self._add_violation(
                    node.lineno, ERROR_GLOBAL_BLOCKED_NAME.format(name=name)
                )
        self.generic_visit(node)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Detect function definitions with blocked names, e.g. `def __builtins__(): pass`"""
        if node.name in BLOCKED_NAMES:
            self._add_violation(
                node.lineno, ERROR_FUNCDEF_BLOCKED_NAME.format(name=node.name)
            )
        self.generic_visit(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Detect async function definitions with blocked names."""
        if node.name in BLOCKED_NAMES:
            self._add_violation(
                node.lineno, ERROR_FUNCDEF_BLOCKED_NAME.format(name=node.name)
            )
        self.generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Detect class definitions with blocked names, e.g. `class __builtins__: pass`"""
        if node.name in BLOCKED_NAMES:
            self._add_violation(
                node.lineno, ERROR_CLASSDEF_BLOCKED_NAME.format(name=node.name)
            )
        self.generic_visit(node)

    def visit_MatchClass(self, node: ast.MatchClass) -> None:
        """Detect match patterns that extract blocked attributes, e.g. `case AttributeError(obj=x)`"""
        # Any positional sub-pattern in a class pattern is reported.
        if node.patterns:
            self._add_violation(node.lineno, ERROR_MATCH_POSITIONAL_PATTERN)

        for attr in node.kwd_attrs:
            if attr in BLOCKED_ATTRIBUTES:
                self._add_violation(
                    node.lineno, ERROR_MATCH_PATTERN_ATTRIBUTE.format(attr=attr)
                )

        self.generic_visit(node)

    def _check_format_string(self, s: str, lineno: int) -> None:
        """Check if a string contains format patterns that access blocked attributes.

        Every ``{...}`` field is scanned for ``.attr`` and ``[key]`` lookups
        whose name is in BLOCKED_ATTRIBUTES or BLOCKED_NAMES.
        """
        # "{{" is an escaped literal brace and can never open a field, so drop
        # it. "}}" is deliberately NOT stripped: besides being an escape it is
        # also how a nested field at the end of a format spec terminates
        # (e.g. "{0:{1.__class__}}"). Removing it would leave that field
        # without a closing brace and hide it from FORMAT_FIELD_PATTERN.
        s = s.replace("{{", "")

        for match in FORMAT_FIELD_PATTERN.finditer(s):
            field = match.group(1)

            # attribute access
            for attr_match in _FORMAT_ATTRIBUTE_PATTERN.finditer(field):
                attr = attr_match.group(1)
                if attr in BLOCKED_ATTRIBUTES or attr in BLOCKED_NAMES:
                    self._add_violation(
                        lineno, ERROR_DANGEROUS_STRING_PATTERN.format(attr=attr)
                    )

            # subscript access
            for subscript_match in _FORMAT_SUBSCRIPT_PATTERN.finditer(field):
                key = subscript_match.group(2)
                if key in BLOCKED_ATTRIBUTES or key in BLOCKED_NAMES:
                    self._add_violation(
                        lineno, ERROR_DANGEROUS_STRING_PATTERN.format(attr=key)
                    )

    # ========== Validation ==========

    def _validate_import(self, module_path: str, lineno: int) -> None:
        """Validate that a module import is allowed based on allowlists. Also disallow relative imports."""
        if module_path.startswith("."):
            self._add_violation(lineno, ERROR_RELATIVE_IMPORT)
            return

        module_name = module_path.split(".")[0]  # e.g., os.path -> os

        # Each top-level module is validated (and reported) at most once.
        if module_name in self.checked_modules:
            return

        self.checked_modules.add(module_name)

        is_allowed, error_msg = validate_module_import(
            module_path, self.security_config
        )

        if not is_allowed:
            assert error_msg is not None
            self._add_violation(lineno, error_msg)

    def _add_violation(self, lineno: int, message: str) -> None:
        """Record a violation, prefixed with the source line it was found on."""
        self.violations.append(f"Line {lineno}: {message}")


class TaskAnalyzer:
    """Validate source code against a security config, caching results.

    The cache is a class attribute, so it is shared by every instance. Keys
    combine the SHA-256 of the code with the sorted allowlists, so different
    configurations do not share entries.
    """

    _cache: ValidationCache = OrderedDict()

    def __init__(self, security_config: SecurityConfig):
        self._security_config = security_config
        # Sorted tuples give a hashable, order-independent cache key component.
        self._allowlists = (
            tuple(sorted(security_config.stdlib_allow)),
            tuple(sorted(security_config.external_allow)),
        )
        # When both allowlists contain "*", validation is skipped entirely.
        self._allow_all = (
            "*" in security_config.stdlib_allow
            and "*" in security_config.external_allow
        )

    def validate(self, code: str) -> None:
        """Validate ``code`` and raise if it contains security violations.

        Raises:
            SecurityViolationError: if any violation is found (or was cached).
            SyntaxError: propagated from ``ast.parse`` for unparsable code.
        """
        if self._allow_all:
            return

        cache_key = self._to_cache_key(code)
        cached_violations = self._cache.get(cache_key)

        if cached_violations is not None:
            # Mark the entry as most recently used.
            self._cache.move_to_end(cache_key)
            if not cached_violations:
                return
            self._raise_security_error(cached_violations)

        tree = ast.parse(code)

        security_validator = SecurityValidator(self._security_config)
        security_validator.visit(tree)

        self._set_in_cache(cache_key, security_validator.violations)

        if security_validator.violations:
            self._raise_security_error(security_validator.violations)

    def _raise_security_error(self, violations: CachedViolations) -> None:
        """Raise SecurityViolationError listing one violation per line."""
        raise SecurityViolationError(
            message="Security violations detected", description="\n".join(violations)
        )

    def _to_cache_key(self, code: str) -> CacheKey:
        """Build the cache key from the code's SHA-256 digest and the allowlists."""
        code_hash = hashlib.sha256(code.encode()).hexdigest()
        return (code_hash, self._allowlists)

    def _set_in_cache(self, cache_key: CacheKey, violations: CachedViolations) -> None:
        """Store a copy of ``violations`` under ``cache_key``, evicting if full."""
        # Only a new key grows the cache; overwriting an existing entry must
        # not push out an unrelated one.
        if (
            cache_key not in self._cache
            and len(self._cache) >= MAX_VALIDATION_CACHE_SIZE
        ):
            # Hits are moved to the end in validate(), so the first item is
            # the least recently used entry.
            self._cache.popitem(last=False)

        # Copy so later changes to the caller's list cannot alter the cache.
        self._cache[cache_key] = violations.copy()
        self._cache.move_to_end(cache_key)