"""Tests for SHA-256 hash-chain audit signing -- Sprint 8."""

from __future__ import annotations

import json
import tempfile
from pathlib import Path

import pytest

from arautopilot.core.audit import (
    GENESIS_HASH,
    AuditEvent,
    AuditLog,
    AuditOutcome,
)

# Characters accepted in a lowercase hexadecimal digest.
_HEX_DIGITS = frozenset("0123456789abcdef")

# Fields stripped from an event before its canonical payload is hashed.
_HASH_FIELDS = ("prev_hash", "line_hash")


def _make_event(**kwargs) -> AuditEvent:
    """Build an ``AuditEvent`` from test defaults, overridden by *kwargs*."""
    fields = {"action": "test_action", "outcome": AuditOutcome.SUCCESS, **kwargs}
    return AuditEvent(**fields)


def _canonical_payload(event: AuditEvent) -> str:
    """Return *event* as sorted-key JSON with both hash fields removed."""
    body = event.model_dump(mode="json")
    for name in _HASH_FIELDS:
        # A KeyError here means the dumped event is missing a hash field.
        del body[name]
    return json.dumps(body, ensure_ascii=False, sort_keys=True)


def _read_lines(path: Path) -> list[str]:
    """Return the UTF-8 text of *path* split into lines."""
    return path.read_text(encoding="utf-8").splitlines()


def _write_lines(path: Path, lines: list[str]) -> None:
    """Overwrite *path* with *lines*, newline-joined and newline-terminated."""
    path.write_text("\n".join(lines) + "\n", encoding="utf-8")


class TestGenesisHash:
    """Shape checks for the genesis constant and for computed digests."""

    def test_genesis_hash_is_64_hex_chars(self):
        """GENESIS_HASH has 64 characters, every one of them ``"0"``."""
        assert len(GENESIS_HASH) == 64
        assert set(GENESIS_HASH) == {"0"}

    def test_compute_hash_returns_64_char_hex(self):
        """``_compute_hash`` yields 64 lowercase hexadecimal characters."""
        digest = AuditEvent._compute_hash(GENESIS_HASH, "payload")
        assert len(digest) == 64
        assert set(digest) <= _HEX_DIGITS


class TestHashChainAppend:
    """Appending events links each one to the hash of its predecessor."""

    def test_first_event_prev_hash_is_genesis(self, tmp_path):
        """The first appended event points back at GENESIS_HASH."""
        audit = AuditLog(tmp_path / "audit.jsonl")
        audit.append(_make_event(action="first"))

        entries = audit.read_all()

        assert entries[0].prev_hash == GENESIS_HASH

    def test_second_event_prev_hash_links_to_first(self, tmp_path):
        """The second event's ``prev_hash`` is the first event's ``line_hash``."""
        audit = AuditLog(tmp_path / "audit.jsonl")
        for label in ("first", "second"):
            audit.append(_make_event(action=label))

        entries = audit.read_all()

        assert entries[1].prev_hash == entries[0].line_hash

    def test_line_hash_is_deterministic(self, tmp_path):
        """Recomputing the first event's hash by hand gives the stored value."""
        audit = AuditLog(tmp_path / "audit.jsonl")
        audit.append(_make_event(action="deterministic"))

        stored = audit.read_all()[0]

        # Recompute manually, chaining from the genesis hash.
        recomputed = AuditEvent._compute_hash(
            GENESIS_HASH, _canonical_payload(stored)
        )
        assert stored.line_hash == recomputed

    def test_chain_links_across_multiple_events(self, tmp_path):
        """Five appended events form one unbroken chain from genesis."""
        audit = AuditLog(tmp_path / "audit.jsonl")
        for i in range(5):
            audit.append(_make_event(action=f"event_{i}"))

        entries = audit.read_all()

        assert entries[0].prev_hash == GENESIS_HASH
        for earlier in range(4):
            later = earlier + 1
            assert entries[later].prev_hash == entries[earlier].line_hash

    def test_chain_continues_after_reload(self, tmp_path):
        """A second ``AuditLog`` on the same file continues the existing chain."""
        log_path = tmp_path / "audit.jsonl"
        original = AuditLog(log_path)
        original.append(_make_event(action="first"))
        hash_before_reload = original._last_line_hash

        reopened = AuditLog(log_path)  # reload
        reopened.append(_make_event(action="second"))

        entries = reopened.read_all()
        assert entries[1].prev_hash == hash_before_reload


