Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,13 @@ class GoogleBatchScriptLauncher extends BashWrapperBuilder implements GoogleBatc

/**
 * Attach the Google configuration to this launcher.
 *
 * When the batch settings carry a logs bucket path, the name of that bucket
 * is also appended to {@code buckets} (presumably the set of buckets this
 * launcher mounts - confirm against the rest of the class).
 *
 * @param config the Google options to use; may be {@code null}
 * @return this launcher, to allow call chaining
 */
GoogleBatchScriptLauncher withConfig(GoogleOpts config) {
    this.config = config
    // Resolve the bucket name only when a logs path has actually been configured
    final logsPath = config?.batch?.logsBucket
    final logsBucketName = logsPath ? config.batch.extractBucketName(logsPath) : null
    if( logsBucketName )
        buckets.add(logsBucketName)
    return this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -447,14 +447,30 @@ class GoogleBatchTaskHandler extends TaskHandler implements FusionAwareTask {
return Job.newBuilder()
.addTaskGroups(taskGroup)
.setAllocationPolicy(allocationPolicy)
.setLogsPolicy(
LogsPolicy.newBuilder()
.setDestination(LogsPolicy.Destination.CLOUD_LOGGING)
)
.setLogsPolicy(createLogsPolicy())
.putAllLabels(task.config.getResourceLabels())
.build()
}

/**
 * Build the logs policy for the job from the executor's batch configuration.
 *
 * With no logs bucket configured the job logs go to Cloud Logging. Otherwise
 * the configured {@code gs://} path is translated to its container mount path
 * and used as a PATH destination.
 *
 * @return the {@link LogsPolicy} using either the PATH or the CLOUD_LOGGING destination
 */
protected LogsPolicy createLogsPolicy() {
    final bucketPath = executor.batchConfig.logsBucket

    // Default behaviour: nothing configured, keep writing to Cloud Logging
    if( !bucketPath ) {
        final cloudLogging = LogsPolicy.newBuilder()
        cloudLogging.setDestination(LogsPolicy.Destination.CLOUD_LOGGING)
        return cloudLogging.build()
    }

    // PATH destinations take the path as seen from inside the job, not the gs:// URI
    final mountedLogsPath = executor.batchConfig.convertGcsPathToMountPath(bucketPath)
    final pathLogging = LogsPolicy.newBuilder()
    pathLogging.setDestination(LogsPolicy.Destination.PATH)
    pathLogging.setLogsPath(mountedLogsPath)
    return pathLogging.build()
}

/**
* @return Retrieve the submitted task state
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ class BatchConfig implements ConfigScope {
""")
final boolean installGpuDrivers

// Optional `gs://` destination for job logs. The value is taken from
// `opts.logsBucket` and passed through `validateLogsBucket` before being
// assigned, so it is either null (option not set) or a path starting with `gs://`.
@ConfigOption
@Description("""
The Google Cloud Storage bucket path where job logs should be stored, e.g. `gs://my-logs-bucket/logs`. When specified, Google Batch will write job logs to this location instead of Cloud Logging. The bucket must be accessible and writable by the service account.
""")
final String logsBucket

@ConfigOption
@Description("""
Max number of execution attempts of a job interrupted by a Compute Engine Spot reclaim event (default: `0`).
Expand Down Expand Up @@ -142,6 +148,7 @@ class BatchConfig implements ConfigScope {
cpuPlatform = opts.cpuPlatform
gcsfuseOptions = opts.gcsfuseOptions as List<String> ?: DEFAULT_GCSFUSE_OPTS
installGpuDrivers = opts.installGpuDrivers as boolean
logsBucket = validateLogsBucket(opts.logsBucket as String)
maxSpotAttempts = opts.maxSpotAttempts != null ? opts.maxSpotAttempts as int : DEFAULT_MAX_SPOT_ATTEMPTS
network = opts.network
networkTags = opts.networkTags as List<String> ?: Collections.emptyList()
Expand All @@ -155,4 +162,44 @@ class BatchConfig implements ConfigScope {

/**
 * @return the value of the {@code retry} field of this configuration
 */
BatchRetryConfig getRetryConfig() {
    return retry
}

/**
 * Validate the configured logs bucket path.
 *
 * @param bucket the raw option value, e.g. {@code gs://my-logs-bucket/logs}; may be {@code null} or empty
 * @return {@code null} when no value was provided, otherwise the path unchanged
 * @throws IllegalArgumentException when the path does not start with {@code gs://}
 *      or when it does not contain a bucket name (e.g. {@code gs://} or {@code gs:///logs})
 */
private static String validateLogsBucket(String bucket) {
    if( !bucket )
        return null

    final scheme = 'gs://'
    if( !bucket.startsWith(scheme) )
        throw new IllegalArgumentException("Logs bucket path must start with 'gs://' - provided: $bucket")

    // The bucket name is the segment between the scheme and the first slash.
    // It must not be empty, otherwise no bucket can be mounted and the derived
    // mount path would be malformed (e.g. `/mnt/disks//logs`).
    final remainder = bucket.substring(scheme.length())
    if( !remainder || remainder.startsWith('/') )
        throw new IllegalArgumentException("Invalid logs bucket path - provided: $bucket")

    return bucket
}

/**
 * Extract the bucket name from a GCS path.
 *
 * The bucket name is everything between the {@code gs://} prefix and the
 * first following slash (or the end of the string when there is no slash).
 *
 * @param gcsPath GCS path like "gs://bucket-name/path/to/logs"
 * @return bucket name like "bucket-name"; an empty string when the path has
 *      no bucket segment (e.g. "gs://" or "gs:///logs"); {@code null} when the
 *      path is null, empty or does not start with "gs://"
 */
static String extractBucketName(String gcsPath) {
    if( !gcsPath || !gcsPath.startsWith('gs://') )
        return null

    final pathWithoutProtocol = gcsPath.substring(5) // Remove "gs://"
    final slashIndex = pathWithoutProtocol.indexOf('/')
    // `>= 0` rather than `> 0`: a slash right after the scheme means the bucket
    // name is empty, it must not leak the object path (e.g. "/logs") as the name
    return slashIndex >= 0 ? pathWithoutProtocol.substring(0, slashIndex) : pathWithoutProtocol
}

/**
 * Translate a GCS path into the matching path below {@code /mnt/disks}.
 *
 * @param gcsPath GCS path like "gs://bucket-name/path/to/logs"
 * @return container path like "/mnt/disks/bucket-name/path/to/logs"; any value
 *      that is null, empty or does not start with "gs://" is returned unchanged
 */
static String convertGcsPathToMountPath(String gcsPath) {
    final scheme = 'gs://'
    if( gcsPath?.startsWith(scheme) ) {
        // Drop the scheme, keep "bucket/object/path"
        final bucketAndKey = gcsPath.substring(scheme.length())
        return "/mnt/disks/${bucketAndKey}"
    }
    return gcsPath
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,62 @@ class GoogleBatchTaskHandlerTest extends Specification {
"SUCCEEDED" | JobStatus.State.FAILED | makeTaskStatus(TaskStatus.State.SUCCEEDED, "") // get from task status
}

// Checks that, when `logsBucket` is set in the batch config, the job built by
// `newSubmitRequest` uses a PATH logs policy pointing at the mounted bucket
// path and carries a volume for the logs bucket.
def 'should create submit request with logs bucket PATH policy' () {
given:
def GCS_VOL = Volume.newBuilder().setGcs(GCS.newBuilder().setRemotePath('foo').build() ).build()
def WORK_DIR = CloudStorageFileSystem.forBucket('foo').getPath('/scratch')
def CONTAINER_IMAGE = 'ubuntu:22.1'
def LOGS_BUCKET = 'gs://my-logs-bucket/logs'

// Session stub exposing the same logs bucket through the raw config map
def session = Mock(Session) {
getBucketDir() >> CloudStorageFileSystem.forBucket('foo').getPath('/')
getConfig() >> [google: [project: 'test-project', batch: [logsBucket: LOGS_BUCKET]]]
}

// Executor stub; a real BatchConfig is used so the logs bucket option is actually parsed
def exec = Mock(GoogleBatchExecutor) {
getSession() >> session
getBatchConfig() >> new BatchConfig([logsBucket: LOGS_BUCKET])
getConfig() >> Mock(ExecutorConfig)
isFusionEnabled() >> false
}

def bean = new TaskBean(workDir: WORK_DIR, inputFiles: [:])
def task = Mock(TaskRun) {
toTaskBean() >> bean
getHashLog() >> 'abcd1234'
getWorkDir() >> WORK_DIR
getContainer() >> CONTAINER_IMAGE
getConfig() >> Mock(TaskConfig) {
getCpus() >> 2
getResourceLabels() >> [:]
}
}

// The launcher mock is handed both the work dir volume and the logs bucket volume
// NOTE(review): presumably GoogleBatchLauncherSpecMock returns these values verbatim - confirm
def LOGS_VOL = Volume.newBuilder().setGcs(GCS.newBuilder().setRemotePath('my-logs-bucket').build()).setMountPath('/mnt/disks/my-logs-bucket').build()
def mounts = ['/mnt/disks/foo/scratch:/mnt/disks/foo/scratch:rw']
def volumes = [ GCS_VOL, LOGS_VOL ]
def launcher = new GoogleBatchLauncherSpecMock('bash .command.run', mounts, volumes)

// Spy so that selected handler methods can be stubbed in the `then:` block
def handler = Spy(new GoogleBatchTaskHandler(task, exec))

when:
def req = handler.newSubmitRequest(task, launcher)
then:
// Stubbed responses for the spied handler
handler.fusionEnabled() >> false
handler.findBestMachineType(_, false) >> null

and:
// The logs policy targets the mount path derived from LOGS_BUCKET
req.getLogsPolicy().getDestination().toString() == 'PATH'
req.getLogsPolicy().getLogsPath() == '/mnt/disks/my-logs-bucket/logs'
and:
// The logs bucket is among the task volumes, mounted under /mnt/disks
def taskGroup = req.getTaskGroups(0)
def taskVolumes = taskGroup.getTaskSpec().getVolumesList()
taskVolumes.size() >= 2 // At least work dir volume and logs bucket volume
def logsBucketVolume = taskVolumes.find { it.getGcs().getRemotePath() == 'my-logs-bucket' }
logsBucketVolume != null
logsBucketVolume.getMountPath() == '/mnt/disks/my-logs-bucket'
}

def makeTask(String name, TaskStatus.State state){
Task.newBuilder().setName(name)
.setStatus(TaskStatus.newBuilder().setState(state).build())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,57 @@ class BatchConfigTest extends Specification {
config.bootDiskSize == MemoryUnit.of('100GB')
}

// A well-formed logs bucket path is kept as-is; omitting the option leaves it null
@Requires({System.getenv('GOOGLE_APPLICATION_CREDENTIALS')})
def 'should validate logs bucket config' () {
    when: 'a gs:// path is provided'
    def configured = new BatchConfig([logsBucket: 'gs://my-logs-bucket/logs'])
    then: 'the path is stored unchanged'
    configured.logsBucket == 'gs://my-logs-bucket/logs'

    when: 'the option is not provided'
    def unset = new BatchConfig([:])
    then: 'no logs bucket is set'
    unset.logsBucket == null
}

// Each malformed logs bucket value must make the BatchConfig constructor fail
// with an IllegalArgumentException carrying the matching message
@Requires({System.getenv('GOOGLE_APPLICATION_CREDENTIALS')})
def 'should reject invalid logs bucket paths' () {
    when: 'the value has no scheme'
    new BatchConfig([logsBucket: 'invalid-bucket'])
    then:
    def missingScheme = thrown(IllegalArgumentException)
    missingScheme.message.contains("Logs bucket path must start with 'gs://'")

    when: 'the value is only the scheme'
    new BatchConfig([logsBucket: 'gs://'])
    then:
    def schemeOnly = thrown(IllegalArgumentException)
    schemeOnly.message.contains("Invalid logs bucket path")

    when: 'the value uses another storage scheme'
    new BatchConfig([logsBucket: 's3://bucket'])
    then:
    def wrongScheme = thrown(IllegalArgumentException)
    wrongScheme.message.contains("Logs bucket path must start with 'gs://'")
}

// Table-driven check of BatchConfig.extractBucketName for valid, empty and invalid paths
def 'should extract bucket name from GCS path' () {
    expect:
    BatchConfig.extractBucketName(GCS_PATH) == BUCKET

    where:
    GCS_PATH                        | BUCKET
    'gs://my-bucket'                | 'my-bucket'
    'gs://my-bucket/logs'           | 'my-bucket'
    'gs://my-bucket/path/to/logs'   | 'my-bucket'
    'gs://'                         | ''
    'invalid-path'                  | null
    null                            | null
}

// Table-driven check of BatchConfig.convertGcsPathToMountPath; values that are
// not gs:// paths are expected to come back unchanged
def 'should convert GCS path to mount path' () {
    expect:
    BatchConfig.convertGcsPathToMountPath(GCS_PATH) == MOUNT_PATH

    where:
    GCS_PATH                        | MOUNT_PATH
    'gs://my-bucket'                | '/mnt/disks/my-bucket'
    'gs://my-bucket/logs'           | '/mnt/disks/my-bucket/logs'
    'gs://my-bucket/path/to/logs'   | '/mnt/disks/my-bucket/path/to/logs'
    'invalid-path'                  | 'invalid-path'
    null                            | null
}

}