@@ -1229,44 +1229,156 @@ int tq_encode(const tq_tokenizer_t* tok, const char* text,
12291229 }
12301230 }
12311231
1232- /* BPE merge pass: repeatedly merge the highest-priority pair.
1233- * A merge has higher priority if its score is larger.
1234- * We check all consecutive token pairs against the merge table. */
1235- while (n_tokens >= 2 ) {
1236- float best_score = -1e30f ;
1237- int best_idx = -1 ;
1238- int best_id = -1 ;
1239-
1232+ /* BPE merge pass using a max-heap for O(n log n) instead of O(n²).
1233+ *
1234+ * The naive algorithm scans all pairs on each merge step → O(n²).
1235+ * For 17K initial tokens (GPT2 byte-level), that's ~289M ops = minutes.
1236+ *
1237+ * Heap approach:
1238+ * 1. Build a heap of all mergeable consecutive pairs (score, position)
1239+ * 2. Pop max-score pair, apply merge, invalidate stale entries
1240+ * 3. Insert new pairs formed at the merge point
1241+ * 4. O(n log n) total: n initial inserts + n pops + O(1) updates each
1242+ *
1243+ * We use a simple binary max-heap with lazy deletion (stale entries
1244+ * are skipped when popped, identified by a generation counter). */
1245+ {
1246+ /* Linked list for O(1) neighbor access after merges */
1247+ int * prev = (int * )malloc ((size_t )n_tokens * sizeof (int ));
1248+ int * next = (int * )malloc ((size_t )n_tokens * sizeof (int ));
1249+ if (!prev || !next ) { free (prev ); free (next ); return n_tokens ; }
1250+ for (int i = 0 ; i < n_tokens ; i ++ ) { prev [i ] = i - 1 ; next [i ] = i + 1 ; }
1251+
1252+ /* Heap entry: (score, left_pos, merge_id, generation) */
1253+ typedef struct { float score ; int pos ; int merge_id ; int gen ; } heap_entry_t ;
1254+ int heap_cap = n_tokens + 16 ;
1255+ heap_entry_t * heap = (heap_entry_t * )malloc ((size_t )heap_cap * sizeof (heap_entry_t ));
1256+ int * gen = (int * )calloc ((size_t )n_tokens , sizeof (int )); /* per-position generation */
1257+ if (!heap || !gen ) { free (prev ); free (next ); free (heap ); free (gen ); return n_tokens ; }
1258+ int heap_size = 0 ;
1259+
1260+ /* Heap helpers (max-heap by score) */
1261+ #define HEAP_PARENT (i ) (((i)-1)/2)
1262+ #define HEAP_LEFT (i ) (2*(i)+1)
1263+ #define HEAP_RIGHT (i ) (2*(i)+2)
1264+ #define HEAP_SWAP (a ,b ) { heap_entry_t _t = heap[a]; heap[a] = heap[b]; heap[b] = _t; }
1265+
1266+ void * _dummy_ptr = NULL ; (void )_dummy_ptr ; /* suppress unused warning */
1267+
1268+ /* Sift up */
1269+ int sift_up_idx = 0 ;
1270+ #define SIFT_UP (idx ) do { \
1271+ sift_up_idx = (idx); \
1272+ while (sift_up_idx > 0 && heap[sift_up_idx].score > heap[HEAP_PARENT(sift_up_idx)].score) { \
1273+ HEAP_SWAP(sift_up_idx, HEAP_PARENT(sift_up_idx)); \
1274+ sift_up_idx = HEAP_PARENT(sift_up_idx); \
1275+ } \
1276+ } while(0)
1277+
1278+ /* Sift down */
1279+ #define SIFT_DOWN (idx ) do { \
1280+ int _si = (idx); \
1281+ for (;;) { \
1282+ int _best = _si; \
1283+ int _l = HEAP_LEFT(_si), _r = HEAP_RIGHT(_si); \
1284+ if (_l < heap_size && heap[_l].score > heap[_best].score) _best = _l; \
1285+ if (_r < heap_size && heap[_r].score > heap[_best].score) _best = _r; \
1286+ if (_best == _si) break; \
1287+ HEAP_SWAP(_si, _best); _si = _best; \
1288+ } \
1289+ } while(0)
1290+
1291+ /* Try to create a merge entry for position i and its next neighbor */
1292+ #define TRY_INSERT_PAIR (i ) do { \
1293+ int _ni = next[i]; \
1294+ if (_ni < n_tokens && tokens[_ni] >= 0) { \
1295+ const char* _s1 = tok->vocab[tokens[i]]; \
1296+ const char* _s2 = tok->vocab[tokens[_ni]]; \
1297+ int _l1 = (int)strlen(_s1), _l2 = (int)strlen(_s2); \
1298+ if (_l1 + _l2 < 512) { \
1299+ char _m[512]; memcpy(_m, _s1, _l1); memcpy(_m+_l1, _s2, _l2); _m[_l1+_l2]=0; \
1300+ int _mid = str_lookup(tok, _m); \
1301+ if (_mid >= 0) { \
1302+ if (heap_size >= heap_cap) { heap_cap *= 2; heap = realloc(heap, (size_t)heap_cap * sizeof(heap_entry_t)); } \
1303+ heap[heap_size] = (heap_entry_t){tok->scores[_mid], (i), _mid, gen[i]}; \
1304+ SIFT_UP(heap_size); heap_size++; \
1305+ } \
1306+ } \
1307+ } \
1308+ } while(0)
1309+
1310+ /* Build initial heap */
12401311 for (int i = 0 ; i < n_tokens - 1 ; i ++ ) {
1241- /* Construct merged string */
1242- const char * s1 = tok -> vocab [ tokens [ i ]];
1243- const char * s2 = tok -> vocab [tokens [i + 1 ]];
1244- int len1 = ( int ) strlen ( s1 ) ;
1245- int len2 = (int )strlen (s2 );
1246-
1247- if ( len1 + len2 >= 512 ) continue ;
1248-
1249- char merged [ 512 ] ;
1250- memcpy ( merged , s1 , ( size_t ) len1 ) ;
1251- memcpy ( merged + len1 , s2 , ( size_t ) len2 );
1252- merged [ len1 + len2 ] = '\0' ;
1253-
1254- int id = str_lookup ( tok , merged ) ;
1255- if ( id >= 0 && tok -> scores [ id ] > best_score ) {
1256- best_score = tok -> scores [ id ] ;
1257- best_idx = i ;
1258- best_id = id ;
1312+ int ni = next [ i ];
1313+ if ( ni < n_tokens ) {
1314+ const char * s1 = tok -> vocab [tokens [i ]];
1315+ const char * s2 = tok -> vocab [ tokens [ ni ]] ;
1316+ int l1 = ( int ) strlen ( s1 ), l2 = (int )strlen (s2 );
1317+ if ( l1 + l2 < 512 ) {
1318+ char merged [ 512 ] ;
1319+ memcpy ( merged , s1 , ( size_t ) l1 );
1320+ memcpy ( merged + l1 , s2 , ( size_t ) l2 ) ;
1321+ merged [ l1 + l2 ] = '\0' ;
1322+ int mid = str_lookup ( tok , merged );
1323+ if ( mid >= 0 ) {
1324+ if ( heap_size >= heap_cap ) { heap_cap *= 2 ; heap = realloc ( heap , ( size_t ) heap_cap * sizeof ( heap_entry_t )); }
1325+ heap [ heap_size ] = ( heap_entry_t ){ tok -> scores [ mid ], i , mid , 0 } ;
1326+ SIFT_UP ( heap_size );
1327+ heap_size ++ ;
1328+ }
1329+ }
12591330 }
12601331 }
12611332
1262- if (best_idx < 0 ) break ;
1333+ /* Merge loop */
1334+ int active_count = n_tokens ;
1335+ while (heap_size > 0 && active_count >= 2 ) {
1336+ /* Pop max */
1337+ heap_entry_t top = heap [0 ];
1338+ heap [0 ] = heap [-- heap_size ];
1339+ if (heap_size > 0 ) { SIFT_DOWN (0 ); }
1340+
1341+ /* Check if stale (position was already merged) */
1342+ if (top .gen != gen [top .pos ]) continue ;
1343+ int ri = next [top .pos ];
1344+ if (ri >= n_tokens || tokens [ri ] < 0 ) continue ;
1345+
1346+ /* Apply merge: left absorbs right */
1347+ tokens [top .pos ] = top .merge_id ;
1348+ tokens [ri ] = -1 ; /* mark dead */
1349+ gen [top .pos ]++ ; /* invalidate old entries for this position */
1350+
1351+ /* Update linked list: skip the dead right node */
1352+ int rr = next [ri ];
1353+ next [top .pos ] = rr ;
1354+ if (rr < n_tokens ) prev [rr ] = top .pos ;
1355+ active_count -- ;
1356+
1357+ /* Insert new pairs: (prev_of_left, left) and (left, next_of_right) */
1358+ if (prev [top .pos ] >= 0 && tokens [prev [top .pos ]] >= 0 ) {
1359+ gen [prev [top .pos ]]++ ;
1360+ TRY_INSERT_PAIR (prev [top .pos ]);
1361+ }
1362+ if (next [top .pos ] < n_tokens && tokens [next [top .pos ]] >= 0 ) {
1363+ TRY_INSERT_PAIR (top .pos );
1364+ }
1365+ }
12631366
1264- /* Apply the merge */
1265- tokens [ best_idx ] = best_id ;
1266- for (int i = best_idx + 1 ; i < n_tokens - 1 ; i ++ ) {
1267- tokens [i ] = tokens [i + 1 ];
1367+ /* Compact: remove dead tokens */
1368+ int out = 0 ;
1369+ for (int i = 0 ; i < n_tokens ; i ++ ) {
1370+ if ( tokens [i ] >= 0 ) tokens [out ++ ] = tokens [ i ];
12681371 }
1269- n_tokens -- ;
1372+ n_tokens = out ;
1373+
1374+ free (prev ); free (next ); free (heap ); free (gen );
1375+ #undef HEAP_PARENT
1376+ #undef HEAP_LEFT
1377+ #undef HEAP_RIGHT
1378+ #undef HEAP_SWAP
1379+ #undef SIFT_UP
1380+ #undef SIFT_DOWN
1381+ #undef TRY_INSERT_PAIR
12701382 }
12711383
12721384 return n_tokens ;
0 commit comments