Skip to content

Commit eecd816

Browse files
Add spot interruption tracking to trace records (#6606)
Track and report spot/preemptible instance interruptions for cloud batch executors.

Changes:
- Add a transient `numSpotInterruptions` field to TraceRecord
- AWS Batch: detect spot interruptions by checking for attempt status reasons that match the pattern "Host EC2*"
- Google Batch: detect spot preemptions via exit code 50001 in task status events
- Tower plugin: send numSpotInterruptions to Seqera Platform telemetry

This enables workflow optimization and cost analysis by tracking how often tasks are retried due to spot instance reclamation.
1 parent 3c3e9f5 commit eecd816

File tree

8 files changed

+358
-3
lines changed

8 files changed

+358
-3
lines changed

modules/nextflow/src/main/groovy/nextflow/trace/TraceRecord.groovy

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ class TraceRecord implements Serializable {
121121
transient private String executorName
122122
transient private CloudMachineInfo machineInfo
123123
transient private ContainerMeta containerMeta
124+
transient private Integer numSpotInterruptions
124125

125126
/**
126127
* Convert the given value to a string
@@ -611,6 +612,14 @@ class TraceRecord implements Serializable {
611612
this.machineInfo = value
612613
}
613614

/**
 * @return The number of spot interruptions stored in this record,
 *      or {@code null} when no value has been set
 */
Integer getNumSpotInterruptions() {
    this.numSpotInterruptions
}
618+
/**
 * Store the number of spot interruptions in this record. The backing
 * field is declared transient, so the value is not kept across serialization.
 *
 * @param numSpotInterruptions The interruption count, possibly {@code null}
 */
void setNumSpotInterruptions(Integer numSpotInterruptions) {
    this.@numSpotInterruptions = numSpotInterruptions
}
622+
/**
 * @return The {@link ContainerMeta} object held by this trace record
 */
ContainerMeta getContainerMeta() {
    this.containerMeta
}

modules/nextflow/src/test/groovy/nextflow/trace/TraceRecordTest.groovy

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,4 +344,29 @@ class TraceRecordTest extends Specification {
344344
then:
345345
thrown(NoSuchFileException)
346346
}
347+
// The value must be readable and writable in memory, both through the accessor
// methods and through property syntax, but must be gone after a
// serialize/deserialize round trip.
def 'should manage numSpotInterruptions and not persist it across serialization'() {
    given:
    def record = new TraceRecord()

    expect:
    record.getNumSpotInterruptions() == null
    and:
    record.numSpotInterruptions == null

    when:
    record.setNumSpotInterruptions(3)

    then:
    record.getNumSpotInterruptions() == 3
    record.numSpotInterruptions == 3

    when:
    def restored = TraceRecord.deserialize(record.serialize())

    then:
    restored.getNumSpotInterruptions() == null
}
371+
347372
}

plugins/nf-amazon/src/main/nextflow/cloud/aws/batch/AwsBatchTaskHandler.groovy

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -917,10 +917,48 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler<String,Job
917917
return machineInfo
918918
}
919919

/**
 * Report how many attempts of the given AWS Batch job ended because of a
 * spot instance reclamation.
 *
 * An attempt is counted when its status reason begins with {@code Host EC2};
 * AWS Batch reports reclaimed hosts as "Host EC2 (instance i-xxx) terminated.",
 * so a prefix match works regardless of the instance id.
 *
 * @param jobId The AWS Batch Job Id
 * @return The number of matching attempts; {@code 0} when the job has no attempts;
 *      {@code null} when the job id is empty, the task is not completed, the job
 *      description is not available, or the lookup raises an exception
 */
protected Integer getNumSpotInterruptions(String jobId) {
    // nothing to report until the task has completed and a job id is known
    if( !(jobId && isCompleted()) )
        return null

    try {
        final detail = describeJob(jobId)
        if( !detail )
            return null

        final attempts = detail.attempts()
        if( !attempts )
            return 0

        final int count = attempts
            .count { attempt -> attempt.statusReason()?.startsWith('Host EC2') }
            .intValue()
        log.trace "Job $jobId had $count spot interruptions"
        return count
    }
    catch (Exception e) {
        // best effort: a failure here must not break the trace record creation
        log.debug "[AWS BATCH] Unable to count spot interruptions for job=$jobId - ${e.message}"
        return null
    }
}
956+
/**
 * Build the task trace record, adding the AWS Batch job id as {@code native_id},
 * the cloud machine info and the number of spot interruptions.
 *
 * @return The {@link TraceRecord} for this task
 */