class TestVerifyChain:
    """``verify_chain`` accepts intact logs and rejects edited ones."""

    def test_empty_log_verifies_ok(self, tmp_path):
        """A log with no appended events verifies with reason ``"ok"``."""
        audit = AuditLog(tmp_path / "audit.jsonl")

        ok, reason = audit.verify_chain()

        assert ok
        assert reason == "ok"

    def test_valid_chain_verifies_ok(self, tmp_path):
        """Ten untouched events verify successfully."""
        audit = AuditLog(tmp_path / "audit.jsonl")
        for i in range(10):
            audit.append(_make_event(action=f"ev_{i}"))

        ok, reason = audit.verify_chain()

        assert ok, reason

    def test_tampered_content_detected(self, tmp_path):
        """Editing a stored field makes verification fail with a telling reason."""
        log_path = tmp_path / "audit.jsonl"
        audit = AuditLog(log_path)
        audit.append(_make_event(action="before_tamper"))
        audit.append(_make_event(action="after_tamper"))

        # Tamper the first line: change action field in the raw JSON.
        raw = _read_lines(log_path)
        record = json.loads(raw[0])
        record["action"] = "TAMPERED"
        raw[0] = json.dumps(record)
        _write_lines(log_path, raw)

        ok, reason = AuditLog(log_path).verify_chain()

        assert not ok
        lowered = reason.lower()
        assert any(word in lowered for word in ("tampered", "mismatch"))

    def test_deleted_line_detected(self, tmp_path):
        """Removing the middle line of a three-event log breaks verification."""
        log_path = tmp_path / "audit.jsonl"
        audit = AuditLog(log_path)
        for i in range(3):
            audit.append(_make_event(action=f"ev_{i}"))

        # Remove the second line, ignoring any blank lines in the file.
        kept = [row for row in _read_lines(log_path) if row.strip()]
        del kept[1]
        _write_lines(log_path, kept)

        ok, _ = AuditLog(log_path).verify_chain()

        assert not ok

    def test_inserted_line_detected(self, tmp_path):
        """A forged record spliced between two events breaks verification."""
        log_path = tmp_path / "audit.jsonl"
        audit = AuditLog(log_path)
        audit.append(_make_event(action="first"))
        audit.append(_make_event(action="last"))

        # Insert a fake line between them (with wrong prev_hash).
        raw = _read_lines(log_path)
        forged = json.loads(raw[0])
        forged.update(
            action="injected",
            prev_hash="a" * 64,
            line_hash="b" * 64,
        )
        raw[1:1] = [json.dumps(forged)]
        _write_lines(log_path, raw)

        ok, _ = AuditLog(log_path).verify_chain()

        assert not ok


class TestHashChainIsolation:
    """The hashed payload excludes the hash fields but covers ``extra``."""

    def test_hash_fields_excluded_from_payload(self, tmp_path):
        """``line_hash`` matches a recomputation that omits both hash fields."""
        audit = AuditLog(tmp_path / "audit.jsonl")
        audit.append(_make_event(action="isolated"))

        stored = audit.read_all()[0]

        # The hash must NOT depend on the hash fields themselves (circular).
        # Recompute without hash fields and confirm it matches.
        recomputed = AuditEvent._compute_hash(
            stored.prev_hash, _canonical_payload(stored)
        )
        assert stored.line_hash == recomputed

    def test_extra_field_change_breaks_chain(self, tmp_path):
        """Changing a value inside ``extra`` makes verification fail."""
        log_path = tmp_path / "audit.jsonl"
        audit = AuditLog(log_path)
        audit.append(_make_event(action="good", extra={"key": "value"}))

        raw = _read_lines(log_path)
        record = json.loads(raw[0])
        record["extra"]["key"] = "EVIL"
        raw[0] = json.dumps(record)
        _write_lines(log_path, raw)

        ok, _ = AuditLog(log_path).verify_chain()

        assert not ok