Skip to content

Commit a220909

Browse files
linyihaievenyag
authored andcommitted
feat: Add vec_mul function. (GreptimeTeam#5205)
1 parent 65e3f80 commit a220909

File tree

4 files changed

+239
-0
lines changed

4 files changed

+239
-0
lines changed

src/common/function/src/scalars/vector.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ mod distance;
1717
pub(crate) mod impl_conv;
1818
mod scalar_add;
1919
mod scalar_mul;
20+
mod vector_mul;
2021

2122
use std::sync::Arc;
2223

@@ -38,5 +39,8 @@ impl VectorFunction {
3839
// scalar calculation
3940
registry.register(Arc::new(scalar_add::ScalarAddFunction));
4041
registry.register(Arc::new(scalar_mul::ScalarMulFunction));
42+
43+
// vector calculation
44+
registry.register(Arc::new(vector_mul::VectorMulFunction));
4145
}
4246
}
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
// Copyright 2023 Greptime Team
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use std::borrow::Cow;
16+
use std::fmt::Display;
17+
18+
use common_query::error::{InvalidFuncArgsSnafu, Result};
19+
use common_query::prelude::Signature;
20+
use datatypes::prelude::ConcreteDataType;
21+
use datatypes::scalars::ScalarVectorBuilder;
22+
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23+
use nalgebra::DVectorView;
24+
use snafu::ensure;
25+
26+
use crate::function::{Function, FunctionContext};
27+
use crate::helper;
28+
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
29+
30+
const NAME: &str = "vec_mul";
31+
32+
/// Multiplies corresponding elements of two vectors.
33+
///
34+
/// # Example
35+
///
36+
/// ```sql
37+
/// SELECT vec_to_string(vec_mul("[1, 2, 3]", "[1, 2, 3]")) as result;
38+
///
39+
/// +---------+
40+
/// | result |
41+
/// +---------+
42+
/// | [1,4,9] |
43+
/// +---------+
44+
///
45+
/// ```
46+
#[derive(Debug, Clone, Default)]
47+
pub struct VectorMulFunction;
48+
49+
impl Function for VectorMulFunction {
50+
fn name(&self) -> &str {
51+
NAME
52+
}
53+
54+
fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
55+
Ok(ConcreteDataType::binary_datatype())
56+
}
57+
58+
fn signature(&self) -> Signature {
59+
helper::one_of_sigs2(
60+
vec![
61+
ConcreteDataType::string_datatype(),
62+
ConcreteDataType::binary_datatype(),
63+
],
64+
vec![
65+
ConcreteDataType::string_datatype(),
66+
ConcreteDataType::binary_datatype(),
67+
],
68+
)
69+
}
70+
71+
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
72+
ensure!(
73+
columns.len() == 2,
74+
InvalidFuncArgsSnafu {
75+
err_msg: format!(
76+
"The length of the args is not correct, expect exactly two, have: {}",
77+
columns.len()
78+
),
79+
}
80+
);
81+
82+
let arg0 = &columns[0];
83+
let arg1 = &columns[1];
84+
85+
let len = arg0.len();
86+
let mut result = BinaryVectorBuilder::with_capacity(len);
87+
if len == 0 {
88+
return Ok(result.to_vector());
89+
}
90+
91+
let arg0_const = as_veclit_if_const(arg0)?;
92+
let arg1_const = as_veclit_if_const(arg1)?;
93+
94+
for i in 0..len {
95+
let arg0 = match arg0_const.as_ref() {
96+
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
97+
None => as_veclit(arg0.get_ref(i))?,
98+
};
99+
100+
let arg1 = match arg1_const.as_ref() {
101+
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
102+
None => as_veclit(arg1.get_ref(i))?,
103+
};
104+
105+
if let (Some(arg0), Some(arg1)) = (arg0, arg1) {
106+
ensure!(
107+
arg0.len() == arg1.len(),
108+
InvalidFuncArgsSnafu {
109+
err_msg: format!(
110+
"The length of the vectors must match for multiplying, have: {} vs {}",
111+
arg0.len(),
112+
arg1.len()
113+
),
114+
}
115+
);
116+
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
117+
let vec1 = DVectorView::from_slice(&arg1, arg1.len());
118+
let vec_res = vec1.component_mul(&vec0);
119+
120+
let veclit = vec_res.as_slice();
121+
let binlit = veclit_to_binlit(veclit);
122+
result.push(Some(&binlit));
123+
} else {
124+
result.push_null();
125+
}
126+
}
127+
128+
Ok(result.to_vector())
129+
}
130+
}
131+
132+
impl Display for VectorMulFunction {
133+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134+
write!(f, "{}", NAME.to_ascii_uppercase())
135+
}
136+
}
137+
138+
#[cfg(test)]
139+
mod tests {
140+
use std::sync::Arc;
141+
142+
use common_query::error;
143+
use datatypes::vectors::StringVector;
144+
145+
use super::*;
146+
147+
#[test]
148+
fn test_vector_mul() {
149+
let func = VectorMulFunction;
150+
151+
let vec0 = vec![1.0, 2.0, 3.0];
152+
let vec1 = vec![1.0, 1.0];
153+
let (len0, len1) = (vec0.len(), vec1.len());
154+
let input0 = Arc::new(StringVector::from(vec![Some(format!("{vec0:?}"))]));
155+
let input1 = Arc::new(StringVector::from(vec![Some(format!("{vec1:?}"))]));
156+
157+
let err = func
158+
.eval(FunctionContext::default(), &[input0, input1])
159+
.unwrap_err();
160+
161+
match err {
162+
error::Error::InvalidFuncArgs { err_msg, .. } => {
163+
assert_eq!(
164+
err_msg,
165+
format!(
166+
"The length of the vectors must match for multiplying, have: {} vs {}",
167+
len0, len1
168+
)
169+
)
170+
}
171+
_ => unreachable!(),
172+
}
173+
174+
let input0 = Arc::new(StringVector::from(vec![
175+
Some("[1.0,2.0,3.0]".to_string()),
176+
Some("[8.0,10.0,12.0]".to_string()),
177+
Some("[7.0,8.0,9.0]".to_string()),
178+
None,
179+
]));
180+
181+
let input1 = Arc::new(StringVector::from(vec![
182+
Some("[1.0,1.0,1.0]".to_string()),
183+
Some("[2.0,2.0,2.0]".to_string()),
184+
None,
185+
Some("[3.0,3.0,3.0]".to_string()),
186+
]));
187+
188+
let result = func
189+
.eval(FunctionContext::default(), &[input0, input1])
190+
.unwrap();
191+
192+
let result = result.as_ref();
193+
assert_eq!(result.len(), 4);
194+
assert_eq!(
195+
result.get_ref(0).as_binary().unwrap(),
196+
Some(veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice())
197+
);
198+
assert_eq!(
199+
result.get_ref(1).as_binary().unwrap(),
200+
Some(veclit_to_binlit(&[16.0, 20.0, 24.0]).as_slice())
201+
);
202+
assert!(result.get_ref(2).is_null());
203+
assert!(result.get_ref(3).is_null());
204+
}
205+
}

