diff --git a/src/main/java/fiji/plugin/trackmate/cellpose/AbstractCellposeSettings.java b/src/main/java/fiji/plugin/trackmate/cellpose/AbstractCellposeSettings.java index 883cdd3d..9ae15963 100644 --- a/src/main/java/fiji/plugin/trackmate/cellpose/AbstractCellposeSettings.java +++ b/src/main/java/fiji/plugin/trackmate/cellpose/AbstractCellposeSettings.java @@ -114,8 +114,6 @@ public List< String > toCmdLine( final String imagesDir, final boolean is3D, fin * Cellpose command line arguments. */ - cmd.add( "--verbose" ); - // Target dir. cmd.add( "--dir" ); cmd.add( imagesDir ); diff --git a/src/main/java/fiji/plugin/trackmate/cellpose/CellposeDetector.java b/src/main/java/fiji/plugin/trackmate/cellpose/CellposeDetector.java index 5ee3da96..6a658934 100644 --- a/src/main/java/fiji/plugin/trackmate/cellpose/CellposeDetector.java +++ b/src/main/java/fiji/plugin/trackmate/cellpose/CellposeDetector.java @@ -38,7 +38,7 @@ import fiji.plugin.trackmate.detection.DetectionUtils; import fiji.plugin.trackmate.detection.LabelImageDetectorFactory; import fiji.plugin.trackmate.detection.SpotGlobalDetector; -import fiji.plugin.trackmate.omnipose.advanced.AdvancedOmniposeSettings; +import fiji.plugin.trackmate.omnipose.OmniposeSettings; import fiji.plugin.trackmate.util.TMUtils; import ij.IJ; import ij.ImagePlus; @@ -47,6 +47,9 @@ import ij.plugin.Duplicator; import ij.process.ImageConverter; import ij.process.StackConverter; +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.util.stream.Collectors; import net.imagej.ImgPlus; import net.imagej.axis.Axes; import net.imglib2.Interval; @@ -202,7 +205,7 @@ public boolean process() { results = executors.invokeAll( processes ); for ( final Future< String > future : results ) - resultDirs.add( future.get() ); + resultDirs.add( future.get() ); } catch ( final InterruptedException | ExecutionException e ) { @@ -558,8 +561,8 @@ public String call() throws Exception for ( final ImagePlus imp : imps ) { final String name = imp.getShortTitle() + ".tif"; - // If we are running an advanced omnipose detector, just save the segmentation channel as tmp image - if (cellposeSettings instanceof AdvancedOmniposeSettings) { + // If we are running an omnipose detector, just save the segmentation channel as tmp image + if (cellposeSettings instanceof OmniposeSettings) { ImagePlus chanImp = new Duplicator().run(imp, cellposeSettings.chan, cellposeSettings.chan, 0, 0, 0, 0); IJ.saveAsTiff( chanImp, Paths.get( tmpDir.toString(), name ).toString() ); } else { @@ -572,19 +575,75 @@ public String call() throws Exception */ try - { - final List< String > cmd = cellposeSettings.toCmdLine( tmpDir.toString(), is3D, anisotropy ); - logger.setStatus( "Running " + cellposeSettings.getExecutableName() ); - logger.log( "Running " + cellposeSettings.getExecutableName() + " with args:\n" ); - logger.log( String.join( " ", cmd ) ); - logger.log( "\n" ); - final ProcessBuilder pb = new ProcessBuilder( cmd ); - pb.redirectOutput( ProcessBuilder.Redirect.INHERIT ); - pb.redirectError( ProcessBuilder.Redirect.INHERIT ); - - process = pb.start(); - process.waitFor(); - } + { + if (cellposeSettings instanceof OmniposeSettings) + { + final List< String > cmdOmni = new ArrayList<>( cellposeSettings.toCmdLine( tmpDir.toString(), is3D, anisotropy ) ); + + int nClasses = 2; // 2 by default in custom models + cmdOmni.add( "--nclasses" ); + cmdOmni.add( String.valueOf( nClasses ) ); + + logger.setStatus( "Running " + cellposeSettings.getExecutableName() ); + logger.log( "Running " + cellposeSettings.getExecutableName() + " with args:\n" ); + logger.log( String.join( " ", cmdOmni ) ); + logger.log( "\n" ); + + final ProcessBuilder pbOmni = new ProcessBuilder( cmdOmni ); +// pbOmni.redirectOutput( ProcessBuilder.Redirect.INHERIT ); +// pbOmni.redirectError( ProcessBuilder.Redirect.INHERIT ); + pbOmni.redirectErrorStream(true); + process = pbOmni.start(); + + BufferedReader reader = new BufferedReader( new InputStreamReader(process.getInputStream())); +// reader.lines().forEach(line -> System.out.println(line)); + String pythonErrorOutput = reader.lines().collect(Collectors.joining()); + if (pythonErrorOutput.contains( "size mismatch for output.2.bias:" ) ) + { + if( pythonErrorOutput.contains( "copying a param with shape torch.Size([4]) from checkpoint" ) ) + { + nClasses = 3; + logger.log( "Regarding the model loaded, --nclasses argument should be set to " + String.valueOf( nClasses ) + "\n" ); + } + else + { + nClasses = 4; + logger.log( "Regarding the model loaded, --nclasses argument should be set to " + String.valueOf( nClasses ) + "\n" ); + } + + cmdOmni.remove(cmdOmni.size() - 1); + cmdOmni.add(String.valueOf( nClasses )); + + logger.log( "Running " + cellposeSettings.getExecutableName() + " with args:\n" ); + logger.log( String.join( " ", cmdOmni ) ); + logger.log( "\n" ); + final ProcessBuilder updatedPbOmni = new ProcessBuilder( cmdOmni ); + updatedPbOmni.redirectOutput( ProcessBuilder.Redirect.INHERIT ); + updatedPbOmni.redirectError( ProcessBuilder.Redirect.INHERIT ); + + process = updatedPbOmni.start(); + process.waitFor(); + } + else if (pythonErrorOutput.contains("pretrained model has incorrect path") ) + { + logger.log( "Pretrained model has incorrect path \n" ); + } + } + else + { + final List< String > cmd = cellposeSettings.toCmdLine( tmpDir.toString(), is3D, anisotropy ); + logger.setStatus( "Running " + cellposeSettings.getExecutableName() ); + logger.log( "Running " + cellposeSettings.getExecutableName() + " with args:\n" ); + logger.log( String.join( " ", cmd ) ); + logger.log( "\n" ); + final ProcessBuilder pb = new ProcessBuilder( cmd ); + pb.redirectOutput( ProcessBuilder.Redirect.INHERIT ); + pb.redirectError( ProcessBuilder.Redirect.INHERIT ); + + process = pb.start(); + process.waitFor(); + } + } catch ( final IOException e ) { final String msg = e.getMessage(); @@ -593,7 +652,7 @@ public String call() throws Exception errorMessage = baseErrorMessage + "Problem running " + cellposeSettings.getExecutableName() + ":\n" + "The executable does not have the file permission to run.\n" + "Please see https://github.com/MouseLand/cellpose#run-cellpose-without-local-python-installation for more information.\n"; - } + } else { errorMessage = baseErrorMessage + "Problem running " + cellposeSettings.getExecutableName() + ":\n" + e.getMessage(); diff --git a/src/main/java/fiji/plugin/trackmate/omnipose/OmniposeSettings.java b/src/main/java/fiji/plugin/trackmate/omnipose/OmniposeSettings.java index 98a58b05..38316b9c 100644 --- a/src/main/java/fiji/plugin/trackmate/omnipose/OmniposeSettings.java +++ b/src/main/java/fiji/plugin/trackmate/omnipose/OmniposeSettings.java @@ -28,7 +28,7 @@ import fiji.plugin.trackmate.cellpose.AbstractCellposeSettings; public class OmniposeSettings extends AbstractCellposeSettings -{ +{ public OmniposeSettings( final String omniposePythonPath, @@ -48,7 +48,11 @@ public List< String > toCmdLine( final String imagesDir, final boolean is3D, fin { final List< String > cmd = new ArrayList<>( super.toCmdLine( imagesDir, is3D, anisotropy ) ); // omnipose executable adds it anyway, but let's make sure. - cmd.add( "--omni" ); + cmd.add( "--omni" ); + + cmd.add( "--nchan" ); + cmd.add( String.valueOf( 1 ) ); // Only segment with --nchan to 1, and save only the channel to segment in temp directory + return Collections.unmodifiableList( cmd ); } diff --git a/src/main/java/fiji/plugin/trackmate/omnipose/advanced/AdvancedOmniposeDetectorConfigurationPanel.java b/src/main/java/fiji/plugin/trackmate/omnipose/advanced/AdvancedOmniposeDetectorConfigurationPanel.java index 43ff2fdd..d213d452 100644 --- a/src/main/java/fiji/plugin/trackmate/omnipose/advanced/AdvancedOmniposeDetectorConfigurationPanel.java +++ b/src/main/java/fiji/plugin/trackmate/omnipose/advanced/AdvancedOmniposeDetectorConfigurationPanel.java @@ -3,8 +3,6 @@ import static fiji.plugin.trackmate.cellpose.advanced.AdvancedCellposeDetectorFactory.KEY_CELL_PROB_THRESHOLD; import static fiji.plugin.trackmate.cellpose.advanced.AdvancedCellposeDetectorFactory.KEY_FLOW_THRESHOLD; import static fiji.plugin.trackmate.gui.Fonts.SMALL_FONT; -import static fiji.plugin.trackmate.omnipose.advanced.AdvancedOmniposeDetectorFactory.KEY_NB_CLASSES; - import java.awt.GridBagConstraints; import java.awt.Insets; import java.util.ArrayList; @@ -31,7 +29,6 @@ public class AdvancedOmniposeDetectorConfigurationPanel extends OmniposeDetector private static final String TITLE = AdvancedOmniposeDetectorFactory.NAME;; - protected final JComboBox< String > cmbboxNbClasses; private final StyleElements.BoundedDoubleElement flowThresholdEl = new StyleElements.BoundedDoubleElement( "Flow threshold", 0.0, 3.0 ) { @@ -75,36 +72,6 @@ public AdvancedOmniposeDetectorConfigurationPanel( final Settings settings, fina int gridy = 12; - /* - * Add model number of output classes. - */ - - final JLabel lblNbClasses = new JLabel( "N. output classes:" ); - lblNbClasses.setFont( SMALL_FONT ); - final GridBagConstraints gbcLblNbClasses = new GridBagConstraints(); - gbcLblNbClasses.anchor = GridBagConstraints.EAST; - gbcLblNbClasses.insets = new Insets( 0, 5, 5, 5 ); - gbcLblNbClasses.gridx = 0; - gbcLblNbClasses.gridy = gridy; - mainPanel.add( lblNbClasses, gbcLblNbClasses ); - - final int nbClassesMin = 2; - final int nbClassesMax = 4; - - final List< String > lNbClasses = new ArrayList< String >(); - for ( int nc = nbClassesMin; nc <= nbClassesMax; nc++ ) - lNbClasses.add( "" + nc ); - - cmbboxNbClasses = new JComboBox<>( new Vector<>( lNbClasses ) ); - cmbboxNbClasses.setFont( SMALL_FONT ); - final GridBagConstraints gbcSpinner = new GridBagConstraints(); - gbcSpinner.fill = GridBagConstraints.HORIZONTAL; - gbcSpinner.gridwidth = 2; - gbcSpinner.insets = new Insets( 0, 5, 5, 5 ); - gbcSpinner.gridx = 1; - gbcSpinner.gridy = gridy; - mainPanel.add( cmbboxNbClasses, gbcSpinner ); - /* * Add flow threshold. */ @@ -171,7 +138,6 @@ public void setSettings( final Map< String, Object > settings ) flowThresholdEl.update(); cellProbThresholdEl.set( ( double ) settings.get( KEY_CELL_PROB_THRESHOLD ) ); cellProbThresholdEl.update(); - cmbboxNbClasses.setSelectedIndex( ( int ) settings.get( KEY_NB_CLASSES ) ); } @Override @@ -180,7 +146,6 @@ public Map< String, Object > getSettings() final Map< String, Object > settings = super.getSettings(); settings.put( KEY_FLOW_THRESHOLD, flowThresholdEl.get() ); settings.put( KEY_CELL_PROB_THRESHOLD, cellProbThresholdEl.get() ); - settings.put( KEY_NB_CLASSES, cmbboxNbClasses.getSelectedIndex() + 2 ); return settings; } } diff --git a/src/main/java/fiji/plugin/trackmate/omnipose/advanced/AdvancedOmniposeDetectorFactory.java b/src/main/java/fiji/plugin/trackmate/omnipose/advanced/AdvancedOmniposeDetectorFactory.java index 43880258..586958f8 100644 --- a/src/main/java/fiji/plugin/trackmate/omnipose/advanced/AdvancedOmniposeDetectorFactory.java +++ b/src/main/java/fiji/plugin/trackmate/omnipose/advanced/AdvancedOmniposeDetectorFactory.java @@ -83,16 +83,7 @@ public class AdvancedOmniposeDetectorFactory< T extends RealType< T > & NativeTy + "

" + "Documentation for this module " + "on the ImageJ Wiki." - + ""; - - /** - * The key to the parameter that store the output number of classes. - */ - public static final String KEY_NB_CLASSES = "NB_CLASSES"; - - public static final Integer DEFAULT_NB_CLASSES = Integer.valueOf( 0 ); - - + + ""; /* * METHODS */ @@ -120,7 +111,6 @@ public SpotGlobalDetector< T > getDetector( final Interval interval ) final double flowThreshold = ( Double ) settings.get( KEY_FLOW_THRESHOLD ); final double cellProbThreshold = ( Double ) settings.get( KEY_CELL_PROB_THRESHOLD ); - final int nbClasses = (Integer) settings.get( KEY_NB_CLASSES ); final AdvancedOmniposeSettings cellposeSettings = AdvancedOmniposeSettings .create() @@ -134,7 +124,6 @@ public SpotGlobalDetector< T > getDetector( final Interval interval ) .simplifyContours( simplifyContours ) .flowThreshold( flowThreshold ) .cellProbThreshold( cellProbThreshold ) - .nbClasses(nbClasses) .get(); // Logger. @@ -152,7 +141,6 @@ public boolean marshall( final Map< String, Object > settings, final Element ele final StringBuilder errorHolder = new StringBuilder(); boolean ok = writeAttribute( settings, element, KEY_FLOW_THRESHOLD, Double.class, errorHolder ); ok = ok && writeAttribute( settings, element, KEY_CELL_PROB_THRESHOLD, Double.class, errorHolder ); - ok = ok && writeAttribute( settings, element, KEY_NB_CLASSES, Integer.class, errorHolder ); if ( !ok ) errorMessage = errorHolder.toString(); return ok; @@ -173,7 +161,6 @@ public boolean unmarshall( final Element element, final Map< String, Object > se ok = ok && readBooleanAttribute( element, settings, KEY_SIMPLIFY_CONTOURS, errorHolder ); ok = ok && readDoubleAttribute( element, settings, KEY_FLOW_THRESHOLD, errorHolder ); ok = ok && readDoubleAttribute( element, settings, KEY_CELL_PROB_THRESHOLD, errorHolder ); - ok = ok && readDoubleAttribute( element, settings, KEY_NB_CLASSES, errorHolder ); // Read model. final String str = element.getAttributeValue( KEY_OMNIPOSE_MODEL ); @@ -199,7 +186,6 @@ public Map< String, Object > getDefaultSettings() final Map< String, Object > settings = super.getDefaultSettings(); settings.put( KEY_FLOW_THRESHOLD, DEFAULT_FLOW_THRESHOLD ); settings.put( KEY_CELL_PROB_THRESHOLD, DEFAULT_CELL_PROB_THRESHOLD ); - settings.put( KEY_NB_CLASSES, DEFAULT_NB_CLASSES ); return settings; } @@ -218,7 +204,6 @@ public boolean checkSettings( final Map< String, Object > settings ) ok = ok & checkParameter( settings, KEY_SIMPLIFY_CONTOURS, Boolean.class, errorHolder ); ok = ok & checkParameter( settings, KEY_FLOW_THRESHOLD, Double.class, errorHolder ); ok = ok & checkParameter( settings, KEY_CELL_PROB_THRESHOLD, Double.class, errorHolder ); - ok = ok & checkParameter( settings, KEY_NB_CLASSES, Integer.class, errorHolder ); // If we have a logger, test it is of the right class. final Object loggerObj = settings.get( KEY_LOGGER ); @@ -241,8 +226,7 @@ public boolean checkSettings( final Map< String, Object > settings ) KEY_OMNIPOSE_CUSTOM_MODEL_FILEPATH, KEY_LOGGER, KEY_FLOW_THRESHOLD, - KEY_CELL_PROB_THRESHOLD, - KEY_NB_CLASSES); + KEY_CELL_PROB_THRESHOLD); ok = ok & checkMapKeys( settings, mandatoryKeys, optionalKeys, errorHolder ); if ( !ok ) errorMessage = errorHolder.toString(); diff --git a/src/main/java/fiji/plugin/trackmate/omnipose/advanced/AdvancedOmniposeSettings.java b/src/main/java/fiji/plugin/trackmate/omnipose/advanced/AdvancedOmniposeSettings.java index 5d238ed6..6da86cf0 100644 --- a/src/main/java/fiji/plugin/trackmate/omnipose/advanced/AdvancedOmniposeSettings.java +++ b/src/main/java/fiji/plugin/trackmate/omnipose/advanced/AdvancedOmniposeSettings.java @@ -11,8 +11,6 @@ public class AdvancedOmniposeSettings extends OmniposeSettings private final double flowThreshold; private final double cellProbThreshold; - - private final int nbClasses; public AdvancedOmniposeSettings( final String omniposePythonPath, @@ -24,13 +22,11 @@ public AdvancedOmniposeSettings( final boolean useGPU, final boolean simplifyContours, final double flowThreshold, - final double cellProbThreshold, - final int nbClasses) + final double cellProbThreshold) { super( omniposePythonPath, model, customModelPath, chan, chan2, diameter, useGPU, simplifyContours ); this.flowThreshold = flowThreshold; this.cellProbThreshold = cellProbThreshold; - this.nbClasses = nbClasses; } @Override @@ -45,10 +41,7 @@ public List< String > toCmdLine( final String imagesDir, final boolean is3D, fin */ cmd.add( "--mask_threshold" ); cmd.add( String.valueOf( cellProbThreshold ) ); - - cmd.add( "--nclasses" ); - cmd.add( String.valueOf( nbClasses ) ); - + return Collections.unmodifiableList( cmd ); } @@ -63,8 +56,6 @@ public static final class Builder extends OmniposeSettings.Builder private double flowThreshold = 0.4; private double cellProbThreshold = 0.0; - - private int nbClasses = 4; public Builder flowThreshold( final double flowThreshold ) { @@ -77,13 +68,7 @@ public Builder cellProbThreshold( final double cellProbThreshold ) this.cellProbThreshold = cellProbThreshold; return this; } - - public Builder nbClasses( final int nbClasses ) - { - this.nbClasses = nbClasses; - return this; - } - + @Override public Builder channel1( final int ch ) { @@ -153,8 +138,7 @@ public AdvancedOmniposeSettings get() useGPU, simplifyContours, flowThreshold, - cellProbThreshold, - nbClasses); + cellProbThreshold); } } }