Skip to content

Commit

Permalink
Implement foldIndexed
Browse files Browse the repository at this point in the history
+ Adding basic implementation of a `TriFunction`, which Java does not have
  • Loading branch information
tginsberg committed Jan 25, 2025
1 parent 7f4fcaa commit 6090e1a
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
+ Implement `everyNth()` to get every `n`<sup>th</sup> element from the stream
+ Implement `uniquelyOccurring()` to emit stream elements that occur a single time
+ Implement `takeUntil()` to take from a stream until a predicate is met, including the first element that matches the predicate
+ Implement `foldIndexed()` to perform a traditional fold along with the index of each element

### 0.7.0
+ Use greedy integrators where possible (Fixes #57)
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ implementation("com.ginsberg:gatherers4j:0.8.0")
| `dropLast(n)` | Keep all but the last `n` elements of the stream |
| `everyNth(n)` | Limit the stream to every `n`<sup>th</sup> element |
| `filterWithIndex(predicate)` | Filter the stream with the given `predicate`, which takes an `element` and its `index` |
| `foldIndexed(fn)` | Perform a fold over the input stream where each element is included along with its index |
| `grouping()` | Group consecutive identical elements into lists |
| `groupingBy(fn)` | Group consecutive elements that are identical according to `fn` into lists |
| `interleave(iterable)` | Creates a stream of alternating objects from the input stream and the argument iterable |
Expand Down
70 changes: 70 additions & 0 deletions src/main/java/com/ginsberg/gatherers4j/FoldIndexedGatherer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright 2025 Todd Ginsberg
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.ginsberg.gatherers4j;

import org.jspecify.annotations.Nullable;

import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.Gatherer;

import static com.ginsberg.gatherers4j.GathererUtils.mustNotBeNull;

public class FoldIndexedGatherer<INPUT extends @Nullable Object, OUTPUT extends @Nullable Object>
implements Gatherer<INPUT, FoldIndexedGatherer.State<OUTPUT>, OUTPUT> {

private final TriFunction<Long, OUTPUT, INPUT, OUTPUT> foldFunction;
private final Supplier<OUTPUT> initialValue;

FoldIndexedGatherer(
final Supplier<OUTPUT> initialValue,
final TriFunction<Long, OUTPUT, INPUT, OUTPUT> foldFunction
) {
mustNotBeNull(initialValue, "Initial value supplier must not be null");
mustNotBeNull(foldFunction, "Fold function must not be null");
this.foldFunction = foldFunction;
this.initialValue = initialValue;
}

@Override
public Supplier<State<OUTPUT>> initializer() {
return () -> new State<>(initialValue.get());
}

@Override
public Integrator<State<OUTPUT>, INPUT, OUTPUT> integrator() {
return Integrator.ofGreedy((state, element, downstream) -> {
state.carriedValue = foldFunction.apply(state.index++, state.carriedValue, element);
return !downstream.isRejecting();
});
}

@Override
public BiConsumer<State<OUTPUT>, Downstream<? super OUTPUT>> finisher() {
return (outputState, downstream) -> downstream.push(outputState.carriedValue);
}

public static class State<OUTPUT> {
@Nullable
OUTPUT carriedValue;
long index;

private State(@Nullable final OUTPUT initialValue) {
carriedValue = initialValue;
}
}
}
47 changes: 31 additions & 16 deletions src/main/java/com/ginsberg/gatherers4j/Gatherers4j.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.function.BiPredicate;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.random.RandomGenerator;
import java.util.stream.Stream;

Expand Down Expand Up @@ -89,7 +90,7 @@ public abstract class Gatherers4j {

/// Keep every nth element of the stream.
///
/// @param count The number of the elements to keep, must be at least 2
/// @param count The number of the elements to keep, must be at least 2
/// @param <INPUT> Type of elements in both the input and output streams
/// @return A non-null `EveryNthGatherer`
public static <INPUT extends @Nullable Object> EveryNthGatherer<INPUT> everyNth(final int count) {
Expand All @@ -100,7 +101,7 @@ public abstract class Gatherers4j {
/// and its index.
///
/// @param predicate A non-null `BiPredicate<Long,INPUT>` where the `Long` is the zero-based index of the element
/// being filtered, and the `INPUT` is the element itself.
/// being filtered, and the `INPUT` is the element itself.
/// @param <INPUT> Type of elements in the input stream
/// @return A non-null `FilteringWithIndexGatherer`
public static <INPUT extends @Nullable Object> FilteringWithIndexGatherer<INPUT> filterWithIndex(
Expand All @@ -109,6 +110,20 @@ public abstract class Gatherers4j {
return new FilteringWithIndexGatherer<>(predicate);
}

/// Perform a fold over every element in the input stream along with its index
///
/// @param <INPUT> Type of elements in the input stream
/// @param <OUTPUT> Type elements are folded to (the carry value)
/// @param initialValue Initial value of the fold
/// @param foldFunction Function that performs the fold given an element, its index, and the carry value
/// @return A non-null FoldIndexedGatherer
public static <INPUT extends @Nullable Object, OUTPUT extends @Nullable Object> FoldIndexedGatherer<INPUT, OUTPUT> foldIndexed(
final Supplier<OUTPUT> initialValue,
final TriFunction<Long, OUTPUT, INPUT, OUTPUT> foldFunction
) {
return new FoldIndexedGatherer<>(initialValue, foldFunction);
}

/// Turn a `Stream<INPUT>` into a `Stream<List<INPUT>>` where consecutive
/// equal elements, where equality is measured by `Object.equals(Object)`.
///
Expand Down Expand Up @@ -180,7 +195,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) {
///
/// @param windowSize The trailing number of elements to multiply, must be greater than 1.
/// @param mappingFunction A function to map `<INPUT>` objects to `BigDecimal`, the results of which will be used
/// in the moving product calculation
/// in the moving product calculation
/// @param <INPUT> Type of elements in the input stream, to be remapped to `BigDecimal` by the `mappingFunction`
/// @return A non-null `BigDecimalMovingProductGatherer`
public static <INPUT extends @Nullable Object> BigDecimalMovingProductGatherer<INPUT> movingProductBy(
Expand All @@ -204,7 +219,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) {
///
/// @param windowSize The trailing number of elements to multiply, must be greater than 1.
/// @param mappingFunction A function to map `<INPUT>` objects to `BigDecimal`, the results of which will be used
/// in the moving sum calculation
/// in the moving sum calculation
/// @param <INPUT> Type of elements in the input stream, to be remapped to `BigDecimal` by the `mappingFunction`
/// @return A non-null `BigDecimalMovingSumGatherer`
public static <INPUT extends @Nullable Object> BigDecimalMovingSumGatherer<INPUT> movingSumBy(
Expand Down Expand Up @@ -286,7 +301,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) {
/// objects mapped from a `Stream<BigDecimal>` via a `mappingFunction`.
///
/// @param mappingFunction A function to map `<INPUT>` objects to `BigDecimal`, the results of which will be used
/// in the standard deviation calculation
/// in the standard deviation calculation
/// @param <INPUT> Type of elements in the input stream, to be remapped to `BigDecimal` by the `mappingFunction`
/// @return A non-null `BigDecimalStandardDeviationGatherer`
public static <INPUT extends @Nullable Object> BigDecimalStandardDeviationGatherer<INPUT> runningPopulationStandardDeviationBy(
Expand All @@ -309,7 +324,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) {
/// from a `Stream<INPUT>` via a `mappingFunction`.
///
/// @param mappingFunction A function to map `<INPUT>` objects to `BigDecimal`, the results of which will be used
/// in the product calculation
/// in the product calculation
/// @param <INPUT> Type of elements in the input stream, to be remapped to `BigDecimal` by the `mappingFunction`
/// @return A non-null `BigDecimalProductGatherer`
public static <INPUT extends @Nullable Object> BigDecimalProductGatherer<INPUT> runningProductBy(
Expand All @@ -332,7 +347,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) {
/// from a `Stream<INPUT>` via a `mappingFunction`.
///
/// @param mappingFunction A function to map `<INPUT>` objects to `BigDecimal`, the results of which will be used
/// in the standard deviation calculation
/// in the standard deviation calculation
/// @param <INPUT> Type of elements in the input stream, to be remapped to `BigDecimal` by the `mappingFunction`
/// @return A non-null `BigDecimalStandardDeviationGatherer`
public static <INPUT extends @Nullable Object> BigDecimalStandardDeviationGatherer<INPUT> runningSampleStandardDeviationBy(
Expand All @@ -355,7 +370,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) {
/// from a `Stream<INPUT>` via a `mappingFunction`.
///
/// @param mappingFunction A function to map `<INPUT>` objects to `BigDecimal`, the results of which will be used
/// in the running sum calculation
/// in the running sum calculation
/// @param <INPUT> Type of elements in the input stream, to be remapped to `BigDecimal` by the `mappingFunction`
/// @return A non-null `BigDecimalSumGatherer`
public static <INPUT extends @Nullable Object> BigDecimalSumGatherer<INPUT> runningSumBy(
Expand Down Expand Up @@ -398,7 +413,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) {
/// the given function. This is useful when paired with the `withOriginal` function.
///
/// @param mappingFunction A function to map `<INPUT>` objects to `BigDecimal`, the results of which will be used
/// in the running average calculation
/// in the running average calculation
/// @param <INPUT> Type of elements in the input stream, to be remapped to `BigDecimal` by the `mappingFunction`
/// @return A non-null `BigDecimalSimpleAverageGatherer`
public static <INPUT extends @Nullable Object> BigDecimalSimpleAverageGatherer<INPUT> simpleRunningAverageBy(
Expand All @@ -421,7 +436,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) {
///
/// @param windowSize The number of elements to average, must be greater than 1.
/// @param mappingFunction A function to map `<INPUT>` objects to `BigDecimal`, the results of which will be used
/// in the moving average calculation
/// in the moving average calculation
/// @param <INPUT> Type of elements in the input stream, to be remapped to `BigDecimal` by the `mappingFunction`
/// @return A non-null `BigDecimalSimpleMovingAverageGatherer`
public static <INPUT extends @Nullable Object> BigDecimalSimpleMovingAverageGatherer<INPUT> simpleMovingAverageBy(
Expand All @@ -445,7 +460,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) {
/// Ensure the input stream is greater than `size` elements long, and emit all elements if so.
/// If not, throw an `IllegalStateException`.
///
/// @param size The size the stream must be longer than
/// @param size The size the stream must be longer than
/// @param <INPUT> Type of elements in both the input and output streams
/// @return A non-null `SizeGatherer`
/// @throws IllegalStateException when the input stream is not exactly `size` elements long
Expand All @@ -456,7 +471,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) {
/// Ensure the input stream is greater than or equal to `size` elements long, and emit all elements if so.
/// If not, throw an `IllegalStateException`.
///
/// @param size The minimum size of the stream
/// @param size The minimum size of the stream
/// @param <INPUT> Type of elements in both the input and output streams
/// @return A non-null `SizeGatherer`
/// @throws IllegalStateException when the input stream is not exactly `size` elements long
Expand All @@ -467,7 +482,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) {
/// Ensure the input stream is less than `size` elements long, and emit all elements if so.
/// If not, throw an `IllegalStateException`.
///
/// @param size The size the stream must be shorter than
/// @param size The size the stream must be shorter than
/// @param <INPUT> Type of elements in both the input and output streams
/// @return A non-null `SizeGatherer`
/// @throws IllegalStateException when the input stream is not exactly `size` elements long
Expand All @@ -478,7 +493,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) {
/// Ensure the input stream is less than or equal to `size` elements long, and emit all elements if so.
/// If not, throw an `IllegalStateException`.
///
/// @param size The maximum size the stream
/// @param size The maximum size the stream
/// @param <INPUT> Type of elements in both the input and output streams
/// @return A non-null `SizeGatherer`
/// @throws IllegalStateException when the input stream is not exactly `size` elements long
Expand All @@ -490,7 +505,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) {
/// matches the `predicate`.
///
/// @param predicate A non-null predicate function
/// @param <INPUT> Type of elements in both the input and output streams
/// @param <INPUT> Type of elements in both the input and output streams
/// @return A non-null `TakeUntilGatherer`
public static <INPUT extends @Nullable Object> TakeUntilGatherer<INPUT> takeUntil(
final Predicate<INPUT> predicate
Expand All @@ -514,7 +529,7 @@ public static <INPUT> LastGatherer<INPUT> last(final int count) {

/// Emit only those elements that occur in the input stream a single time.
///
/// @param <INPUT> Type of elements in the input stream
/// @param <INPUT> Type of elements in the input stream
/// @return A non-null `UniquelyOccurringGatherer`
public static <INPUT extends @Nullable Object> UniquelyOccurringGatherer<INPUT> uniquelyOccurring() {
return new UniquelyOccurringGatherer<>();
Expand Down
35 changes: 35 additions & 0 deletions src/main/java/com/ginsberg/gatherers4j/TriFunction.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright 2025 Todd Ginsberg
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.ginsberg.gatherers4j;

import org.jspecify.annotations.Nullable;

@FunctionalInterface
public interface TriFunction<
A extends @Nullable Object,
B extends @Nullable Object,
C extends @Nullable Object,
R extends @Nullable Object> {

/// Applies this function to the given arguments
///
/// @param a the first function argument
/// @param b the second function argument
/// @param c the third function argument
/// @return the function result
R apply(A a, B b, C c);
}
Loading

0 comments on commit 6090e1a

Please sign in to comment.