tests/cases/standalone/common/function/vector/vector.result

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,27 @@ SELECT vec_to_string(parse_vec('[]'));
2222
| [] |
2323
+--------------------------------------+
2424

25+
SELECT vec_to_string(vec_mul('[1.0, 2.0]', '[3.0, 4.0]'));
26+
27+
+---------------------------------------------------------------+
28+
| vec_to_string(vec_mul(Utf8("[1.0, 2.0]"),Utf8("[3.0, 4.0]"))) |
29+
+---------------------------------------------------------------+
30+
| [3,8] |
31+
+---------------------------------------------------------------+
32+
33+
SELECT vec_to_string(vec_mul(parse_vec('[1.0, 2.0]'), '[3.0, 4.0]'));
34+
35+
+--------------------------------------------------------------------------+
36+
| vec_to_string(vec_mul(parse_vec(Utf8("[1.0, 2.0]")),Utf8("[3.0, 4.0]"))) |
37+
+--------------------------------------------------------------------------+
38+
| [3,8] |
39+
+--------------------------------------------------------------------------+
40+
41+
SELECT vec_to_string(vec_mul('[1.0, 2.0]', parse_vec('[3.0, 4.0]')));
42+
43+
+--------------------------------------------------------------------------+
44+
| vec_to_string(vec_mul(Utf8("[1.0, 2.0]"),parse_vec(Utf8("[3.0, 4.0]")))) |
45+
+--------------------------------------------------------------------------+
46+
| [3,8] |
47+
+--------------------------------------------------------------------------+
48+

tests/cases/standalone/common/function/vector/vector.sql

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,9 @@ SELECT vec_to_string(parse_vec('[1.0, 2.0]'));
33
SELECT vec_to_string(parse_vec('[1.0, 2.0, 3.0]'));
44

55
SELECT vec_to_string(parse_vec('[]'));
6+
7+
SELECT vec_to_string(vec_mul('[1.0, 2.0]', '[3.0, 4.0]'));
8+
9+
SELECT vec_to_string(vec_mul(parse_vec('[1.0, 2.0]'), '[3.0, 4.0]'));
10+
11+
SELECT vec_to_string(vec_mul('[1.0, 2.0]', parse_vec('[3.0, 4.0]')));

0 commit comments

Comments
 (0)