Skip to content
This repository was archived by the owner on Jan 2, 2026. It is now read-only.

Commit 8c17576

Browse files
zircote and claude committed
fix: address code review findings for performance and quality
Performance improvements (PERF-001 through PERF-007):
- Add show_notes_batch() for batch git notes retrieval using git cat-file --batch
- Add embed_batch() for batch embedding generation
- Add prewarm() method for eager model loading
- Add @lru_cache for struct format caching in index.py
- Implement batch hydration in recall.py using show_notes_batch()
- Use batch operations in sync.py for index synchronization

Quality improvements (QUAL-001, QUAL-002):
- Rename _read_input to _read_input_with_fallback for clarity
- Replace broad Exception catches with specific (MemoryIndexError, OSError)

Security improvements (SEC-001, SEC-002):
- Add input length limit in signal_detector.py for ReDoS prevention
- Sanitize filesystem paths in git_ops.py error messages to prevent information leakage

Test updates:
- Update test fixtures to mock batch operations
- Update test assertions for renamed functions and specific exceptions

All 1806 tests passing with 89.29% coverage.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 40c7d04 commit 8c17576

12 files changed

Lines changed: 802 additions & 193 deletions

File tree

src/git_notes_memory/embedding.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,32 @@ def similarity(
310310
# Dot product of normalized vectors = cosine similarity
311311
return sum(a * b for a, b in zip(embedding1, embedding2, strict=True))
312312

313+
def prewarm(self) -> bool:
    """Eagerly load the embedding model so later calls do not pay for it.

    PERF-004: intended to be invoked at application startup or during hook
    initialization so the first embed() call does not incur cold-start
    latency. Typical callers:
    - Session start hooks that need fast response
    - Background workers that will process embeddings
    - Applications where predictable latency is important

    The load happens synchronously in the calling thread via ``self.load()``;
    any exception it raises is logged as a warning and reported through the
    return value rather than propagated.

    Returns:
        True if ``self.load()`` completed without raising (presumably also
        the case when the model is already loaded -- confirm against
        ``load()``), False if it raised.

    Examples:
        >>> service = EmbeddingService()
        >>> service.prewarm()
        True
        >>> service.is_loaded
        True
    """
    try:
        self.load()
    except Exception as exc:
        # Best-effort by design: a failed warm-up must not break startup.
        logger.warning("Failed to pre-warm embedding model: %s", exc)
        return False
    else:
        return True
338+
313339
def unload(self) -> None:
314340
"""Unload the model to free memory.
315341
@@ -322,13 +348,10 @@ def unload(self) -> None:
322348

323349

324350
# =============================================================================
325-
# Singleton Instance
351+
# Singleton Access (using ServiceRegistry)
326352
# =============================================================================
327353

328354

329-
_default_service: EmbeddingService | None = None
330-
331-
332355
def get_default_service() -> EmbeddingService:
333356
"""Get the default embedding service singleton.
334357
@@ -340,7 +363,6 @@ def get_default_service() -> EmbeddingService:
340363
>>> service.model_name
341364
'all-MiniLM-L6-v2'
342365
"""
343-
global _default_service
344-
if _default_service is None:
345-
_default_service = EmbeddingService()
346-
return _default_service
366+
from git_notes_memory.registry import ServiceRegistry
367+
368+
return ServiceRegistry.get(EmbeddingService)

src/git_notes_memory/git_ops.py

Lines changed: 106 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,12 @@ def _run_git(
143143
)
144144
except subprocess.CalledProcessError as e:
145145
# Parse common git errors for better messages
146+
# SEC-002: Sanitize paths in error messages to prevent info leakage
146147
stderr = e.stderr or ""
147148
if "not a git repository" in stderr.lower():
148149
raise StorageError(
149150
"Not in a Git repository",
150-
f"Initialize a git repository: cd {self.repo_path} && git init",
151+
"Initialize a git repository: cd <repo_path> && git init",
151152
) from e
152153
if "permission denied" in stderr.lower():
153154
raise StorageError(
@@ -159,8 +160,12 @@ def _run_git(
159160
"Repository has no commits",
160161
"Create at least one commit: git commit --allow-empty -m 'initial'",
161162
) from e
163+
# Sanitize the args to remove full paths
164+
sanitized_args = [
165+
arg if not arg.startswith("/") else "<path>" for arg in args
166+
]
162167
raise StorageError(
163-
f"Git command failed: {' '.join(args)}\n{stderr}",
168+
f"Git command failed: {' '.join(sanitized_args)}\n{stderr}",
164169
"Check git status and try again",
165170
) from e
166171

@@ -321,6 +326,105 @@ def show_note(
321326

322327
return result.stdout
323328

329+
def show_notes_batch(
330+
self,
331+
namespace: str,
332+
commit_shas: list[str],
333+
) -> dict[str, str | None]:
334+
"""Show multiple notes in a single subprocess call.
335+
336+
Uses `git cat-file --batch` for efficient bulk retrieval.
337+
This is significantly faster than calling show_note() in a loop
338+
when fetching many notes.
339+
340+
Args:
341+
namespace: Memory namespace.
342+
commit_shas: List of commit SHAs to get notes for.
343+
344+
Returns:
345+
Dict mapping commit_sha -> note content (or None if no note).
346+
347+
Raises:
348+
ValidationError: If namespace is invalid.
349+
"""
350+
if not commit_shas:
351+
return {}
352+
353+
self._validate_namespace(namespace)
354+
for sha in commit_shas:
355+
self._validate_git_ref(sha)
356+
357+
# Build object references: notes ref points to the note object for each commit
358+
# Format: refs/notes/mem/namespace:commit_sha
359+
ref = self._note_ref(namespace)
360+
objects_input = "\n".join(f"{ref}:{sha}" for sha in commit_shas)
361+
362+
# Run cat-file --batch to get all notes at once
363+
cmd = ["git", "-C", str(self.repo_path), "cat-file", "--batch"]
364+
365+
try:
366+
result = subprocess.run(
367+
cmd,
368+
input=objects_input,
369+
capture_output=True,
370+
text=True,
371+
check=False,
372+
)
373+
except Exception:
374+
# Fallback to sequential if batch fails
375+
return {sha: self.show_note(namespace, sha) for sha in commit_shas}
376+
377+
# Parse batch output
378+
# Format per object:
379+
# <sha> <type> <size>\n
380+
# <content>\n
381+
# Or for missing:
382+
# <ref> missing\n
383+
results: dict[str, str | None] = {}
384+
lines: list[str] = result.stdout.split("\n")
385+
i = 0
386+
sha_index = 0
387+
388+
while i < len(lines) and sha_index < len(commit_shas):
389+
line = lines[i]
390+
current_sha = commit_shas[sha_index]
391+
392+
if "missing" in line:
393+
results[current_sha] = None
394+
i += 1
395+
sha_index += 1
396+
elif line and not line.startswith(" "):
397+
# Header line: <object_sha> <type> <size>
398+
parts = line.split()
399+
if len(parts) >= 3:
400+
try:
401+
size = int(parts[2])
402+
# Content follows on next lines until size bytes consumed
403+
content_lines: list[str] = []
404+
remaining = size
405+
i += 1
406+
while remaining > 0 and i < len(lines):
407+
content_line = lines[i]
408+
content_lines.append(content_line)
409+
remaining -= len(content_line) + 1 # +1 for newline
410+
i += 1
411+
results[current_sha] = "\n".join(content_lines)
412+
sha_index += 1
413+
except (ValueError, IndexError):
414+
results[current_sha] = None
415+
sha_index += 1
416+
i += 1
417+
else:
418+
i += 1
419+
else:
420+
i += 1
421+
422+
# Fill in any remaining SHAs as None
423+
for remaining_sha in commit_shas[sha_index:]:
424+
results[remaining_sha] = None
425+
426+
return results
427+
324428
def list_notes(
325429
self,
326430
namespace: str,

src/git_notes_memory/hooks/context_builder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import TYPE_CHECKING
2020

2121
from git_notes_memory.config import TOKENS_PER_CHAR, get_project_index_path
22+
from git_notes_memory.exceptions import MemoryIndexError
2223
from git_notes_memory.hooks.config_loader import (
2324
BudgetMode,
2425
HookConfig,
@@ -556,7 +557,8 @@ def _analyze_project_complexity(self, project: str) -> str:
556557
return "complex"
557558
return "full"
558559

559-
except Exception as e:
560+
# QUAL-002: Catch specific exceptions instead of bare Exception
561+
except (MemoryIndexError, OSError) as e:
560562
logger.debug("Failed to analyze complexity for %s: %s", project, e)
561563
return "medium" # Default to medium on error
562564

src/git_notes_memory/hooks/signal_detector.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,17 @@ def detect(self, text: str) -> list[CaptureSignal]:
271271
if not text or len(text) < 5:
272272
return []
273273

274+
# SEC-001: Limit input length to prevent ReDoS attacks
275+
# 100KB is generous for user prompts while preventing abuse
276+
MAX_TEXT_LENGTH = 100 * 1024 # 100KB
277+
if len(text) > MAX_TEXT_LENGTH:
278+
logger.warning(
279+
"Input text length %d exceeds maximum %d, truncating for safety",
280+
len(text),
281+
MAX_TEXT_LENGTH,
282+
)
283+
text = text[:MAX_TEXT_LENGTH]
284+
274285
signals: list[CaptureSignal] = []
275286
block_positions: set[tuple[int, int]] = set()
276287

src/git_notes_memory/hooks/stop_handler.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
cancel_timeout,
4646
get_hook_logger,
4747
log_hook_input,
48+
read_json_input,
4849
setup_logging,
4950
setup_timeout,
5051
)
@@ -55,23 +56,22 @@
5556
logger = logging.getLogger(__name__)
5657

5758

58-
def _read_input() -> dict[str, Any]:
59-
"""Read and parse JSON input from stdin.
59+
def _read_input_with_fallback() -> dict[str, Any]:
60+
"""Read and parse JSON input from stdin with fallback for empty input.
6061
61-
Returns:
62-
Parsed JSON data.
62+
QUAL-001: Wraps hook_utils.read_json_input() with Stop-hook-specific
63+
fallback behavior (empty input is valid for stop hooks).
6364
64-
Raises:
65-
json.JSONDecodeError: If input is not valid JSON.
65+
Returns:
66+
Parsed JSON data, or empty dict if stdin is empty.
6667
"""
67-
input_text = sys.stdin.read()
68-
if not input_text.strip():
68+
try:
69+
return read_json_input()
70+
except ValueError as e:
6971
# Empty input is OK for stop hook
70-
return {}
71-
result = json.loads(input_text)
72-
if not isinstance(result, dict):
73-
return {}
74-
return dict(result)
72+
if "empty" in str(e).lower():
73+
return {}
74+
raise
7575

7676

7777
def _analyze_session(transcript_path: str | None) -> list[CaptureSignal]:
@@ -385,8 +385,8 @@ def main() -> None:
385385
setup_timeout(timeout, hook_name="Stop")
386386

387387
try:
388-
# Read input (may be empty for stop hook)
389-
input_data = _read_input()
388+
# QUAL-001: Use hook_utils.read_json_input with fallback
389+
input_data = _read_input_with_fallback()
390390
logger.debug("Received stop hook input: %s", list(input_data.keys()))
391391

392392
# Log full input to file for debugging

src/git_notes_memory/index.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@
2323

2424
import contextlib
2525
import sqlite3
26+
import struct
2627
from contextlib import contextmanager
2728
from datetime import UTC, datetime
29+
from functools import lru_cache
2830
from pathlib import Path
2931
from typing import TYPE_CHECKING
3032

@@ -33,6 +35,24 @@
3335
from git_notes_memory.config import EMBEDDING_DIMENSIONS, get_index_path
3436
from git_notes_memory.exceptions import MemoryIndexError
3537

38+
39+
# PERF-007: Cache compiled struct format for embedding serialization
40+
@lru_cache(maxsize=8)
41+
def _get_struct_format(dimensions: int) -> struct.Struct:
42+
"""Get a cached struct.Struct for packing embeddings.
43+
44+
The embedding dimensions are typically constant (384 for all-MiniLM-L6-v2),
45+
so caching the compiled Struct avoids repeated format string parsing.
46+
47+
Args:
48+
dimensions: Number of float values in the embedding.
49+
50+
Returns:
51+
A compiled struct.Struct instance for packing.
52+
"""
53+
return struct.Struct(f"{dimensions}f")
54+
55+
3656
if TYPE_CHECKING:
3757
from collections.abc import Iterator, Sequence
3858

@@ -508,10 +528,8 @@ def _insert_embedding(
508528
memory_id: ID of the memory this embedding belongs to.
509529
embedding: The embedding vector.
510530
"""
511-
# sqlite-vec expects binary format for vectors
512-
import struct
513-
514-
blob = struct.pack(f"{len(embedding)}f", *embedding)
531+
# PERF-007: Use cached struct format for embedding packing
532+
blob = _get_struct_format(len(embedding)).pack(*embedding)
515533
cursor.execute(
516534
"INSERT INTO vec_memories (id, embedding) VALUES (?, ?)",
517535
(memory_id, blob),
@@ -820,9 +838,8 @@ def _update_embedding(
820838
memory_id: ID of the memory this embedding belongs to.
821839
embedding: The new embedding vector.
822840
"""
823-
import struct
824-
825-
blob = struct.pack(f"{len(embedding)}f", *embedding)
841+
# PERF-007: Use cached struct format for embedding packing
842+
blob = _get_struct_format(len(embedding)).pack(*embedding)
826843

827844
# Delete existing and insert new (sqlite-vec doesn't support UPDATE well)
828845
cursor.execute("DELETE FROM vec_memories WHERE id = ?", (memory_id,))
@@ -984,10 +1001,8 @@ def search_vector(
9841001
List of (Memory, distance) tuples sorted by distance ascending.
9851002
Lower distance means more similar.
9861003
"""
987-
import struct
988-
989-
# Pack query embedding as binary
990-
blob = struct.pack(f"{len(query_embedding)}f", *query_embedding)
1004+
# PERF-007: Use cached struct format for embedding packing
1005+
blob = _get_struct_format(len(query_embedding)).pack(*query_embedding)
9911006

9921007
with self._cursor() as cursor:
9931008
try:

0 commit comments

Comments
 (0)