TraceRecord getTraceRecord() {
    final record = super.getTraceRecord()
    record.put('native_id', jobId)
    record.setMachineInfo(getMachineInfo())
    record.setNumSpotInterruptions(getNumSpotInterruptions(jobId))
    record
}
926964

plugins/nf-amazon/src/test/nextflow/cloud/aws/batch/AwsBatchTaskHandlerTest.groovy

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
package nextflow.cloud.aws.batch
1818

19-
import software.amazon.awssdk.services.batch.model.JobStatus
20-
2119
import java.nio.file.Path
2220
import java.time.Instant
2321

@@ -45,6 +43,7 @@ import nextflow.script.ProcessConfig
4543
import nextflow.util.CacheHelper
4644
import nextflow.util.MemoryUnit
4745
import software.amazon.awssdk.services.batch.BatchClient
46+
import software.amazon.awssdk.services.batch.model.AttemptDetail
4847
import software.amazon.awssdk.services.batch.model.ContainerDetail
4948
import software.amazon.awssdk.services.batch.model.DescribeJobDefinitionsRequest
5049
import software.amazon.awssdk.services.batch.model.DescribeJobDefinitionsResponse
@@ -54,6 +53,7 @@ import software.amazon.awssdk.services.batch.model.EvaluateOnExit
5453
import software.amazon.awssdk.services.batch.model.JobDefinition
5554
import software.amazon.awssdk.services.batch.model.JobDefinitionType
5655
import software.amazon.awssdk.services.batch.model.JobDetail
56+
import software.amazon.awssdk.services.batch.model.JobStatus
5757
import software.amazon.awssdk.services.batch.model.KeyValuePair
5858
import software.amazon.awssdk.services.batch.model.PlatformCapability
5959
import software.amazon.awssdk.services.batch.model.RegisterJobDefinitionResponse
@@ -908,7 +908,7 @@ class AwsBatchTaskHandlerTest extends Specification {
908908
when:
909909
def trace = handler.getTraceRecord()
910910
then:
911-
1 * handler.isCompleted() >> false
911+
2 * handler.isCompleted() >> false
912912
1 * handler.getMachineInfo() >> new CloudMachineInfo('x1.large', 'us-east-1b', PriceModel.spot)
913913

914914
and:
@@ -919,6 +919,48 @@ class AwsBatchTaskHandlerTest extends Specification {
919919
trace.machineInfo.priceModel == PriceModel.spot
920920
}
921921

// One of the two job attempts carries the "Host EC2 ... terminated." reason,
// so the trace record is expected to report exactly one spot interruption.
def 'should create the trace record when job is completed with spot interruptions' () {
    given:
    def executor = Mock(Executor) { getName() >> 'awsbatch' }
    def processConfig = new ProcessConfig(Mock(BaseScript))
    def processor = Mock(TaskProcessor) {
        getExecutor() >> executor
        getName() >> 'foo'
        getConfig() >> processConfig
    }
    def taskConfig = GroovyMock(TaskConfig)
    def task = Mock(TaskRun) {
        getProcessor() >> processor
        getConfig() >> taskConfig
    }
    def batchClient = Mock(AwsBatchProxy)

    and:
    def handler = Spy(AwsBatchTaskHandler)
    handler.@client = batchClient
    handler.task = task
    handler.@jobId = 'xyz-123'
    handler.setStatus(TaskStatus.COMPLETED)

    and:
    def reclaimedAttempt = GroovyMock(AttemptDetail) {
        statusReason() >> 'Host EC2 (instance i-123) terminated.'
        container() >> null
    }
    def failedAttempt = GroovyMock(AttemptDetail) {
        statusReason() >> 'Essential container in task exited'
        container() >> null
    }
    def job = JobDetail.builder().attempts([reclaimedAttempt, failedAttempt]).build()

    and:
    // the stubs have to be declared before the method under test is invoked
    handler.isCompleted() >> true
    handler.getMachineInfo() >> new CloudMachineInfo('x1.large', 'us-east-1b', PriceModel.spot)
    handler.describeJob('xyz-123') >> job

    when:
    def record = handler.getTraceRecord()

    then:
    record.native_id == 'xyz-123'
    record.executorName == 'awsbatch'
    record.machineInfo.type == 'x1.large'
    record.machineInfo.zone == 'us-east-1b'
    record.machineInfo.priceModel == PriceModel.spot
    record.numSpotInterruptions == 1
}
963+
922964
def 'should render submit command' () {
923965
given:
924966
def executor = Spy(AwsBatchExecutor)
@@ -1243,4 +1285,56 @@ class AwsBatchTaskHandlerTest extends Specification {
12431285
handler.task.error.message == 'Unknown termination'
12441286

12451287
}
1288+
// Covers the two "zero" outcomes: a job without attempts, and a job whose
// attempts all ended for reasons unrelated to a spot reclamation.
def 'should return zero spot interruptions when no attempts or non-spot terminations exist'() {
    given:
    def handler = Spy(AwsBatchTaskHandler)

    and:
    def containerExit = GroovyMock(AttemptDetail)
    containerExit.statusReason() >> 'Essential container in task exited'
    def otherFailure = GroovyMock(AttemptDetail)
    otherFailure.statusReason() >> 'Some other reason'

    and:
    def jobWithoutAttempts = JobDetail.builder().attempts([]).build()
    def jobWithoutSpotLoss = JobDetail.builder().attempts([containerExit, otherFailure]).build()

    when:
    def countNoAttempts = handler.getNumSpotInterruptions('job-123')
    then:
    1 * handler.isCompleted() >> true
    1 * handler.describeJob('job-123') >> jobWithoutAttempts
    countNoAttempts == 0

    when:
    def countNonSpot = handler.getNumSpotInterruptions('job-456')
    then:
    1 * handler.isCompleted() >> true
    1 * handler.describeJob('job-456') >> jobWithoutSpotLoss
    countNonSpot == 0
}
1313+
// Covers the three "unknown" outcomes: the task is not completed yet,
// the job id is missing, and the job lookup raises an exception.
def 'should return null when job cannot be processed'() {
    given:
    def handler = Spy(AwsBatchTaskHandler)

    when: 'the task has not completed'
    def whenNotCompleted = handler.getNumSpotInterruptions('job-123')
    then:
    1 * handler.isCompleted() >> false
    0 * handler.describeJob(_)
    whenNotCompleted == null

    when: 'no job id is given'
    def whenNoJobId = handler.getNumSpotInterruptions(null)
    then:
    0 * handler.isCompleted()
    0 * handler.describeJob(_)
    whenNoJobId == null

    when: 'describing the job fails'
    def whenLookupFails = handler.getNumSpotInterruptions('job-789')
    then:
    1 * handler.isCompleted() >> true
    1 * handler.describeJob('job-789') >> { throw new RuntimeException("Error") }
    whenLookupFails == null
}
1339+
12461340
}

plugins/nf-google/src/main/nextflow/cloud/google/batch/GoogleBatchTaskHandler.groovy

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,12 +674,52 @@ class GoogleBatchTaskHandler extends TaskHandler implements FusionAwareTask {
674674
return machineInfo
675675
}
676676

/**
 * Report how many status events of this task denote a spot preemption.
 *
 * A status event is counted when it carries a task execution whose exit
 * code is {@code 50001}, the code Google Batch uses for spot preemption.
 *
 * @param jobId The Google Batch Job Id
 * @return The number of matching events; {@code 0} when the task status has no
 *      events; {@code null} when the job id or task id is missing, the task is
 *      not completed, the task status is not available, or the lookup raises
 *      an exception
 */
protected Integer getNumSpotInterruptions(String jobId) {
    // nothing to report until the task has completed and its identifiers are known
    if( !(jobId && taskId && isCompleted()) )
        return null

    try {
        final status = client.getTaskStatus(jobId, taskId)
        if( !status )
            return null

        // a valid status without events means no interruption took place
        if( !status.statusEventsList )
            return 0

        return status.statusEventsList
            .count { event -> event.hasTaskExecution() && event.taskExecution.exitCode == 50001 }
            .intValue()
    }
    catch (Exception e) {
        // best effort: a failure here must not break the trace record creation
        log.debug "[GOOGLE BATCH] Unable to count spot interruptions for job=$jobId task=$taskId - ${e.message}"
        return null
    }
}
715+
/**
 * Build the task trace record, adding the cloud machine info and the number
 * of spot interruptions. The {@code native_id} entry is added only when both
 * the job id and the uid are available.
 *
 * @return The {@link TraceRecord} for this task
 */
@Override
TraceRecord getTraceRecord() {
    final record = super.getTraceRecord()
    if( jobId && uid ) {
        record.put('native_id', "$jobId/$taskId/$uid")
    }
    record.setMachineInfo(getMachineInfo())
    record.setNumSpotInterruptions(getNumSpotInterruptions(jobId))
    record
}
685725

0 commit comments

Comments
 (0)