Skip to content

Commit

Permalink
Merge pull request #44 from petebankhead/batch-performance
Browse files Browse the repository at this point in the history
Support custom batch sizes
  • Loading branch information
petebankhead authored Nov 10, 2023
2 parents d9b2c72 + 03dea46 commit c88101e
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 58 deletions.
68 changes: 34 additions & 34 deletions src/main/java/qupath/ext/wsinfer/TileLoader.java
Original file line number Diff line number Diff line change
Expand Up @@ -110,43 +110,43 @@ public int getNumWorkers() {
}

private TileBatch nextBatch() {
List<Image> inputs = new ArrayList<>();
List<PathObject> pathObjectBatch = new ArrayList<>();
while (!pathObjects.isEmpty() && inputs.size() < maxBatchSize) {
PathObject pathObject = pathObjects.poll();
if (pathObject == null) {
break;
}
List<Image> inputs = new ArrayList<>();
List<PathObject> pathObjectBatch = new ArrayList<>();
while (!pathObjects.isEmpty() && inputs.size() < maxBatchSize) {
PathObject pathObject = pathObjects.poll();
if (pathObject == null) {
break;
}

ROI roi = pathObject.getROI();
int x = (int) Math.round(roi.getCentroidX() - width / 2.0);
int y = (int) Math.round(roi.getCentroidY() - height / 2.0);
try {
BufferedImage img;
if (x < 0 || y < 0 || x + width >= server.getWidth() || y + height >= server.getHeight()) {
// Handle out-of-bounds coordinates
// This reuses code from DnnTools.readPatch, but is not ideal since it uses a trip through OpenCV
var mat = DnnTools.readPatch(server, roi, downsample, width, height);
img = OpenCVTools.matToBufferedImage(mat);
mat.close();
logger.warn("Detected out-of-bounds tile request - results may be influenced by padding ({}, {}, {}, {})", x, y, width, height);
} else {
// Handle normal case of within-bounds coordinates
img = server.readRegion(downsample, x, y, width, height);
if (resizeWidth > 0 && resizeHeight > 0)
img = BufferedImageTools.resize(img, resizeWidth, resizeHeight, true);
}
Image input = BufferedImageFactory.getInstance().fromImage(img);
pathObjectBatch.add(pathObject);
inputs.add(input);
} catch (IOException e) {
logger.error("Failed to read tile: {}", e.getMessage(), e);
ROI roi = pathObject.getROI();
int x = (int) Math.round(roi.getCentroidX() - width / 2.0);
int y = (int) Math.round(roi.getCentroidY() - height / 2.0);
try {
BufferedImage img;
if (x < 0 || y < 0 || x + width >= server.getWidth() || y + height >= server.getHeight()) {
// Handle out-of-bounds coordinates
// This reuses code from DnnTools.readPatch, but is not ideal since it uses a trip through OpenCV
var mat = DnnTools.readPatch(server, roi, downsample, width, height);
img = OpenCVTools.matToBufferedImage(mat);
mat.close();
logger.warn("Detected out-of-bounds tile request - results may be influenced by padding ({}, {}, {}, {})", x, y, width, height);
} else {
// Handle normal case of within-bounds coordinates
img = server.readRegion(downsample, x, y, width, height);
if (resizeWidth > 0 && resizeHeight > 0)
img = BufferedImageTools.resize(img, resizeWidth, resizeHeight, true);
}
Image input = BufferedImageFactory.getInstance().fromImage(img);
pathObjectBatch.add(pathObject);
inputs.add(input);
} catch (IOException e) {
logger.error("Failed to read tile: {}", e.getMessage(), e);
}
if (inputs.isEmpty())
return new TileBatch();
else
return new TileBatch(inputs, pathObjectBatch);
}
if (inputs.isEmpty())
return new TileBatch();
else
return new TileBatch(inputs, pathObjectBatch);
}

class TileWorker implements Runnable {
Expand Down
41 changes: 25 additions & 16 deletions src/main/java/qupath/ext/wsinfer/WSInfer.java
Original file line number Diff line number Diff line change
Expand Up @@ -205,27 +205,27 @@ public static void runInference(ImageData<BufferedImage> imageData, WSInferModel
// Number of workers who will be busy fetching tiles for us while we're busy inferring
int nWorkers = Math.max(1, WSInferPrefs.numWorkersProperty().getValue());

// Make a guess at a batch size currently... *must* be 1 for MPS
// FIXME: Make batch size adjustable when using a GPU (or CPU?)
int batchSize = isMPS(device) || !device.isGpu() ? 1 : 4;
if (device == Device.cpu())
batchSize = 4;
// Set batch size
// Previously, this *had* to be 1 for MPS - but since DJL 0.24.0 that doesn't seem necessary any more
int batchSize = Math.max(1, WSInferPrefs.batchSizeProperty().getValue());

// Number of tiles each worker should prefetch
int numPrefetch = (int)Math.max(2, Math.ceil((double)batchSize * 2 / nWorkers));

var tileLoader = TileLoader.builder()
.batchSize(batchSize)
.numWorkers(nWorkers)
.numPrefetch(numPrefetch)
.server(server)
.tileSize(width, height)
.downsample(downsample)
.tiles(tiles)
.resizeTile(resize, resize)
.build();

double completedTiles = 0;
double totalTiles = tiles.size();
progressListener.updateProgress(
String.format(resources.getString("ui.processing-progress"), Math.round(completedTiles), Math.round(totalTiles)),
completedTiles/totalTiles);
int completedTiles = 0;
int totalTiles = tiles.size();
updateProgressForTiles(progressListener, completedTiles, totalTiles, startTime);

try (Predictor<Image, Classifications> predictor = model.newPredictor()) {
var batchQueue = tileLoader.getBatchQueue();
Expand Down Expand Up @@ -257,16 +257,12 @@ public static void runInference(ImageData<BufferedImage> imageData, WSInferModel
pathObject.setPathClass(PathClass.fromString(name));
}
completedTiles += inputs.size();
progressListener.updateProgress(
String.format(resources.getString("ui.processing-progress"), Math.round(completedTiles), Math.round(totalTiles)),
completedTiles/totalTiles);
updateProgressForTiles(progressListener, completedTiles, totalTiles, startTime);
}
}
long endTime = System.currentTimeMillis();
long duration = endTime - startTime;
progressListener.updateProgress(
String.format(resources.getString("ui.processing-completed"), Math.round(completedTiles), Math.round(totalTiles)),
1.0);
updateProgressForTiles(progressListener, completedTiles, totalTiles, startTime);

imageData.getHierarchy().fireObjectClassificationsChangedEvent(WSInfer.class, tiles);
long durationSeconds = duration/1000;
Expand All @@ -283,6 +279,19 @@ public static void runInference(ImageData<BufferedImage> imageData, WSInferModel
}
}

private static void updateProgressForTiles(ProgressListener progress, int completedTiles, int totalTiles, long startTime) {
double timeSeconds = (System.currentTimeMillis() - startTime) / 1000.0;
if (completedTiles == totalTiles)
progress.updateProgress(
String.format(resources.getString("ui.processing-completed"), completedTiles, totalTiles, completedTiles/timeSeconds),
(double)completedTiles/totalTiles);
else {
progress.updateProgress(
String.format(resources.getString("ui.processing-progress"), completedTiles, totalTiles, completedTiles / timeSeconds),
(double)completedTiles / totalTiles);
}
}

/**
* Check if a specified device corresponds to using the Metal Performance Shaders (MPS) backend (Apple Silicon)
* @param device
Expand Down
15 changes: 14 additions & 1 deletion src/main/java/qupath/ext/wsinfer/ui/WSInferController.java
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ public class WSInferController {
@FXML
private Spinner<Integer> spinnerNumWorkers;
@FXML
private Spinner<Integer> spinnerBatchSize;
@FXML
private TextField tfModelDirectory;
@FXML
private TextField localModelDirectory;
Expand Down Expand Up @@ -157,6 +159,7 @@ private void initialize() {
configureAvailableDevices();
configureModelDirectory();
configureNumWorkers();
configureBatchSize();

configureMessageLabel();
configureRunInferenceButton();
Expand Down Expand Up @@ -278,6 +281,10 @@ private void configureNumWorkers() {
spinnerNumWorkers.getValueFactory().valueProperty().bindBidirectional(WSInferPrefs.numWorkersProperty());
}

private void configureBatchSize() {
spinnerBatchSize.getValueFactory().valueProperty().bindBidirectional(WSInferPrefs.batchSizeProperty());
}

/**
* Try to run inference on the current image using the current model and parameters.
*/
Expand Down Expand Up @@ -493,7 +500,7 @@ private static class WSInferTask extends Task<Void> {

private final ImageData<BufferedImage> imageData;
private final WSInferModel model;
private final ProgressListener progressListener;
private final WSInferProgressDialog progressListener;

private WSInferTask(ImageData<BufferedImage> imageData, WSInferModel model) {
this.imageData = imageData;
Expand All @@ -504,6 +511,12 @@ private WSInferTask(ImageData<BufferedImage> imageData, WSInferModel model) {
e.consume();
}
});
this.stateProperty().addListener(this::handleStateChange);
}

private void handleStateChange(ObservableValue<? extends Worker.State> value, Worker.State oldValue, Worker.State newValue) {
if (progressListener != null && newValue == Worker.State.CANCELLED)
progressListener.cancel();
}

private String getDialogTitle() {
Expand Down
15 changes: 14 additions & 1 deletion src/main/java/qupath/ext/wsinfer/ui/WSInferPrefs.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,14 @@ public class WSInferPrefs {

private static final Property<Integer> numWorkersProperty = PathPrefs.createPersistentPreference(
"wsinfer.numWorkers",
1
Math.min(4, Runtime.getRuntime().availableProcessors())
).asObject();

private static final Property<Integer> batchSizeProperty = PathPrefs.createPersistentPreference(
"wsinfer.batchSize",
4
).asObject();

private static StringProperty localDirectoryProperty = PathPrefs.createPersistentPreference(
"wsinfer.localDirectory",
null);
Expand All @@ -68,6 +74,13 @@ public static Property<Integer> numWorkersProperty() {
return numWorkersProperty;
}

/**
* Integer storing the batch size for inference.
*/
public static Property<Integer> batchSizeProperty() {
return batchSizeProperty;
}

/**
* String storing the preferred directory for user-supplied model folders.
*/
Expand Down
15 changes: 15 additions & 0 deletions src/main/java/qupath/ext/wsinfer/ui/WSInferProgressDialog.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class WSInferProgressDialog extends AnchorPane implements ProgressListener {
@FXML
private Button btnCancel;

private boolean isCancelled = false;

public WSInferProgressDialog(Window owner, EventHandler<ActionEvent> cancelHandler) {
URL url = getClass().getResource("progress_dialog.fxml");
ResourceBundle resources = ResourceBundle.getBundle("qupath.ext.wsinfer.ui.strings");
Expand All @@ -76,6 +78,8 @@ public WSInferProgressDialog(Window owner, EventHandler<ActionEvent> cancelHandl

@Override
public void updateProgress(String message, Double progress) {
if (isCancelled)
return;
if (Platform.isFxApplicationThread()) {
if (message != null)
progressLabel.setText(message);
Expand All @@ -92,4 +96,15 @@ public void updateProgress(String message, Double progress) {
}
}

/**
* Immediately cancel and hide the progress dialog.
* Subsequent calls to updateProgress will be ignored.
*/
public void cancel() {
isCancelled = true;
stage.hide();
progressLabel.setText("Cancelled");
progressBar.setProgress(1.0);
}

}
6 changes: 4 additions & 2 deletions src/main/resources/qupath/ext/wsinfer/ui/strings.properties
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ ui.options.localModelDirectory = User model directory:
ui.options.localModelDirectory.tooltip = Choose the directory where user-created models can be loaded from
ui.options.pworkers = Number of parallel tile loaders:
ui.options.pworkers.tooltip = Choose the desired number of threads used to request tiles for inference
ui.options.batchSize = Batch size:
ui.options.batchSize.tooltip = Choose the batch size for inference

# Model directories
ui.model-directory.choose-directory = Choose directory
Expand All @@ -67,8 +69,8 @@ ui.model-directory.found-n-local-models = %d local models found
## Other Windows
# Processing Window and progress pop-ups
ui.processing = Processing tiles
ui.processing-progress = Processing %d/%d tiles
ui.processing-completed = Completed %d/%d tiles
ui.processing-progress = Processing %d/%d tiles (%.1f per second)
ui.processing-completed = Completed %d/%d tiles (%.1f per second)
ui.cancel = Cancel
ui.popup.fetching = Downloading model: %s
ui.popup.available = Model available: %s
Expand Down
21 changes: 17 additions & 4 deletions src/main/resources/qupath/ext/wsinfer/ui/wsinfer_control.fxml
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@
</TitledPane>

<!-- Hardware Pane************************************************************-->
<TitledPane fx:id="pane3" animated="false" maxHeight="Infinity" text="%ui.options.pane" VBox.vgrow="NEVER">
<TitledPane fx:id="pane3" animated="false" expanded="false" maxHeight="Infinity" text="%ui.options.pane" VBox.vgrow="NEVER">
<VBox alignment="TOP_CENTER" spacing="7.5" styleClass="standard-padding">
<children>
<HBox alignment="CENTER" styleClass="standard-spacing">
Expand All @@ -163,13 +163,26 @@
</ChoiceBox>
</children>
</HBox>
<HBox alignment="CENTER" styleClass="standard-spacing">
<children>
<Label styleClass="regular" text="%ui.options.batchSize" />
<Spinner fx:id="spinnerBatchSize" prefWidth="75.0">
<tooltip>
<Tooltip text="%ui.options.batchSize.tooltip" />
</tooltip>
<valueFactory>
<SpinnerValueFactory.IntegerSpinnerValueFactory initialValue="1" max="512" min="1" />
</valueFactory>
</Spinner>
</children>
</HBox>
<HBox alignment="CENTER" styleClass="standard-spacing">
<children>
<Label styleClass="regular" text="%ui.options.pworkers" />
<Spinner fx:id="spinnerNumWorkers" prefWidth="75.0">
<tooltip><Tooltip text="%ui.options.pworkers.tooltip" /></tooltip>
<valueFactory>
<SpinnerValueFactory.IntegerSpinnerValueFactory initialValue="1" max="8" min="1" />
<SpinnerValueFactory.IntegerSpinnerValueFactory initialValue="1" max="128" min="1" />
</valueFactory>
</Spinner>
</children>
Expand All @@ -178,7 +191,7 @@
<VBox alignment="CENTER" styleClass="standard-spacing">
<VBox alignment="CENTER" styleClass="standard-spacing">
<children>
<Label styleClass="regular" text="%ui.options.directory" labelFor="$tfModelDirectory" />
<Label styleClass="regular" text="%ui.options.directory" />
<HBox styleClass="standard-spacing">
<children>
<TextField fx:id="tfModelDirectory" HBox.hgrow="ALWAYS">
Expand All @@ -200,7 +213,7 @@
</VBox>
<VBox alignment="CENTER" styleClass="standard-spacing">
<children>
<Label styleClass="regular" text="%ui.options.localModelDirectory" labelFor="$localModelDirectory" />
<Label styleClass="regular" text="%ui.options.localModelDirectory" />
<HBox styleClass="standard-spacing">
<children>
<TextField fx:id="localModelDirectory" HBox.hgrow="ALWAYS">
Expand Down

0 comments on commit c88101e

Please sign in to comment.