Skip to content

Commit

Permalink
Unit-test
Browse files Browse the repository at this point in the history
Signed-off-by: Andy Kwok <[email protected]>
  • Loading branch information
andy-k-improving committed Feb 3, 2025
1 parent 3836f31 commit 0fcd688
Show file tree
Hide file tree
Showing 2 changed files with 309 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -234,34 +234,46 @@ private ExprValue computeWma(ArrayList<ExprValue> receivedValues) {
} else if (type == ExprCoreType.DATE) {
return ExprValueUtils.dateValue(
ExprValueUtils.timestampValue(Instant.ofEpochMilli(
calculateWmaInLong(receivedValues))).dateValue());
calculateWmaInTs(receivedValues))).dateValue());

} else if ( type == ExprCoreType.TIME) {
return ExprValueUtils.timeValue(
LocalTime.MIN.plus(calculateWmaInLong(receivedValues), MILLIS));
LocalTime.MIN.plus(calculateWmaInTime(receivedValues), MILLIS));

} else if (type == ExprCoreType.TIMESTAMP) {
return ExprValueUtils.timestampValue(Instant.ofEpochMilli(
calculateWmaInLong(receivedValues)));
calculateWmaInTs(receivedValues)));
}
return null;
}

private double calculateWmaInDouble (ArrayList<ExprValue> receivedValues) {
double sum = 0D;
int totalWeight = (receivedValues.size()*(receivedValues.size()+1)) / 2;
for (int i=0 ; i<receivedValues.size() ; i++) {
sum += receivedValues.get(i).doubleValue() / (i+1);
sum += receivedValues.get(i).doubleValue() * ((i + 1D) / totalWeight);
}
return sum / receivedValues.size();
return sum;
}

private long calculateWmaInLong (ArrayList<ExprValue> receivedValues) {
private long calculateWmaInTs (ArrayList<ExprValue> receivedValues) {
long sum = 0L;
int totalWeight = (receivedValues.size()*(receivedValues.size()+1)) / 2;
for (int i=0 ; i<receivedValues.size() ; i++) {
sum += receivedValues.get(i).longValue() / (i+1);
sum += (long) (receivedValues.get(i).timestampValue().toEpochMilli() * ((i + 1D) / totalWeight));
}
return sum / receivedValues.size();
return sum;
}

private long calculateWmaInTime (ArrayList<ExprValue> receivedValues) {
long sum = 0L;
int totalWeight = (receivedValues.size()*(receivedValues.size()+1)) / 2;
for (int i=0 ; i<receivedValues.size() ; i++) {
sum += (long) (MILLIS.between(LocalTime.MIN, receivedValues.get(i).timeValue()) * ((i + 1D) / totalWeight));
}
return sum;
}

}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.when;
import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA;
import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.WMA;

import com.google.common.collect.ImmutableMap;
import java.time.Instant;
Expand Down Expand Up @@ -395,4 +396,292 @@ public void calculates_simple_moving_average_timestamp() {
plan.next());
assertFalse(plan.hasNext());
}

@Test
public void calculates_weighted_moving_average_one_field_one_sample() {
when(inputPlan.hasNext()).thenReturn(true, false);
when(inputPlan.next())
.thenReturn(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)));

var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", WMA),
ExprCoreType.DOUBLE)));

plan.open();
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(
ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)),
plan.next());
}

@Test
public void calculates_weighted_moving_average_one_field_two_samples() {
when(inputPlan.hasNext()).thenReturn(true, true, false);
when(inputPlan.next())
.thenReturn(
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)));

var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA),
ExprCoreType.DOUBLE)));

plan.open();
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next());
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(
ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)),
plan.next());
assertFalse(plan.hasNext());
}

@Test
public void calculates_weighted_moving_average_one_field_two_samples_three_rows() {
when(inputPlan.hasNext()).thenReturn(true, true, true, false);
when(inputPlan.next())
.thenReturn(
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)),
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)));

var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA),
ExprCoreType.DOUBLE)));

plan.open();
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next());
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(
ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)),
plan.next());
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(
ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 199.99999999999997)),
plan.next());
assertFalse(plan.hasNext());
}

@Test
public void calculates_weighted_moving_average_multiple_computations() {
when(inputPlan.hasNext()).thenReturn(true, true, true, false);
when(inputPlan.next())
.thenReturn(
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20)),
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20)));

