Skip to content

Commit 8058657

Browse files
authored
Documentation updates for 4.2 (#205)
* Updating docs for 4.2. * Migrating StripProvenance over to use the ProvenanceSerialization interface. * Updating the configuration tutorial. * Initial readme updates for 4.2 * Adding the start of the onnx export tutorial. * Adding first draft of release notes for 4.2. * Fixing the JEP 290 filter. * Changing Model.castModel so it's not static. * Adding a reproducibility tutorial and updating the irises and onnx export tutorials. * Updating gitignore file. * Updating reproducibility tutorial with a bigger diff example. * Updating the v4.2 release notes. * Adding TF-Java PR number. * Updating docs for the HDBSCAN implementation. * Updating 4.2 release notes. * Adding Tribuo v4.1.1 release notes. * Finishing ONNX export tutorial. * Updating ONNX export tutorial. * Docs updates after rebase. * CastModel fix in OCIModelCLI. * Javadoc updates after rebase. * Fixing the circular PR reference in the 4.2 release notes. * Fixing an accidental deletion in the external models tutorial.
1 parent de2e5d4 commit 8058657

34 files changed

+3313
-129
lines changed

.gitignore

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,32 @@ bin/
1616
.*.swp
1717

1818
# Other files
19-
*.jar
20-
*.class
2119
*.er
2220
*.log
2321
*.bck
2422
*.so
23+
*.patch
24+
25+
# Binaries
26+
*.jar
27+
*.class
28+
29+
# Archives
30+
*.gz
31+
*.zip
2532

2633
# Serialised models
2734
*.ser
2835

2936
# Temporary stuff
3037
junk/*
3138
.DS_Store
39+
.ipynb_checkpoints
40+
41+
# Profiling files
42+
*.jfr
43+
*.iprof
44+
*.jfc
45+
46+
# Tutorial files
47+
tutorials/*.svm

Common/LibSVM/src/main/java/org/tribuo/common/libsvm/LibSVMTrainer.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ protected LibSVMTrainer() {}
131131
/**
132132
* Constructs a LibSVMTrainer from the parameters.
133133
* @param parameters The SVM parameters.
134+
* @param seed The RNG seed.
134135
*/
135136
protected LibSVMTrainer(SVMParameters<T> parameters, long seed) {
136137
this.parameters = parameters.getParameters();

Core/src/main/java/org/tribuo/Model.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -300,20 +300,21 @@ public String toString() {
300300

301301
/**
302302
* Casts the model to the specified output type, assuming it is valid.
303-
* <p>
304303
* If it's not valid, throws {@link ClassCastException}.
305-
* @param inputModel The model to cast.
304+
* <p>
305+
* This method is intended for use on a deserialized model to restore it's
306+
* generic type in a safe way.
306307
* @param outputType The output type to cast to.
307-
* @param <T> The output type.
308+
* @param <U> The output type.
308309
* @return The model cast to the correct value.
309310
*/
310-
public static <T extends Output<T>> Model<T> castModel(Model<?> inputModel, Class<T> outputType) {
311-
if (inputModel.validate(outputType)) {
311+
public <U extends Output<U>> Model<U> castModel(Class<U> outputType) {
312+
if (validate(outputType)) {
312313
@SuppressWarnings("unchecked") // guarded by validate
313-
Model<T> castedModel = (Model<T>) inputModel;
314+
Model<U> castedModel = (Model<U>) this;
314315
return castedModel;
315316
} else {
316-
throw new ClassCastException("Attempted to cast model to " + outputType.getName() + " which is not valid for model " + inputModel.toString());
317+
throw new ClassCastException("Attempted to cast model to " + outputType.getName() + " which is not valid for model " + this.toString());
317318
}
318319
}
319320

Core/src/main/java/org/tribuo/ensemble/BaggingTrainer.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
* "The Elements of Statistical Learning"
5050
* Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a>
5151
* </pre>
52+
* @param <T> The prediction type.
5253
*/
5354
public class BaggingTrainer<T extends Output<T>> implements Trainer<T> {
5455

@@ -177,6 +178,7 @@ public EnsembleModel<T> train(Dataset<T> examples, Map<String, Provenance> runPr
177178
* @param labelIDs The output domain.
178179
* @param randInt A random int from an rng instance
179180
* @param runProvenance Provenance for this instance.
181+
* @param invocationCount The invocation count for the inner trainer.
180182
* @return The trained ensemble member.
181183
*/
182184
protected Model<T> trainSingleModel(Dataset<T> examples, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> labelIDs, int randInt, Map<String,Provenance> runProvenance, int invocationCount) {

Core/src/main/java/org/tribuo/ensemble/EnsembleCombiner.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,10 @@ default ONNXNode exportCombiner(ONNXNode input) {
7979
* will be required to provide ONNX support.
8080
* @param input the node to be ensembled according to this implementation.
8181
* @param weight The node of weights for ensembling.
82+
* @param <U> The type of the weights input reference.
8283
* @return The leaf node of the graph of operations added to ensemble input.
8384
*/
84-
default <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight) {
85+
default <U extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, U weight) {
8586
Logger.getLogger(this.getClass().getName()).severe("Tried to export an ensemble combiner to ONNX format, but this is not implemented.");
8687
throw new IllegalStateException("This ensemble cannot be exported as the combiner '" + this.getClass() + "' uses the default implementation of EnsembleCombiner.exportCombiner.");
8788
}

Data/src/main/java/org/tribuo/data/sql/SQLDBConfig.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ private SQLDBConfig() {}
7272
/**
7373
* Constructs a SQL database configuration.
7474
* <p>
75-
* Note it is recommended that wallet based connections are used rather than this constructor using {@link SQLDBConfig(String,Map)}.
75+
* Note it is recommended that wallet based connections are used rather than this constructor using {@link #SQLDBConfig(String,Map)}.
7676
* @param connectionString The connection string.
7777
* @param username The username.
7878
* @param password The password.
@@ -87,7 +87,7 @@ public SQLDBConfig(String connectionString, String username, String password, Ma
8787
/**
8888
* Constructs a SQL database configuration.
8989
* <p>
90-
* Note it is recommended that wallet based connections are used rather than this constructor using {@link SQLDBConfig(String,Map)}.
90+
* Note it is recommended that wallet based connections are used rather than this constructor using {@link #SQLDBConfig(String,Map)}.
9191
* @param host The host to connect to.
9292
* @param port The port to connect on.
9393
* @param db The db name.

Interop/OCI/src/main/java/org/tribuo/interop/oci/OCIModel.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ public static ConfigFileAuthenticationDetailsProvider makeAuthProvider(Path conf
309309
* @param configFile The OCI configuration file, if null use the default file.
310310
* @param endpointURL The endpoint URL.
311311
* @param outputConverter The converter for the specified output type.
312+
* @param <T> The output type.
312313
* @return An OCIModel ready to score new inputs.
313314
*/
314315
public static <T extends Output<T>> OCIModel<T> createOCIModel(OutputFactory<T> factory,
@@ -332,6 +333,7 @@ public static <T extends Output<T>> OCIModel<T> createOCIModel(OutputFactory<T>
332333
* @param profileName The profile name in the OCI configuration file, if null uses the default profile.
333334
* @param endpointURL The endpoint URL.
334335
* @param outputConverter The converter for the specified output type.
336+
* @param <T> The output type.
335337
* @return An OCIModel ready to score new inputs.
336338
*/
337339
public static <T extends Output<T>> OCIModel<T> createOCIModel(OutputFactory<T> factory,

Interop/OCI/src/main/java/org/tribuo/interop/oci/OCIModelCLI.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ private static void createModelAndDeploy(OCIModelOptions options) throws IOExcep
6464
// Load the Tribuo model
6565
Model<Label> model;
6666
try (ObjectInputStream ois = new ObjectInputStream(Files.newInputStream(options.modelPath))) {
67-
model = Model.castModel((Model<?>) ois.readObject(),Label.class);
67+
model = ((Model<?>)ois.readObject()).castModel(Label.class);
6868
}
6969
if (!(model instanceof ONNXExportable)) {
7070
throw new IllegalArgumentException("Model not ONNXExportable, received " + model.toString());

Interop/OCI/src/main/java/org/tribuo/interop/oci/OCIUtil.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ public static <T extends Output<T>, U extends Model<T> & ONNXExportable> String
373373
/**
374374
* Creates the OCI DS model artifact zip file.
375375
* @param onnxFile The ONNX file to create.
376+
* @param config The model artifact configuration.
376377
* @return The path referring to the zip file.
377378
* @throws IOException If the file could not be created or the ONNX file could not be read.
378379
*/

Json/src/main/java/org/tribuo/json/StripProvenance.java

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,15 @@
1616

1717
package org.tribuo.json;
1818

19-
import com.fasterxml.jackson.databind.ObjectMapper;
20-
import com.fasterxml.jackson.databind.SerializationFeature;
2119
import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
2220
import com.oracle.labs.mlrg.olcut.config.Option;
2321
import com.oracle.labs.mlrg.olcut.config.Options;
2422
import com.oracle.labs.mlrg.olcut.config.UsageException;
25-
import com.oracle.labs.mlrg.olcut.config.json.JsonProvenanceModule;
23+
import com.oracle.labs.mlrg.olcut.config.json.JsonProvenanceSerialization;
2624
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
2725
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
2826
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
2927
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
30-
import com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance;
3128
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
3229
import com.oracle.labs.mlrg.olcut.util.IOUtil;
3330
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
@@ -315,11 +312,8 @@ public static <T extends Output<T>> void main(String[] args) {
315312
ModelProvenance oldProvenance = input.getProvenance();
316313

317314
logger.info("Marshalling provenance and creating JSON.");
318-
List<ObjectMarshalledProvenance> list = ProvenanceUtil.marshalProvenance(oldProvenance);
319-
ObjectMapper mapper = new ObjectMapper();
320-
mapper.registerModule(new JsonProvenanceModule());
321-
mapper.enable(SerializationFeature.INDENT_OUTPUT);
322-
String jsonResult = mapper.writeValueAsString(list);
315+
JsonProvenanceSerialization jsonProvenanceSerialization = new JsonProvenanceSerialization(true);
316+
String jsonResult = jsonProvenanceSerialization.marshalAndSerialize(oldProvenance);
323317

324318
logger.info("Hashing JSON file");
325319
MessageDigest digest = o.hashType.getDigest();
@@ -340,8 +334,7 @@ public static <T extends Output<T>> void main(String[] args) {
340334

341335
ModelProvenance newProvenance = tuple.provenance;
342336
logger.info("Marshalling provenance and creating JSON.");
343-
List<ObjectMarshalledProvenance> newList = ProvenanceUtil.marshalProvenance(newProvenance);
344-
String newJsonResult = mapper.writeValueAsString(newList);
337+
String newJsonResult = jsonProvenanceSerialization.marshalAndSerialize(newProvenance);
345338

346339
logger.info("Old provenance = \n" + jsonResult);
347340
logger.info("New provenance = \n" + newJsonResult);

0 commit comments

Comments
 (0)