Skip to content

Commit c8decc5

Browse files
committed
feat(pjrt): add KV cache I/O rewriting and executable cache
T61.3.2: EmitKVCacheProgram rewrites stateful KV cache as explicit function I/O for PJRT's pure-functional model. T64.1.1+T64.1.2: Content-addressed executable cache with LRU eviction.
1 parent 382ea0a commit c8decc5

4 files changed

Lines changed: 925 additions & 0 deletions

File tree

internal/pjrt/cache.go

Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
package pjrt
2+
3+
import (
4+
"crypto/sha256"
5+
"encoding/hex"
6+
"fmt"
7+
"os"
8+
"path/filepath"
9+
"sort"
10+
"sync"
11+
"time"
12+
)
13+
14+
// DefaultCacheDir is the default directory for cached PJRT executables.
15+
const DefaultCacheDir = ".cache/zerfoo/pjrt"
16+
17+
// DefaultMaxCacheSize is the default maximum cache size in bytes (2 GB).
18+
const DefaultMaxCacheSize int64 = 2 << 30
19+
20+
// CacheOption configures a Cache.
21+
type CacheOption func(*Cache)
22+
23+
// WithCacheDir sets the cache directory. If empty, defaults to
24+
// $ZERFOO_PJRT_CACHE or ~/.cache/zerfoo/pjrt/.
25+
func WithCacheDir(dir string) CacheOption {
26+
return func(c *Cache) { c.dir = dir }
27+
}
28+
29+
// WithMaxCacheSize sets the maximum total size of cached files in bytes.
30+
func WithMaxCacheSize(n int64) CacheOption {
31+
return func(c *Cache) { c.maxSize = n }
32+
}
33+
34+
// CacheStats holds cache hit/miss/size statistics.
35+
type CacheStats struct {
36+
Hits int64
37+
Misses int64
38+
Size int64 // total bytes on disk
39+
Files int // number of cached entries
40+
}
41+
42+
// Cache stores serialized PJRT executables keyed by a content hash of
43+
// the StableHLO program text and platform name. It provides LRU eviction
44+
// when the total size exceeds MaxSize.
45+
type Cache struct {
46+
mu sync.Mutex
47+
dir string
48+
maxSize int64
49+
hits int64
50+
misses int64
51+
}
52+
53+
// NewCache creates a new executable cache. The cache directory is created
54+
// on first Put if it does not already exist.
55+
func NewCache(opts ...CacheOption) *Cache {
56+
c := &Cache{maxSize: DefaultMaxCacheSize}
57+
for _, o := range opts {
58+
o(c)
59+
}
60+
if c.dir == "" {
61+
c.dir = resolveCacheDir()
62+
}
63+
return c
64+
}
65+
66+
// Key returns the content-addressed cache key for the given StableHLO
67+
// program and platform name: SHA256(program + "\x00" + platform).
68+
func Key(stablehloMLIR, platformName string) string {
69+
h := sha256.New()
70+
h.Write([]byte(stablehloMLIR))
71+
h.Write([]byte{0})
72+
h.Write([]byte(platformName))
73+
return hex.EncodeToString(h.Sum(nil))
74+
}
75+
76+
// Get looks up a cached serialized executable by key. If found, the raw
77+
// bytes are returned (caller must DeserializeAndLoad). Returns nil, nil
78+
// on cache miss.
79+
func (c *Cache) Get(key string) ([]byte, error) {
80+
c.mu.Lock()
81+
defer c.mu.Unlock()
82+
83+
path := c.entryPath(key)
84+
data, err := os.ReadFile(path)
85+
if os.IsNotExist(err) {
86+
c.misses++
87+
return nil, nil
88+
}
89+
if err != nil {
90+
c.misses++
91+
return nil, fmt.Errorf("pjrt cache: read %s: %w", key, err)
92+
}
93+
94+
// Touch access time for LRU tracking.
95+
now := time.Now()
96+
_ = os.Chtimes(path, now, now)
97+
98+
c.hits++
99+
return data, nil
100+
}
101+
102+
// Put stores serialized executable bytes under the given key. If storing
103+
// the new entry would exceed MaxSize, the least-recently-used entries are
104+
// evicted first.
105+
func (c *Cache) Put(key string, data []byte) error {
106+
c.mu.Lock()
107+
defer c.mu.Unlock()
108+
109+
if err := os.MkdirAll(c.dir, 0o755); err != nil {
110+
return fmt.Errorf("pjrt cache: create dir: %w", err)
111+
}
112+
113+
path := c.entryPath(key)
114+
115+
// Write atomically: write to tmp then rename.
116+
tmp := path + ".tmp"
117+
if err := os.WriteFile(tmp, data, 0o644); err != nil {
118+
return fmt.Errorf("pjrt cache: write %s: %w", key, err)
119+
}
120+
if err := os.Rename(tmp, path); err != nil {
121+
_ = os.Remove(tmp)
122+
return fmt.Errorf("pjrt cache: rename %s: %w", key, err)
123+
}
124+
125+
// Evict if over budget.
126+
c.evictLocked()
127+
return nil
128+
}
129+
130+
// Evict removes the least-recently-used entries until total cache size
131+
// is within MaxSize.
132+
func (c *Cache) Evict() {
133+
c.mu.Lock()
134+
defer c.mu.Unlock()
135+
c.evictLocked()
136+
}
137+
138+
// Clear removes all cached entries.
139+
func (c *Cache) Clear() error {
140+
c.mu.Lock()
141+
defer c.mu.Unlock()
142+
143+
entries, _ := os.ReadDir(c.dir)
144+
for _, e := range entries {
145+
if e.IsDir() {
146+
continue
147+
}
148+
_ = os.Remove(filepath.Join(c.dir, e.Name()))
149+
}
150+
return nil
151+
}
152+
153+
// Stats returns current cache statistics.
154+
func (c *Cache) Stats() CacheStats {
155+
c.mu.Lock()
156+
defer c.mu.Unlock()
157+
158+
var totalSize int64
159+
var fileCount int
160+
entries, _ := os.ReadDir(c.dir)
161+
for _, e := range entries {
162+
if e.IsDir() {
163+
continue
164+
}
165+
info, err := e.Info()
166+
if err != nil {
167+
continue
168+
}
169+
totalSize += info.Size()
170+
fileCount++
171+
}
172+
173+
return CacheStats{
174+
Hits: c.hits,
175+
Misses: c.misses,
176+
Size: totalSize,
177+
Files: fileCount,
178+
}
179+
}
180+
181+
// Dir returns the cache directory path.
182+
func (c *Cache) Dir() string {
183+
return c.dir
184+
}
185+
186+
// entryPath returns the filesystem path for a cache key.
187+
func (c *Cache) entryPath(key string) string {
188+
return filepath.Join(c.dir, key+".pjrt")
189+
}
190+
191+
// cacheEntry holds file info for LRU sorting.
192+
type cacheEntry struct {
193+
path string
194+
size int64
195+
modTime time.Time
196+
}
197+
198+
// evictLocked removes LRU entries until total size <= maxSize. Caller must hold mu.
199+
func (c *Cache) evictLocked() {
200+
entries, err := os.ReadDir(c.dir)
201+
if err != nil {
202+
return
203+
}
204+
205+
var files []cacheEntry
206+
var totalSize int64
207+
for _, e := range entries {
208+
if e.IsDir() {
209+
continue
210+
}
211+
info, err := e.Info()
212+
if err != nil {
213+
continue
214+
}
215+
files = append(files, cacheEntry{
216+
path: filepath.Join(c.dir, e.Name()),
217+
size: info.Size(),
218+
modTime: info.ModTime(),
219+
})
220+
totalSize += info.Size()
221+
}
222+
223+
if totalSize <= c.maxSize {
224+
return
225+
}
226+
227+
// Sort oldest first (least recently used).
228+
sort.Slice(files, func(i, j int) bool {
229+
return files[i].modTime.Before(files[j].modTime)
230+
})
231+
232+
for _, f := range files {
233+
if totalSize <= c.maxSize {
234+
break
235+
}
236+
if err := os.Remove(f.path); err == nil {
237+
totalSize -= f.size
238+
}
239+
}
240+
}
241+
242+
// resolveCacheDir returns the cache directory, checking ZERFOO_PJRT_CACHE
243+
// env var first, then falling back to ~/.cache/zerfoo/pjrt/.
244+
func resolveCacheDir() string {
245+
if dir := os.Getenv("ZERFOO_PJRT_CACHE"); dir != "" {
246+
return dir
247+
}
248+
home, err := os.UserHomeDir()
249+
if err != nil {
250+
return filepath.Join(os.TempDir(), "zerfoo-pjrt-cache")
251+
}
252+
return filepath.Join(home, DefaultCacheDir)
253+
}
254+

0 commit comments

Comments
 (0)