Skip to content

Commit

Permalink
Merge pull request #138 from scijava/scijava-ops-engine/conversion-anys
Browse files Browse the repository at this point in the history
Improve Conversion with Anys & Other Types Goodness
  • Loading branch information
gselzer authored Mar 22, 2024
2 parents 5997158 + e04ea9f commit 8ac7025
Show file tree
Hide file tree
Showing 14 changed files with 203 additions and 160 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ private List<RichOp<?>> resolveOpDependencies(OpInfo info,
Map<TypeVariable<?>, Type> dependencyTypeVarAssigns = new HashMap<>();
for (var entry : typeVarAssigns.entrySet()) {
var value = entry.getValue();
if (!value.equals(Any.class) && !(value instanceof Any)) {
if (!Any.is(value)) {
dependencyTypeVarAssigns.put(entry.getKey(), entry.getValue());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,7 @@ private void captureTypeVarsFromCandidate(Type srcType, OpCandidate candidate,
var existing = map.get(key);
var replacement = assigns.get(key);
// Ignore bounds that are weaker than current bounds.
if (Types.isAssignable(existing, replacement) && !existing.equals(
Any.class) && !(existing instanceof Any))
{
if (Types.isAssignable(existing, replacement) && !Any.is(existing)) {
continue;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,14 @@ private static ConvertedOpInfo convert(OpEnvironment env, OpInfo info,
BaseOpHints.Conversion.FORBIDDEN, //
BaseOpHints.History.IGNORE //
);
final RichOp<Function<?, ?>> identity = identityOp(env);
final Map<TypeVariable<?>, Type> vars = new HashMap<>();
// Find input converters
Type[] fromArgs = request.getArgs();
List<Type> toArgs = inputTypesAgainst(info, Types.raw(reqType));
List<RichOp<Function<?, ?>>> preConverters = new ArrayList<>();
for (int i = 0; i < fromArgs.length; i++) {
var opt = findConverter(fromArgs[i], toArgs.get(i), vars, env, h);
preConverters.add(opt.orElse(identity));
preConverters.add(opt);
}

// Find output converter
Expand All @@ -137,7 +136,7 @@ private static ConvertedOpInfo convert(OpEnvironment env, OpInfo info,
return opt.get();
}
// Attempt 2: Computer with identity mutable output
opt = postprocessIdentity(info, request, preConverters, identity, env);
opt = postprocessIdentity(info, request, preConverters, env);
if (opt.isPresent()) {
return opt.get();
}
Expand Down Expand Up @@ -206,7 +205,9 @@ private static Optional<ConvertedOpInfo> postprocessFunction(OpInfo info,
info, //
request.getType(), //
preConverters, //
Arrays.asList(request.getArgs()), //
postConverter, //
request.getOutType(), //
null, //
env //
));
Expand All @@ -223,30 +224,33 @@ private static Optional<ConvertedOpInfo> postprocessFunction(OpInfo info,
* @param request the original {@link OpRequest}
* @param preConverters the {@link List} of {@link RichOp}s responsible for
* converting the inputs to the {@link ConvertedOpInfo}
* @param identity an Op that simply returns the input value
* @param env the {@link OpEnvironment} used to match Ops necessary to create
* the {@code ConvertedOpInfo}
* @return a {@link ConvertedOpInfo}, aligning {@code info} to {@code request}
* if that is possible
*/
private static Optional<ConvertedOpInfo> postprocessIdentity(OpInfo info,
OpRequest request, List<RichOp<Function<?, ?>>> preConverters,
RichOp<Function<?, ?>> identity, OpEnvironment env)
OpEnvironment env)
{
// This procedure only applies to Ops with mutable outputs
int ioIndex = mutableIndexOf(request.getType());
if (ioIndex == -1) {
return Optional.empty();
}
// And only applies when the mutable index was "converted" with identity
if (preConverters.get(ioIndex) == identity) {
if (Ops.info(preConverters.get(ioIndex)).names().contains(
"engine.identity"))
{
// In this case, we need neither a postprocessor nor a copier,
// because the mutable output was directly edited.
return Optional.of(new ConvertedOpInfo( //
info, //
request.getType(), //
preConverters, //
identity, //
Arrays.asList(request.getArgs()), //
null, //
request.getOutType(), //
null, //
env //
));
Expand Down Expand Up @@ -304,7 +308,9 @@ private static Optional<ConvertedOpInfo> postprocessConvertAndCopy(
info, //
request.getType(), //
preConverters, //
Arrays.asList(request.getArgs()), //
postConverter, //
request.getOutType(), //
copyOp, //
env //
));
Expand Down Expand Up @@ -365,7 +371,9 @@ private static Optional<ConvertedOpInfo> postprocessCopy(OpInfo info,
info, //
request.getType(), //
preConverters, //
Arrays.asList(request.getArgs()), //
postConverter, //
request.getOutType(), //
copyOp, //
env //
));
Expand All @@ -387,15 +395,9 @@ private static Optional<ConvertedOpInfo> postprocessCopy(OpInfo info,
* @param hints the {@link Hints} to use in matching
* @return a rich converter Op, if it is both necessary and can be found
*/
private static Optional<RichOp<Function<?, ?>>> findConverter(Type from,
Type to, Map<TypeVariable<?>, Type> vars, OpEnvironment env, Hints hints)
private static RichOp<Function<?, ?>> findConverter(Type from, Type to,
Map<TypeVariable<?>, Type> vars, OpEnvironment env, Hints hints)
{
// If the request argument can be assigned to the info parameter directly,
// we don't need to call a preconverter
if (Types.isAssignable(from, to)) {
return Optional.empty();
}
// If direct assignment fails, we need a preconverter
var source = Nil.of(from);
// If the op parameter type has type variables that have been mapped
// already, substitute those mappings in.
Expand All @@ -412,7 +414,7 @@ private static Optional<ConvertedOpInfo> postprocessCopy(OpInfo info,
// The resulting Op can give us further information about type variable
// mappings - let's find them
resolveTypes(from, preDest, rich, vars);
return Optional.of(Ops.rich(op));
return Ops.rich(op);
}

private static void resolveTypes(Type source, Type dest,
Expand Down Expand Up @@ -454,15 +456,6 @@ private static Nil<?> wildcardVacuousTypeVars(final Type t) {
return Nil.of(Types.mapVarToTypes(t, vars));
}

private static <T> RichOp<Function<?, ?>> identityOp(OpEnvironment env) {
Nil<T> t = new Nil<>() {};
var op = env.unary("engine.identity") //
.inType(t) //
.outType(t) //
.function();
return Ops.rich(op);
}

/**
* {@link Class}es of array types return "[]" when
* {@link Class#getSimpleName()} is called. Those characters are invalid in a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@
import java.lang.invoke.MethodHandles;
import java.lang.reflect.*;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;

Expand All @@ -56,6 +53,7 @@
import org.scijava.struct.Struct;
import org.scijava.struct.StructInstance;
import org.scijava.struct.Structs;
import org.scijava.types.Any;
import org.scijava.types.Nil;
import org.scijava.types.Types;
import org.scijava.types.inference.FunctionalInterfaces;
Expand Down Expand Up @@ -97,7 +95,9 @@ public class ConvertedOpInfo implements OpInfo {
private final OpInfo info;
private final OpEnvironment env;
final List<RichOp<Function<?, ?>>> preconverters;
final List<Type> inTypes;
final RichOp<Function<?, ?>> postconverter;
final Type outType;
final RichOp<Computers.Arity1<?, ?>> copyOp;
private final Type opType;
private final Struct struct;
Expand All @@ -113,8 +113,9 @@ public ConvertedOpInfo(OpInfo info,
info, //
generateOpType(info, preconverters, postconverter), //
preconverters, //
Arrays.asList(inTypes(info.inputTypes(), preconverters)), //
postconverter, //
copyOp, //
outType(info.outputType(), postconverter), copyOp, //
env //
);
}
Expand All @@ -133,22 +134,21 @@ public ConvertedOpInfo( //
OpInfo info, //
Type opType, //
List<RichOp<Function<?, ?>>> preconverters, //
List<Type> reqInputs, //
RichOp<Function<?, ?>> postconverter, //
Type reqOutput, //
final RichOp<Computers.Arity1<?, ?>> copyOp, //
OpEnvironment env //
) {
this.info = info;
this.opType = opType;
this.opType = mapAnys(opType, info);
this.preconverters = preconverters;
this.inTypes = reqInputs;
this.postconverter = postconverter;
this.outType = reqOutput;
this.copyOp = copyOp;
this.env = env;
this.struct = generateStruct( //
info, //
opType, //
preconverters, //
postconverter //
);
this.struct = generateStruct();
this.hints = info.declaredHints().plus( //
BaseOpHints.Conversion.FORBIDDEN, //
"converted" //
Expand All @@ -158,18 +158,9 @@ public ConvertedOpInfo( //
/**
* Helper method to generate the new {@link Struct}
*
* @param info the original {@link OpInfo}
* @param opType the new Op's {@link Type}
* @param preconverters the {@link Function}s responsible for converting Op
* parameters
* @param postconverter the {@link Function} responsible for converting the
* Op's return
* @return the {@link Struct} of the converted Op
*/
private static Struct generateStruct(OpInfo info, Type opType,
List<RichOp<Function<?, ?>>> preconverters,
RichOp<Function<?, ?>> postconverter)
{
private Struct generateStruct() {
List<Type> originalIns = new ArrayList<>(info.inputTypes());
List<Member<?>> ioMembers = new ArrayList<>(info.struct().members());
// If the mutable index differs between the declared Op type and the
Expand All @@ -180,17 +171,13 @@ private static Struct generateStruct(OpInfo info, Type opType,
originalIns.add(toIOIdx, originalIns.remove(fromIOIdx));
ioMembers.add(toIOIdx, ioMembers.remove(fromIOIdx));
}
// Determine the new input and output types of the Op
Type[] inTypes = inTypes(originalIns, preconverters);
Type outType = outType(info.outputType(), postconverter);

// Create the functional member types of the new OpInfo
int index = 0;
List<FunctionalMethodType> fmts = new ArrayList<>();
for (Member<?> m : ioMembers) {
if (m.getIOType() == ItemIO.NONE) continue;
Type newType = m.isInput() ? inTypes[index++] : m.isOutput() ? outType
: null;
Type newType = m.isInput() ? this.inTypes.get(index++) : m.isOutput()
? outType : null;
fmts.add(new FunctionalMethodType(newType, m.getIOType()));
}
// generate new struct
Expand Down Expand Up @@ -266,7 +253,7 @@ public StructInstance<?> createOpInstance(List<?> dependencies) {
}
return rich.asOpType();
}).collect(Collectors.toList()), //
this.postconverter.asOpType(), //
this.postconverter == null ? IGNORED : this.postconverter.asOpType(), //
this.copyOp == null ? null : this.copyOp.asOpType() //
);
return struct().createInstance(convertedOp);
Expand Down Expand Up @@ -380,6 +367,9 @@ private static Type outType( //
Type originalOutput, //
RichOp<Function<?, ?>> postconverter //
) {
if (postconverter == null) {
return originalOutput;
}
// Start by looking at the output type T of the postconverter
var type = postconverter.instance().getType();
var pType = (ParameterizedType) type;
Expand Down Expand Up @@ -410,6 +400,31 @@ public double priority() {
return priority;
}

/**
* Helper method that removes any {@link Any}s in the Op type, replacing them
* with the types declared by the original {@link OpInfo}
*
* @param opType the Op {@link Type}, which may have some {@link Any}s
* @param info the original {@link OpInfo} being wrapped
* @return {@code opType}, without any {@link Any}s.
*/
private static Type mapAnys(Type opType, OpInfo info) {
var raw = Types.parameterizeRaw(Types.raw(opType));
Map<TypeVariable<?>, Type> reqMap = new HashMap<>();
GenericAssignability.inferTypeVariables(new Type[] { raw }, new Type[] {
opType }, reqMap);
Map<TypeVariable<?>, Type> infoMap = new HashMap<>();
GenericAssignability.inferTypeVariables(new Type[] { raw }, new Type[] {
info.opType() }, infoMap);
for (var key : reqMap.keySet()) {
var val = reqMap.get(key);
if (Any.is(val)) {
reqMap.put(key, infoMap.get(key));
}
}
return Types.mapVarToTypes(raw, reqMap);
}

/**
* We define the priority of any {@link ConvertedOpInfo} as the sum of the
* following:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
package org.scijava.ops.engine.matcher.convert;

import org.scijava.function.Inplaces;
import org.scijava.ops.engine.BaseOpHints;
import org.scijava.ops.engine.BaseOpHints.Conversion;
import org.scijava.ops.spi.OpCollection;
import org.scijava.ops.spi.OpField;
Expand All @@ -44,21 +45,23 @@
* @author Gabriel Selzer
* @param <T>
*/
public class IdentityCollection<T> implements OpCollection {
public class IdentityCollection<T, U extends T> implements OpCollection {

/**
* @input t the object to be converted
* @output the converted object (since we are doing an identity conversion,
* this is just a reference to the input object).
*/
@OpHints(hints = { Conversion.FORBIDDEN })
@OpField(names = "engine.convert, engine.identity", priority = Priority.LAST)
public final Function<T, T> identity = (t) -> t;
@OpHints(hints = { Conversion.FORBIDDEN,
BaseOpHints.DependencyMatching.FORBIDDEN })
@OpField(names = "engine.convert, engine.identity", priority = Priority.FIRST)
public final Function<U, T> identity = (t) -> t;

/**
* @mutable t the object to be "mutated"
*/
@OpHints(hints = { Conversion.FORBIDDEN })
@OpField(names = "engine.identity", priority = Priority.LAST)
@OpField(names = "engine.identity", priority = Priority.FIRST)
public final Inplaces.Arity1<T> inplace = (t) -> {};

}
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ static int checkGenericOutputsAssignability(Type[] froms, Type[] tos,
Type from = froms[i];
Type to = tos[i];

if (to instanceof Any || to.equals(Any.class)) continue;
if (Any.is(to)) continue;

if (from instanceof TypeVariable) {
TypeVarInfo typeVarInfo = typeBounds.get(from);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

import java.lang.reflect.Type;
import java.util.function.BiFunction;
import java.util.function.Function;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
Expand Down Expand Up @@ -156,4 +157,16 @@ public void testConvertedOpInfo() {
Assertions.assertEquals(expImplName, actImplName);
}

@OpMethod(names = "test.anyConversion", type = Function.class)
public static Integer foo(Integer i1) {
return i1 * i1;
}

@Test
public void testConvertAnys() {
Double in = 2.0;
var out = ops.unary("test.anyConversion").input(in).apply();
Assertions.assertInstanceOf(Integer.class, out);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ private LiftNeighborhoodComputersToImg() {
/**
* @implNote op names='engine.adapt', priority='100.', type=Function
*/
public static <T, U, F extends RandomAccessibleInterval<T>>
Computers.Arity3<F, Shape, OutOfBoundsFactory<T, F>, RandomAccessibleInterval<U>>
public static <T, U, F extends RandomAccessibleInterval<T>, G extends F>
Computers.Arity3<G, Shape, OutOfBoundsFactory<T, F>, RandomAccessibleInterval<U>>
adapt1UsingShapeAndOOBF(Computers.Arity1<Neighborhood<T>, U> op)
{
return (in, shape, oobf, out) -> {
Expand Down
Loading

0 comments on commit 8ac7025

Please sign in to comment.