Skip to content

Commit

Permalink
Bulk Load CDK: Remove interface from scope provider and migrate tests (
Browse files Browse the repository at this point in the history
  • Loading branch information
johnny-schmidt authored Jan 9, 2025
1 parent 4265fab commit d3eae1a
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ interface DestinationTaskLauncher : TaskLauncher {
justification = "arguments are guaranteed to be non-null by Kotlin's type system"
)
class DefaultDestinationTaskLauncher(
private val taskScopeProvider: TaskScopeProvider<WrappedTask<ScopedTask>>,
private val taskScopeProvider: TaskScopeProvider,
private val catalog: DestinationCatalog,
private val config: DestinationConfiguration,
private val syncManager: SyncManager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

package io.airbyte.cdk.load.task

import io.airbyte.cdk.load.util.CloseableCoroutine

interface Task {
suspend fun execute()
}
Expand All @@ -21,24 +19,3 @@ interface TaskLauncher {
*/
suspend fun run()
}

/**
* Wraps tasks with exception handling. It should perform all necessary exception handling, then
* execute the provided callback.
*/
interface TaskExceptionHandler<T : Task, U : Task> {
// Wrap a task with exception handling.
suspend fun withExceptionHandling(task: T): U

// Set a callback that will be invoked when any exception handling is done.
suspend fun setCallback(callback: suspend () -> Unit)
}

/** Provides the scope(s) in which tasks run. */
interface TaskScopeProvider<T : Task> : CloseableCoroutine {
/** Launch a task in the correct scope. */
suspend fun launch(task: T)

/** Unliked [close], may attempt to fail gracefully, but should guarantee return. */
suspend fun kill()
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ interface WrappedTask<T : Task> : Task {

@Singleton
@Secondary
class DestinationTaskScopeProvider(config: DestinationConfiguration) :
TaskScopeProvider<WrappedTask<ScopedTask>> {
class TaskScopeProvider(config: DestinationConfiguration) {
private val log = KotlinLogging.logger {}

private val timeoutMs = config.gracefulCancellationTimeoutMs
Expand All @@ -81,7 +80,7 @@ class DestinationTaskScopeProvider(config: DestinationConfiguration) :

private val failFastScope = ControlScope("input", Job(), Dispatchers.IO)

override suspend fun launch(task: WrappedTask<ScopedTask>) {
suspend fun launch(task: WrappedTask<ScopedTask>) {
val scope =
when (task.innerTask) {
is InternalScope -> internalScope
Expand All @@ -97,7 +96,7 @@ class DestinationTaskScopeProvider(config: DestinationConfiguration) :
}
}

override suspend fun close() {
suspend fun close() {
// Under normal operation, all tasks should be complete
// (except things like force flush, which loop). So
// - it's safe to force cancel the internal tasks
Expand Down Expand Up @@ -126,7 +125,7 @@ class DestinationTaskScopeProvider(config: DestinationConfiguration) :
internalScope.job.cancel()
}

override suspend fun kill() {
suspend fun kill() {
log.info { "Killing task scopes" }
// Terminate tasks which should be immediately terminated
failFastScope.job.cancel()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ import org.junit.jupiter.api.Test
"MockScopeProvider",
]
)
class DestinationTaskLauncherTest<T : ScopedTask> {
@Inject lateinit var mockScopeProvider: MockScopeProvider
class DestinationTaskLauncherTest {
@Inject lateinit var taskScopeProvider: TaskScopeProvider
@Inject lateinit var taskLauncher: DestinationTaskLauncher
@Inject lateinit var syncManager: SyncManager

Expand Down Expand Up @@ -447,42 +447,6 @@ class DestinationTaskLauncherTest<T : ScopedTask> {
teardownTaskFactory.hasRun.receive()
}

@Test
fun testHandleTeardownComplete() = runTest {
// This should close the scope provider.
launch {
taskLauncher.run()
Assertions.assertTrue(mockScopeProvider.didClose)
}
taskLauncher.handleTeardownComplete()
}

@Test
fun testHandleCallbackWithFailure() = runTest {
launch {
taskLauncher.run()
Assertions.assertTrue(mockScopeProvider.didKill)
}
taskLauncher.handleTeardownComplete(success = false)
}

@Test
fun `test exceptions in tasks throw`(catalog: DestinationCatalog) = runTest {
mockSpillToDiskTaskFactory.forceFailure.getAndSet(true)

val job = launch { taskLauncher.run() }
taskLauncher.handleTeardownComplete()
job.join()

mockFailStreamTaskFactory.didRunFor.close()

Assertions.assertEquals(
catalog.streams.map { it.descriptor }.toSet(),
mockFailStreamTaskFactory.didRunFor.toList().toSet(),
"FailStreamTask was run for each stream"
)
}

@Test
fun `test sync failure after stream failure`() = runTest {
val job = launch { taskLauncher.run() }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test

class DestinationTaskLauncherUTest {
private val taskScopeProvider: TaskScopeProvider<WrappedTask<ScopedTask>> =
mockk(relaxed = true)
private val taskScopeProvider: TaskScopeProvider = mockk(relaxed = true)
private val catalog: DestinationCatalog = mockk(relaxed = true)
private val syncManager: SyncManager = mockk(relaxed = true)

Expand Down Expand Up @@ -179,4 +178,52 @@ class DestinationTaskLauncherUTest {
)
coVerify(exactly = 1) { closeStreamTaskFactory.make(any(), any()) }
}

@Test
fun `task successful completion triggers scope close`() = runTest {
// This should close the scope provider.
val taskLauncher = getDefaultDestinationTaskLauncher(false)
launch {
taskLauncher.run()
coVerify { taskScopeProvider.close() }
}
taskLauncher.handleTeardownComplete()
}

@Test
fun `test completion with failure triggers scope kill`() = runTest {
val taskLauncher = getDefaultDestinationTaskLauncher(false)
launch {
taskLauncher.run()
coVerify { taskScopeProvider.kill() }
}
taskLauncher.handleTeardownComplete(success = false)
}

@Test
fun `test exceptions in tasks throw`() = runTest {
coEvery { spillToDiskTaskFactory.make(any(), any()) } answers
{
val task = mockk<SpillToDiskTask>(relaxed = true)
coEvery { task.execute() } throws Exception("spill to disk task failed")
task
}
coEvery { taskScopeProvider.launch(any()) } coAnswers
{
val task = firstArg<Task>()
task.execute()
}

val taskLauncher = getDefaultDestinationTaskLauncher(false)
val job = launch { taskLauncher.run() }
taskLauncher.handleTeardownComplete()
job.join()
coVerify {
failStreamTaskFactory.make(
any(),
any(),
match { it.namespace == "namespace" && it.name == "name" }
)
}
}
}

This file was deleted.

0 comments on commit d3eae1a

Please sign in to comment.