Skip to content

Commit 02bca34

Browse files
committed
Add nbits for SQ
1 parent 30857aa commit 02bca34

File tree

4 files changed

+43
-6
lines changed

4 files changed

+43
-6
lines changed

src/index/ivf/ivf.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,7 @@ IvfIndexNode<DataType, IndexType>::TrainInternal(const DataSetPtr dataset, std::
614614
if constexpr (std::is_same<IndexIVFSQWrapper, IndexType>::value) {
615615
const IvfSqConfig& ivf_sq_cfg = static_cast<const IvfSqConfig&>(*cfg);
616616
auto nlist = MatchNlist(rows, ivf_sq_cfg.nlist.value());
617+
auto nbits = MatchNbits(rows, ivf_sq_cfg.nbits.value());
617618

618619
const bool use_elkan = ivf_sq_cfg.use_elkan.value_or(true);
619620

@@ -623,7 +624,7 @@ IvfIndexNode<DataType, IndexType>::TrainInternal(const DataSetPtr dataset, std::
623624

624625
DataFormatEnum data_format = DataType2EnumHelper<DataType>::value;
625626

626-
auto result = IndexIvfFactory::create_for_sq(qzr.get(), dim, nlist, ivf_sq_cfg, data_format, metric.value());
627+
auto result = IndexIvfFactory::create_for_sq(qzr.get(), dim, nlist, nbits, ivf_sq_cfg, data_format, metric.value());
627628
if (!result.has_value()) {
628629
return result.error();
629630
}

src/index/ivf/ivf_config.h

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,35 @@ class ScannConfig : public IvfFlatConfig {
230230
}
231231
};
232232

233-
class IvfSqConfig : public IvfConfig {};
233+
class IvfSqConfig : public IvfConfig {
234+
public:
235+
CFG_INT nbits;
236+
KNOHWERE_DECLARE_CONFIG(IvfSqConfig) {
237+
KNOWHERE_CONFIG_DECLARE_FIELD(nbits).description("nbits").set_default(8).for_train().set_range(4, 8);
238+
}
239+
240+
Status
241+
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
242+
// check the base class
243+
const auto base_status = IvfConfig::CheckAndAdjust(param_type, err_msg);
244+
if (base_status != Status::success) {
245+
return base_status;
246+
}
247+
248+
// check our parameters
249+
if (param_type == PARAM_TYPE::TRAIN) {
250+
// check nbits
251+
if (nbits.has_value()) {
252+
if (nbits.value() != 4 && nbits.value() != 6 && nbits.value() != 8) {
253+
std::string msg = "invalid nbits : " + std::to_string(nbits.value()) +
254+
", optional values are [4, 6, 8]";
255+
return HandleError(err_msg, msg, Status::invalid_args);
256+
}
257+
}
258+
}
259+
return Status::success;
260+
}
261+
};
234262

235263
class IvfBinConfig : public IvfConfig {
236264
Status

src/index/ivf/ivf_wrapper.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,15 +239,23 @@ IndexIvfFactory::create_for_pq(faiss::IndexFlat* qzr_raw_ptr,
239239

240240
expected<std::unique_ptr<IndexIVFSQWrapper>>
241241
IndexIvfFactory::create_for_sq(faiss::IndexFlat* qzr_raw_ptr,
242-
const faiss::idx_t d, const size_t nlist, const IvfSqConfig& ivf_sq_cfg,
242+
const faiss::idx_t d, const size_t nlist, const size_t nbits, const IvfSqConfig& ivf_sq_cfg,
243243
const DataFormatEnum raw_data_format, const faiss::MetricType metric) {
244244
// the index factory string is either `IVFx,SQ,Refine(y)` or `IVFx,SQ`,
245245
// depends on the refine parameters
246246

247247
// create IndexIVFSQ
248248
// Index does not own qzr
249+
faiss::ScalarQuantizer::QuantizerType quantizer_type;
250+
if (nbits == 4) {
251+
quantizer_type = faiss::ScalarQuantizer::QuantizerType::QT_4bit;
252+
} else if (nbits == 6) {
253+
quantizer_type = faiss::ScalarQuantizer::QuantizerType::QT_6bit;
254+
} else {
255+
quantizer_type = faiss::ScalarQuantizer::QuantizerType::QT_8bit;
256+
}
249257
auto index = std::make_unique<faiss::IndexIVFScalarQuantizer>(
250-
qzr_raw_ptr, d, nlist, faiss::ScalarQuantizer::QuantizerType::QT_8bit, metric);
258+
qzr_raw_ptr, d, nlist, quantizer_type, metric);
251259

252260
// create a refiner index, if needed
253261
std::unique_ptr<faiss::Index> idx_final;

src/index/ivf/ivf_wrapper.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ using IndexIVFPQWrapper = IndexIVFWrapper<faiss::IndexIVFPQ>;
9696
using IndexIVFSQWrapper = IndexIVFWrapper<faiss::IndexIVFScalarQuantizer>;
9797

9898
class IndexIvfFactory {
99-
public:
99+
public:
100100
static expected<std::unique_ptr<IndexIVFPQWrapper>>
101101
create_for_pq(faiss::IndexFlat* qzr_raw_ptr,
102102
const faiss::idx_t d, const size_t nlist, const size_t nbits, const IvfPqConfig& ivf_pq_cfg,
@@ -105,7 +105,7 @@ class IndexIvfFactory {
105105

106106
static expected<std::unique_ptr<IndexIVFSQWrapper>>
107107
create_for_sq(faiss::IndexFlat* qzr_raw_ptr,
108-
const faiss::idx_t d, const size_t nlist, const IvfSqConfig& ivf_sq_cfg,
108+
const faiss::idx_t d, const size_t nlist, const size_t nbits, const IvfSqConfig& ivf_sq_cfg,
109109
// this is the data format of the raw data (if the refine is used)
110110
const DataFormatEnum raw_data_format, const faiss::MetricType metric = faiss::METRIC_L2);
111111
};

0 commit comments

Comments
 (0)