Skip to content

Commit

Permalink
fix: Create a new standardMachineType runtime attribute instead of re…
Browse files Browse the repository at this point in the history
…using cpuPlatform
  • Loading branch information
javiergaitan committed Sep 17, 2024
1 parent 8b20d18 commit 4ba50c3
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,12 @@ class GcpBatchRequestFactoryImpl()(implicit gcsTransferConfiguration: GcsTransfe
val machineType = GcpBatchMachineConstraints.machineType(runtimeAttributes.memory,
runtimeAttributes.cpu,
cpuPlatformOption = runtimeAttributes.cpuPlatform,
standardMachineTypeOption = runtimeAttributes.standardMachineType,
googleLegacyMachineSelection = false,
jobLogger = jobLogger
)
// Don't set cpuPlatform if the field was used to specify standard machine type when submitting the GCP Batch request
val cpuPlatformBatchRequest = if (GcpBatchMachineConstraints.isStandardMachineType(cpuPlatform)) "" else cpuPlatform
val instancePolicy =
createInstancePolicy(cpuPlatform = cpuPlatformBatchRequest, spotModel, accelerators, allDisks, machineType = machineType)
createInstancePolicy(cpuPlatform = cpuPlatform, spotModel, accelerators, allDisks, machineType = machineType)
val locationPolicy = LocationPolicy.newBuilder.addAllowedLocations(zones).build
val allocationPolicy =
createAllocationPolicy(data, locationPolicy, instancePolicy.build, networkPolicy, gcpSa, accelerators)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ final case class GcpBatchRuntimeAttributes(cpu: Int Refined Positive,
continueOnReturnCode: ContinueOnReturnCode,
noAddress: Boolean,
useDockerImageCache: Option[Boolean],
checkpointFilename: Option[String]
checkpointFilename: Option[String],
standardMachineType: Option[String]
)

object GcpBatchRuntimeAttributes {
Expand Down Expand Up @@ -85,6 +86,8 @@ object GcpBatchRuntimeAttributes {
UseDockerImageCacheKey
).optional

val StandardMachineTypeKey = "StandardMachineType"

val CheckpointFileKey = "checkpointFile"
private val checkpointFileValidationInstance = new StringRuntimeAttributesValidation(CheckpointFileKey).optional

Expand All @@ -98,6 +101,8 @@ object GcpBatchRuntimeAttributes {
)
private def cpuPlatformValidation(runtimeConfig: Option[Config]): OptionalRuntimeAttributesValidation[String] =
cpuPlatformValidationInstance
private def standardMachineTypeValidation(runtimeConfig: Option[Config]): OptionalRuntimeAttributesValidation[String] =
new StringRuntimeAttributesValidation(StandardMachineTypeKey).optional
private def gpuTypeValidation(runtimeConfig: Option[Config]): OptionalRuntimeAttributesValidation[GpuType] =
GpuTypeValidation.optional

Expand Down Expand Up @@ -171,7 +176,8 @@ object GcpBatchRuntimeAttributes {
bootDiskSizeValidation(runtimeConfig),
useDockerImageCacheValidation(runtimeConfig),
checkpointFileValidationInstance,
dockerValidation
dockerValidation,
standardMachineTypeValidation(runtimeConfig)
)
}

Expand Down Expand Up @@ -228,6 +234,10 @@ object GcpBatchRuntimeAttributes {
useDockerImageCacheValidation(runtimeAttrsConfig).key,
validatedRuntimeAttributes
)
val standardMachineType: Option[String] = RuntimeAttributesValidation.extractOption(
standardMachineTypeValidation(runtimeAttrsConfig).key,
validatedRuntimeAttributes
)

new GcpBatchRuntimeAttributes(
cpu = cpu,
Expand All @@ -243,7 +253,8 @@ object GcpBatchRuntimeAttributes {
continueOnReturnCode = continueOnReturnCode,
noAddress = noAddress,
useDockerImageCache = useDockerImageCache,
checkpointFilename = checkpointFileName
checkpointFilename = checkpointFileName,
standardMachineType = standardMachineType
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,19 @@ import cromwell.backend.google.batch.models.{
import cromwell.core.logging.JobLogger
import eu.timepit.refined.api.Refined
import eu.timepit.refined.numeric.Positive
import scala.util.matching.Regex
import wdl4s.parser.MemoryUnit
import wom.format.MemorySize

object GcpBatchMachineConstraints {
private val machineTypePattern: Regex = """^\w{2}\w?-\w+-\w+$""".r

def machineType(memory: MemorySize,
cpu: Int Refined Positive,
cpuPlatformOption: Option[String],
standardMachineTypeOption: Option[String],
googleLegacyMachineSelection: Boolean,
jobLogger: JobLogger
): String =
if (isStandardMachineType(cpuPlatformOption.getOrElse(""))) {
StandardMachineType(cpuPlatformOption.getOrElse("")).machineType
if (standardMachineTypeOption.exists(_.trim.nonEmpty)) {
StandardMachineType(standardMachineTypeOption.get).machineType
} else if (googleLegacyMachineSelection) {
s"predefined-$cpu-${memory.to(MemoryUnit.MB).amount.intValue()}"
} else {
Expand All @@ -39,6 +37,4 @@ object GcpBatchMachineConstraints {
}
customMachineType.machineType(memory, cpu, jobLogger)
}

def isStandardMachineType(machineType: String): Boolean = machineTypePattern.matches(machineType)
}

0 comments on commit 4ba50c3

Please sign in to comment.