Skip to content

Commit 21b8227

Browse files
yutannihilationpetern48paleolimbot
authored
feat(rust/sedona-functions): Implement ST_Azimuth() (#183)
Co-authored-by: Peter Nguyen <[email protected]> Co-authored-by: Dewey Dunnington <[email protected]>
1 parent 69fd05b commit 21b8227

File tree

5 files changed

+259
-0
lines changed

5 files changed

+259
-0
lines changed

python/sedonadb/tests/functions/test_functions.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,29 @@ def test_st_astext(eng, geom):
119119
eng.assert_query_result(f"SELECT ST_AsText({geom_or_null(geom)})", expected)
120120

121121

122+
@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
123+
@pytest.mark.parametrize(
124+
("geom1", "geom2", "expected"),
125+
[
126+
# TODO: PostGIS fails without explicit ::GEOMETRY type cast, but casting
127+
# doesn't work on SedonaDB yet.
128+
# (None, None, None),
129+
("POINT (0 0)", None, None),
130+
(None, "POINT (0 0)", None),
131+
("POINT (0 0)", "POINT (0 0)", None),
132+
("POINT (0 0)", "POINT (1 1)", 0.7853981633974483), # 45 / 180 * PI
133+
("POINT (0 0)", "POINT (-1 -1)", 3.9269908169872414), # 225 / 180 * PI
134+
],
135+
)
136+
def test_st_azimuth(eng, geom1, geom2, expected):
137+
eng = eng.create_or_skip()
138+
eng.assert_query_result(
139+
f"SELECT ST_Azimuth({geom_or_null(geom1)}, {geom_or_null(geom2)})",
140+
expected,
141+
numeric_epsilon=1e-8,
142+
)
143+
144+
122145
@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
123146
@pytest.mark.parametrize(
124147
("geom", "dist", "expected_area"),

rust/sedona-functions/benches/native-functions.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,14 @@ fn criterion_benchmark(c: &mut Criterion) {
138138
benchmark::scalar(c, &f, "native", "st_mmin", LineString(10));
139139
benchmark::scalar(c, &f, "native", "st_mmax", LineString(10));
140140

141+
benchmark::scalar(
142+
c,
143+
&f,
144+
"native",
145+
"st_azimuth",
146+
BenchmarkArgs::ArrayArray(Point, Point),
147+
);
148+
141149
benchmark::aggregate(c, &f, "native", "st_envelope_aggr", Point);
142150
benchmark::aggregate(c, &f, "native", "st_envelope_aggr", LineString(10));
143151

rust/sedona-functions/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pub mod st_analyze_aggr;
2626
mod st_area;
2727
mod st_asbinary;
2828
mod st_astext;
29+
mod st_azimuth;
2930
mod st_buffer;
3031
mod st_centroid;
3132
mod st_collect;

rust/sedona-functions/src/register.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ pub fn default_function_set() -> FunctionSet {
6464
crate::st_area::st_area_udf,
6565
crate::st_asbinary::st_asbinary_udf,
6666
crate::st_astext::st_astext_udf,
67+
crate::st_azimuth::st_azimuth_udf,
6768
crate::st_buffer::st_buffer_udf,
6869
crate::st_centroid::st_centroid_udf,
6970
crate::st_dimension::st_dimension_udf,
@@ -127,6 +128,7 @@ pub mod stubs {
127128
pub use crate::predicates::*;
128129
pub use crate::referencing::*;
129130
pub use crate::st_area::st_area_udf;
131+
pub use crate::st_azimuth::st_azimuth_udf;
130132
pub use crate::st_centroid::st_centroid_udf;
131133
pub use crate::st_length::st_length_udf;
132134
pub use crate::st_perimeter::st_perimeter_udf;
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
use arrow_array::builder::Float64Builder;
18+
use arrow_schema::DataType;
19+
use datafusion_common::{error::Result, exec_err};
20+
use datafusion_expr::{
21+
scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation, Volatility,
22+
};
23+
use geo_traits::{CoordTrait, GeometryTrait, GeometryType, PointTrait};
24+
use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF};
25+
use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher};
26+
use std::sync::Arc;
27+
use wkb::reader::Wkb;
28+
29+
use crate::executor::WkbExecutor;
30+
31+
/// ST_Azimuth() scalar UDF
32+
///
33+
/// Stub function for azimuth calculation between two points.
34+
pub fn st_azimuth_udf() -> SedonaScalarUDF {
35+
SedonaScalarUDF::new(
36+
"st_azimuth",
37+
vec![Arc::new(STAzimuth {})],
38+
Volatility::Immutable,
39+
Some(st_azimuth_doc()),
40+
)
41+
}
42+
43+
fn st_azimuth_doc() -> Documentation {
44+
Documentation::builder(
45+
DOC_SECTION_OTHER,
46+
"Returns the azimuth (a clockwise angle measured from north) in radians from geomA to geomB",
47+
"ST_Azimuth (A: Geometry, B: Geometry)",
48+
)
49+
.with_argument("geomA", "geometry: Start point geometry")
50+
.with_argument("geomB", "geometry: End point geometry")
51+
.with_sql_example(
52+
"SELECT degrees(ST_Azimuth(ST_Point(0, 0), ST_Point(1, 1)))",
53+
)
54+
.build()
55+
}
56+
57+
#[derive(Debug)]
58+
struct STAzimuth {}
59+
60+
impl SedonaScalarKernel for STAzimuth {
61+
fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
62+
let matcher = ArgMatcher::new(
63+
vec![ArgMatcher::is_geometry(), ArgMatcher::is_geometry()],
64+
SedonaType::Arrow(DataType::Float64),
65+
);
66+
67+
matcher.match_args(args)
68+
}
69+
70+
fn invoke_batch(
71+
&self,
72+
arg_types: &[SedonaType],
73+
args: &[ColumnarValue],
74+
) -> Result<ColumnarValue> {
75+
let executor = WkbExecutor::new(arg_types, args);
76+
let mut builder = Float64Builder::with_capacity(executor.num_iterations());
77+
executor.execute_wkb_wkb_void(|maybe_start, maybe_end| {
78+
match (maybe_start, maybe_end) {
79+
(Some(start), Some(end)) => match invoke_scalar(start, end)? {
80+
Some(angle) => builder.append_value(angle),
81+
None => builder.append_null(),
82+
},
83+
_ => builder.append_null(),
84+
}
85+
86+
Ok(())
87+
})?;
88+
89+
executor.finish(Arc::new(builder.finish()))
90+
}
91+
}
92+
93+
fn invoke_scalar(start: &Wkb, end: &Wkb) -> Result<Option<f64>> {
94+
match (start.as_type(), end.as_type()) {
95+
(GeometryType::Point(start_point), GeometryType::Point(end_point)) => {
96+
match (start_point.coord(), end_point.coord()) {
97+
// If both geometries are non-empty points, calculate the angle
98+
(Some(start_coord), Some(end_coord)) => Ok(calc_azimuth(
99+
start_coord.x(),
100+
start_coord.y(),
101+
end_coord.x(),
102+
end_coord.y(),
103+
)),
104+
// If either of the points is empty, raise an error.
105+
_ => {
106+
exec_err!("ST_Azimuth expects both arguments to be non-empty POINT geometries")
107+
}
108+
}
109+
}
110+
_ => exec_err!("ST_Azimuth expects both arguments to be non-empty POINT geometries"),
111+
}
112+
}
113+
114+
fn calc_azimuth(start_x: f64, start_y: f64, end_x: f64, end_y: f64) -> Option<f64> {
115+
let dx = end_x - start_x;
116+
let dy = end_y - start_y;
117+
118+
if dx == 0.0 && dy == 0.0 {
119+
return None;
120+
}
121+
122+
let mut angle = dx.atan2(dy);
123+
if angle < 0.0 {
124+
angle += 2.0 * std::f64::consts::PI;
125+
}
126+
127+
Some(angle)
128+
}
129+
130+
#[cfg(test)]
131+
mod tests {
132+
use datafusion_common::scalar::ScalarValue;
133+
use datafusion_expr::ScalarUDF;
134+
use rstest::rstest;
135+
use sedona_schema::datatypes::{WKB_GEOMETRY, WKB_VIEW_GEOMETRY};
136+
use sedona_testing::create::create_scalar;
137+
use sedona_testing::testers::ScalarUdfTester;
138+
139+
use super::*;
140+
141+
#[test]
142+
fn udf_metadata() {
143+
let udf: ScalarUDF = st_azimuth_udf().into();
144+
assert_eq!(udf.name(), "st_azimuth");
145+
assert!(udf.documentation().is_some());
146+
}
147+
148+
#[rstest]
149+
fn udf(
150+
#[values(WKB_GEOMETRY, WKB_VIEW_GEOMETRY)] start_type: SedonaType,
151+
#[values(WKB_GEOMETRY, WKB_VIEW_GEOMETRY)] end_type: SedonaType,
152+
) {
153+
let tester = ScalarUdfTester::new(
154+
st_azimuth_udf().into(),
155+
vec![start_type.clone(), end_type.clone()],
156+
);
157+
158+
assert_eq!(
159+
tester.return_type().unwrap(),
160+
SedonaType::Arrow(DataType::Float64)
161+
);
162+
163+
let start = create_scalar(Some("POINT (0 0)"), &start_type);
164+
let north = create_scalar(Some("POINT (0 1)"), &end_type);
165+
let east = create_scalar(Some("POINT (1 0)"), &end_type);
166+
let south = create_scalar(Some("POINT (0 -1)"), &end_type);
167+
let west = create_scalar(Some("POINT (-1 0)"), &end_type);
168+
let same = create_scalar(Some("POINT (0 0)"), &end_type);
169+
let empty = create_scalar(Some("POINT EMPTY"), &end_type);
170+
171+
let result = tester
172+
.invoke_scalar_scalar(start.clone(), north.clone())
173+
.unwrap();
174+
assert!(matches!(
175+
result,
176+
ScalarValue::Float64(Some(val)) if (val - 0.0).abs() < 1e-12
177+
));
178+
179+
let result = tester
180+
.invoke_scalar_scalar(start.clone(), east.clone())
181+
.unwrap();
182+
assert!(matches!(
183+
result,
184+
ScalarValue::Float64(Some(val)) if (val - std::f64::consts::FRAC_PI_2).abs() < 1e-12
185+
));
186+
187+
let result = tester
188+
.invoke_scalar_scalar(start.clone(), south.clone())
189+
.unwrap();
190+
assert!(matches!(
191+
result,
192+
ScalarValue::Float64(Some(val)) if (val - std::f64::consts::PI).abs() < 1e-12
193+
));
194+
195+
let result = tester
196+
.invoke_scalar_scalar(start.clone(), west.clone())
197+
.unwrap();
198+
assert!(matches!(
199+
result,
200+
ScalarValue::Float64(Some(val)) if (val - (3.0 * std::f64::consts::FRAC_PI_2)).abs() < 1e-12
201+
));
202+
203+
// If two points are the same, return NULL
204+
let result = tester
205+
.invoke_scalar_scalar(start.clone(), same.clone())
206+
.unwrap();
207+
assert!(result.is_null());
208+
209+
// If either one of the points is empty, return NULL
210+
let result = tester.invoke_scalar_scalar(start.clone(), empty.clone());
211+
assert!(
212+
result.is_err()
213+
&& result
214+
.unwrap_err()
215+
.to_string()
216+
.contains("ST_Azimuth expects both arguments to be non-empty POINT geometries")
217+
);
218+
219+
// If either one of the points is NULL, return NULL
220+
let result = tester
221+
.invoke_scalar_scalar(ScalarValue::Null, north.clone())
222+
.unwrap();
223+
assert!(result.is_null());
224+
}
225+
}

0 commit comments

Comments
 (0)