var plan =
new TrendlineOperator(
inputPlan,
Arrays.asList(
Pair.of(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA),
ExprCoreType.DOUBLE),
Pair.of(
AstDSL.computation(2, AstDSL.field("time"), "time_alias", WMA),
ExprCoreType.DOUBLE)));

plan.open();
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next());
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(
ImmutableMap.of(
"distance", 200, "time", 20, "distance_alias", 166.66666666666663, "time_alias", 16.666666666666664)),
plan.next());
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(
ImmutableMap.of(
"distance", 200, "time", 20, "distance_alias", 199.99999999999997, "time_alias", 20.0)),
plan.next());
assertFalse(plan.hasNext());
}


@Test
public void calculates_weighted_moving_average_one_field_two_samples_three_rows_null_value() {
when(inputPlan.hasNext()).thenReturn(true, true, true, false);
when(inputPlan.next())
.thenReturn(
ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)),
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)),
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 300, "time", 10)));

var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA),
ExprCoreType.DOUBLE)));

plan.open();
assertTrue(plan.hasNext());
assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), plan.next());
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), plan.next());
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(
ImmutableMap.of("distance", 300, "time", 10, "distance_alias", 266.66666666666663)),
plan.next());
assertFalse(plan.hasNext());
}

@Test
public void calculates_weighted_moving_average_date() {
when(inputPlan.hasNext()).thenReturn(true, true, true, false);
when(inputPlan.next())
.thenReturn(
ExprValueUtils.tupleValue(
ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))),
ExprValueUtils.tupleValue(
ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)))),
ExprValueUtils.tupleValue(
ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)))));

var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(2, AstDSL.field("date"), "date_alias", WMA),
ExprCoreType.DATE)));

plan.open();
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(
ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))),
plan.next());
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(
ImmutableMap.of(
"date",
ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)),
"date_alias",
ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(4)))),
plan.next());
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(
ImmutableMap.of(
"date",
ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)),
"date_alias",
ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(10)))),
plan.next());
assertFalse(plan.hasNext());
}

@Test
public void calculates_weighted_moving_average_time() {
when(inputPlan.hasNext()).thenReturn(true, true, true, false);
when(inputPlan.next())
.thenReturn(
ExprValueUtils.tupleValue(
ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN))),
ExprValueUtils.tupleValue(
ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(6)))),
ExprValueUtils.tupleValue(
ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(12)))));

var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(2, AstDSL.field("time"), "time_alias", WMA),
ExprCoreType.TIME)));

plan.open();
assertTrue(plan.hasNext());
assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", LocalTime.MIN)), plan.next());
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(
ImmutableMap.of(
"time", LocalTime.MIN.plusHours(6), "time_alias", LocalTime.MIN.plusHours(4))),
plan.next());
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(
ImmutableMap.of(
"time", LocalTime.MIN.plusHours(12), "time_alias", LocalTime.MIN.plusHours(10))),
plan.next());
assertFalse(plan.hasNext());
}

@Test
public void calculates_weighted_moving_average_timestamp() {
when(inputPlan.hasNext()).thenReturn(true, true, true, false);
when(inputPlan.next())
.thenReturn(
ExprValueUtils.tupleValue(
ImmutableMap.of("timestamp", ExprValueUtils.timestampValue(Instant.EPOCH))),
ExprValueUtils.tupleValue(
ImmutableMap.of(
"timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1000)))),
ExprValueUtils.tupleValue(
ImmutableMap.of(
"timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1500)))));

var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(2, AstDSL.field("timestamp"), "timestamp_alias", WMA),
ExprCoreType.TIMESTAMP)));

plan.open();
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(ImmutableMap.of("timestamp", Instant.EPOCH)), plan.next());
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(
ImmutableMap.of(
"timestamp",
Instant.EPOCH.plusMillis(1000),
"timestamp_alias",
Instant.EPOCH.plusMillis(666))),
plan.next());
assertTrue(plan.hasNext());
assertEquals(
ExprValueUtils.tupleValue(
ImmutableMap.of(
"timestamp",
Instant.EPOCH.plusMillis(1500),
"timestamp_alias",
Instant.EPOCH.plusMillis(1333))),
plan.next());
assertFalse(plan.hasNext());
}

}

0 comments on commit 0fcd688

Please sign in to comment.