diff --git a/modules/nextflow/src/main/groovy/nextflow/trace/TraceRecord.groovy b/modules/nextflow/src/main/groovy/nextflow/trace/TraceRecord.groovy index 06e37b541c..361eb53a57 100644 --- a/modules/nextflow/src/main/groovy/nextflow/trace/TraceRecord.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/trace/TraceRecord.groovy @@ -121,6 +121,7 @@ class TraceRecord implements Serializable { transient private String executorName transient private CloudMachineInfo machineInfo transient private ContainerMeta containerMeta + transient private Integer numSpotInterruptions /* null until setNumSpotInterruptions is called; not persisted by serialize()/deserialize(), so it reads back as null */ /** * Convert the given value to a string @@ -611,6 +612,14 @@ class TraceRecord implements Serializable { this.machineInfo = value } + Integer getNumSpotInterruptions() { + return numSpotInterruptions + } + + void setNumSpotInterruptions(Integer numSpotInterruptions) { + this.numSpotInterruptions = numSpotInterruptions + } + ContainerMeta getContainerMeta() { return containerMeta } diff --git a/modules/nextflow/src/test/groovy/nextflow/trace/TraceRecordTest.groovy b/modules/nextflow/src/test/groovy/nextflow/trace/TraceRecordTest.groovy index 528d93afd2..f827747a42 100644 --- a/modules/nextflow/src/test/groovy/nextflow/trace/TraceRecordTest.groovy +++ b/modules/nextflow/src/test/groovy/nextflow/trace/TraceRecordTest.groovy @@ -344,4 +344,29 @@ class TraceRecordTest extends Specification { then: thrown(NoSuchFileException) } + + def 'should manage numSpotInterruptions and not persist it across serialization'() { + given: + def rec = new TraceRecord() + + expect: + rec.getNumSpotInterruptions() == null + and: + rec.numSpotInterruptions == null + + when: + rec.setNumSpotInterruptions(3) + + then: + rec.getNumSpotInterruptions() == 3 + rec.numSpotInterruptions == 3 + + when: + def buf = rec.serialize() + def rec2 = TraceRecord.deserialize(buf) + + then: + rec2.getNumSpotInterruptions() == null + } + + } diff --git a/plugins/nf-amazon/src/main/nextflow/cloud/aws/batch/AwsBatchTaskHandler.groovy 
b/plugins/nf-amazon/src/main/nextflow/cloud/aws/batch/AwsBatchTaskHandler.groovy index 6cada9f836..24a1903762 100644 --- a/plugins/nf-amazon/src/main/nextflow/cloud/aws/batch/AwsBatchTaskHandler.groovy +++ b/plugins/nf-amazon/src/main/nextflow/cloud/aws/batch/AwsBatchTaskHandler.groovy @@ -917,10 +917,48 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler> false + 2 * handler.isCompleted() >> false 1 * handler.getMachineInfo() >> new CloudMachineInfo('x1.large', 'us-east-1b', PriceModel.spot) and: @@ -919,6 +919,48 @@ class AwsBatchTaskHandlerTest extends Specification { trace.machineInfo.priceModel == PriceModel.spot } + def 'should create the trace record when job is completed with spot interruptions' () { + given: + def exec = Mock(Executor) { getName() >> 'awsbatch' } + def processor = Mock(TaskProcessor) + processor.getExecutor() >> exec + processor.getName() >> 'foo' + processor.getConfig() >> new ProcessConfig(Mock(BaseScript)) + def task = Mock(TaskRun) + task.getProcessor() >> processor + task.getConfig() >> GroovyMock(TaskConfig) + def proxy = Mock(AwsBatchProxy) + def handler = Spy(AwsBatchTaskHandler) + handler.@client = proxy + handler.task = task + handler.@jobId = 'xyz-123' + handler.setStatus(TaskStatus.COMPLETED) + + def attempt1 = GroovyMock(AttemptDetail) + def attempt2 = GroovyMock(AttemptDetail) + attempt1.statusReason() >> 'Host EC2 (instance i-123) terminated.' 
+ attempt1.container() >> null + attempt2.statusReason() >> 'Essential container in task exited' + attempt2.container() >> null + def job = JobDetail.builder().attempts([attempt1, attempt2]).build() + + // Stub BEFORE calling the method + handler.isCompleted() >> true + handler.getMachineInfo() >> new CloudMachineInfo('x1.large', 'us-east-1b', PriceModel.spot) + handler.describeJob('xyz-123') >> job + + when: + def trace = handler.getTraceRecord() + + then: + trace.native_id == 'xyz-123' + trace.executorName == 'awsbatch' + trace.machineInfo.type == 'x1.large' + trace.machineInfo.zone == 'us-east-1b' + trace.machineInfo.priceModel == PriceModel.spot + trace.numSpotInterruptions == 1 /* two attempts are stubbed and one is expected to count; presumably attempt1, whose status reason reports the EC2 host terminated - TODO confirm against the handler implementation */ + } + def 'should render submit command' () { given: def executor = Spy(AwsBatchExecutor) @@ -1243,4 +1285,56 @@ class AwsBatchTaskHandlerTest extends Specification { handler.task.error.message == 'Unknown termination' } + + def 'should return zero spot interruptions when no attempts or non-spot terminations exist'() { + given: + def handler = Spy(AwsBatchTaskHandler) + def attempt1 = GroovyMock(AttemptDetail) { + statusReason() >> 'Essential container in task exited' + } + def attempt2 = GroovyMock(AttemptDetail) { + statusReason() >> 'Some other reason' + } + + when: + def resultNoAttempts = handler.getNumSpotInterruptions('job-123') + then: + 1 * handler.isCompleted() >> true + 1 * handler.describeJob('job-123') >> JobDetail.builder().attempts([]).build() + resultNoAttempts == 0 + + when: + def resultNonSpot = handler.getNumSpotInterruptions('job-456') + then: + 1 * handler.isCompleted() >> true + 1 * handler.describeJob('job-456') >> JobDetail.builder().attempts([attempt1, attempt2]).build() + resultNonSpot == 0 + } + + def 'should return null when job cannot be processed'() { + given: + def handler = Spy(AwsBatchTaskHandler) + + when: + def resultNotCompleted = handler.getNumSpotInterruptions('job-123') + then: + 1 * handler.isCompleted() >> false + 0 * handler.describeJob(_) 
+ resultNotCompleted == null + + when: + def resultNullJobId = handler.getNumSpotInterruptions(null) + then: + 0 * handler.isCompleted() + 0 * handler.describeJob(_) + resultNullJobId == null + + when: + def resultException = handler.getNumSpotInterruptions('job-789') + then: + 1 * handler.isCompleted() >> true + 1 * handler.describeJob('job-789') >> { throw new RuntimeException("Error") } + resultException == null + } + } diff --git a/plugins/nf-google/src/main/nextflow/cloud/google/batch/GoogleBatchTaskHandler.groovy b/plugins/nf-google/src/main/nextflow/cloud/google/batch/GoogleBatchTaskHandler.groovy index d13f6583c8..30ec45afd5 100644 --- a/plugins/nf-google/src/main/nextflow/cloud/google/batch/GoogleBatchTaskHandler.groovy +++ b/plugins/nf-google/src/main/nextflow/cloud/google/batch/GoogleBatchTaskHandler.groovy @@ -674,12 +674,52 @@ class GoogleBatchTaskHandler extends TaskHandler implements FusionAwareTask { return machineInfo } + /** + * Count the spot instance reclamations recorded for this task: fetches the task status + * from the Batch client and counts the status events whose task execution exit code is 50001 + * + * @param jobId The Google Batch Job Id + * @return The number of matching status events; {@code 0} when the status carries no events; {@code null} when the job id or task id is missing, the task is not completed, no status is returned, or the lookup throws (the failure is logged at debug level) + */ + + protected Integer getNumSpotInterruptions(String jobId) { + if (!jobId || !taskId || !isCompleted()) { + return null + } + + try { + final status = client.getTaskStatus(jobId, taskId) + + if (!status) + return null + + // valid status but no events present means no interruptions occurred + if (!status?.statusEventsList) + return 0 + + int count = 0 + for (def event : status.statusEventsList) { + // Google Batch uses exit code 50001 for spot preemption + // Check if the event has a task execution with exit code 50001 + if (event.hasTaskExecution() && event.taskExecution.exitCode == 50001) { + count++ + } + } + return count + + } catch (Exception e) { + log.debug "[GOOGLE BATCH] Unable to count spot interruptions for job=$jobId task=$taskId - 
${e.message}" + return null + } + } + @Override TraceRecord getTraceRecord() { def result = super.getTraceRecord() if( jobId && uid ) result.put('native_id', "$jobId/$taskId/$uid") result.machineInfo = getMachineInfo() + result.numSpotInterruptions = getNumSpotInterruptions(jobId) return result } diff --git a/plugins/nf-google/src/test/nextflow/cloud/google/batch/GoogleBatchTaskHandlerTest.groovy b/plugins/nf-google/src/test/nextflow/cloud/google/batch/GoogleBatchTaskHandlerTest.groovy index 44b061b73d..cf16b497ba 100644 --- a/plugins/nf-google/src/test/nextflow/cloud/google/batch/GoogleBatchTaskHandlerTest.groovy +++ b/plugins/nf-google/src/test/nextflow/cloud/google/batch/GoogleBatchTaskHandlerTest.groovy @@ -388,6 +388,49 @@ class GoogleBatchTaskHandlerTest extends Specification { trace.executorName == 'google-batch' } + def 'should create the trace record when job is completed with spot interruptions' () { + given: + def exec = Mock(Executor) { getName() >> 'google-batch' } + def processor = Mock(TaskProcessor) { + getExecutor() >> exec + getName() >> 'foo' + getConfig() >> new ProcessConfig(Mock(BaseScript)) + } + and: + def task = Mock(TaskRun) + task.getProcessor() >> processor + task.getConfig() >> new TaskConfig() + and: + def client = Mock(BatchClient) + def handler = Spy(GoogleBatchTaskHandler) + handler.task = task + handler.@client = client + handler.@jobId = 'xyz-123' + handler.@taskId = '0' + handler.@uid = '789' + + def event1 = StatusEvent.newBuilder() + .setTaskExecution(TaskExecution.newBuilder().setExitCode(50001).build()) + .build() + def event2 = StatusEvent.newBuilder() + .setTaskExecution(TaskExecution.newBuilder().setExitCode(0).build()) + .build() + def taskStatus = TaskStatus.newBuilder() + .addStatusEvents(event1) + .addStatusEvents(event2) + .build() + + when: + def trace = handler.getTraceRecord() + then: + 2 * handler.isCompleted() >> true + 1 * client.getTaskStatus('xyz-123', '0') >> taskStatus + and: + trace.native_id == 
'xyz-123/0/789' + trace.executorName == 'google-batch' + trace.numSpotInterruptions == 1 + } + def 'should create submit request with fusion enabled' () { given: def WORK_DIR = CloudStorageFileSystem.forBucket('foo').getPath('/scratch') @@ -863,4 +906,83 @@ class GoogleBatchTaskHandlerTest extends Specification { 2 | true | true | 2 } + def 'should return zero when no status events exist'() { + given: + def handler = Spy(GoogleBatchTaskHandler) + handler.@taskId = 'task-123' + handler.@client = Mock(BatchClient) { + getTaskStatus('job-123', 'task-123') >> TaskStatus.newBuilder().build() + } + + when: + def result = handler.getNumSpotInterruptions('job-123') + + then: + handler.isCompleted() >> true + result == 0 + } + + def 'should count spot interruptions correctly from status events'() { + given: + def handler = Spy(GoogleBatchTaskHandler) + handler.@taskId = 'task-123' + handler.@client = Mock(BatchClient) { + getTaskStatus('job-123', 'task-123') >> TaskStatus.newBuilder() + .addStatusEvents( + StatusEvent.newBuilder() + .setTaskExecution( + TaskExecution.newBuilder().setExitCode(0).build() + ).build()) + .addStatusEvents( + StatusEvent.newBuilder() + .setTaskExecution( + TaskExecution.newBuilder().setExitCode(50001).build() + ).build()) + .addStatusEvents( + StatusEvent.newBuilder() + .setTaskExecution( + TaskExecution.newBuilder().setExitCode(50001).build() + ).build() + ).build() + } + + when: + def result = handler.getNumSpotInterruptions('job-123') + + then: + handler.isCompleted() >> true + result == 2 + } + + def 'should return null when jobId is null or task is incomplete'() { + given: + def handler = Spy(GoogleBatchTaskHandler) + handler.@taskId = 'task-123' + + when: + def resultNullJobId = handler.getNumSpotInterruptions(null) + def resultIncompleteTask = handler.getNumSpotInterruptions('job-123') + + then: + handler.isCompleted() >> false + resultNullJobId == null + resultIncompleteTask == null + } + + def 'should return null when an exception occurs 
while fetching task status'() { + given: + def handler = Spy(GoogleBatchTaskHandler) + handler.@taskId = 'task-123' + handler.@client = Mock(BatchClient) { + getTaskStatus('job-123', 'task-123') >> { throw new RuntimeException("Error") } + } + + when: + def result = handler.getNumSpotInterruptions('job-123') + + then: + handler.isCompleted() >> true + result == null + } + } diff --git a/plugins/nf-tower/src/main/io/seqera/tower/plugin/TowerClient.groovy b/plugins/nf-tower/src/main/io/seqera/tower/plugin/TowerClient.groovy index d070b758c8..4ffcc150d5 100644 --- a/plugins/nf-tower/src/main/io/seqera/tower/plugin/TowerClient.groovy +++ b/plugins/nf-tower/src/main/io/seqera/tower/plugin/TowerClient.groovy @@ -660,6 +660,7 @@ class TowerClient implements TraceObserverV2 { record.cloudZone = trace.getMachineInfo()?.zone record.machineType = trace.getMachineInfo()?.type record.priceModel = trace.getMachineInfo()?.priceModel?.toString() + record.numSpotInterruptions = trace.getNumSpotInterruptions() return record } diff --git a/plugins/nf-tower/src/test/io/seqera/tower/plugin/TowerClientTest.groovy b/plugins/nf-tower/src/test/io/seqera/tower/plugin/TowerClientTest.groovy index 8b69f963dc..c133e3d897 100644 --- a/plugins/nf-tower/src/test/io/seqera/tower/plugin/TowerClientTest.groovy +++ b/plugins/nf-tower/src/test/io/seqera/tower/plugin/TowerClientTest.groovy @@ -534,4 +534,30 @@ class TowerClientTest extends Specification { request.method() == 'POST' request.uri().toString() == 'http://example.com/test' } + + def 'should include numSpotInterruptions in task map'() { + given: + def client = Spy(new TowerClient()) + client.getWorkflowProgress(true) >> new WorkflowProgress() + + def now = System.currentTimeMillis() + def trace = new TraceRecord([ + taskId: 42, + process: 'foo', + workdir: "/work/dir", + cpus: 1, + submit: now-2000, + start: now-1000, + complete: now + ]) + trace.setNumSpotInterruptions(3) + + when: + def req = client.makeTasksReq([trace]) + + then: + 
req.tasks.size() == 1 + req.tasks[0].numSpotInterruptions == 3 + } + }