Skip to content

Commit 6afb877

Browse files
committed
Added bounds checking
1 parent 2b0842f commit 6afb877

File tree

5 files changed

+49
-32
lines changed

5 files changed

+49
-32
lines changed

include/pgvector/halfvec.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,18 @@ class HalfVector {
5757

5858
friend std::ostream& operator<<(std::ostream& os, const HalfVector& value) {
5959
os << "[";
60-
for (size_t i = 0; i < value.value_.size(); i++) {
60+
// TODO use std::views::enumerate for C++23
61+
size_t i = 0;
62+
for (auto v : value.value_) {
6163
if (i > 0) {
6264
os << ",";
6365
}
6466
#if __STDCPP_FLOAT16_T__
65-
os << value.value_[i];
67+
os << v;
6668
#else
67-
os << static_cast<float>(value.value_[i]);
69+
os << static_cast<float>(v);
6870
#endif
71+
i++;
6972
}
7073
os << "]";
7174
return os;

include/pgvector/pqxx.hpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,14 @@ template<> struct string_traits<pgvector::Vector> {
6565
size_t here = 0;
6666
here += pqxx::into_buf(buf.subspan(here), "[", c);
6767

68-
for (size_t i = 0; i < values.size(); i++) {
68+
// TODO use std::views::enumerate for C++23
69+
size_t i = 0;
70+
for (auto v : values) {
6971
if (i != 0) {
7072
here += pqxx::into_buf(buf.subspan(here), ",", c);
7173
}
72-
here += pqxx::into_buf(buf.subspan(here), values[i], c);
74+
here += pqxx::into_buf(buf.subspan(here), v, c);
75+
i++;
7376
}
7477

7578
here += pqxx::into_buf(buf.subspan(here), "]", c);
@@ -134,11 +137,14 @@ template<> struct string_traits<pgvector::HalfVector> {
134137
size_t here = 0;
135138
here += pqxx::into_buf(buf.subspan(here), "[", c);
136139

137-
for (size_t i = 0; i < values.size(); i++) {
140+
// TODO use std::views::enumerate for C++23
141+
size_t i = 0;
142+
for (auto v : values) {
138143
if (i != 0) {
139144
here += pqxx::into_buf(buf.subspan(here), ",", c);
140145
}
141-
here += pqxx::into_buf(buf.subspan(here), static_cast<float>(values[i]), c);
146+
here += pqxx::into_buf(buf.subspan(here), static_cast<float>(v), c);
147+
i++;
142148
}
143149

144150
here += pqxx::into_buf(buf.subspan(here), "]", c);
@@ -232,14 +238,15 @@ template<> struct string_traits<pgvector::SparseVector> {
232238
size_t here = 0;
233239
here += pqxx::into_buf(buf.subspan(here), "{", c);
234240

241+
// TODO use std::views::zip for C++23
235242
for (size_t i = 0; i < nnz; i++) {
236243
if (i != 0) {
237244
here += pqxx::into_buf(buf.subspan(here), ",", c);
238245
}
239246
// cast to avoid undefined behavior and require less buffer space
240-
here += pqxx::into_buf(buf.subspan(here), static_cast<unsigned int>(indices[i]) + 1, c);
247+
here += pqxx::into_buf(buf.subspan(here), static_cast<unsigned int>(indices.at(i)) + 1, c);
241248
here += pqxx::into_buf(buf.subspan(here), ":", c);
242-
here += pqxx::into_buf(buf.subspan(here), values[i], c);
249+
here += pqxx::into_buf(buf.subspan(here), values.at(i), c);
243250
}
244251

245252
here += pqxx::into_buf(buf.subspan(here), "}/", c);
@@ -259,11 +266,12 @@ template<> struct string_traits<pgvector::SparseVector> {
259266

260267
size_t size = 0;
261268
size += pqxx::size_buffer("{");
269+
// TODO use std::views::zip for C++23
262270
for (size_t i = 0; i < nnz; i++) {
263271
size += pqxx::size_buffer(",");
264-
size += pqxx::size_buffer(static_cast<unsigned int>(indices[i]) + 1);
272+
size += pqxx::size_buffer(static_cast<unsigned int>(indices.at(i)) + 1);
265273
size += pqxx::size_buffer(":");
266-
size += pqxx::size_buffer(values[i]);
274+
size += pqxx::size_buffer(values.at(i));
267275
}
268276
size += pqxx::size_buffer("}/");
269277
size += pqxx::size_buffer(dimensions);

include/pgvector/sparsevec.hpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,14 @@ class SparseVector {
3030
dimensions_ = static_cast<int>(value.size());
3131

3232
// do not reserve capacity for indices/values since likely many zeros
33-
for (size_t i = 0; i < value.size(); i++) {
34-
float v = value[i];
33+
// TODO use std::views::enumerate for C++23
34+
size_t i = 0;
35+
for (auto v : value) {
3536
if (v != 0) {
3637
indices_.push_back(static_cast<int>(i));
3738
values_.push_back(v);
3839
}
40+
i++;
3941
}
4042
}
4143

@@ -85,13 +87,14 @@ class SparseVector {
8587

8688
friend std::ostream& operator<<(std::ostream& os, const SparseVector& value) {
8789
os << "{";
90+
// TODO use std::views::zip for C++23
8891
for (size_t i = 0; i < value.indices_.size(); i++) {
8992
if (i > 0) {
9093
os << ",";
9194
}
92-
os << value.indices_[i] + 1;
95+
os << value.indices_.at(i) + 1;
9396
os << ":";
94-
os << value.values_[i];
97+
os << value.values_.at(i);
9598
}
9699
os << "}/";
97100
os << value.dimensions_;

include/pgvector/vector.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,14 @@ class Vector {
4141

4242
friend std::ostream& operator<<(std::ostream& os, const Vector& value) {
4343
os << "[";
44-
for (size_t i = 0; i < value.value_.size(); i++) {
44+
// TODO use std::views::enumerate for C++23
45+
size_t i = 0;
46+
for (auto v : value.value_) {
4547
if (i > 0) {
4648
os << ",";
4749
}
48-
os << value.value_[i];
50+
os << v;
51+
i++;
4952
}
5053
os << "]";
5154
return os;

test/pqxx_test.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ void test_vector(pqxx::connection &conn) {
4242

4343
pqxx::result res = tx.exec("SELECT embedding FROM items ORDER BY embedding <-> $1", {embedding2});
4444
assert_equal(res.size(), 3);
45-
assert_equal(res[0][0].as<pgvector::Vector>(), embedding2);
46-
assert_equal(res[1][0].as<pgvector::Vector>(), embedding);
47-
assert_equal(res[2][0].as<std::optional<pgvector::Vector>>().has_value(), false);
45+
assert_equal(res.at(0).at(0).as<pgvector::Vector>(), embedding2);
46+
assert_equal(res.at(1).at(0).as<pgvector::Vector>(), embedding);
47+
assert_equal(res.at(2).at(0).as<std::optional<pgvector::Vector>>().has_value(), false);
4848
}
4949

5050
void test_halfvec(pqxx::connection &conn) {
@@ -57,9 +57,9 @@ void test_halfvec(pqxx::connection &conn) {
5757

5858
pqxx::result res = tx.exec("SELECT half_embedding FROM items ORDER BY half_embedding <-> $1", {embedding2});
5959
assert_equal(res.size(), 3);
60-
assert_equal(res[0][0].as<pgvector::HalfVector>(), embedding2);
61-
assert_equal(res[1][0].as<pgvector::HalfVector>(), embedding);
62-
assert_equal(res[2][0].as<std::optional<pgvector::HalfVector>>().has_value(), false);
60+
assert_equal(res.at(0).at(0).as<pgvector::HalfVector>(), embedding2);
61+
assert_equal(res.at(1).at(0).as<pgvector::HalfVector>(), embedding);
62+
assert_equal(res.at(2).at(0).as<std::optional<pgvector::HalfVector>>().has_value(), false);
6363
}
6464

6565
void test_bit(pqxx::connection &conn) {
@@ -72,9 +72,9 @@ void test_bit(pqxx::connection &conn) {
7272

7373
pqxx::result res = tx.exec("SELECT binary_embedding FROM items ORDER BY binary_embedding <~> $1", pqxx::params{embedding2});
7474
assert_equal(res.size(), 3);
75-
assert_equal(res[0][0].as<std::string>(), embedding2);
76-
assert_equal(res[1][0].as<std::string>(), embedding);
77-
assert_equal(res[2][0].as<std::optional<std::string>>().has_value(), false);
75+
assert_equal(res.at(0).at(0).as<std::string>(), embedding2);
76+
assert_equal(res.at(1).at(0).as<std::string>(), embedding);
77+
assert_equal(res.at(2).at(0).as<std::optional<std::string>>().has_value(), false);
7878
}
7979

8080
void test_sparsevec(pqxx::connection &conn) {
@@ -87,9 +87,9 @@ void test_sparsevec(pqxx::connection &conn) {
8787

8888
pqxx::result res = tx.exec("SELECT sparse_embedding FROM items ORDER BY sparse_embedding <-> $1", {embedding2});
8989
assert_equal(res.size(), 3);
90-
assert_equal(res[0][0].as<pgvector::SparseVector>(), embedding2);
91-
assert_equal(res[1][0].as<pgvector::SparseVector>(), embedding);
92-
assert_equal(res[2][0].as<std::optional<pgvector::SparseVector>>().has_value(), false);
90+
assert_equal(res.at(0).at(0).as<pgvector::SparseVector>(), embedding2);
91+
assert_equal(res.at(1).at(0).as<pgvector::SparseVector>(), embedding);
92+
assert_equal(res.at(2).at(0).as<std::optional<pgvector::SparseVector>>().has_value(), false);
9393
}
9494

9595
void test_sparsevec_nnz(pqxx::connection &conn) {
@@ -126,8 +126,8 @@ void test_stream_to(pqxx::connection &conn) {
126126
stream.write_values(pgvector::Vector{{4, 5, 6}});
127127
stream.complete();
128128
pqxx::result res = tx.exec("SELECT embedding FROM items ORDER BY id");
129-
assert_equal(res[0][0].as<std::string>(), "[1,2,3]");
130-
assert_equal(res[1][0].as<std::string>(), "[4,5,6]");
129+
assert_equal(res.at(0).at(0).as<std::string>(), "[1,2,3]");
130+
assert_equal(res.at(1).at(0).as<std::string>(), "[4,5,6]");
131131
}
132132

133133
void test_precision(pqxx::connection &conn) {
@@ -138,7 +138,7 @@ void test_precision(pqxx::connection &conn) {
138138
tx.exec("INSERT INTO items (embedding) VALUES ($1)", {embedding});
139139
tx.exec("SET extra_float_digits = 3");
140140
pqxx::result res = tx.exec("SELECT embedding FROM items ORDER BY id DESC LIMIT 1");
141-
assert_equal(res[0][0].as<pgvector::Vector>(), embedding);
141+
assert_equal(res.at(0).at(0).as<pgvector::Vector>(), embedding);
142142
}
143143

144144
void test_vector_to_string() {

0 commit comments

Comments
 (0)