Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
a91de85
Added number of reclamations in TraceRecord
munishchouhan Nov 21, 2025
f08d9e2
Merge branch 'master' into add-num-reclamations-trace
munishchouhan Nov 24, 2025
de2c09a
Merge branch 'master' into add-num-reclamations-trace
munishchouhan Nov 25, 2025
42c49eb
Merge branch 'master' into add-num-reclamations-trace
munishchouhan Nov 26, 2025
6e54f65
Merge branch 'master' into add-num-reclamations-trace
munishchouhan Dec 3, 2025
28e8081
added countSpotReclamations in nf-amazon
munishchouhan Dec 9, 2025
29175b3
added tests for countSpotReclamations
munishchouhan Dec 9, 2025
b53ecc2
reverted unwanted changes
munishchouhan Dec 9, 2025
0dc5cac
reverted unwanted changes
munishchouhan Dec 9, 2025
ff0426a
added countSpotReclamations in nf-google
munishchouhan Dec 9, 2025
d6e8eaa
Merge branch 'master' into add-num-reclamations-trace
munishchouhan Dec 10, 2025
8d23dd2
Merge branch 'master' into add-num-reclamations-trace
munishchouhan Dec 11, 2025
51f61a5
changes as per review
munishchouhan Dec 12, 2025
11237ed
changes as per review
munishchouhan Dec 12, 2025
4a1df3a
formating [ci skip]
munishchouhan Dec 12, 2025
b9fc623
removed unused code
munishchouhan Dec 12, 2025
9bc3cf1
Merge branch 'master' into add-num-reclamations-trace
munishchouhan Dec 12, 2025
8fe1e6c
chnaged GroovyMock(TaskConfig()) to new TaskConfig()
munishchouhan Dec 12, 2025
2e704af
Merge branch 'master' into add-num-reclamations-trace
munishchouhan Dec 12, 2025
b411d65
chnaged GroovyMock(TaskConfig()) to new TaskConfig()
munishchouhan Dec 12, 2025
927cbce
added getNumSpotInterruptions and changed name to numSpotInterruptions
munishchouhan Dec 15, 2025
3bdb604
fixed tests
munishchouhan Dec 15, 2025
6ffab69
fixed tests
munishchouhan Dec 15, 2025
9a58bd6
Merge branch 'master' into add-num-reclamations-trace
munishchouhan Dec 15, 2025
d8d0d8e
added if guard
munishchouhan Dec 15, 2025
6c6f153
updated if statements format
munishchouhan Dec 15, 2025
dcb8465
updated if statements format [ci skip]
munishchouhan Dec 15, 2025
a1d7506
Revert "updated if statements format [ci skip]"
munishchouhan Dec 15, 2025
f399c63
Revert "updated if statements format" [ci skip]
munishchouhan Dec 15, 2025
3352120
reverted unwanted changes
munishchouhan Dec 15, 2025
48fce52
reverted unwanted formatting [ci skip]
munishchouhan Dec 15, 2025
e0a2508
reverted unwanted formatting [ci skip]
munishchouhan Dec 15, 2025
e5c49f7
reverted unwanted formatting [ci skip]
munishchouhan Dec 15, 2025
71b155a
refactored
munishchouhan Dec 15, 2025
fd29de3
Merge branch 'master' into add-num-reclamations-trace
munishchouhan Dec 15, 2025
68aedb5
Merge branch 'master' into add-num-reclamations-trace
munishchouhan Dec 16, 2025
90fb949
Refactor getNumSpotInterruptions methods [ci fast]
pditommaso Dec 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class TraceRecord implements Serializable {
transient private String executorName
transient private CloudMachineInfo machineInfo
transient private ContainerMeta containerMeta
transient private Integer numSpotInterruptions

/**
* Convert the given value to a string
Expand Down Expand Up @@ -611,6 +612,14 @@ class TraceRecord implements Serializable {
this.machineInfo = value
}

Integer getNumSpotInterruptions() {
return numSpotInterruptions
}

void setNumSpotInterruptions(Integer numSpotInterruptions) {
this.numSpotInterruptions = numSpotInterruptions
}

ContainerMeta getContainerMeta() {
return containerMeta
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,4 +344,29 @@ class TraceRecordTest extends Specification {
then:
thrown(NoSuchFileException)
}

def 'should manage numSpotInterruptions and not persist it across serialization'() {
given:
def rec = new TraceRecord()

expect:
rec.getNumSpotInterruptions() == null
and:
rec.numSpotInterruptions == null

when:
rec.setNumSpotInterruptions(3)

then:
rec.getNumSpotInterruptions() == 3
rec.numSpotInterruptions == 3

when:
def buf = rec.serialize()
def rec2 = TraceRecord.deserialize(buf)

then:
rec2.getNumSpotInterruptions() == null
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -917,10 +917,48 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler<String,Job
return machineInfo
}

/**
* Count the number of spot instance reclamations for this job by examining
* the job attempts and checking for EC2 spot interruption status reasons
*
* @param jobId The AWS Batch Job Id
* @return The number of times this job was retried due to spot instance reclamation
*/
protected Integer getNumSpotInterruptions(String jobId) {
if (!jobId || !isCompleted())
return null

try {
def job = describeJob(jobId)
if (!job)
return null
if (!job.attempts())
return 0

int count = 0
for (def attempt : job.attempts()) {
// Check attempt-level statusReason
def attemptReason = attempt.statusReason()
// AWS Batch uses "Host EC2 (instance i-xxx) terminated." pattern for spot interruptions
// Using startsWith to match the pattern regardless of instance ID
if (attemptReason && attemptReason.startsWith('Host EC2')) {
count++
}
}
log.trace "Job $jobId had $count spot interruptions"
return count
}
catch (Exception e) {
log.debug "[AWS BATCH] Unable to count spot interruptions for job=$jobId - ${e.message}"
return null
}
}

TraceRecord getTraceRecord() {
def result = super.getTraceRecord()
result.put('native_id', jobId)
result.machineInfo = getMachineInfo()
result.numSpotInterruptions = getNumSpotInterruptions(jobId)
return result
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

package nextflow.cloud.aws.batch

import software.amazon.awssdk.services.batch.model.JobStatus

import java.nio.file.Path
import java.time.Instant

Expand Down Expand Up @@ -45,6 +43,7 @@ import nextflow.script.ProcessConfig
import nextflow.util.CacheHelper
import nextflow.util.MemoryUnit
import software.amazon.awssdk.services.batch.BatchClient
import software.amazon.awssdk.services.batch.model.AttemptDetail
import software.amazon.awssdk.services.batch.model.ContainerDetail
import software.amazon.awssdk.services.batch.model.DescribeJobDefinitionsRequest
import software.amazon.awssdk.services.batch.model.DescribeJobDefinitionsResponse
Expand All @@ -54,6 +53,7 @@ import software.amazon.awssdk.services.batch.model.EvaluateOnExit
import software.amazon.awssdk.services.batch.model.JobDefinition
import software.amazon.awssdk.services.batch.model.JobDefinitionType
import software.amazon.awssdk.services.batch.model.JobDetail
import software.amazon.awssdk.services.batch.model.JobStatus
import software.amazon.awssdk.services.batch.model.KeyValuePair
import software.amazon.awssdk.services.batch.model.PlatformCapability
import software.amazon.awssdk.services.batch.model.RegisterJobDefinitionResponse
Expand Down Expand Up @@ -908,7 +908,7 @@ class AwsBatchTaskHandlerTest extends Specification {
when:
def trace = handler.getTraceRecord()
then:
1 * handler.isCompleted() >> false
2 * handler.isCompleted() >> false
1 * handler.getMachineInfo() >> new CloudMachineInfo('x1.large', 'us-east-1b', PriceModel.spot)

and:
Expand All @@ -919,6 +919,48 @@ class AwsBatchTaskHandlerTest extends Specification {
trace.machineInfo.priceModel == PriceModel.spot
}

def 'should create the trace record when job is completed with spot interruptions' () {
given:
def exec = Mock(Executor) { getName() >> 'awsbatch' }
def processor = Mock(TaskProcessor)
processor.getExecutor() >> exec
processor.getName() >> 'foo'
processor.getConfig() >> new ProcessConfig(Mock(BaseScript))
def task = Mock(TaskRun)
task.getProcessor() >> processor
task.getConfig() >> GroovyMock(TaskConfig)
def proxy = Mock(AwsBatchProxy)
def handler = Spy(AwsBatchTaskHandler)
handler.@client = proxy
handler.task = task
handler.@jobId = 'xyz-123'
handler.setStatus(TaskStatus.COMPLETED)

def attempt1 = GroovyMock(AttemptDetail)
def attempt2 = GroovyMock(AttemptDetail)
attempt1.statusReason() >> 'Host EC2 (instance i-123) terminated.'
attempt1.container() >> null
attempt2.statusReason() >> 'Essential container in task exited'
attempt2.container() >> null
def job = JobDetail.builder().attempts([attempt1, attempt2]).build()

// Stub BEFORE calling the method
handler.isCompleted() >> true
handler.getMachineInfo() >> new CloudMachineInfo('x1.large', 'us-east-1b', PriceModel.spot)
handler.describeJob('xyz-123') >> job

when:
def trace = handler.getTraceRecord()

then:
trace.native_id == 'xyz-123'
trace.executorName == 'awsbatch'
trace.machineInfo.type == 'x1.large'
trace.machineInfo.zone == 'us-east-1b'
trace.machineInfo.priceModel == PriceModel.spot
trace.numSpotInterruptions == 1
}

def 'should render submit command' () {
given:
def executor = Spy(AwsBatchExecutor)
Expand Down Expand Up @@ -1243,4 +1285,56 @@ class AwsBatchTaskHandlerTest extends Specification {
handler.task.error.message == 'Unknown termination'

}

def 'should return zero spot interruptions when no attempts or non-spot terminations exist'() {
given:
def handler = Spy(AwsBatchTaskHandler)
def attempt1 = GroovyMock(AttemptDetail) {
statusReason() >> 'Essential container in task exited'
}
def attempt2 = GroovyMock(AttemptDetail) {
statusReason() >> 'Some other reason'
}

when:
def resultNoAttempts = handler.getNumSpotInterruptions('job-123')
then:
1 * handler.isCompleted() >> true
1 * handler.describeJob('job-123') >> JobDetail.builder().attempts([]).build()
resultNoAttempts == 0

when:
def resultNonSpot = handler.getNumSpotInterruptions('job-456')
then:
1 * handler.isCompleted() >> true
1 * handler.describeJob('job-456') >> JobDetail.builder().attempts([attempt1, attempt2]).build()
resultNonSpot == 0
}

def 'should return null when job cannot be processed'() {
given:
def handler = Spy(AwsBatchTaskHandler)

when:
def resultNotCompleted = handler.getNumSpotInterruptions('job-123')
then:
1 * handler.isCompleted() >> false
0 * handler.describeJob(_)
resultNotCompleted == null

when:
def resultNullJobId = handler.getNumSpotInterruptions(null)
then:
0 * handler.isCompleted()
0 * handler.describeJob(_)
resultNullJobId == null

when:
def resultException = handler.getNumSpotInterruptions('job-789')
then:
1 * handler.isCompleted() >> true
1 * handler.describeJob('job-789') >> { throw new RuntimeException("Error") }
resultException == null
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -674,12 +674,52 @@ class GoogleBatchTaskHandler extends TaskHandler implements FusionAwareTask {
return machineInfo
}

/**
* Count the number of spot instance reclamations for this task by examining
* the task status events and checking for preemption exit codes
*
* @param jobId The Google Batch Job Id
* @return The number of times this task was retried due to spot instance reclamation
*/

protected Integer getNumSpotInterruptions(String jobId) {
if (!jobId || !taskId || !isCompleted()) {
return null
}

try {
final status = client.getTaskStatus(jobId, taskId)

if (!status)
return null

// valid status but no events present means no interruptions occurred
if (!status?.statusEventsList)
return 0

int count = 0
for (def event : status.statusEventsList) {
// Google Batch uses exit code 50001 for spot preemption
// Check if the event has a task execution with exit code 50001
if (event.hasTaskExecution() && event.taskExecution.exitCode == 50001) {
count++
}
}
return count

} catch (Exception e) {
log.debug "[GOOGLE BATCH] Unable to count spot interruptions for job=$jobId task=$taskId - ${e.message}"
return null
}
}

@Override
TraceRecord getTraceRecord() {
def result = super.getTraceRecord()
if( jobId && uid )
result.put('native_id', "$jobId/$taskId/$uid")
result.machineInfo = getMachineInfo()
result.numSpotInterruptions = getNumSpotInterruptions(jobId)
return result
}

Expand Down
Loading
Loading