Add TurboQuant rotation-based vector quantization codec to sandbox #15903

Draft
xande wants to merge 19 commits into apache:main from xande:turboquant

Conversation


@xande xande commented Mar 31, 2026

Summary

This PR adds a new FlatVectorsFormat implementation based on the TurboQuant algorithm (Zandieh et al., arXiv:2504.19874, ICLR 2026) to the lucene/sandbox module.

This implementation was co-authored with an AI coding agent (Kiro) as an experiment in AI-assisted open source contribution. The agent handled the bulk of the code generation, test writing, and iterative debugging while I provided direction, reviewed outputs, ran benchmarks, and validated against real datasets. I want to be transparent that while I've tested and benchmarked this across various configurations, I don't have deep expertise in Lucene's codec internals - I'd greatly appreciate thorough review from the community.

Motivation

Current Lucene vector quantization formats (scalar quantization, BBQ) are limited to 1024 dimensions and require per-segment calibration. With embedding models increasingly producing higher-dimensional vectors (OpenAI text-embedding-3-large at 3072d, various 4096d models emerging), we need a quantization approach that scales beyond this limit.

TurboQuant is a data-oblivious rotation-based quantizer that:

  • Requires no calibration - precomputed codebooks, each vector quantized independently
  • Supports dimensions up to 16,384 (vs 1024 limit of existing formats)
  • Enables byte-copy merge - all segments share the same rotation seed, so merge never re-quantizes
  • Is streaming-friendly - no warmup period, no need to sample data upfront

Design

Follows the Lucene104ScalarQuantizedVectorsFormat pattern:

  • TurboQuantFlatVectorsFormat extends FlatVectorsFormat - stores quantized vectors in .vetq, metadata in .vemtq, delegates raw vectors to Lucene99FlatVectorsFormat
  • TurboQuantHnswVectorsFormat extends KnnVectorsFormat - convenience composition with Lucene99HnswVectorsWriter/Reader
  • TurboQuantVectorsScorer implements FlatVectorsScorer - LUT-based scoring directly from packed bytes
  • TurboQuantEncoding enum: BITS_2 (16x), BITS_3 (~10.7x), BITS_4 (8x), BITS_8 (4x) compression
  • Block-diagonal Hadamard rotation for non-power-of-2 dimensions (e.g., d=768 → blocks 512+256)
  • Placed in lucene/sandbox - no changes to lucene/core
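
The block-diagonal Hadamard rotation mentioned above can be sketched as follows. This is an illustrative, self-contained sketch (class and method names are mine, not the PR's API), and it omits the random permutation and sign flips that the actual HadamardRotation also applies:

```java
// Illustrative sketch of a block-diagonal fast Walsh-Hadamard transform
// (FWHT). Names are hypothetical; the random permutation and sign flips
// applied by the real HadamardRotation are omitted here.
final class FwhtSketch {

  /** In-place FWHT on data[offset, offset+len); len must be a power of 2. */
  static void fwht(float[] data, int offset, int len) {
    for (int h = 1; h < len; h <<= 1) {
      for (int i = 0; i < len; i += h << 1) {
        for (int j = i; j < i + h; j++) {
          float a = data[offset + j];
          float b = data[offset + j + h];
          data[offset + j] = a + b;
          data[offset + j + h] = a - b;
        }
      }
    }
    // Scale by 1/sqrt(len) so the transform is orthonormal (norm-preserving).
    float norm = (float) (1.0 / Math.sqrt(len));
    for (int i = 0; i < len; i++) {
      data[offset + i] *= norm;
    }
  }

  /** Apply the FWHT independently to each power-of-2 block, e.g. {512, 256} for d=768. */
  static void blockRotate(float[] v, int[] blockSizes) {
    int off = 0;
    for (int len : blockSizes) {
      fwht(v, off, len);
      off += len;
    }
  }
}
```

Because each block transform is orthonormal, vector norms (and dot products between identically rotated vectors) are preserved, which is what makes the rotation safe to apply before quantization.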

Implementation

12 source files (2,090 lines), 11 test files (1,591 lines), 1 JMH benchmark:

sandbox/codecs/turboquant/
├── TurboQuantEncoding.java              Enum: BITS_2/3/4/8
├── BetaCodebook.java                    Precomputed Lloyd-Max centroids
├── HadamardRotation.java                Block-diagonal FWHT + permutation + sign flip
├── TurboQuantBitPacker.java             Bit-packing for b=2,3,4,8
├── TurboQuantScoringUtil.java           LUT-based dot product & distance
├── TurboQuantFlatVectorsFormat.java     FlatVectorsFormat SPI entry point
├── TurboQuantFlatVectorsWriter.java     Rotate + quantize + write at flush
├── TurboQuantFlatVectorsReader.java     Off-heap read + scoring delegation
├── OffHeapTurboQuantVectorValues.java   mmap'd random access to quantized data
├── TurboQuantVectorsScorer.java         FlatVectorsScorer implementation
├── TurboQuantHnswVectorsFormat.java     HNSW + TurboQuant composition
└── package-info.java                    Javadoc with format spec

Benchmark: Cohere v3 Wikipedia English

400K vectors, 1024 dimensions, dot product similarity, HNSW (maxConn=64, beamWidth=250), fanout=100, topK=100, 10K queries, force-merged to 1 segment.

Encoding        Compression  Recall@100  QPS  Latency  Index time  Force merge  Index size
TurboQuant b=2  16x          0.811       113  8.8ms    75s         924s         1,719 MB
TurboQuant b=3  ~10.7x       0.887       116  8.6ms    76s         160s         1,767 MB
TurboQuant b=4  8x           0.935       117  8.5ms    76s         870s         1,816 MB
TurboQuant b=8  4x           0.983       115  8.6ms    81s         126s         2,011 MB

Note: force merge times for b=2 and b=4 are anomalously high compared to b=3 and b=8 - this may be a caching artifact and needs further investigation.

Test results

107 dedicated TurboQuant tests pass, with 3 skipped (byte-vector-only tests; TurboQuant is float32 only):

Test suite                             Tests  Notes
TestTurboQuantEncoding                     7  Wire number serialization, packed lengths
TestBetaCodebook                           7  Centroid symmetry, MSE distortion validation
TestHadamardRotation                       9  Norm/IP preservation, round-trip, block-diagonal quality
TestTurboQuantBitPacker                    6  Round-trip all encodings × dimensions
TestTurboQuantScoringUtil                  2  LUT vs naive agreement < 1e-5
TestTurboQuantHnswVectorsFormat           53  Inherited from BaseKnnVectorsFormatTestCase
TestTurboQuantHnswVectorsFormatParams      6  Parameter validation, toString
TestTurboQuantHighDim                      2  Index + search at d=768 and d=4096
TestTurboQuantQuality                     11  Recall, edge cases, merge stress, similarity matrix
TestTurboQuantRecall                       5  Recall at d=768/4096 across all encodings
TestTurboQuantBruteForceRecall             3  Brute-force recall isolating quantization quality

JMH microbenchmarks (d=4096, b=4, single thread)

Benchmark                              Mode  Cnt       Score   Units
TurboQuantBenchmark.dotProductScoring  thrpt    2  313,617   ops/s   (~3.2 µs/candidate)
TurboQuantBenchmark.hadamardRotation   thrpt    2   32,125   ops/s   (~31 µs/query)
TurboQuantBenchmark.quantize           thrpt    2    8,169   ops/s   (~122 µs/vector)

What's not implemented (deferred)

  1. Byte-copy merge optimization - currently re-quantizes from raw vectors during merge. The architecture supports byte-copy since all segments share the same rotation seed.
  2. Panama Vector API SIMD - the LUT scorer relies on JVM auto-vectorization. Explicit intrinsics could improve scoring throughput.
  3. Quantized-only mode - raw float32 vectors are always stored alongside quantized data.

When to use TurboQuant

TurboQuant is best suited for:

  • Workloads with shifting data distributions (no recalibration needed)
  • Merge-heavy indices (byte-copy merge advantage)
  • Streaming/online indexing

Open questions

  1. File extensions — .vetq / .vemtq, following the convention that different format types use different extensions. Any concerns?
  2. Max dimensions — getMaxDimensions() returns 16384. Reasonable?
  3. Sandbox vs codecs — placed in sandbox as this is a new, community-unvetted contribution. Should it move to codecs once stabilized?

Alex Baranov added 19 commits March 30, 2026 21:51
…fold (codec integration)

Phase 1 - Core Algorithm (COMPLETE):
- TurboQuantEncoding: enum with BITS_2/3/4/8, wire numbers, packing math
- BetaCodebook: precomputed Lloyd-Max optimal centroids for N(0,1)
- HadamardRotation: block-diagonal FWHT with random permutation + sign flip
- TurboQuantBitPacker: optimized bit-packing for b=2,3,4,8
- All 32 Phase 1 unit tests pass
- MSE distortion at d=4096 b=4 matches paper (0.009)

Phase 2 - Codec Integration (IN PROGRESS):
- TurboQuantFlatVectorsFormat: FlatVectorsFormat SPI entry point
- TurboQuantFlatVectorsWriter: rotate + quantize + write at flush time
- TurboQuantFlatVectorsReader: off-heap read + scoring delegation
- OffHeapTurboQuantVectorValues: mmap'd random access to quantized vectors
- TurboQuantVectorsScorer: naive scorer (correctness-first, SIMD in Phase 3)
- TurboQuantHnswVectorsFormat: HNSW + TurboQuant composition
- SPI registration in META-INF/services
- 31/53 inherited BaseKnnVectorsFormatTestCase tests pass
- Remaining failures: byte vector tests (expected), merge path, off-heap map
…s pass

Three root causes fixed:
1. Merge path file handle: use temp file for scorer instead of opening
   .vetq while still writing (AccessDeniedException)
2. Byte vector support: delegate to raw reader instead of throwing
   UnsupportedOperationException
3. Off-heap size assertion: override assertOffHeapByteSize in test to
   handle TurboQuant's unique 'vetq' extension key

Results: 85 total tests pass (32 Phase 1 + 53 Phase 2 inherited), 3 skipped
…nd d=768 verified

Phase 2 Gate: COMPLETE
- 53/53 inherited BaseKnnVectorsFormatTestCase tests pass
- Index + search verified at d=768 and d=4096
- High-dim test added (TestTurboQuantHighDim)
- All merge tests pass
- CheckIndex integrity passes
- No resource leaks (testRandomExceptions passes)

Total: 87 tests pass, 0 failures, 3 skipped (byte-only tests)
… scorer

Phase 3 - SIMD Scoring:
- TurboQuantScoringUtil: LUT-based dot product and square distance
  for b=2,3,4,8 — operates directly on packed bytes without unpacking
- Scorer updated to use TurboQuantScoringUtil
- All 89 tests pass (no regression from Phase 2)
- SIMD vs naive agreement verified within 1e-5 for all encodings
- Performance benchmark deferred (JMH in Phase 4)

Phase 3 Gate: 3/4 items complete (perf benchmark deferred to Phase 4)
… cases, merge stress

Phase 4 - Comprehensive Testing:
- Recall validation: b=4 recall@10 >= 0.8 at d=128, b=8 >= 0.9, b=2 >= 0.5
- Edge cases: empty segment, single vector, all pass
- Merge stress: force merge 3 segments to 1, merge with 50% deleted docs
- All 4 similarity functions produce valid scores (non-NaN, non-negative)
- Total: 97 tests pass, 0 failures, 3 skipped

Phase 4 Gate: 5/7 items complete (full ant test + perf benchmarks deferred)
…rs verified

Phase 5 - Documentation:
- package-info.java with algorithm summary, file format spec, usage guidance
- All 20 Java files have ASF license headers
- No external dependencies (pure Java + precomputed constants)
- SPI registration in META-INF/services

All 5 phases complete. 97 tests pass, 0 failures.
Phase 1 gap fixed:
- Block-diagonal MSE quality test at d=768 vs d=1024 (within 5%)

Phase 2 gaps fixed:
- TestTurboQuantHnswVectorsFormatParams: testLimits, testToString,
  testMaxDimensions per section 2.6a

Phase 4 gaps fixed:
- Recall test at d=768 b=4 per section 4.1
- Randomized dimension recall test per section 4.1
- All similarity × all encoding combinations per section 4.2
- 10-segment force merge stress test per section 4.4

Phase 4.6:
- JMH benchmark: TurboQuantBenchmark (hadamard, scoring, quantize)
- benchmark-jmh module dependency and module export added

Phase 5.2:
- CHANGES.txt entry under New Features

Total: 107 tests pass, 0 failures, 3 skipped
3 items remain unchecked — all are runtime measurements, not code:
1. SIMD perf benchmark (JMH code written, needs execution)
2. Full test suite with randomized codec (needs CI run)
3. Perf comparison with scalar quant (needs JMH execution)

All code deliverables are complete. 107 tests pass.
…test suite

Scorer fixes:
- DOT_PRODUCT: remove docNorm multiplication (vectors are unit by contract)
- MAXIMUM_INNER_PRODUCT: use VectorUtil.scaleMaxInnerProductScore()
- Separate DOT_PRODUCT and MAXIMUM_INNER_PRODUCT cases

RandomCodec integration:
- Added TurboQuantHnswVectorsFormat to RandomCodec's knn format pool
- Random encoding selection per test run
- Exported turboquant package from codecs module-info
- 504 core vector tests pass with TurboQuant in random rotation
- 107 TurboQuant-specific tests pass
Final gates cleared:
- Phase 3: LUT scorer 313K ops/s dot product at d=4096 b=4 (JMH)
- Phase 4: Randomized codec test pass (504 core vector tests)
- Phase 4: Performance benchmarks documented

JMH Results (d=4096, b=4):
  dotProductScoring: 313,617 ops/s (~3.2 µs/score)
  hadamardRotation:   32,125 ops/s (~31 µs/rotation)
  quantize:            8,169 ops/s (~122 µs/quantize)

All gate checkboxes in TURBOQUANT_IMPLEMENTATION_PLAN.md are [x].
TURBOQUANT_IMPLEMENTATION_REPORT.md covers:
- Architecture & design decisions with rationale
- Implementation details (file format, index/search/merge flows)
- Full test results (107 dedicated + 504 core tests)
- JMH benchmark results (313K scoring ops/s at d=4096)
- 4 bugs found and fixed during implementation
- Deferred items and reproduction instructions
Recall results (HNSW search with over-retrieval):
  d=4096 b=4: 0.905 recall@10 (searchK=50, 500 vectors)
  d=768  b=4: 0.850 recall@10 (searchK=50, 1000 vectors)
  d=768  b=8: 0.980 recall@10 (searchK=10, 500 vectors)
  d=768  b=3: 0.810 recall@10 (searchK=30, 500 vectors)
  d=768  b=2: 0.680 recall@10 (searchK=50, 500 vectors)

Brute-force quantization quality (no HNSW):
  d=768  b=4: 0.856 recall@10 (pure quantization ranking)
  d=128  b=4: 0.876 recall@10
  d=768  b=8: 0.980 recall@10

Key finding: quantization quality is good (brute-force 0.856 at d=768 b=4)
but HNSW greedy traversal needs over-retrieval (searchK > k) to compensate
for quantized distance approximation during graph traversal.
Replaced placeholder recall numbers with actual measured values:
- Brute-force quantization quality: 0.856 at d=768 b=4
- HNSW recall with over-retrieval: 0.905 at d=4096 b=4
- Key finding documented: over-retrieval needed for HNSW + quantization
Covers the full implementation session: Phase 1-5 execution, 4 bugs found
and fixed, recall validation findings, JMH benchmarks, and final artifact
summary (12 source files, 10 test files, 11 commits).
These are local planning/review documents that should not be part of
the Lucene contribution.
An unvetted codec should not be randomly injected into the entire
Lucene test suite. TurboQuant compatibility is validated by its own
BaseKnnVectorsFormatTestCase extension.
The sandbox module is the appropriate home for new experimental codecs
that have not yet been community-vetted. This follows the precedent
set by FaissKnnVectorsFormat.

- Move source and tests to org.apache.lucene.sandbox.codecs.turboquant
- Update module-info.java and SPI registrations for both modules
- Update benchmark-jmh imports
- Remove @nightly from TestTurboQuantRecall (3s total, not slow)
- Update CHANGES.txt to reference sandbox module
@benwtrent
Member

How does turboquant's performance compare with Lucene's existing quantization techniques? They are honestly very similar, though I would think that Lucene's lends itself more to faster inner-product than TQ.

@xande
Author

xande commented Mar 31, 2026

@benwtrent - I still have to do proper benchmarking in terms of performance, and there are a couple more optimizations to do. Though my primary motivation is reducing the memory footprint for high-dimensional vectors. Recall of 0.935 at 4-bit on Cohere V3 is quite impressive.

@benwtrent
Member

benwtrent commented Mar 31, 2026

recall  latency(ms)  netCPU  avgCpuCount     nDoc  topK  fanout  maxConn  beamWidth  quantized  index(s)  index_docs/s  force_merge(s)  num_segments  index_size(MB)  vec_disk(MB)  vec_RAM(MB)  indexType
 0.875        0.858   0.855        0.996  1000000    10     100       32        250    -4 bits    203.69       4909.49          168.11             1         3349.44      3311.157      381.470       HNSW

That's the last run @mccullocht did for Lucene's OSQ technique (1M vectors; we would need to run the exact same data set for apples to apples).

I realize performance apples to apples will take way more work (panama vector APIs, etc.). I am more concerned about recall, and I am not sure TQ will provide any significant recall improvement itself.

The main thing I think that OSQ might be missing is some random rotation for non-Gaussian component vectors (which are an anomaly for modern models). But adding that to the existing OSQ for Lucene would be a snap (though careful thought would be needed, as that could be a significant performance burden with very little benefit for many users).

It would be good to just do a "flat" index to remove any HNSW noise.

@xande
Author

xande commented Mar 31, 2026

Interestingly enough, at Amazon we were putting a lot of bets on OSQ, though on internal datasets we did not see meaningful recall improvement vs non-OSQ - low enough for us not to use it. I am planning to run more benchmarks to see how TQ compares.

@benwtrent
Member

And sorry for immediately asking for more ;). Thank you for the initiative and initial contribution. I do think there are things to learn from TQ.

@mccullocht
Contributor

I've played with TQ a bit over the last week and wrote a less sophisticated implementation covering 1 and 2 bit encodings. I came to the conclusion that there was a small recall improvement on modern embeddings (voyage-3.5 in my case). I think testing the flat case is a good idea in terms of an upper bound improvement.

One worry I have with TQ in Lucene is related to per-segment overhead at query time. The transforms can be addressed by pushing it up to the query layer, but an efficient scoring implementation would likely use lookup tables that are expensive to compute and may not have a good implementation on panama depending on how well Vector.shuffle() is implemented.

@mccullocht
Contributor

IIUC this implementation covers TurboQuantMSE, which minimizes MSE, and not TurboQuantProd, which minimizes inner product distance and would require a second random projection on the MSE-encoded residuals.

@mikemccand
Member

Wow, what an impressive genai example! I also know nearly nothing about TQ, and only scratch the surface in understanding OSQ. I am curious how the two compare. E.g. does OSQ also not alter the quantization per-segment (merge of flat vectors could use optimized copyBytes (the hardest function in the world to implement correctly/performantly!))? Do we get a 3-bit option with OSQ?

Thank you @xande for preserving the iterations (separate commits) as you stepped through the plan with Kiro. Is the original plan/prompting visible somewhere here? I wish we all would preserve all prompts/plans/soul context docs -- they should be treated like source code. Imagine finding an exotic bug in this Codec some time in the future and being able to look back at how the prompts were written, how Kiro iterated, etc., to gain insight. Also, it would help us all learn how to use genai if we were better about sharing prompts / steering docs. Today, genai is a lonely endeavor -- what little human contact we had in a team / our craft is being replaced with solo time with your genai. Kinda like putting on your Apple Vision Pro. Genai is missing good tooling/culture to enable human-to-human collaboration/learning.

I'd love to see ROC-type curves using luceneutil's knnPerfTest.py, showing tradeoff of latency vs recall as you turn the "try harder" query time knob (oversample? fanout?) for all of Lucene's vector codecs (core OSQ/BBQ/etc., but also including Faiss and JVector!).

[A side rant: it's weird that nobody talks about precision of our vector queries, I guess because that's a lot more work to measure (you need an annotated corpus that marks pairs of query/index vectors with at least relevant/irrelevant binary classification), and, it's really measuring the model that generated the embeddings. So, we drastically simplify, assume the model is perfection, all vectors are precisely relevant if they are close, and only measure recall.]

That's the last run @mccullocht did for Lucene's OSQ technique (1M vectors; we would need to run the exact same data set for apples to apples).

+1 to doing as apples-to-apples a comparison as we can. But what corpus was this, @benwtrent? (@mccullocht later mentioned voyage-3.5 but I want to confirm the results you listed). It's interesting how different each corpus is -- I wish there were some way to visualize these massive-dimension vectors. High-dimensional math can be crazy counterintuitive!

I realize performance apples to apples will take way more work (panama vector APIs, etc.). I am more concerned about recall, and I am not sure TQ will provide any significant recall improvement itself.

I think recall, total CPU, wall-clock-time-with-many-cores, and effective hot RAM required (e.g. 2nd reranking phase is a big penalty there) at query time, and then also indexing performance, are all important when comparing the many vector Codecs we have now. Maybe we can just submit a bunch of competitors to ann-benchmarks? Hmm maybe one can run their own ann-benchmarks instance (using their GitHub repo)?

@mikemccand
Member

One worry I have with TQ in Lucene is related to per-segment overhead at query time. The transforms can be addressed by pushing it up to the query layer, but an efficient scoring implementation would likely use lookup tables that are expensive to compute and may not have a good implementation on panama depending on how well Vector.shuffle() is implemented.

Isn't it one global transform (not per segment) in this PR? Or would we want to change that to per-segment, to increase randomness/protection against unlucky rotation choice?

These highly concurrent silicon implementations of SIMD shuffle instructions (like VPSHUFB) are wild: https://claude.ai/share/226e0cf6-0d44-4602-804d-b2449777621e

@tveasey

tveasey commented Apr 1, 2026

Wow, what an impressive genai example! I also know nearly nothing about TQ, and only scratch the surface in understanding OSQ. I am curious how the two compare. E.g. does OSQ also not alter the quantization per-segment (merge of flat vectors could use optimized copyBytes (the hardest function in the world to implement correctly/performantly!))? Do we get a 3-bit option with OSQ?

So they are actually very similar in conception. If you notice Equation (4) in their paper is almost exactly equal to our initialisation procedure. The only difference is they allow for a non-uniform grid. For example, whereas for 2 bits we put centroids at [−1.493, -0.498, 0.498, 1.493] they put them at [-1.51, -0.453, 0.453, 1.51]. If you abandon uniform grid spacing you can no longer implement the dot product via integer arithmetic. This is actually a huge performance hit, IIRC we get 4-8x performance vs float arithmetic for well crafted SIMD variants of low bit integer dot products. The final implementation for TurboQuant is table lookup (centroid positions) followed by floating point arithmetic.
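
To make the uniform-grid point concrete, here is a small sketch (class name and the base/step values are illustrative) of why uniform spacing lets the dequantized dot product reduce to an integer dot product of the codes plus scalar corrections, the decomposition that non-uniform Lloyd-Max centroids break:

```java
// Sketch: with a uniform grid, centroid(code) = base + step * code, so
//   sum_i (base + step*x_i) * (base + step*y_i)
//     = d*base^2 + base*step*(sum x + sum y) + step^2 * (integer dot of codes).
// The bulk of the work is the integer dot product, which SIMD handles well.
final class UniformGridSketch {

  /** Dot product of dequantized vectors via integer arithmetic plus scalar corrections. */
  static float decomposedDot(int[] xCodes, int[] yCodes, float base, float step) {
    int d = xCodes.length;
    long intDot = 0, sumX = 0, sumY = 0;
    for (int i = 0; i < d; i++) {
      intDot += (long) xCodes[i] * yCodes[i]; // pure integer work
      sumX += xCodes[i];
      sumY += yCodes[i];
    }
    return d * base * base + base * step * (sumX + sumY) + step * step * intDot;
  }

  /** Reference: dequantize each coordinate, then take the float dot product. */
  static float naiveDot(int[] xCodes, int[] yCodes, float base, float step) {
    float sum = 0f;
    for (int i = 0; i < xCodes.length; i++) {
      sum += (base + step * xCodes[i]) * (base + step * yCodes[i]);
    }
    return sum;
  }
}
```

With non-uniform centroids there is no (base, step) pair that reproduces all centroid values, so the integer-only inner loop is no longer available and scoring falls back to table lookup plus float arithmetic.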

They also do a bias correction based on QJL. Since we optimise the dot product in the direction of document vector in our full implementation I don't think this will actually help OSQ, but I will try this out.

@mccullocht
Contributor

E.g. does OSQ also not alter the quantization per-segment (merge of flat vectors could use optimized copyBytes (the hardest function in the world to implement correctly/performantly!))? Do we get a 3-bit option with OSQ?

OSQ centers the vectors -- during a segment build it computes the mean vector then quantizes v - c. This adds a data dependency that requires re-encoding vectors as you merge. You could operate OSQ without centering but the resulting quantized vectors would be a less accurate representation on average. Requantizing OSQ is likely cheaper than the transform in TQ, but TQ could more easily discard the original full fidelity vector field.

We could support any value in [1,8] for OSQ, but efficiently unpacking for comparisons can be a real challenge. This PR is packing 3 bits as 8 values in 3 consecutive bytes. I can think of an efficient 128 bit implementation of this that would work on x86 and ARM but AVX/AVX512 are not amenable to the approach that I am thinking of.
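
For concreteness, the 3-bit layout described above (8 codes in 3 bytes) can be sketched like this. It is an illustration of the packing arithmetic only, not the PR's TurboQuantBitPacker, and the little-endian bit order is an assumption:

```java
// Minimal sketch of 3-bit packing: 8 codes of 3 bits each fit exactly in
// 3 bytes (24 bits). Bit order here (little-endian within the 24-bit word)
// is illustrative.
final class Pack3Sketch {

  /** Pack 8 values in [0, 7] into 3 bytes starting at out[outOff]. */
  static void pack8(int[] codes, byte[] out, int outOff) {
    int bits = 0;
    for (int i = 0; i < 8; i++) {
      bits |= (codes[i] & 0x7) << (3 * i);
    }
    out[outOff] = (byte) bits;
    out[outOff + 1] = (byte) (bits >>> 8);
    out[outOff + 2] = (byte) (bits >>> 16);
  }

  /** Inverse of pack8: recover 8 codes from 3 bytes starting at in[inOff]. */
  static void unpack8(byte[] in, int inOff, int[] codes) {
    int bits = (in[inOff] & 0xFF)
        | ((in[inOff + 1] & 0xFF) << 8)
        | ((in[inOff + 2] & 0xFF) << 16);
    for (int i = 0; i < 8; i++) {
      codes[i] = (bits >>> (3 * i)) & 0x7;
    }
  }
}
```

The scalar version is trivial; as noted above, the hard part is unpacking groups like this efficiently inside a SIMD comparison loop, where codes straddle byte boundaries.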

Isn't it one global transform (not per segment) in this PR? Or would we want to change that to per-segment, to increase randomness/protection against unlucky rotation choice?

This PR is doing one global transform. If we use a transform per segment, then we will have to re-quantize vectors during merge so it would be more complicated/expensive than copyBytes. I personally have not examined the effect of the random seed in a rigorous way but it is plausible that some transforms would be "better" than others in some measurable way like minimizing MSE.

If you abandon uniform grid spacing you can no longer implement the dot product via integer arithmetic. This is actually a huge performance hit, IIRC we get 4-8x performance vs float arithmetic for well crafted SIMD variants of low bit integer dot products. The final implementation for TurboQuant is table lookup (centroid positions) followed by floating point arithmetic.

For this you'd take a totally different approach -- probably something that looks more like distance computation for product quantization since it uses a codebook in a similar way. This involves generating lookup tables that can be quite large (8KB+) and you would not want to repeat this process on every segment. It can still be very fast but it almost certainly won't be as fast as OSQ's arithmetic comparisons.

@tveasey

tveasey commented Apr 2, 2026

For this you'd take a totally different approach -- probably something that looks more like distance computation for product quantization since it uses a codebook in a similar way. This involves generating lookup tables that can be quite large (8KB+) and you would not want to repeat this process on every segment. It can still be very fast but it almost certainly won't be as fast as OSQ's arithmetic comparisons.

The paper simply proposes DeQuant and doing things in the original vector space as I read it. This is driven, I suspect, by targeting mainly GPU where you really want to use matmul operations. I agree that for CPU ANN you'd probably want a more per-query PQ codebook approach, but AFAIK getting these fast requires imposing fairly significant limitations on table sizes which I'm not sure this satisfies. In fairness, I haven't looked at this topic in detail so maybe there are other tricks available. My expectation would be that the better route to squeeze more accuracy is to use residual quantisation better, since one can centre both the query and document vectors w.r.t. different arbitrary centroids.

@benwtrent
Member

I ran the 400k cohere v3 set on my machine; TQ does seem to provide nicer quantization mechanics at low bit counts. However, it's really tricky to get performance right with LUTs for all CPUs in Java land :/

recall  latency(ms)  netCPU  avgCpuCount  quantized  visited  index(s)  index_docs/s  force_merge(s)  index_size(MB)  vec_disk(MB)  vec_RAM(MB)
 0.651        1.630   1.624        0.997     1 bits     9683    121.66       3287.88           62.24         1679.17      1616.669       54.169
 0.772        2.051   2.037        0.993     2 bits     8591    126.26       3168.17           80.05         1722.58      1665.497      102.997
 0.905        2.442   2.424        0.993     4 bits     8061    149.44       2676.59           93.41         1820.22      1763.916      201.416
 0.979        3.448   3.431        0.995     8 bits     7960    187.99       2127.78          220.49         2015.69      1959.229      396.729
 0.989        2.303   2.269        0.985         no     7952    197.75       2022.74

@xande
Author

xande commented Apr 3, 2026

Thanks everyone for the incredibly thorough feedback. This is exactly the kind of review I was hoping for. I spent some time with Kiro going back through the paper, the Elastic OSQ blog posts, and the actual codec source. Hopefully this will help the community decide on the path forward (which could be improving OSQ with learnings from TQ). The Lucene community might also find the discussions and findings in this reference implementation helpful: https://github.com/tonbistudio/turboquant-pytorch

On TurboQuantMSE vs TurboQuantProd (@mccullocht)

You're correct — this implements TurboQuantMSE only. The paper's TurboQuantProd variant (Algorithm 2, Section 3.2) applies a (b-1)-bit MSE quantizer followed by a 1-bit QJL transform on the residual vector r = x - dequant(quant(x)), yielding an unbiased inner product estimator at total bit-width b.

Based on the community comments on https://github.com/tonbistudio/turboquant-pytorch?tab=readme-ov-file#v3-improvements-community-informed, MSE alone is enough.

On the non-uniform grid and performance (@tveasey, @mccullocht)

@tveasey's analysis is spot-on and this is the most important tradeoff. I had Kiro pull up the actual centroid values from BetaCodebook.java and the OSQ initialization intervals from the Elastic blog to compare side by side. The TQ centroids are Lloyd-Max optimal for the Gaussian distribution, which means non-uniform spacing. For b=2, TQ places centroids at [-1.510, -0.453, 0.453, 1.510] vs OSQ's initial uniform grid at [-1.493, -0.498, 0.498, 1.493] (before OSQ's per-vector interval refinement). The spacing difference is small but the consequence is large: you can't decompose the dot product into integer arithmetic plus scalar corrections.
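
As a sanity check on those centroid values, the Lloyd-Max codebook for N(0,1) can be reproduced with a quick Monte-Carlo Lloyd iteration. This is a rough sketch (names are mine; BetaCodebook ships precomputed tables rather than computing anything at runtime); for k=4 it converges to roughly ±0.453 and ±1.510, matching the b=2 values quoted above:

```java
// Monte-Carlo sketch of Lloyd-Max quantizer design for a standard Gaussian:
// draw samples from N(0,1), then alternate nearest-centroid assignment and
// centroid recomputation (conditional means). Illustrative only.
final class LloydMaxSketch {

  static float[] centroids(int k, int samples, long seed, int iters) {
    java.util.Random rnd = new java.util.Random(seed);
    double[] x = new double[samples];
    for (int i = 0; i < samples; i++) x[i] = rnd.nextGaussian();
    java.util.Arrays.sort(x);
    // Initialize centroids at evenly spaced sample quantiles.
    double[] c = new double[k];
    for (int j = 0; j < k; j++) c[j] = x[(int) ((j + 0.5) * samples / k)];
    for (int it = 0; it < iters; it++) {
      double[] sum = new double[k];
      int[] cnt = new int[k];
      for (double xi : x) {
        int best = 0; // nearest centroid; linear scan is fine for small k
        for (int j = 1; j < k; j++) {
          if (Math.abs(xi - c[j]) < Math.abs(xi - c[best])) best = j;
        }
        sum[best] += xi;
        cnt[best]++;
      }
      for (int j = 0; j < k; j++) if (cnt[j] > 0) c[j] = sum[j] / cnt[j];
    }
    float[] out = new float[k];
    for (int j = 0; j < k; j++) out[j] = (float) c[j];
    return out;
  }
}
```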

The current implementation uses a float gather-multiply-accumulate loop: for each packed byte, extract indices, look up centroid values from a 2^b-entry table, multiply by the rotated query coordinate, accumulate. At b=4 this is a 16-entry LUT. The JMH numbers (313K ops/s at d=4096, ~3.2µs per candidate) reflect JVM auto-vectorization of this loop, not explicit SIMD [benchmark data has to be cross-checked].
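
That inner loop, for the b=4 case, looks roughly like this (a simplified sketch assuming two 4-bit codes per byte in low/high nibble order; the actual TurboQuantScoringUtil layout may differ):

```java
// Sketch of b=4 LUT scoring: each packed byte holds two 4-bit codes, each
// code indexes a 16-entry centroid table, and the dequantized coordinate is
// multiplied by the rotated query coordinate and accumulated. Illustrative
// layout; not the PR's actual code.
final class LutScoreSketch {

  static float dotProduct4bit(float[] rotatedQuery, byte[] packed, float[] centroids16) {
    float sum = 0f;
    for (int i = 0; i < packed.length; i++) {
      int b = packed[i] & 0xFF;
      int lo = b & 0xF;  // code for dimension 2*i
      int hi = b >>> 4;  // code for dimension 2*i + 1
      sum += rotatedQuery[2 * i] * centroids16[lo];
      sum += rotatedQuery[2 * i + 1] * centroids16[hi];
    }
    return sum;
  }
}
```

The 16-entry float table stays resident in L1 cache (or, with Panama, potentially in a single 512-bit register), which is what keeps the gather step cheap.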

For comparison, @tveasey notes OSQ gets 4-8x performance vs float arithmetic with well-crafted SIMD integer dot products. That's a real and significant gap. The question is whether the recall advantage at low bits justifies the scoring cost, or whether the other properties (no calibration, byte-copy merge, streaming) matter enough for specific workloads.

The framing: TQ is not going to beat OSQ on scoring throughput (even though it looks like there are ways to make it faster). Its value proposition is the combination of (a) no calibration overhead, (b) merge-friendly architecture, and (c) high-dimension support beyond 1024d. If scoring throughput is the bottleneck, OSQ wins.

On the Panama path forward: @mccullocht is right that Vector.shuffle() is the key operation for LUT-based scoring. For b=4, the 16-entry centroid table fits in a single 512-bit register (16 × float32), and vpermps (AVX-512) or tbl (NEON) can do the gather. Whether the JVM's Panama implementation of VectorShuffle actually emits these instructions efficiently is an open question.

On the PQ-style codebook approach (@mccullocht, @tveasey)

@mccullocht suggested a PQ-like approach with precomputed per-query lookup tables. For TQ at b=4, the current implementation already does something similar: the centroid table is the LUT, and the inner loop is sum += query[i] * centroids[packed_index[i]]. The 16-entry table is small enough to stay in L1 cache.

The larger question @tveasey raises about whether the better route is residual quantization with centering is interesting. I had Kiro review the OSQ blog's "Refining the quantization interval" section in detail. It thinks "OSQ's per-vector interval optimization — the coordinate descent that minimizes dot product error weighted toward the document vector direction (with λ=0.1) — is clever (Kiro thinks so, I agree:)) and exploits structure that TQ's data-oblivious approach deliberately ignores. For data that has exploitable per-dimension structure, OSQ should win on recall-per-bit."

On the global transform and per-segment overhead (@mikemccand, @mccullocht)

Yes, this PR uses one global transform per field (seed derived from field name hash). This means:

  • No per-segment rotation storage
  • Merge can byte-copy quantized data (not yet implemented, but architecturally supported since all segments share the same rotation)
  • Query rotation is done once per query, not per-segment

@mccullocht asked whether per-segment transforms would increase randomness. Theoretically yes — different random rotations would give independent quantization errors across segments, which could improve recall after merging results from multiple segments. But the cost is losing byte-copy merge and needing a separate query rotation per segment.
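
The determinism behind the global transform can be reduced to one invariant: every writer and reader must derive the same rotation seed from the field alone. A hypothetical sketch (the PR's actual hash may differ):

```java
// Hypothetical sketch of deriving a per-field rotation seed. Any stable
// hash works; what matters is that all segments (and all JVMs) agree on it,
// since that is what allows quantized bytes to be copied verbatim at merge
// time instead of re-quantizing.
final class SeedSketch {

  static long rotationSeed(String fieldName) {
    long h = 1125899906842597L; // arbitrary large prime starting value
    for (int i = 0; i < fieldName.length(); i++) {
      h = 31 * h + fieldName.charAt(i);
    }
    return h;
  }
}
```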

On the AI collaboration process (@mikemccand)

The original prompts and iteration history are preserved in the commit history and in the TURBOQUANT_.md files in the repo root. The session log (TURBOQUANT_SESSION_LOG.md) documents the full interaction timeline, including the expert review simulation rounds that shaped the architecture. The implementation plan (TURBOQUANT_IMPLEMENTATION_PLAN.md) shows the phased approach with gate conditions. I agree these could be treated like source artifacts, though I could not find a good place for them and deleted them in one of the latest commits. In fact I'm using Kiro right now to research these responses, cross-referencing those TURBOQUANT.md plans, the paper, the OSQ blog posts, and the actual source code to fact-check my claims before posting.

On adding random rotation to OSQ (@benwtrent, @tveasey)

@benwtrent mentioned that adding random rotation to OSQ for non-Gaussian components would be straightforward. This could be the most interesting direction (and MAY explain our low recall values on internal data sets): OSQ with a Hadamard pre-rotation would get the best of both worlds — the rotation homogenizes coordinate distributions (helping with non-Gaussian embeddings), and OSQ's per-vector interval optimization + integer arithmetic scoring handles the rest. The rotation would add some latency per query at d=4096 but OSQ's scoring would remain fast. Worth exploring?

-@xande, co-authored with Kiro/Opus 4.6

@tveasey

tveasey commented Apr 4, 2026

This could be the most interesting direction (and MAY explain our low recall values on internal data sets)

This is probably the first thing to try out for your problem case. For some data sets the uplift can be dramatic. This is usually the only reason we see bad accuracy with OSQ, although low dimension vectors are also less compressible. Note though that many general embedding models produce fairly normal components for which you get small benefits from this technique.

If the internal cases are model based and you own the training process, there are very standard methods (such as spreading losses or simulating quantisation with straight-through estimators) which give dramatic improvements in compressibility of vectors and should probably be introduced to the training pipeline.

Another challenging case is CLIP style models, which suffer from a modality gap (between text queries and image documents for example) unless it is trained out. These pose additional challenges for quantisation, which would ideally be query distribution aware. These days CLIP models perform less well than VLMs in relevance for multimodal retrieval, so if your internal use cases use CLIP architecture these might be something to explore.

The spacing difference is small but the consequence is large

Definitely non-uniform grids will retain more information about the original vectors. We can't really compete in this respect.

However, the crux comes down to what are you trying to optimise for here. If it is purely recall for maximum compression then great. In this case ideally we'd consider the query distribution too, but this seems a reasonable first step. However, typically you care about recall vs latency. In this case I think we'd have to prove out that we can implement the distance operation competitively with integer arithmetic.

There are two other factors in the mix:

  1. Centring the data with respect to multiple centroids helps (a reasonable amount). We can account for the centroid contribution exactly in OSQ so we gain by having to quantise smaller magnitude vectors (PQ implementations often use a similar trick). We also (although not in the original blog) worked out an asymmetric centring approach. This means we can quantise the query fewer times in the course of a search. This is good for the recall latency trade off. We get the centres for free with IVF style indices but not HNSW. One can potentially get cheap approximate clustering from the HNSW graph. We've also thought HNSW may benefit from some clustering in the mix at construction time.
  2. You get massive wins from reranking. For example, for recall@10 reranking 30 vectors often gives you 15-20 pt improvement in recall. This will often be the most important lever at your disposal. So, if you want to hit high recalls, working out how to maximise rerank performance is probably the most valuable thing. For example, jumping straight to loading float32 vectors for reranking is significantly suboptimal. Even just bfloat16 reranking is probably essentially as good, but has 2x memory throughput.

@shbhar

shbhar commented Apr 6, 2026

Hi all,

I also work at Amazon, but in Advertising (a different org from @xande's product search), where we also use Lucene heavily. I've been independently iterating on a TurboQuant implementation for the past week (also with Kiro CLI), with various tests and benchmarks, but with a different approach from this PR: I focus more on 1-bit TQ and never store fp32 vectors, to get the full compression benefits. I have early comparison benchmark data below; more benchmarks are still running (after some bug fixes) and I'll update as more data comes in. Branch: https://github.com/shbhar/lucene/tree/turboquant-v1

Below is a summary of the approach and current results, co-authored with Claude 4.6.

Design philosophy: quantized-only storage

The implementation stores only quantized data on disk — no float32 vectors alongside. The key insight is that TQ's quantization quality after FWHT rotation is so good, especially at higher dimensions, that float32 rescoring is unnecessary for most use cases. Instead, I rescore directly from quantized data using centroid lookup tables.

For users who do need higher-fidelity rescore, this doesn't require baking float32 into the codec — they can store vectors in a separate field and use Lucene's existing FloatVectorSimilarityValuesSource with a rescore query. A TurboQuant-specific ValuesSource that rescores from a higher-bit-width TQ field (e.g., search at 1-bit, rescore at 8-bit) is also straightforward to add. I started with the approach of baking higher-bit-width TQ vectors for rescoring into the codec, but reverted it after realizing that search+rescore with quantized-only vectors is actually viable and the choice is better left to users.

The storage impact at 1M × 4096d (Qwen3-8B embeddings):

| Method | Index Size | Compression |
|---|---|---|
| Float32 | 15,674 MB | 1× (baseline) |
| BBQ-1bit (quantized + float32) | 16,178 MB | 0.97× (larger!) |
| TQ-1bit (quantized only) | 539 MB | 29× |
| TQ-4bit (quantized only) | 2,000 MB | 7.8× |
| TQ-8bit (quantized only) | 3,951 MB | 4.0× |

This is possible because the centroid LUT rescore operates on the same packed bytes as search — score = queryNorm × docNorm × Σ centroid[bin[i]] × rotatedQuery[i]. No float32 vectors needed.
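That rescore formula can be written out as a minimal sketch (illustrative names; how the norms and bin indices are stored is an assumption, not the branch's actual API):

```java
// Sketch of centroid-LUT rescore operating on already-unpacked bin indices:
// score = queryNorm * docNorm * sum_i centroid[bin[i]] * rotatedQuery[i].
public class LutRescoreSketch {
  static float rescore(float queryNorm, float docNorm,
                       int[] bins, float[] centroids, float[] rotatedQuery) {
    float acc = 0f;
    for (int i = 0; i < bins.length; i++) {
      acc += centroids[bins[i]] * rotatedQuery[i];
    }
    // Norms restore the scale lost when quantizing unit-normalized rotated vectors.
    return queryNorm * docNorm * acc;
  }
}
```

The point of the design is that `bins` comes straight from the same packed bytes the search phase reads, so no float32 copy of the document vector is ever needed.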

Dimension scaling: "blessing of dimensionality"

100K MS MARCO passages, Qwen3-8B, MRL-truncated to test lower dimensions. HNSW (M=32, beamWidth=100, topK=10, fanout=50).

First, raw quantized recall without rescore — this isolates quantization quality during HNSW graph traversal:

| Dim | Float32 | SQ-4bit | SQ-8bit | BBQ-1bit | TQ-1bit | TQ-4bit | TQ-8bit |
|---|---|---|---|---|---|---|---|
| 128 | 0.973 | 0.846 | 0.962 | 0.477 | 0.303 | 0.799 | 0.949 |
| 512 | 0.971 | 0.822 | 0.903 | 0.632 | 0.587 | 0.878 | 0.956 |
| 1024 | 0.973 | 0.853 | 0.901 | 0.721 | 0.720 | 0.907 | 0.960 |
| 4096 | 0.977 | 0.804 | 0.838 | 0.722 | 0.807 | 0.929 | 0.962 |

SQ degrades with dimension (0.846→0.804 at 4-bit, 0.962→0.838 at 8-bit) while TQ improves (0.303→0.807 at 1-bit, 0.799→0.929 at 4-bit).

At ≥1024d, TQ-1bit already matches BBQ-1bit (0.720 vs 0.721), and TQ-4bit/8bit surpass their SQ counterparts at ≥512d. At 4096d, TQ-1bit (0.807) surpasses SQ-4bit (0.804) without any rescore.

With 5× rescore, the pattern holds and latency tells the full story:

Each cell is R@10 / latency:

| Dim | SQ-4bit+rsc | SQ-8bit+rsc | BBQ-1bit+rsc | TQ-1bit+rsc | TQ-4bit+rsc |
|---|---|---|---|---|---|
| 128 | 0.994 / 0.70ms | 0.997 / 0.78ms | 0.845 / 0.71ms | 0.621 / 0.74ms | 0.995 / 0.76ms |
| 512 | 0.994 / 0.93ms | 0.997 / 1.39ms | 0.934 / 0.79ms | 0.902 / 0.96ms | 0.998 / 1.28ms |
| 1024 | 0.996 / 1.29ms | 0.997 / 1.89ms | 0.964 / 0.85ms | 0.962 / 1.00ms | 0.997 / 1.88ms |
| 2048 | 0.978 / 1.62ms | 0.989 / 2.87ms | 0.931 / 1.08ms | 0.990 / 1.18ms | 0.988 / 2.67ms |
| 4096 | 0.986 / 2.77ms | 0.991 / 5.06ms | 0.951 / 1.48ms | 0.997 / 1.53ms | 0.997 / 5.19ms |

At 4096d, TQ-1bit+rsc achieves the highest recall (0.997) at the lowest latency (1.53ms) — beating SQ-8bit+rsc (0.991, 5.06ms) on both recall and latency, at 30× less storage.

At ≥1024d, TQ-1bit+rsc matches or exceeds BBQ-1bit+rsc, and at ≥2048d it's so strong on every axis (recall, latency, storage) that higher bit widths may not even be necessary depending on the dataset. BBQ-1bit+rsc plateaus at 0.951 because its float32 rescore can't recover from the binary quantization error at high dimensions, while TQ-1bit's centroid LUT rescore continues improving.

Early Benchmark data: 1M ASIN Vectors, Qwen3-8B, 4096d

1M Amazon product ASINs encoded with Qwen3-Embedding-8B at native 4096 dimensions. 5K real product search queries. HNSW (M=32, beamWidth=200, topK=10, fanout=50, forceMerge to 1 segment).

Note: TQ-4bit and TQ-8bit had a int overflow bug during merge path in this run that caused multi-segment indices — latency and forceMerge times for those methods were wrong and are omitted. Recall and index size should be mostly unaffected. Re-running all methods with the fix and more SQ options - will update when done (~12 hours)

| Method | R@10 | Lat (ms) | Docs/s | FMerge (s) | Index MB |
|---|---|---|---|---|---|
| Float32 | 0.931 | 0.84 | 5,516 | 386 | 15,674 |
| BBQ-1bit | 0.771 | 0.57 | 12,986 | 418 | 16,178 |
| BBQ-1bit+5×rsc | 0.977 | 1.49 | 12,847 | 419 | 16,178 |
| BBQ-1bit+10×rsc | 0.987 | 2.30 | 13,463 | 432 | 16,178 |
| TQ-1bit | 0.741 | 0.48 | 19,924 | 103 | 539 |
| TQ-1bit+5×rsc | 0.968 | 1.56 | 19,966 | 283 | 539 |
| TQ-1bit+10×rsc | 0.985 | 2.43 | 20,054 | 205 | 539 |
| TQ-4bit† | 0.856 | — | 8,386 | — | 2,000 |
| TQ-4bit+5×rsc† | 0.975 | — | 8,252 | — | 2,000 |
| TQ-8bit† | 0.946 | — | 6,825 | — | 3,951 |
| TQ-8bit+5×rsc† | 0.984 | — | 6,680 | — | 3,951 |
| SQ-4bit+5×rsc | 0.977 | 2.39 | 9,276 | 509 | 17,642 |

† Latency/merge omitted — int overflow caused multi-segment index. Recall is largely valid.

TQ-1bit+10×rsc (0.985) matches BBQ-1bit+10×rsc (0.987) at 30× less storage (539 MB vs 16,178 MB), with 1.5× faster indexing (20K docs/s vs BBQ's 13K) and 4× faster forceMerge (103s vs BBQ's 418s).

Early Benchmark data: 5M Cohere Wikipedia, 1024d

From a previous run where I had reliable TQ-1bit numbers (but TQ-4bit/8bit were affected by the same int overflow bug at this scale — fixed and re-running this one too, will update):

| Method | R@10 | Latency | Docs/s | FMerge (s) | Index MB |
|---|---|---|---|---|---|
| Float32 | 0.929 | 1.51ms | 13,543 | 2,194 | 20,021 |
| SQ-4bit | 0.854 | 0.82ms | 19,553 | 1,310 | 22,540 |
| SQ-4bit+5×rsc | 0.984 | 2.90ms | 19,323 | 1,303 | 22,540 |
| SQ-8bit | 0.917 | 1.17ms | 15,769 | 1,737 | 24,979 |
| SQ-8bit+5×rsc | 0.984 | 3.92ms | 15,689 | 1,759 | 24,981 |
| BBQ-1bit | 0.636 | 0.69ms | 23,384 | 1,170 | 20,743 |
| BBQ-1bit+5×rsc | 0.945 | 2.37ms | 22,919 | 1,150 | 20,744 |
| TQ-1bit | 0.605 | 0.68ms | 30,381 | 1,748 | 1,063 |
| TQ-1bit+5×rsc | 0.928 | 2.49ms | 30,460 | 1,247 | 1,064 |
| TQ-4bit† | 0.887 | — | 18,075 | — | 2,827 |
| TQ-4bit+5×rsc† | 0.986 | — | 18,341 | — | 2,802 |
| TQ-8bit† | 0.941 | — | 1,237 | — | 5,176 |

† Latency/merge omitted — int overflow caused force merge failure and multi-segment index. Recall is largely valid.

TQ-1bit+5×rescore (0.928) matches Float32 (0.929) at 19× less storage, with 2.2× faster indexing (30K docs/s vs Float32's 13.5K). ForceMerge is slower in this run (no byte-copy merge yet) — the re-run with byte-copy should improve this significantly.

Addressing the discussion points

On LUT scoring performance (@tveasey, @mccullocht): I handle each bit width differently. 1-bit uses weighted popcount via Integer.bitCount() (which JITs to hardware popcnt/NEON cnt) — no per-dimension LUT gather, just two centroid constants and bit-plane accumulation. 8-bit quantizes both the query and the centroid LUT to int8 for a SIMD dot product via the Vector API (ByteVector-to-IntVector widening multiply-accumulate). 2-bit and 4-bit currently use scalar nibble extraction; dedicated SIMD paths are planned. My focus has been mostly on 1-bit, where TQ's compression advantage is highest.
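The bit-plane accumulation can be sketched as follows (a simplified illustration with assumed names, not the branch's code; the query is assumed pre-quantized into unsigned bit planes, and the result would still be combined with the two 1-bit centroid constants as `score ~= c0 * sumQ + (c1 - c0) * acc`):

```java
// Sketch of weighted-popcount scoring for a 1-bit document against a query
// quantized into bit planes: qInt[i] = sum_p (bit i of plane p) << p.
public class BitPlaneScorerSketch {
  /** Computes sum_i qInt[i] * docBit(i) using one popcount per plane per word. */
  static long weightedPopcount(int[] docBits, int[][] queryPlanes) {
    long acc = 0;
    for (int p = 0; p < queryPlanes.length; p++) {
      long planeSum = 0;
      for (int w = 0; w < docBits.length; w++) {
        // Integer.bitCount JITs to popcnt / NEON cnt on common hardware.
        planeSum += Integer.bitCount(docBits[w] & queryPlanes[p][w]);
      }
      acc += planeSum << p; // plane p carries weight 2^p
    }
    return acc;
  }
}
```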

On TurboQuantMSE vs TurboQuantProd (@mccullocht): My initial tests showed QJL correction (TurboQuantProd) was expensive and, at least on smaller datasets and higher bit widths (even at qjlBits=1024), did not improve recall much. I reverted it for now, but this needs to be explored further, especially at 1-bit on larger datasets where the correction might matter more.

On byte-copy merge (@mikemccand, @mccullocht): Implemented. All segments share the same rotation seed, so same-codec merge copies packed bytes directly via MemorySegment.copy — no re-quantization. I also found and fixed an int overflow bug in the merge path (new byte[numVecs * bytesPerVec] overflows at 5.1GB) — currently validating the fix in the benchmark re-runs.
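The overflow and its fix can be illustrated with a minimal sketch (illustrative names; per the comment above, the real merge path copies via MemorySegment.copy rather than streams, and never needs one giant array at all):

```java
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

public class MergeCopySketch {
  /**
   * Total packed size in bytes. The int expression numVecs * bytesPerVec
   * silently wraps past 2^31 - 1 (e.g. 10M vectors x 512 B = 5.12 GB);
   * widening to long BEFORE multiplying avoids that.
   */
  static long packedSizeBytes(int numVecs, int bytesPerVec) {
    return (long) numVecs * bytesPerVec;
  }

  /** Copy totalBytes in bounded chunks instead of allocating one huge byte[]. */
  static void byteCopy(InputStream in, OutputStream out, long totalBytes) throws IOException {
    byte[] buf = new byte[1 << 16];
    long remaining = totalBytes;
    while (remaining > 0) {
      int n = in.read(buf, 0, (int) Math.min(buf.length, remaining));
      if (n < 0) throw new IOException("unexpected EOF");
      out.write(buf, 0, n);
      remaining -= n;
    }
  }
}
```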

On ROC curves (@mikemccand): The dimension scaling table above includes latency alongside recall for each method. I haven't done a full overSample sweep yet but plan to.

Benchmarks in progress

After fixing the merge overflow bug and implementing byte-copy merge, I'm re-running everything to get clean numbers:

| Run | Dataset | What it proves |
|---|---|---|
| Cohere 5M v5 | 5M × 1024d Wikipedia | Clean TQ-4bit/8bit numbers — v3 had a merge overflow bug that left these as multi-segment indices with unreliable results. Also includes byte-copy merge for faster forceMerge times. |
| ASIN 1M v2 | 1M × 4096d product embeddings | High-dimension real-world data (Qwen3-8B). The v1 run had TQ-8bit merge overflow — this will give correct TQ-8bit numbers at 4096d for the first time. |
| Cohere 5M v4 | 5M × 1024d Wikipedia | Same as v5 but without byte-copy merge — provides a comparison point for merge optimization impact. |

Will update this thread with full results as they complete. Please feel free to ask for any other benchmark/comparison and I can include them in future runs.

What I'd like to contribute

  1. 1-bit encoding — not in the current PR (which starts at 2-bit). At high dimensions, 1-bit is so effective that higher bit widths may not even be necessary: 20-30× compression, a popcount-based SIMD scorer, and, as shown above, viable recall with rescoring across multiple datasets, which seems to improve with more dimensions
  2. Centroid LUT rescore — the key to eliminating float32 storage and getting true compression
  3. Byte-copy merge implementation with overflow fixes (under testing)
  4. SIMD scorer paths for 1-bit (popcount) and 8-bit (int8 dot product)
  5. Large-scale benchmarks (5M Cohere, 1M ASIN 4096d, multi-dim scaling/analysis with MRL to know where TQ overtakes existing quantization in Lucene)

Happy to collaborate on merging these into the existing PR or opening a companion PR. The implementations are complementary — @xande's has some advantages mine doesn't:

  • 3-bit encoding (I only do 1/2/4/8)
  • Block-diagonal Hadamard for non-power-of-2 dims like 768 (I require power-of-2 right now)
  • All 4 similarity functions (DOT_PRODUCT, COSINE, EUCLIDEAN, MIP — I only support DOT_PRODUCT/MIP)

Code: https://github.com/shbhar/lucene/tree/turboquant-v1


@mccullocht
Contributor

I think that OSQ does a good job of quantizing the same distribution as TurboQuant and the rotation is the real secret sauce here.

I tried naively rotating the entire cohere dataset used by luceneutil and running it through exhaustive recall tests. I also hacked it so that we could operate OSQ without centering so we can discuss data blind performance.

| format | centered + unrotated (default) | centered + rotated | uncentered + rotated | uncentered + unrotated |
|---|---|---|---|---|
| OSQ1 | 0.660770 | 0.692940 | 0.647200 | 0.624450 |
| OSQ2 | 0.784210 | 0.811090 | 0.762620 | 0.741080 |
| OSQ4 | 0.884820 | 0.938840 | 0.923630 | 0.880600 |
| OSQ8 | 0.975710 | 0.993360 | 0.992320 | 0.977370 |
| Naive BQ | - | - | 0.671290 | 0.647720 |

The results suggest that rotations and centering are two great tastes that taste great together. I can see how the data-blind property is really desirable, though, and it's possible to make changes to OSQ to allow this mode of operation; rotation-only also performs pretty well at high bit rates.

I tried a similar exercise with voyage vectors and rotating showed no improvement but centering still helped. I'm going to follow up with someone about distribution/rotation.

@xande @shbhar I suggest you try rotating your vectors first and test recall with OSQ. It should be easy enough to perform the rotation outside of Lucene, and if there's significant value we can figure out if or how we'd like to internalize this.

@tveasey

tveasey commented Apr 7, 2026

rotation is the real secret sauce here

One thing to note on rotations is that block diagonal with random permutation performs basically as well as dense with block sizes of 64 x 64. This might be competitive with Hadamard given we can perform the 64d matmul extremely fast with SIMD.
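The block-diagonal-with-permutation structure can be sketched like this (hypothetical names; generating the random orthogonal B×B blocks and the permutation is omitted, and a real implementation would replace the scalar inner loop with the SIMD 64d matmul mentioned above):

```java
// Sketch of a block-diagonal rotation: permute coordinates, then apply an
// independent B x B orthogonal block to each contiguous chunk.
public class BlockRotationSketch {
  /** blocks[k] is a row-major blockSize x blockSize matrix for chunk k. */
  static float[] apply(float[] v, int[] perm, float[][] blocks, int blockSize) {
    float[] p = new float[v.length];
    for (int i = 0; i < v.length; i++) p[i] = v[perm[i]]; // random permutation mixes chunks
    float[] out = new float[v.length];
    for (int k = 0; k * blockSize < v.length; k++) {
      int off = k * blockSize;
      float[] m = blocks[k];
      for (int r = 0; r < blockSize; r++) {
        float acc = 0f;
        for (int c = 0; c < blockSize; c++) acc += m[r * blockSize + c] * p[off + c];
        out[off + r] = acc;
      }
    }
    return out;
  }
}
```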

Regarding this comment

I'm not sure we should mix up the choice of reranking representation and retrieval representation. There is also something a bit odd about these results: TQ-1bit shows fairly consistently worse recall than BBQ (even without reranking). This makes me wonder if the accelerated distance calculation is just different. If we were to make the argument to use quantised representations for reranking on accuracy grounds (care is needed here), this suggests we should just use higher-bit OSQ.

@shbhar

shbhar commented Apr 7, 2026

Here are the updated results after some fixes - TQ-8bit is now also comparable to SQ-8bit, but TQ-4bit is still slower than SQ-4bit (at comparable recall and a much smaller index size). Changes from the last run:

  1. Merge int overflow fix
  2. Byte copy merge
  3. int8 SIMD scorer for TQ8Bit

Benchmark data: 5M Cohere Wikipedia, 1024d

5M Cohere Wikipedia vectors at 1024 dimensions. HNSW (M=32, beamWidth=100, topK=10, fanout=50, forceMerge to 1 segment).

| Method | R@10 | Latency | Docs/s | FMerge (s) | Index MB |
|---|---|---|---|---|---|
| Float32 | 0.928 | 1.60ms | 11,881 | 2,267 | 20,020 |
| SQ-4bit | 0.855 | 0.86ms | 19,025 | 1,347 | 22,538 |
| SQ-4bit+5×rsc | 0.986 | 3.17ms | 19,133 | 1,361 | 22,539 |
| SQ-8bit | 0.918 | 1.23ms | 15,784 | 1,791 | 24,980 |
| SQ-8bit+5×rsc | 0.987 | 4.13ms | 15,318 | 1,776 | 24,979 |
| BBQ-1bit | 0.631 | 0.82ms | 22,840 | 1,208 | 20,743 |
| BBQ-1bit+5×rsc | 0.944 | 2.59ms | 23,257 | 1,215 | 20,744 |
| TQ-1bit | 0.608 | 0.77ms | 30,381 | 1,897 | 1,064 |
| TQ-1bit+5×rsc | 0.928 | 2.83ms | 30,405 | 1,532 | 1,066 |
| TQ-4bit | 0.852 | 1.19ms | 17,915 | 2,459 | 2,851 |
| TQ-4bit+5×rsc | 0.983 | 3.91ms | 17,195 | 3,858 | 2,849 |
| TQ-8bit | 0.902 | 0.94ms | 19,369 | 2,659 | 5,293 |
| TQ-8bit+5×rsc | 0.983 | 3.13ms | 18,471 | 3,110 | 5,293 |

TQ-1bit vs BBQ-1bit: TQ-1bit (0.608) nearly matches BBQ-1bit (0.631) raw recall, but at 19× less storage (1,064 MB vs 20,743 MB) and 1.3× faster indexing (30K vs 23K docs/s). With 5× rescore, TQ-1bit+rsc (0.928) nearly matches BBQ-1bit+rsc (0.944) — the gap narrows further at higher dimensions (see ASIN 4096d below).

TQ-8bit vs SQ-8bit: TQ-8bit (0.902) nearly matches SQ-8bit (0.918) raw recall at 0.94ms vs 1.23ms latency (1.3× faster), with 4.7× less storage (5,293 MB vs 24,980 MB). With 5× rescore, TQ-8bit+rsc (0.983) nearly matches SQ-8bit+rsc (0.987) at 24% less latency (3.13ms vs 4.13ms).

Benchmark data: 1M ASIN Vectors, Qwen3-8B, 4096d

1M Amazon product ASINs encoded with Qwen3-Embedding-8B at native 4096 dimensions. 5K real product search queries. HNSW (M=32, beamWidth=200, topK=10, fanout=50, forceMerge to 1 segment).

| Method | R@10 | Lat (ms) | Docs/s | FMerge (s) | Index MB |
|---|---|---|---|---|---|
| Float32 | 0.925 | 0.85 | 5,430 | 389 | 15,674 |
| SQ-4bit | 0.883 | 0.84 | 9,287 | 504 | 17,642 |
| SQ-4bit+5×rsc | 0.978 | 2.47 | 9,049 | 512 | 17,642 |
| SQ-8bit | 0.902 | 1.31 | 6,752 | 680 | 19,595 |
| SQ-8bit+5×rsc | 0.980 | 3.91 | 6,947 | 672 | 19,595 |
| BBQ-1bit | 0.774 | 0.56 | 13,200 | 417 | 16,178 |
| BBQ-1bit+5×rsc | 0.976 | 1.53 | 13,235 | 419 | 16,178 |
| BBQ-1bit+10×rsc | 0.987 | 2.26 | 13,120 | 422 | 16,178 |
| TQ-1bit | 0.741 | 0.49 | 20,020 | 210 | 539 |
| TQ-1bit+5×rsc | 0.970 | 1.58 | 19,376 | 353 | 539 |
| TQ-1bit+10×rsc | 0.984 | 2.47 | 19,460 | 352 | 538 |
| TQ-4bit | 0.866 | 1.33 | 8,226 | 1,397 | 2,000 |
| TQ-4bit+5×rsc | 0.974 | 4.18 | 8,181 | 1,409 | 2,000 |
| TQ-8bit | 0.908 | 0.94 | 10,537 | 667 | 3,954 |
| TQ-8bit+5×rsc | 0.974 | 2.89 | 10,564 | 724 | 3,954 |

TQ-8bit beats SQ-8bit on every axis at 4096d: higher recall (0.908 vs 0.902), lower latency (0.94ms vs 1.31ms), faster indexing (10.5K vs 6.8K docs/s), comparable merge time (667s vs 680s), and 5× smaller index (3,954 MB vs 19,595 MB).

TQ-1bit+10×rsc (0.984) matches BBQ-1bit+10×rsc (0.987) at 30× less storage (538 MB vs 16,178 MB), with 1.5× faster indexing and 2× faster merge.

If anyone wants to replicate these results:
lucene: https://github.com/shbhar/lucene/tree/turboquant-v1 (commit for these tests: 62cce04)
luceneutil: https://github.com/shbhar/luceneutil/tree/turboquant-v1 (hacky: make sure to run fp32 first, as TQ ground truth depends on that index) - commit for this test: 911b947dab95a6164ba38c875eca5a1d72298b3c

@shbhar

shbhar commented Apr 7, 2026

@mccullocht let me see if I can try this

I suggest you try rotating your vectors first and test recall with OSQ. It should be easy enough to perform the rotation outside of Lucene, and if there's significant value we can figure out if or how we'd like to internalize this.

@tveasey I think the TQ-1bit vs BBQ-1bit comparison is misleading because the storage is very different. BBQ "1-bit" keeps the full float32 vectors alongside the binary quantization (16,178 MB at 4096d) and uses per-vector scalar correction terms during search. TQ-1bit, as implemented in my POC branch, only stores quantized data, which is 539 MB total (30× less). Today there's no way to opt out of float32 storage in Lucene's quantized formats (what are the reasons for that? I assume it needs to keep float32 vectors around for requantization during segment merges?). This TQ approach gives users the choice: if they want float32 reranking they can store vectors in a separate field and use a rescore query, completely ignoring the dequant rescoring path; and the choice is meaningful because the built-in rescore path does appear to have usable recall depending on dataset/dimensions.

The recall gap also depends heavily on the dataset. The Cohere and ASIN datasets have very different distributions (mean pairwise cosine similarity 0.23 vs 0.50), so comparing TQ-1bit recall across them isn't very informative. When we compare on the same dataset (100K MS MARCO passages, Qwen3-8B), TQ-1bit matches BBQ-1bit at 1024d (0.720 vs 0.721) and beats it at 4096d (0.807 vs 0.722), and even float32 rescore on BBQ doesn't help it beat TQ-1bit+rescore in recall at 4096d (0.951 vs 0.997). So it appears that at higher dimensions TQ has a much bigger advantage. See the "blessing of dimensionality" section in my first post, which has a multi-dimension test on the same dataset using MRL, though I'm not sure if the MRL property itself biases this comparison against BBQ somehow (MRL is increasingly common though).

Some of these results honestly look too good to be true (30× smaller index size and still comparable-or-better latency+recall than BBQ at 4096d on MS MARCO 100K? Really?), but I haven't been able to find a bug so far. It would be great if this could be reviewed and reproduced independently.

@mccullocht
Contributor

TQ-1bit, as implemented in my poc branch, only stores quantized data which is 539 MB total (30x less). Today there's no way to opt out of float32 storage in Lucene's quantized formats (What are the reasons for that? I assume because it needs to keep float32 vectors around for requantization during segment merges?).

OSQ centers the vectors -- it computes a mean vector within the segment, then quantizes the residual vector v - c. When building a new segment this operation is repeated. Lucene104SQFlatVectorsFormat could be pretty easily extended to use a zero vector instead of computing a mean and make merging a byte copy.

@shbhar

shbhar commented Apr 8, 2026

Following @mccullocht's advice, I had Kiro run centered benchmarks (subtract the global mean, re-normalize) on both datasets. I also ran x86 (r7i.8xlarge) alongside to make sure there are no arch-specific discrepancies.

ASIN 1M × 4096d (Qwen3-8B, centered, M=32, beamWidth=200, topK=10, fanout=50)

| Method | R@10 (arm) | Lat (arm) | R@10 (x86) | Lat (x86) | Index MB |
|---|---|---|---|---|---|
| Float32 | 0.937 | 0.80ms | 0.938 | 1.27ms | 15,678 |
| OSQ-1bit | 0.792 | 0.53ms | 0.791 | 0.63ms | 16,183 |
| OSQ-1bit+5×rsc | 0.981 | 1.58ms | 0.981 | 1.61ms | 16,183 |
| TQ-1bit | 0.790 | 0.45ms | 0.793 | 0.58ms | 541 |
| TQ-1bit+5×rsc | 0.976 | 1.51ms | 0.976 | 1.81ms | 541 |

Cohere 5M × 1024d (centered, M=32, beamWidth=100, topK=10, fanout=50)

| Method | R@10 (arm) | Lat (arm) | R@10 (x86) | Lat (x86) | Index MB |
|---|---|---|---|---|---|
| Float32 | 0.925 | 1.45ms | 0.929 | 1.42ms | 20,001 |
| OSQ-1bit | 0.622 | 0.68ms | 0.629 | 0.71ms | 20,726 |
| OSQ-1bit+5×rsc | 0.940 | 2.42ms | 0.946 | 2.32ms | 20,726 |
| TQ-1bit | 0.651 | 0.66ms | 0.644 | 0.55ms | 1,048 |
| TQ-1bit+5×rsc | 0.953 | 2.24ms | 0.951 | 2.05ms | 1,048 |

Centering impact (same dataset cross-run deltas)

| Method | Dataset | Uncentered | Centered | Δ |
|---|---|---|---|---|
| TQ-1bit | ASIN 4096d | 0.741 | 0.790 | +0.049 |
| TQ-1bit | Cohere 1024d | 0.608 | 0.651 | +0.043 |
| OSQ-1bit | ASIN 4096d | 0.774 | 0.792 | +0.018 |
| OSQ-1bit | Cohere 1024d | 0.631 | 0.622 | -0.009 |

As you suspected, centering has a big impact on TQ: with it, TQ-1bit essentially ties OSQ-1bit on ASIN (0.790 vs 0.792 on Graviton, flipped on Intel; probably just run-to-run indeterminism) and beats it on Cohere in both runs (0.651 vs 0.622 on Graviton, 0.644 vs 0.629 on Intel).

@mccullocht
Contributor

@shbhar did you try rotating the vectors first and then testing recall with OSQ 1 bit?

@shbhar

shbhar commented Apr 13, 2026

@mccullocht I had to re-run some benchmarks, but I have now tested all four combinations of centering × rotation on both datasets. To make the comparison fair, I disabled OSQ's internal per-segment centering (forcing the centroid to zero) so both methods are fully data-blind at the segment level. All benchmarks: aarch64 r7g.8xlarge, single segment (forceMerge), M=32, topK=10, fanout=50, 1-bit search R@10.

ASIN 1M × 4096d

| | Unrotated | Rotated |
|---|---|---|
| Uncentered | OSQ=0.740, TQ=0.743 | OSQ=0.754, TQ=0.737* |
| Centered | OSQ=0.792, TQ=0.791 | OSQ=0.805, TQ=0.786* |

Cohere 5M × 1024d

| | Unrotated | Rotated |
|---|---|---|
| Uncentered | OSQ=0.596, TQ=0.607 | OSQ=0.622, TQ=0.604* |
| Centered | OSQ=0.630, TQ=0.648 | OSQ=0.659, TQ=0.643* |

*Double-rotated: applying TQ's internal rotation on top of the external one is maybe not a no-op and seems to hurt TQ (more floating-point error?)

OSQ (no centering): the centroid is forced to zero to disable per-segment centering, but per-vector optimizeIntervals() + 14-byte corrections still run.

Key observations:

  • On centered+unrotated data, TQ ≈ OSQ on ASIN (0.791 vs 0.792) and TQ wins on Cohere (0.648 vs 0.630). This is the most practical comparison, since centering is a common preprocessing step and neither method applies an external rotation.
  • Rotation does help OSQ: it gains +0.013 (ASIN) and +0.029 (Cohere) on centered data.
  • On raw data, TQ has a slight edge (0.743 vs 0.740 ASIN, 0.607 vs 0.596 Cohere) — TQ's rotation partially compensates for the lack of centering. But this is close to the cross-run variance of ~0.005, so the edge appears to be slight on these datasets.

@mccullocht
Contributor

It seems to me like we may want to open a couple of issues:

  • Add a datablind scalar quantized flat vectors format that is OSQ based. I took a stab at this in the existing SQ format and it's quite tricky so we may want it to be a completely separate format.
  • Add support for rotation. This could be a stand-alone class folks can use to rotate their vectors before ingestion and querying, or it could be internalized in the quantized codec (I think there are arguments for both methods)

@shbhar

shbhar commented Apr 13, 2026

And a couple of other updates:

I tried QJL again to realize the "2-stage process" for NN search the paper mentions, but QJL correction, at least at 1-bit, adds so much variance that it makes recall much worse. So I'm not sure how to incorporate QJL and make the TurboQuant prod version actually work like the paper describes (maybe it can work at higher bit widths). Others have observed this as well, e.g. this blog in a KV-cache compression context:

https://dejan.ai/blog/turboquant/

The QJL stage produces a correction term that makes the inner product estimator unbiased. But when you add this correction back to the reconstructed vector and store it in the KV cache, you’re injecting noise into the vector itself. The result: cosine similarity dropped to 0.69 (terrible) and the model produced garbage.

I also got hold of a much larger ASIN production dataset to test on (also 4096d), and it seems much better behaved (pairwise mean cosine similarity of ~0.05 vs ~0.5 for the previous ASIN dataset I was using). The test below uses a 1M random sample with 10K randomly sampled queries. Graph: M=32, efConstruction=200. Search: fanout=50, topK=10. Force-merged to 1 segment. r7g.8xlarge (32 vCPU Graviton3, 256 GiB).

| Method | Recall@10 | Latency (ms) | Search QPS | Index docs/s | Index time (s) | Merge time (s) | Size (MB) |
|---|---|---|---|---|---|---|---|
| Float32 | 0.972 | 0.810 | 3,670 | 1,016 | 984 | 272.5 | 15,682 |
| OSQ-1bit | 0.836 | 0.569 | 10,921 | 973 | 1,028 | 91.6 | 16,188 |
| OSQ-2bit | 0.850 | 0.694 | 9,044 | 1,001 | 999 | 110.6 | 16,674 |
| OSQ-4bit | 0.915 | 0.795 | 6,964 | 1,018 | 982 | 143.6 | 17,650 |
| OSQ-8bit | 0.950 | 1.242 | 5,036 | 1,021 | 979 | 198.6 | 19,604 |
| OSQ-1bit-nocenter | 0.840 | 0.578 | 10,512 | 976 | 1,025 | 95.1 | 16,188 |
| OSQ-2bit-nocenter | 0.843 | 0.609 | 9,244 | 991 | 1,009 | 108.2 | 16,674 |
| OSQ-4bit-nocenter | 0.916 | 0.800 | 7,350 | 1,022 | 978 | 136.1 | 17,650 |
| OSQ-8bit-nocenter | 0.945 | 1.239 | 5,311 | 1,021 | 979 | 188.3 | 19,604 |
| TQ-1bit | 0.840 | 0.483 | 15,278 | 1,129 | 886 | 65.5 | 544 |
| TQ-2bit | 0.879 | 3.537 | 12,647 | 1,040 | 962 | 79.1 | 1,036 |
| TQ-4bit | 0.930 | 3.212 | 6,280 | 1,188 | 842 | 159.2 | 2,007 |
| TQ-8bit | 0.960 | 0.876 | 8,194 | 1,190 | 840 | 122.0 | 3,960 |

Note: I haven't made any attempt to optimize 2bit/4bit latency for TQ yet, so they can be ignored. But 1bit is already ~15% faster and 8bit ~30% faster (I have a couple of other optimization ideas, will have Kiro try them later)

@mccullocht
Contributor

@shbhar for performance some things to consider:

  • The core cost at 8 bits is going to be int8 dot product in both cases. OSQ has a more complicated/expensive correction to the dot product but I doubt it's 30% slower than the TQ correction. Vector incubator code paths are very touchy and it's possible that if you swapped the dot product implementations either way (your dot product on OSQ or lucene dot product on TQ) that performance would even out. It's hard to say unless they use the same interface and support the same set of vector sizes.
  • Consider microbenchmarking quantization. Quantization for the query will be repeated across all segments and in real workloads if your quantizer is significantly slower it may show up. Keep in mind that OSQ quantization is not SIMD accelerated today but could easily be.

@shbhar

shbhar commented Apr 13, 2026

Add a datablind scalar quantized flat vectors format that is OSQ based. I took a stab at this in the existing SQ format and it's quite tricky so we may want it to be a completely separate format.

Makes sense. I guess with this we could also give users the option to not store fp32 vectors at all.

Add support for rotation. This could be a stand-alone class folks can use to rotate their vectors before ingestion and querying, or it could be internalized in the quantized codec (I think there are arguments for both methods)

Do you mean add rotation inside existing OSQ and still keep optimizeIntervals + the 14-byte correction? I have not yet done an experiment where I disable optimizeIntervals + the 14-byte correction and see if it still helps over rotation alone. My understanding of why QJL correction also doesn't work is that while it reduces per-vector reconstruction error/MSE, we don't directly care about reconstruction error and only care about ranking via dot products in NN search; if a correction adds more noise to the ranking it can actually make recall worse.

I will try disabling optimizeIntervals/14byte correction next and see how OSQ with correction vs without correction performs on precentered and prerotated vectors to see if it helps, hurts or is neutral for recall.

@mccullocht
Contributor

Do you mean add rotation inside existing OSQ and still keep optimizeIntervals + 14 byte correction?

Yes. Theoretically you could implement this as a generic wrapper codec that rotates at write and read time.

You can't/shouldn't remove the code in OSQ that corrects the integer dot product using the 16 byte footer, it doesn't make any more sense than returning the int8 dot product directly for TQ.

@shbhar

shbhar commented Apr 13, 2026

You can't/shouldn't remove the code in OSQ that corrects the integer dot product using the 16 byte footer, it doesn't make any more sense than returning the int8 dot product directly for TQ.

You are right - I guess I can only disable optimizeInterval() and see whether the per-vector footer still provides a benefit on already-rotated vectors. So on centered+rotated data, if OSQ recall without optimizeInterval() is also the same as TQ's on centered+unrotated vectors (avoiding double rotation), then maybe that would be an argument for the remaining TQ approach over just adding rotation as an option in OSQ? Does that make sense?

But I guess the footer is a negligible storage cost and optimizeIntervals is cheap anyway (right?), so this might not be worth optimizing away; you are making the argument that it is better to just add rotation & data-blind options to OSQ (to be able to drop fp32). Let me look into that.

One thing I've ignored completely so far is the power-of-2 limitation of the current FWHT implementation, so with padding/block-diagonal etc. approaches I am not sure what happens to recall/performance on something like 1536d vectors.
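For reference, the standard in-place FWHT butterfly, whose repeated pairwise halving is exactly what forces power-of-2 lengths (a textbook sketch, not the branch's implementation; real use would also scale by 1/sqrt(n) to keep the rotation orthonormal):

```java
// In-place fast Walsh-Hadamard transform over a power-of-2-length vector.
public class FwhtSketch {
  static void fwht(float[] v) {
    int n = v.length;
    if ((n & (n - 1)) != 0) throw new IllegalArgumentException("length must be a power of 2");
    for (int h = 1; h < n; h <<= 1) {        // butterfly span doubles each pass
      for (int i = 0; i < n; i += h << 1) {
        for (int j = i; j < i + h; j++) {
          float a = v[j], b = v[j + h];
          v[j] = a + b;                       // sum lane
          v[j + h] = a - b;                   // difference lane
        }
      }
    }
  }
}
```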
