Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modifs to use custom models with omnipose version 1.0.6 #13

Open
wants to merge 1 commit into
base: v8
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,6 @@ public List< String > toCmdLine( final String imagesDir, final boolean is3D, fin
* Cellpose command line arguments.
*/

cmd.add( "--verbose" );

// Target dir.
cmd.add( "--dir" );
cmd.add( imagesDir );
Expand Down
95 changes: 77 additions & 18 deletions src/main/java/fiji/plugin/trackmate/cellpose/CellposeDetector.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import fiji.plugin.trackmate.detection.DetectionUtils;
import fiji.plugin.trackmate.detection.LabelImageDetectorFactory;
import fiji.plugin.trackmate.detection.SpotGlobalDetector;
import fiji.plugin.trackmate.omnipose.advanced.AdvancedOmniposeSettings;
import fiji.plugin.trackmate.omnipose.OmniposeSettings;
import fiji.plugin.trackmate.util.TMUtils;
import ij.IJ;
import ij.ImagePlus;
Expand All @@ -47,6 +47,9 @@
import ij.plugin.Duplicator;
import ij.process.ImageConverter;
import ij.process.StackConverter;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.stream.Collectors;
import net.imagej.ImgPlus;
import net.imagej.axis.Axes;
import net.imglib2.Interval;
Expand Down Expand Up @@ -202,7 +205,7 @@ public boolean process()
{
results = executors.invokeAll( processes );
for ( final Future< String > future : results )
resultDirs.add( future.get() );
resultDirs.add( future.get() );
}
catch ( final InterruptedException | ExecutionException e )
{
Expand Down Expand Up @@ -558,8 +561,8 @@ public String call() throws Exception
for ( final ImagePlus imp : imps )
{
final String name = imp.getShortTitle() + ".tif";
// If we are running an advanced omnipose detector, just save the segmentation channel as tmp image
if (cellposeSettings instanceof AdvancedOmniposeSettings) {
// If we are running an omnipose detector, just save the segmentation channel as tmp image
if (cellposeSettings instanceof OmniposeSettings) {
ImagePlus chanImp = new Duplicator().run(imp, cellposeSettings.chan, cellposeSettings.chan, 0, 0, 0, 0);
IJ.saveAsTiff( chanImp, Paths.get( tmpDir.toString(), name ).toString() );
} else {
Expand All @@ -572,19 +575,75 @@ public String call() throws Exception
*/

try
{
final List< String > cmd = cellposeSettings.toCmdLine( tmpDir.toString(), is3D, anisotropy );
logger.setStatus( "Running " + cellposeSettings.getExecutableName() );
logger.log( "Running " + cellposeSettings.getExecutableName() + " with args:\n" );
logger.log( String.join( " ", cmd ) );
logger.log( "\n" );
final ProcessBuilder pb = new ProcessBuilder( cmd );
pb.redirectOutput( ProcessBuilder.Redirect.INHERIT );
pb.redirectError( ProcessBuilder.Redirect.INHERIT );

process = pb.start();
process.waitFor();
}
{
if (cellposeSettings instanceof OmniposeSettings)
{
final List< String > cmdOmni = new ArrayList<>( cellposeSettings.toCmdLine( tmpDir.toString(), is3D, anisotropy ) );

int nClasses = 2; // 2 by default in custom models
cmdOmni.add( "--nclasses" );
cmdOmni.add( String.valueOf( nClasses ) );

logger.setStatus( "Running " + cellposeSettings.getExecutableName() );
logger.log( "Running " + cellposeSettings.getExecutableName() + " with args:\n" );
logger.log( String.join( " ", cmdOmni ) );
logger.log( "\n" );

final ProcessBuilder pbOmni = new ProcessBuilder( cmdOmni );
// pbOmni.redirectOutput( ProcessBuilder.Redirect.INHERIT );
// pbOmni.redirectError( ProcessBuilder.Redirect.INHERIT );
pbOmni.redirectErrorStream(true);
process = pbOmni.start();

BufferedReader reader = new BufferedReader( new InputStreamReader(process.getInputStream()));
// reader.lines().forEach(line -> System.out.println(line));
String pythonErrorOutput = reader.lines().collect(Collectors.joining());
if (pythonErrorOutput.contains( "size mismatch for output.2.bias:" ) )
{
if( pythonErrorOutput.contains( "copying a param with shape torch.Size([4]) from checkpoint" ) )
{
nClasses = 3;
logger.log( "Regarding the model loaded, --nclasses argument should be set to " + String.valueOf( nClasses ) + "\n" );
}
else
{
nClasses = 4;
logger.log( "Regarding the model loaded, --nclasses argument should be set to " + String.valueOf( nClasses ) + "\n" );
}

cmdOmni.remove(cmdOmni.size() - 1);
cmdOmni.add(String.valueOf( nClasses ));

logger.log( "Running " + cellposeSettings.getExecutableName() + " with args:\n" );
logger.log( String.join( " ", cmdOmni ) );
logger.log( "\n" );
final ProcessBuilder updatedPbOmni = new ProcessBuilder( cmdOmni );
updatedPbOmni.redirectOutput( ProcessBuilder.Redirect.INHERIT );
updatedPbOmni.redirectError( ProcessBuilder.Redirect.INHERIT );

process = updatedPbOmni.start();
process.waitFor();
}
else if (pythonErrorOutput.contains("pretrained model has incorrect path") )
{
logger.log( "Pretrained model has incorrect path \n" );
}
}
else
{
final List< String > cmd = cellposeSettings.toCmdLine( tmpDir.toString(), is3D, anisotropy );
logger.setStatus( "Running " + cellposeSettings.getExecutableName() );
logger.log( "Running " + cellposeSettings.getExecutableName() + " with args:\n" );
logger.log( String.join( " ", cmd ) );
logger.log( "\n" );
final ProcessBuilder pb = new ProcessBuilder( cmd );
pb.redirectOutput( ProcessBuilder.Redirect.INHERIT );
pb.redirectError( ProcessBuilder.Redirect.INHERIT );

process = pb.start();
process.waitFor();
}
}
catch ( final IOException e )
{
final String msg = e.getMessage();
Expand All @@ -593,7 +652,7 @@ public String call() throws Exception
errorMessage = baseErrorMessage + "Problem running " + cellposeSettings.getExecutableName() + ":\n"
+ "The executable does not have the file permission to run.\n"
+ "Please see https://github.com/MouseLand/cellpose#run-cellpose-without-local-python-installation for more information.\n";
}
}
else
{
errorMessage = baseErrorMessage + "Problem running " + cellposeSettings.getExecutableName() + ":\n" + e.getMessage();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import fiji.plugin.trackmate.cellpose.AbstractCellposeSettings;

public class OmniposeSettings extends AbstractCellposeSettings
{
{

public OmniposeSettings(
final String omniposePythonPath,
Expand All @@ -48,7 +48,11 @@ public List< String > toCmdLine( final String imagesDir, final boolean is3D, fin
{
final List< String > cmd = new ArrayList<>( super.toCmdLine( imagesDir, is3D, anisotropy ) );
// omnipose executable adds it anyway, but let's make sure.
cmd.add( "--omni" );
cmd.add( "--omni" );

cmd.add( "--nchan" );
cmd.add( String.valueOf( 1 ) ); // Only segment with --nchan to 1, and save only the channel to segment in temp directory

return Collections.unmodifiableList( cmd );
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import static fiji.plugin.trackmate.cellpose.advanced.AdvancedCellposeDetectorFactory.KEY_CELL_PROB_THRESHOLD;
import static fiji.plugin.trackmate.cellpose.advanced.AdvancedCellposeDetectorFactory.KEY_FLOW_THRESHOLD;
import static fiji.plugin.trackmate.gui.Fonts.SMALL_FONT;
import static fiji.plugin.trackmate.omnipose.advanced.AdvancedOmniposeDetectorFactory.KEY_NB_CLASSES;

import java.awt.GridBagConstraints;
import java.awt.Insets;
import java.util.ArrayList;
Expand All @@ -31,7 +29,6 @@ public class AdvancedOmniposeDetectorConfigurationPanel extends OmniposeDetector

private static final String TITLE = AdvancedOmniposeDetectorFactory.NAME;;

protected final JComboBox< String > cmbboxNbClasses;

private final StyleElements.BoundedDoubleElement flowThresholdEl = new StyleElements.BoundedDoubleElement( "Flow threshold", 0.0, 3.0 )
{
Expand Down Expand Up @@ -75,36 +72,6 @@ public AdvancedOmniposeDetectorConfigurationPanel( final Settings settings, fina

int gridy = 12;

/*
* Add model number of output classes.
*/

final JLabel lblNbClasses = new JLabel( "N. output classes:" );
lblNbClasses.setFont( SMALL_FONT );
final GridBagConstraints gbcLblNbClasses = new GridBagConstraints();
gbcLblNbClasses.anchor = GridBagConstraints.EAST;
gbcLblNbClasses.insets = new Insets( 0, 5, 5, 5 );
gbcLblNbClasses.gridx = 0;
gbcLblNbClasses.gridy = gridy;
mainPanel.add( lblNbClasses, gbcLblNbClasses );

final int nbClassesMin = 2;
final int nbClassesMax = 4;

final List< String > lNbClasses = new ArrayList< String >();
for ( int nc = nbClassesMin; nc <= nbClassesMax; nc++ )
lNbClasses.add( "" + nc );

cmbboxNbClasses = new JComboBox<>( new Vector<>( lNbClasses ) );
cmbboxNbClasses.setFont( SMALL_FONT );
final GridBagConstraints gbcSpinner = new GridBagConstraints();
gbcSpinner.fill = GridBagConstraints.HORIZONTAL;
gbcSpinner.gridwidth = 2;
gbcSpinner.insets = new Insets( 0, 5, 5, 5 );
gbcSpinner.gridx = 1;
gbcSpinner.gridy = gridy;
mainPanel.add( cmbboxNbClasses, gbcSpinner );

/*
* Add flow threshold.
*/
Expand Down Expand Up @@ -171,7 +138,6 @@ public void setSettings( final Map< String, Object > settings )
flowThresholdEl.update();
cellProbThresholdEl.set( ( double ) settings.get( KEY_CELL_PROB_THRESHOLD ) );
cellProbThresholdEl.update();
cmbboxNbClasses.setSelectedIndex( ( int ) settings.get( KEY_NB_CLASSES ) );
}

@Override
Expand All @@ -180,7 +146,6 @@ public Map< String, Object > getSettings()
final Map< String, Object > settings = super.getSettings();
settings.put( KEY_FLOW_THRESHOLD, flowThresholdEl.get() );
settings.put( KEY_CELL_PROB_THRESHOLD, cellProbThresholdEl.get() );
settings.put( KEY_NB_CLASSES, cmbboxNbClasses.getSelectedIndex() + 2 );
return settings;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,7 @@ public class AdvancedOmniposeDetectorFactory< T extends RealType< T > & NativeTy
+ "<p>"
+ "Documentation for this module "
+ "<a href=\"https://imagej.net/plugins/trackmate/trackmate-advanced-omnipose\">on the ImageJ Wiki</a>."
+ "</html>";

/**
* The key to the parameter that store the output number of classes.
*/
public static final String KEY_NB_CLASSES = "NB_CLASSES";

public static final Integer DEFAULT_NB_CLASSES = Integer.valueOf( 0 );


+ "</html>";
/*
* METHODS
*/
Expand Down Expand Up @@ -120,7 +111,6 @@ public SpotGlobalDetector< T > getDetector( final Interval interval )

final double flowThreshold = ( Double ) settings.get( KEY_FLOW_THRESHOLD );
final double cellProbThreshold = ( Double ) settings.get( KEY_CELL_PROB_THRESHOLD );
final int nbClasses = (Integer) settings.get( KEY_NB_CLASSES );

final AdvancedOmniposeSettings cellposeSettings = AdvancedOmniposeSettings
.create()
Expand All @@ -134,7 +124,6 @@ public SpotGlobalDetector< T > getDetector( final Interval interval )
.simplifyContours( simplifyContours )
.flowThreshold( flowThreshold )
.cellProbThreshold( cellProbThreshold )
.nbClasses(nbClasses)
.get();

// Logger.
Expand All @@ -152,7 +141,6 @@ public boolean marshall( final Map< String, Object > settings, final Element ele
final StringBuilder errorHolder = new StringBuilder();
boolean ok = writeAttribute( settings, element, KEY_FLOW_THRESHOLD, Double.class, errorHolder );
ok = ok && writeAttribute( settings, element, KEY_CELL_PROB_THRESHOLD, Double.class, errorHolder );
ok = ok && writeAttribute( settings, element, KEY_NB_CLASSES, Integer.class, errorHolder );
if ( !ok )
errorMessage = errorHolder.toString();
return ok;
Expand All @@ -173,7 +161,6 @@ public boolean unmarshall( final Element element, final Map< String, Object > se
ok = ok && readBooleanAttribute( element, settings, KEY_SIMPLIFY_CONTOURS, errorHolder );
ok = ok && readDoubleAttribute( element, settings, KEY_FLOW_THRESHOLD, errorHolder );
ok = ok && readDoubleAttribute( element, settings, KEY_CELL_PROB_THRESHOLD, errorHolder );
ok = ok && readDoubleAttribute( element, settings, KEY_NB_CLASSES, errorHolder );

// Read model.
final String str = element.getAttributeValue( KEY_OMNIPOSE_MODEL );
Expand All @@ -199,7 +186,6 @@ public Map< String, Object > getDefaultSettings()
final Map< String, Object > settings = super.getDefaultSettings();
settings.put( KEY_FLOW_THRESHOLD, DEFAULT_FLOW_THRESHOLD );
settings.put( KEY_CELL_PROB_THRESHOLD, DEFAULT_CELL_PROB_THRESHOLD );
settings.put( KEY_NB_CLASSES, DEFAULT_NB_CLASSES );
return settings;
}

Expand All @@ -218,7 +204,6 @@ public boolean checkSettings( final Map< String, Object > settings )
ok = ok & checkParameter( settings, KEY_SIMPLIFY_CONTOURS, Boolean.class, errorHolder );
ok = ok & checkParameter( settings, KEY_FLOW_THRESHOLD, Double.class, errorHolder );
ok = ok & checkParameter( settings, KEY_CELL_PROB_THRESHOLD, Double.class, errorHolder );
ok = ok & checkParameter( settings, KEY_NB_CLASSES, Integer.class, errorHolder );

// If we have a logger, test it is of the right class.
final Object loggerObj = settings.get( KEY_LOGGER );
Expand All @@ -241,8 +226,7 @@ public boolean checkSettings( final Map< String, Object > settings )
KEY_OMNIPOSE_CUSTOM_MODEL_FILEPATH,
KEY_LOGGER,
KEY_FLOW_THRESHOLD,
KEY_CELL_PROB_THRESHOLD,
KEY_NB_CLASSES);
KEY_CELL_PROB_THRESHOLD);
ok = ok & checkMapKeys( settings, mandatoryKeys, optionalKeys, errorHolder );
if ( !ok )
errorMessage = errorHolder.toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ public class AdvancedOmniposeSettings extends OmniposeSettings
private final double flowThreshold;

private final double cellProbThreshold;

private final int nbClasses;

public AdvancedOmniposeSettings(
final String omniposePythonPath,
Expand All @@ -24,13 +22,11 @@ public AdvancedOmniposeSettings(
final boolean useGPU,
final boolean simplifyContours,
final double flowThreshold,
final double cellProbThreshold,
final int nbClasses)
final double cellProbThreshold)
{
super( omniposePythonPath, model, customModelPath, chan, chan2, diameter, useGPU, simplifyContours );
this.flowThreshold = flowThreshold;
this.cellProbThreshold = cellProbThreshold;
this.nbClasses = nbClasses;
}

@Override
Expand All @@ -45,10 +41,7 @@ public List< String > toCmdLine( final String imagesDir, final boolean is3D, fin
*/
cmd.add( "--mask_threshold" );
cmd.add( String.valueOf( cellProbThreshold ) );

cmd.add( "--nclasses" );
cmd.add( String.valueOf( nbClasses ) );


return Collections.unmodifiableList( cmd );
}

Expand All @@ -63,8 +56,6 @@ public static final class Builder extends OmniposeSettings.Builder
private double flowThreshold = 0.4;

private double cellProbThreshold = 0.0;

private int nbClasses = 4;

public Builder flowThreshold( final double flowThreshold )
{
Expand All @@ -77,13 +68,7 @@ public Builder cellProbThreshold( final double cellProbThreshold )
this.cellProbThreshold = cellProbThreshold;
return this;
}

public Builder nbClasses( final int nbClasses )
{
this.nbClasses = nbClasses;
return this;
}


@Override
public Builder channel1( final int ch )
{
Expand Down Expand Up @@ -153,8 +138,7 @@ public AdvancedOmniposeSettings get()
useGPU,
simplifyContours,
flowThreshold,
cellProbThreshold,
nbClasses);
cellProbThreshold);
}
}
}