Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ plugins {
id 'io.nextflow.nextflow-plugin' version '1.0.0-beta.9'
}

version = '0.2.0'
version = '0.3.0'

nextflowPlugin {
nextflowVersion = '25.04.0'
Expand Down
125 changes: 92 additions & 33 deletions src/main/groovy/seqeralabs/plugin/NIMTaskHandler.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package seqeralabs.plugin
import groovy.json.JsonBuilder
import groovy.json.JsonSlurper
import groovy.transform.CompileStatic
import groovy.util.logging.Slf4j
import nextflow.processor.TaskHandler
import nextflow.processor.TaskRun
import nextflow.processor.TaskStatus
Expand All @@ -34,6 +35,7 @@ import static java.nio.file.StandardOpenOption.*
* Task handler for NIM tasks
*/
@CompileStatic
@Slf4j
class NIMTaskHandler extends TaskHandler {

private final NIMExecutor executor
Expand Down Expand Up @@ -68,12 +70,12 @@ class NIMTaskHandler extends TaskHandler {
}

/**
* Log a message to stdout and .command.out
* Log an informational message to Nextflow log and .command.out/.command.log files
*/
private void logOut(String message) {
def taskId = task.name ?: task.hashLog?.take(8) ?: "unknown"
def prefixedMessage = "[NIM:${taskId}] ${message}"
println(prefixedMessage)
log.info("NIM task ${taskId}: ${message}")

if (outWriter) {
outWriter.println(message) // File logs don't need prefix
outWriter.flush()
Expand All @@ -85,12 +87,12 @@ class NIMTaskHandler extends TaskHandler {
}

/**
* Log a message to stderr and .command.err
* Log an error message to Nextflow log and .command.err/.command.log files
*/
private void logErr(String message) {
def taskId = task.name ?: task.hashLog?.take(8) ?: "unknown"
def prefixedMessage = "[NIM:${taskId}] ERROR: ${message}"
System.err.println(prefixedMessage)
log.error("NIM task ${taskId}: ${message}")

if (errWriter) {
errWriter.println(message) // File logs don't need prefix
errWriter.flush()
Expand All @@ -102,9 +104,12 @@ class NIMTaskHandler extends TaskHandler {
}

/**
* Log a message only to .command.log (for internal tracking)
* Log a debug message to Nextflow log and .command.log file (for internal tracking)
*/
private void logDebug(String message) {
def taskId = task.name ?: task.hashLog?.take(8) ?: "unknown"
log.debug("NIM task ${taskId}: ${message}")

if (logWriter) {
logWriter.println("DEBUG: ${message}")
logWriter.flush()
Expand Down Expand Up @@ -148,7 +153,7 @@ class NIMTaskHandler extends TaskHandler {
}

/**
* Get the main output filename (typically PDB file)
* Get the main output filename (PDB for structures, FASTA for sequences, etc.)
* @param serviceName The NIM service name
* @return The resolved output filename
*/
Expand All @@ -159,7 +164,15 @@ class NIMTaskHandler extends TaskHandler {
filename = getTaskExtValue('pdbFile', null)
}
if (!filename) {
filename = 'output.pdb' // Default
// Use service-specific default extensions
switch (serviceName) {
case 'proteinmpnn':
filename = 'output.fasta'
break
default:
filename = 'output.pdb' // Default for structure-generating services
break
}
}
return getOutputFilename('outputFile', filename as String, serviceName)
}
Expand Down Expand Up @@ -341,16 +354,16 @@ class NIMTaskHandler extends TaskHandler {
return
}

logOut("Using endpoint: ${endpoint}")
logOut("Executing NIM task: ${serviceName}")
logDebug("Using endpoint: ${endpoint}")

try {
// Build request body based on service type and task parameters
def requestData = buildRequestData(serviceName as String, pdbData)
def requestBody = new JsonBuilder(requestData).toString()
logOut("Request body first 500 chars: ${requestBody.take(500)}")
logDebug("Request body first 500 chars: ${requestBody.take(500)}")
logDebug("Full request body: ${requestBody}")

// Use Java HTTP client with proper SSL configuration
def httpClient = executor.httpClient
def request = HttpRequest.newBuilder()
Expand All @@ -362,53 +375,54 @@ class NIMTaskHandler extends TaskHandler {
.header("nvcf-poll-seconds", "300")
.POST(HttpRequest.BodyPublishers.ofString(requestBody))
.build()
logOut("Executing HTTP request to NIM API...")

logDebug("Executing HTTP request to NIM API...")
def response = httpClient.send(request, HttpResponse.BodyHandlers.ofString())

def statusCode = response.statusCode()
def responseBody = response.body()

logOut("HTTP status code: ${statusCode}")
logOut("Response body: ${responseBody}")
logDebug("Response body: ${responseBody}")

// Save results to work directory regardless of status for debugging
def resultFilename = getResultFilename(serviceName as String)
def resultFile = task.workDir.resolve(resultFilename)
resultFile.text = responseBody
logOut("Saved API response to: ${resultFilename}")
logDebug("Saved API response to: ${resultFilename}")

if (statusCode == 200 || statusCode == 202) {
logOut("NIM task completed successfully")

// Process the response and create output files
try {
processApiResponse(serviceName as String, responseBody)
logOut("Output files created successfully")
logDebug("Output files created successfully")
} catch (Exception e) {
logErr("Error processing API response: ${e.message}")
// Continue as success since API call worked, just output processing failed
}

// Set exit status first, then create files
exitStatus = 0
task.exitStatus = 0 // Critical: Set the task's exit status for TaskPollingMonitor

// Create expected Nextflow files for proper task completion
createNextflowFiles("NIM task completed successfully")

completed = true
} else if (statusCode == 422) {
logOut("NIM API validation error (422): ${responseBody}")
// For integration tests, we'll treat validation errors as "completed"
logOut("NIM API validation error (422) - treated as success for testing")
logDebug("Validation error response: ${responseBody}")
// For integration tests, we'll treat validation errors as "completed"
// since they indicate the API is working but data is invalid
exitStatus = 0 // Consider this success for testing purposes
task.exitStatus = 0 // Critical: Set the task's exit status for TaskPollingMonitor
createNextflowFiles("NIM API validation error (422) - treated as success for testing")
completed = true
} else {
logErr("NIM API request failed with status: ${statusCode}")
logErr("Response: ${responseBody}")
logDebug("Failed response body: ${responseBody}")
exitStatus = 1
task.exitStatus = 1 // Critical: Set the task's exit status for TaskPollingMonitor
createNextflowFiles("NIM API request failed with status: ${statusCode}")
Expand Down Expand Up @@ -475,7 +489,7 @@ class NIMTaskHandler extends TaskHandler {
def jsonSlurper = new JsonSlurper()
def responseData = jsonSlurper.parseText(responseBody) as Map

logOut("Processing API response for service: ${serviceName}")
logDebug("Processing API response for service: ${serviceName}")

switch (serviceName) {
case 'rfdiffusion':
Expand All @@ -485,6 +499,9 @@ class NIMTaskHandler extends TaskHandler {
case 'openfold':
processProteinFoldingResponse(responseData, serviceName)
break
case 'proteinmpnn':
processProteinMPNNResponse(responseData, serviceName)
break
default:
// For unknown services, try to extract common output formats
processGenericResponse(responseData, serviceName)
Expand All @@ -510,7 +527,7 @@ class NIMTaskHandler extends TaskHandler {
def outputFile = task.workDir.resolve(outputFilename)
outputFile.text = outputPdb
logOut("Created RFDiffusion output file: ${outputFilename}")
logOut("Output PDB size: ${outputPdb.length()} characters")
logDebug("Output PDB size: ${outputPdb.length()} characters")
} else if (responseData.containsKey('error')) {
logOut("RFDiffusion API returned error: ${responseData.error}")
// Still create an empty output file so the process doesn't fail
Expand Down Expand Up @@ -550,7 +567,7 @@ class NIMTaskHandler extends TaskHandler {
def outputFile = task.workDir.resolve(outputFilename)
outputFile.text = outputPdb
logOut("Created protein folding output file: ${outputFilename}")
logOut("Output PDB size: ${outputPdb.length()} characters")
logDebug("Output PDB size: ${outputPdb.length()} characters")
} else if (responseData.containsKey('error')) {
logOut("Protein folding API returned error: ${responseData.error}")
def outputFile = task.workDir.resolve(outputFilename)
Expand All @@ -562,7 +579,49 @@ class NIMTaskHandler extends TaskHandler {
outputFile.text = "# Protein folding response did not contain expected structure field\n"
}
}


/**
 * Process a ProteinMPNN API response and write its outputs into the task work directory.
 *
 * Three outcomes are handled:
 * - a non-null 'mfasta' field: its content is written to the main output file, and the
 *   optional 'scores' and 'probs' fields are each written to their own JSON file;
 * - an 'error' field: a one-line comment describing the error is written to the main
 *   output file so that the file still exists;
 * - neither: a placeholder comment is written and the available field names are logged.
 *
 * @param responseData Parsed JSON response data
 * @param serviceName The service name for dynamic filename resolution
 */
private void processProteinMPNNResponse(Map<String, Object> responseData, String serviceName) {
    def outputFilename = getMainOutputFilename(serviceName)
    def outputFile = task.workDir.resolve(outputFilename)

    // Test the value rather than key presence: a response carrying "mfasta": null
    // would otherwise fail on the write/length call and leave no output file at all.
    def mfastaValue = responseData.get('mfasta')

    if (mfastaValue != null) {
        def mfastaContent = mfastaValue as String
        outputFile.text = mfastaContent
        logOut("Created ProteinMPNN FASTA output file: ${outputFilename}")
        logDebug("Output FASTA size: ${mfastaContent.length()} characters")

        // Also save additional data fields if they exist
        writeProteinMPNNJsonField(responseData, 'scores', 'scoresFile', 'scores', serviceName)
        writeProteinMPNNJsonField(responseData, 'probs', 'probsFile', 'probabilities', serviceName)
    } else if (responseData.containsKey('error')) {
        logOut("ProteinMPNN API returned error: ${responseData.error}")
        outputFile.text = "# ProteinMPNN API Error: ${responseData.error}\n"
    } else {
        logOut("Warning: ProteinMPNN response does not contain expected 'mfasta' field")
        logOut("Available fields: ${responseData.keySet()}")
        outputFile.text = "# ProteinMPNN response did not contain mfasta field\n"
    }
}

/**
 * Write one optional ProteinMPNN response field to its own pretty-printed JSON file.
 * Does nothing when the response does not contain the field.
 *
 * @param responseData Parsed JSON response data
 * @param field Response key to export (also used as the default filename suffix)
 * @param extKey Key passed to getOutputFilename for resolving the filename
 * @param label Human-readable name of the field used in the debug log line
 * @param serviceName The service name for dynamic filename resolution
 */
private void writeProteinMPNNJsonField(Map<String, Object> responseData, String field, String extKey, String label, String serviceName) {
    if (!responseData.containsKey(field)) {
        return
    }
    def defaultName = "${task.name}_${serviceName}_${field}.json".toString()
    def fieldFilename = getOutputFilename(extKey, defaultName, serviceName)
    def fieldFile = task.workDir.resolve(fieldFilename)
    fieldFile.text = new JsonBuilder(responseData.get(field)).toPrettyString()
    logDebug("Created ProteinMPNN ${label} file: ${fieldFilename}")
}

/**
* Process generic API response for unknown services
* @param responseData Parsed JSON response data
Expand Down Expand Up @@ -591,7 +650,7 @@ class NIMTaskHandler extends TaskHandler {
def outputFile = task.workDir.resolve(outputFilename)
outputFile.text = outputData as String
logOut("Created generic output file from field '${outputField}': ${outputFilename}")
logOut("Output size: ${(outputData as String).length()} characters")
logDebug("Output size: ${(outputData as String).length()} characters")
} else if (responseData.containsKey('error')) {
logOut("Generic API returned error: ${responseData.error}")
def outputFile = task.workDir.resolve(outputFilename)
Expand All @@ -603,7 +662,7 @@ class NIMTaskHandler extends TaskHandler {
def debugFilename = getOutputFilename('debugFile', 'output.json', serviceName)
def outputFile = task.workDir.resolve(debugFilename)
outputFile.text = new JsonBuilder(responseData).toPrettyString()
logOut("Created debug output file: ${debugFilename}")
logDebug("Created debug output file: ${debugFilename}")
}
}

Expand Down
Loading