diff --git a/docs/background/05_task_chains.md b/docs/background/05_task_chains.md index 4b3b1df6..26596801 100644 --- a/docs/background/05_task_chains.md +++ b/docs/background/05_task_chains.md @@ -29,3 +29,10 @@ Tasks may be posted to a job queue (see `JobQueueTask`) and run by a job queue r ## Compound Task A compound task is `GraphAsTask` that contains a group of tasks (in DAG format) chained together to look like a single task. + +## Streaming Between Tasks + +- Tasks can stream partial results to dependants without waiting for `execute()` to finish by declaring stream-capable outputs via the static `streaming()` descriptor. +- Ports marked with readiness `first-chunk` allow downstream tasks to begin work as soon as the first chunk is emitted. Ports with readiness `final` defer dependants until streaming completes. +- `IExecuteContext` now exposes `pushChunk`, `closeStream`, and `attachStreamController` helpers so tasks can enqueue chunks directly or adapt custom `ReadableStream` producers. +- Dataflows track streaming state and expose async iterables so consumers can react to chunk updates while still receiving the final aggregated output when the stream ends. diff --git a/packages/task-graph/README.md b/packages/task-graph/README.md index f44a35fa..a27e3713 100644 --- a/packages/task-graph/README.md +++ b/packages/task-graph/README.md @@ -688,6 +688,61 @@ try { ## Advanced Patterns +### Streaming Outputs + +Tasks can stream partial results to downstream consumers. Declare stream-capable ports with `static streaming()` and use the execution context helpers to emit chunks. 
+ +```typescript +import { + Task, + type IExecuteContext, + type DataPortSchema, + type TaskStreamingDescriptor, + createStringAccumulator, +} from "@podley/task-graph"; +import { Type } from "@sinclair/typebox"; + +class StreamingTextTask extends Task<{ chunks: string[] }, { output: string }> { + static readonly type = "StreamingTextTask"; + + static inputSchema() { + return Type.Object({ + chunks: Type.Array(Type.String({ description: "Chunk to emit" })), + }); + } + + static outputSchema() { + return Type.Object({ + output: Type.String({ description: "Concatenated output" }), + }); + } + + static streaming(): TaskStreamingDescriptor { + return { + outputs: { + output: { + chunkSchema: Type.String({ + description: "Live chunk emitted while processing", + }) as DataPortSchema, + readiness: "first-chunk", + accumulator: createStringAccumulator(), + }, + }, + }; + } + + async execute(input: { chunks: string[] }, context: IExecuteContext) { + for (const chunk of input.chunks) { + await context.pushChunk("output", chunk); + } + await context.closeStream("output"); + return { output: input.chunks.join("") }; + } +} +``` + +`readiness: "first-chunk"` allows dependent tasks to begin as soon as the first chunk is available. Use `"final"` to defer dependants until the stream closes. The execution context also exposes `attachStreamController` so existing `ReadableStream` producers can be bridged into the task graph without rewriting streaming logic. 
+ ### Array Tasks (Parallel Processing) ```typescript diff --git a/packages/task-graph/src/common.ts b/packages/task-graph/src/common.ts index 36edb4b3..6fce25ea 100644 --- a/packages/task-graph/src/common.ts +++ b/packages/task-graph/src/common.ts @@ -12,6 +12,7 @@ export * from "./task/ITask"; export * from "./task/TaskEvents"; export * from "./task/TaskJSON"; export * from "./task/Task"; +export * from "./task/TaskStream"; export * from "./task/GraphAsTask"; export * from "./task/GraphAsTaskRunner"; export * from "./task/TaskRegistry"; diff --git a/packages/task-graph/src/task-graph/Dataflow.ts b/packages/task-graph/src/task-graph/Dataflow.ts index e14e349d..789c7d94 100644 --- a/packages/task-graph/src/task-graph/Dataflow.ts +++ b/packages/task-graph/src/task-graph/Dataflow.ts @@ -9,7 +9,13 @@ import { areSemanticallyCompatible, EventEmitter } from "@podley/util"; import { Type } from "@sinclair/typebox"; import { TaskError } from "../task/TaskError"; import { DataflowJson } from "../task/TaskJSON"; -import { Provenance, TaskIdType, TaskOutput, TaskStatus } from "../task/TaskTypes"; +import { + Provenance, + TaskIdType, + TaskOutput, + TaskStatus, + type TaskStreamPortDescriptor, +} from "../task/TaskTypes"; import { DataflowEventListener, DataflowEventListeners, @@ -23,6 +29,26 @@ export type DataflowIdType = `${string}[${string}] ==> ${string}[${string}]`; export const DATAFLOW_ALL_PORTS = "*"; export const DATAFLOW_ERROR_PORT = "[error]"; +interface StreamListener { + index: number; + pending: + | { + resolve: (result: IteratorResult) => void; + reject: (error: unknown) => void; + } + | null; + closed: boolean; +} + +interface DataflowStreamState { + descriptor: TaskStreamPortDescriptor; + history: unknown[]; + listeners: Set; + closed: boolean; + error: TaskError | null; + readinessReached: boolean; +} + /** * Represents a data flow between two tasks, indicating how one task's output is used as input for another task */ @@ -53,12 +79,14 @@ export class 
Dataflow { public provenance: Provenance = {}; public status: TaskStatus = TaskStatus.PENDING; public error: TaskError | undefined; + private streamState: DataflowStreamState | null = null; public reset() { this.status = TaskStatus.PENDING; this.error = undefined; this.value = undefined; this.provenance = {}; + this.streamState = null; this.emit("reset"); this.emit("status", this.status); } @@ -70,6 +98,9 @@ export class Dataflow { case TaskStatus.PROCESSING: this.emit("start"); break; + case TaskStatus.STREAMING: + this.emit("stream_start"); + break; case TaskStatus.COMPLETED: this.emit("complete"); break; @@ -110,6 +141,184 @@ export class Dataflow { } } + public beginStream( + descriptor: TaskStreamPortDescriptor, + provenance?: Provenance + ): void { + if (!this.streamState) { + this.streamState = { + descriptor, + history: [], + listeners: new Set(), + closed: false, + error: null, + readinessReached: false, + }; + } + if (provenance) { + this.provenance = provenance; + } + this.setStatus(TaskStatus.STREAMING); + } + + public pushStreamChunk( + chunk: unknown, + aggregate: unknown, + provenance?: Provenance + ): boolean { + const state = this.ensureActiveStream(); + state.history.push(chunk); + let readinessTriggered = false; + if (!state.readinessReached && state.descriptor.readiness === "first-chunk") { + state.readinessReached = true; + readinessTriggered = true; + } + if (provenance) { + this.provenance = provenance; + } + this.value = aggregate; + this.emit("stream_chunk", chunk, aggregate); + this.flushStreamListeners(); + return readinessTriggered; + } + + public endStream(finalValue: unknown, provenance?: Provenance): boolean { + const state = this.ensureActiveStream(); + state.closed = true; + let readinessTriggered = false; + if (state.descriptor.readiness === "final") { + state.readinessReached = true; + readinessTriggered = true; + } + if (provenance) { + this.provenance = provenance; + } + this.value = finalValue; + this.emit("stream_end", 
finalValue);
+    this.flushStreamListeners();
+    return readinessTriggered;
+  }
+
+  /**
+   * Fails the stream: records the error, marks the stream closed, settles waiting
+   * iterators and emits "error".
+   *
+   * Iterators still receive the chunks they have not read yet; the read that follows
+   * the last recorded chunk rejects with `error`.
+   *
+   * @param error The error that terminated the stream
+   * @throws TaskError if streaming was never initialised for this dataflow
+   */
+  public failStream(error: TaskError): void {
+    const state = this.ensureActiveStream();
+    state.error = error;
+    state.closed = true;
+    this.flushStreamListeners();
+    this.emit("error", error);
+  }
+
+  /**
+   * Creates an async iterator over the chunks of this dataflow's stream.
+   *
+   * Every iterator starts at index 0 of the recorded chunk history, so a consumer that
+   * attaches late still receives the chunks that were pushed before it subscribed.
+   *
+   * @returns An iterator that yields each chunk and completes when the stream closes
+   * @throws TaskError if streaming was never initialised for this dataflow
+   */
+  public streamIterator(): AsyncIterableIterator<unknown> {
+    const state = this.ensureActiveStream();
+    const listener: StreamListener = {
+      index: 0,
+      pending: null,
+      closed: false,
+    };
+    state.listeners.add(listener);
+    const iterator: AsyncIterableIterator<unknown> = {
+      next: () => this.resolveStreamListener(listener),
+      return: () => {
+        // Settle a next() that is still waiting: dropping it unsettled would leave
+        // its awaiter hanging forever.
+        this.cleanupStreamListener(listener, true);
+        return Promise.resolve({ value: undefined, done: true });
+      },
+      throw: (error?: unknown) => {
+        // Same as return(): finish any waiting next() before reporting the error.
+        this.cleanupStreamListener(listener, true);
+        return Promise.reject(error);
+      },
+      [Symbol.asyncIterator](): AsyncIterableIterator<unknown> {
+        return this;
+      },
+    };
+    this.flushStreamListener(listener);
+    return iterator;
+  }
+
+  /** Whether a stream has been started on this dataflow and has not been closed yet. */
+  public hasActiveStream(): boolean {
+    return this.streamState !== null && !this.streamState.closed;
+  }
+
+  /** Whether the stream has reached its readiness point; false when no stream exists. */
+  public streamReadinessReached(): boolean {
+    return this.streamState?.readinessReached ?? false;
+  }
+
+  public getStreamDescriptor(): TaskStreamPortDescriptor | null {
+    return this.streamState?.descriptor ?? 
null;
+  }
+
+  /**
+   * Returns the stream state of this dataflow.
+   * Only a missing state throws; a stream that is already closed is still returned.
+   * @throws TaskError if streaming was never initialised for this dataflow
+   */
+  private ensureActiveStream(): DataflowStreamState {
+    if (!this.streamState) {
+      throw new TaskError("Streaming state has not been initialised for this dataflow.");
+    }
+    return this.streamState;
+  }
+
+  /**
+   * Registers a read for the listener and settles it immediately when possible.
+   *
+   * NOTE: a listener stores a single pending read. Calling this again before the
+   * previous promise settled replaces the stored callbacks, so the earlier promise
+   * is never settled.
+   */
+  private resolveStreamListener(
+    listener: StreamListener
+  ): Promise<IteratorResult<unknown>> {
+    return new Promise((resolve, reject) => {
+      listener.pending = { resolve, reject };
+      this.flushStreamListener(listener);
+    });
+  }
+
+  /** Tries to settle the pending read of every registered listener. */
+  private flushStreamListeners(): void {
+    if (!this.streamState) return;
+    // Iterate over a copy: flushing may remove listeners from the set.
+    for (const listener of Array.from(this.streamState.listeners)) {
+      this.flushStreamListener(listener);
+    }
+  }
+
+  /**
+   * Settles the listener's pending read, if it has one and something is deliverable.
+   * Order of precedence: next unread chunk, then the stream error, then end of stream.
+   */
+  private flushStreamListener(listener: StreamListener): void {
+    const state = this.streamState;
+    if (!state) {
+      // The stream state was discarded (reset()), so nothing can ever be delivered
+      // to this listener again: finish it instead of leaving the read unsettled.
+      listener.closed = true;
+      if (listener.pending) {
+        listener.pending.resolve({ value: undefined, done: true });
+        listener.pending = null;
+      }
+      return;
+    }
+    if (listener.closed) {
+      if (listener.pending) {
+        listener.pending.resolve({ value: undefined, done: true });
+        listener.pending = null;
+      }
+      state.listeners.delete(listener);
+      return;
+    }
+    if (!listener.pending) {
+      return;
+    }
+    if (listener.index < state.history.length) {
+      const value = state.history[listener.index++];
+      const { resolve } = listener.pending;
+      listener.pending = null;
+      resolve({ value, done: false });
+      return;
+    }
+    if (state.error) {
+      const { reject } = listener.pending;
+      listener.pending = null;
+      listener.closed = true;
+      reject(state.error);
+      state.listeners.delete(listener);
+      return;
+    }
+    if (state.closed) {
+      const { resolve } = listener.pending;
+      listener.pending = null;
+      listener.closed = true;
+      resolve({ value: undefined, done: true });
+      state.listeners.delete(listener);
+    }
+  }
+
+  /**
+   * Closes a listener and unregisters it from the current stream state, if any.
+   * @param listener The listener to close
+   * @param settle When true, a pending read is resolved as done; when false it is
+   *   dropped without being settled
+   */
+  private cleanupStreamListener(listener: StreamListener, settle: boolean): void {
+    if (listener.pending) {
+      if (settle) {
+        listener.pending.resolve({ value: undefined, done: true });
+      }
+      listener.pending = null;
+    }
+    // Close the listener even when the stream state is gone, so later reads on the
+    // same iterator finish instead of waiting.
+    listener.closed = true;
+    this.streamState?.listeners.delete(listener);
+  }
+
   toJSON(): DataflowJson {
     return {
       sourceTaskId: this.sourceTaskId,
diff --git 
a/packages/task-graph/src/task-graph/DataflowEvents.ts b/packages/task-graph/src/task-graph/DataflowEvents.ts index 4c9283ca..5fe62aa1 100644 --- a/packages/task-graph/src/task-graph/DataflowEvents.ts +++ b/packages/task-graph/src/task-graph/DataflowEvents.ts @@ -20,6 +20,15 @@ export type DataflowEventListeners = { /** Fired when a source task completes successfully */ complete: () => void; + /** Fired when a dataflow begins streaming */ + stream_start: () => void; + + /** Fired when a streaming chunk traverses the dataflow */ + stream_chunk: (chunk: unknown, aggregate: unknown) => void; + + /** Fired when streaming ends */ + stream_end: (aggregate: unknown) => void; + /** Fired when a source task is skipped */ skipped: () => void; diff --git a/packages/task-graph/src/task-graph/TaskGraphRunner.ts b/packages/task-graph/src/task-graph/TaskGraphRunner.ts index cf7d5c81..46ed30d5 100644 --- a/packages/task-graph/src/task-graph/TaskGraphRunner.ts +++ b/packages/task-graph/src/task-graph/TaskGraphRunner.ts @@ -15,7 +15,13 @@ import { import { TASK_OUTPUT_REPOSITORY, TaskOutputRepository } from "../storage/TaskOutputRepository"; import { ITask } from "../task/ITask"; import { TaskAbortedError, TaskConfigurationError, TaskError } from "../task/TaskError"; -import { Provenance, TaskInput, TaskOutput, TaskStatus } from "../task/TaskTypes"; +import { + Provenance, + TaskInput, + TaskOutput, + TaskStatus, + type TaskStreamPortDescriptor, +} from "../task/TaskTypes"; import { DATAFLOW_ALL_PORTS } from "./Dataflow"; import { TaskGraph, TaskGraphRunConfig } from "./TaskGraph"; import { DependencyBasedScheduler, TopologicalScheduler } from "./TaskGraphScheduler"; @@ -427,6 +433,61 @@ export class TaskGraphRunner { }); } + private handleStreamStartForTask( + task: ITask, + portId: string, + descriptor: TaskStreamPortDescriptor, + provenance: Provenance + ): void { + for (const dataflow of this.getOutgoingDataflowsForPort(task, portId)) { + dataflow.beginStream(descriptor, 
provenance); + } + } + + private handleStreamChunkForTask( + task: ITask, + portId: string, + chunk: unknown, + aggregate: unknown, + provenance: Provenance + ): void { + for (const dataflow of this.getOutgoingDataflowsForPort(task, portId)) { + const readinessTriggered = dataflow.pushStreamChunk(chunk, aggregate, provenance); + if (readinessTriggered) { + this.processScheduler.onDataflowReady(dataflow.targetTaskId); + } + } + } + + private handleStreamEndForTask( + task: ITask, + portId: string, + aggregate: unknown, + provenance: Provenance + ): void { + for (const dataflow of this.getOutgoingDataflowsForPort(task, portId)) { + const readinessTriggered = dataflow.endStream(aggregate, provenance); + if (readinessTriggered) { + this.processScheduler.onDataflowReady(dataflow.targetTaskId); + } + } + } + + private handleStreamErrorForTask(task: ITask, portId: string, error: TaskError): void { + for (const dataflow of this.getOutgoingDataflowsForPort(task, portId)) { + dataflow.failStream(error); + } + } + + private getOutgoingDataflowsForPort(task: ITask, portId: string) { + return this.graph + .getTargetDataflows(task.config.id) + .filter( + (dataflow) => + dataflow.sourceTaskPortId === portId || dataflow.sourceTaskPortId === DATAFLOW_ALL_PORTS + ); + } + /** * Runs a task with provenance input * @param task The task to run @@ -447,12 +508,24 @@ export class TaskGraphRunner { this.provenanceInput.set(task.config.id, nodeProvenance); this.copyInputFromEdgesToNode(task); - const results = await task.runner.run(input, { - nodeProvenance, - outputCache: this.outputCache, - updateProgress: async (task: ITask, progress: number, message?: string, ...args: any[]) => - await this.handleProgress(task, progress, message, ...args), - }); + const results = await task.runner.run(input, { + nodeProvenance, + outputCache: this.outputCache, + updateProgress: async (task: ITask, progress: number, message?: string, ...args: any[]) => + await this.handleProgress(task, progress, 
message, ...args), + onStreamStart: async (_task, portId, descriptor) => { + this.handleStreamStartForTask(task, portId, descriptor, nodeProvenance); + }, + onStreamChunk: async (_task, portId, chunk, aggregate) => { + this.handleStreamChunkForTask(task, portId, chunk, aggregate, nodeProvenance); + }, + onStreamEnd: async (_task, portId, aggregate) => { + this.handleStreamEndForTask(task, portId, aggregate, nodeProvenance); + }, + onStreamError: async (_task, portId, error) => { + this.handleStreamErrorForTask(task, portId, error); + }, + }); await this.pushOutputFromNodeToEdges(task, results, nodeProvenance); diff --git a/packages/task-graph/src/task-graph/TaskGraphScheduler.ts b/packages/task-graph/src/task-graph/TaskGraphScheduler.ts index aa30c000..6ef2ba02 100644 --- a/packages/task-graph/src/task-graph/TaskGraphScheduler.ts +++ b/packages/task-graph/src/task-graph/TaskGraphScheduler.ts @@ -24,6 +24,12 @@ export interface ITaskGraphScheduler { */ onTaskCompleted(taskId: unknown): void; + /** + * Notifies the scheduler that a dataflow became ready for consumption + * @param taskId The ID of the task that may now be ready + */ + onDataflowReady(taskId: unknown): void; + /** * Resets the scheduler state */ @@ -54,6 +60,10 @@ export class TopologicalScheduler implements ITaskGraphScheduler { // Topological scheduler doesn't need to track individual task completion } + onDataflowReady(_taskId: unknown): void { + // Topological scheduler does not react to streaming readiness + } + reset(): void { this.sortedNodes = this.dag.topologicallySortedNodes(); this.currentIndex = 0; @@ -76,10 +86,28 @@ export class DependencyBasedScheduler implements ITaskGraphScheduler { } private isTaskReady(task: ITask): boolean { - const dependencies = this.dag - .getSourceDataflows(task.config.id) - .map((dataflow) => dataflow.sourceTaskId); - return dependencies.every((dep) => this.completedTasks.has(dep)); + const dataflows = this.dag.getSourceDataflows(task.config.id); + if 
(dataflows.length === 0) {
+      return true;
+    }
+    // A dependency is satisfied once its dataflow reports that stream readiness was
+    // reached (false for non-streaming dataflows), or once its source task completed.
+    return dataflows.every(
+      (dataflow) =>
+        dataflow.streamReadinessReached() || this.completedTasks.has(dataflow.sourceTaskId)
+    );
   }
 
   private async waitForNextTask(): Promise {
@@ -127,6 +155,27 @@ export class DependencyBasedScheduler implements ITaskGraphScheduler {
     }
   }
 
+  /**
+   * Resolves the stored `nextResolver`, if there is one, with a pending task that is
+   * ready to run. The task identified by `taskId` is preferred; otherwise the first
+   * other pending task that is ready is used. Does nothing when no task is ready.
+   * @param taskId The ID of the task whose incoming dataflow became ready
+   */
+  onDataflowReady(taskId: unknown): void {
+    const resolver = this.nextResolver;
+    if (!resolver) return;
+    const pending = Array.from(this.pendingTasks);
+    const next =
+      pending.find((task) => task.config.id === taskId && this.isTaskReady(task)) ??
+      pending.find((task) => this.isTaskReady(task));
+    if (!next) return;
+    this.pendingTasks.delete(next);
+    this.nextResolver = null;
+    resolver(next);
+  }
+
   reset(): void {
     this.completedTasks.clear();
     this.pendingTasks = new Set(this.dag.topologicallySortedNodes());
diff --git a/packages/task-graph/src/task/ITask.ts b/packages/task-graph/src/task/ITask.ts
index 85ca7da3..74e3cad1 100644
--- a/packages/task-graph/src/task/ITask.ts
+++ b/packages/task-graph/src/task/ITask.ts
@@ -21,7 +21,15 @@ import type {
 import type { JsonTaskItem, TaskGraphItemJson } from "./TaskJSON";
 import { TaskRunner } from "./TaskRunner";
 import type { DataPortSchema } from "./TaskSchema";
-import type { Provenance, TaskConfig, TaskInput, TaskOutput, TaskStatus } from "./TaskTypes";
+import type {
+  Provenance,
+  TaskConfig,
+  TaskInput,
+  TaskOutput, + 
TaskStatus, + TaskStreamPortDescriptor, + TaskStreamingDescriptor, +} from "./TaskTypes"; /** * Context for task execution @@ -31,6 +39,12 @@ export interface IExecuteContext { nodeProvenance: Provenance; updateProgress: (progress: number, message?: string, ...args: any[]) => Promise; own: (i: T) => T; + pushChunk: (portId: string, chunk: unknown) => Promise; + closeStream: (portId: string) => Promise; + attachStreamController: ( + portId: string, + controller: ReadableStreamDefaultController + ) => ReadableStreamDefaultController; } export type IExecuteReactiveContext = Pick; @@ -47,6 +61,19 @@ export interface IRunConfig { message?: string, ...args: any[] ) => Promise; + onStreamStart?: ( + task: ITask, + portId: string, + descriptor: TaskStreamPortDescriptor + ) => Promise | void; + onStreamChunk?: ( + task: ITask, + portId: string, + chunk: unknown, + aggregate: unknown + ) => Promise | void; + onStreamEnd?: (task: ITask, portId: string, aggregate: unknown) => Promise | void; + onStreamError?: (task: ITask, portId: string, error: TaskError) => Promise | void; } /** @@ -62,6 +89,7 @@ export interface ITaskStaticProperties { readonly cacheable: boolean; readonly inputSchema: () => DataPortSchema; readonly outputSchema: () => DataPortSchema; + readonly streaming: () => TaskStreamingDescriptor; } /** @@ -109,6 +137,7 @@ export interface ITaskIO { get type(): string; // gets local access for static type property get category(): string; // gets local access for static category property get title(): string; // gets local access for static title property + streaming(): TaskStreamingDescriptor; setDefaults(defaults: Record): void; resetInputData(): void; diff --git a/packages/task-graph/src/task/Task.ts b/packages/task-graph/src/task/Task.ts index c58b5d15..5a357473 100644 --- a/packages/task-graph/src/task/Task.ts +++ b/packages/task-graph/src/task/Task.ts @@ -27,6 +27,7 @@ import { type TaskInput, type TaskOutput, type TaskTypeName, + type TaskStreamingDescriptor, } 
from "./TaskTypes"; /** @@ -84,6 +85,14 @@ export class Task< return Type.Object({}) as DataPortSchema; } + /** + * Streaming metadata for this task + * Returns descriptors for stream-capable outputs + */ + public static streaming(): TaskStreamingDescriptor { + return { outputs: {} }; + } + // ======================================================================== // Task Execution Methods - Core logic provided by subclasses // ======================================================================== @@ -199,6 +208,10 @@ export class Task< return (this.constructor as typeof Task).outputSchema(); } + public streaming(): TaskStreamingDescriptor { + return (this.constructor as typeof Task).streaming(); + } + public get type(): TaskTypeName { return (this.constructor as typeof Task).type; } diff --git a/packages/task-graph/src/task/TaskEvents.ts b/packages/task-graph/src/task/TaskEvents.ts index 9abcf0a8..0dc8f0aa 100644 --- a/packages/task-graph/src/task/TaskEvents.ts +++ b/packages/task-graph/src/task/TaskEvents.ts @@ -23,6 +23,15 @@ export type TaskEventListeners = { /** Fired when a task completes successfully */ complete: () => void; + /** Fired when a task begins streaming on a specific output port */ + stream_start: (portId: string) => void; + + /** Fired when a streaming chunk is produced */ + stream_chunk: (portId: string, chunk: unknown, aggregate: unknown) => void; + + /** Fired when a task finishes streaming on a specific output port */ + stream_end: (portId: string, aggregate: unknown) => void; + /** Fired when a task is aborted */ abort: (error: TaskAbortedError) => void; diff --git a/packages/task-graph/src/task/TaskRunner.ts b/packages/task-graph/src/task/TaskRunner.ts index 18e6397b..104ca6c8 100644 --- a/packages/task-graph/src/task/TaskRunner.ts +++ b/packages/task-graph/src/task/TaskRunner.ts @@ -11,13 +11,39 @@ import { ITaskGraph } from "../task-graph/ITaskGraph"; import { IWorkflow } from "../task-graph/IWorkflow"; import { TaskGraph } from 
"../task-graph/TaskGraph"; import { ensureTask, type Taskish } from "../task-graph/Conversions"; -import { IRunConfig, ITask } from "./ITask"; +import { IRunConfig, type IExecuteContext, ITask } from "./ITask"; import { ITaskRunner } from "./ITaskRunner"; import { Task } from "./Task"; -import { TaskAbortedError, TaskError, TaskFailedError, TaskInvalidInputError } from "./TaskError"; -import { Provenance, TaskConfig, TaskInput, TaskOutput, TaskStatus } from "./TaskTypes"; +import { + TaskAbortedError, + TaskConfigurationError, + TaskError, + TaskFailedError, + TaskInvalidInputError, +} from "./TaskError"; +import { + Provenance, + TaskConfig, + TaskInput, + TaskOutput, + TaskStatus, + type TaskStreamPortDescriptor, + type TaskStream, +} from "./TaskTypes"; import { GraphAsTask } from "../task/GraphAsTask"; import { Workflow } from "../task-graph/Workflow"; +import { isTaskStream, toAsyncIterable } from "./TaskStream"; + +interface TaskStreamRuntimeState { + readonly portId: string; + readonly descriptor: TaskStreamPortDescriptor; + aggregate: unknown; + started: boolean; + done: boolean; + contextManaged: boolean; + returnedStream: boolean; + chunkCount: number; +} /** * Responsible for running tasks @@ -29,21 +55,26 @@ export class TaskRunner< Config extends TaskConfig = TaskConfig, > implements ITaskRunner { - /** - * Whether the task is currently running - */ - protected running = false; - protected reactiveRunning = false; - - /** - * Provenance information for the task - */ - protected nodeProvenance: Provenance = {}; - - /** - * The task to run - */ - public readonly task: ITask; + /** + * Whether the task is currently running + */ + protected running = false; + protected reactiveRunning = false; + + /** + * Provenance information for the task + */ + protected nodeProvenance: Provenance = {}; + + /** + * The task to run + */ + public readonly task: ITask; + + /** + * Streaming runtime state keyed by output port id + */ + protected streamStates: Map | null = 
null; /** * AbortController for cancelling task execution @@ -92,19 +123,19 @@ export class TaskRunner< const inputs: Input = this.task.runInputData as Input; let outputs: Output | undefined; - if (this.task.cacheable) { - outputs = (await this.outputCache?.getOutput(this.task.type, inputs)) as Output; - if (outputs) { - this.task.runOutputData = await this.executeTaskReactive(inputs, outputs); + if (this.task.cacheable) { + outputs = (await this.outputCache?.getOutput(this.task.type, inputs)) as Output; + if (outputs) { + this.task.runOutputData = await this.executeTaskReactive(inputs, outputs); + } } - } - if (!outputs) { - outputs = await this.executeTask(inputs); - if (this.task.cacheable && outputs !== undefined) { - await this.outputCache?.saveOutput(this.task.type, inputs, outputs); + if (!outputs) { + outputs = await this.executeTask(inputs, config); + if (this.task.cacheable && outputs !== undefined) { + await this.outputCache?.saveOutput(this.task.type, inputs, outputs); + } + this.task.runOutputData = outputs ?? ({} as Output); } - this.task.runOutputData = outputs ?? ({} as Output); - } await this.handleComplete(); @@ -171,15 +202,287 @@ export class TaskRunner< /** * Protected method to execute a task by delegating back to the task itself. + * Handles streaming-aware execution, collecting chunks and final output. 
*/ - protected async executeTask(input: Input): Promise { - const result = await this.task.execute(input, { + protected async executeTask(input: Input, config: IRunConfig): Promise { + const streamStates = this.prepareStreamStates(); + const streamPromises: Promise[] = []; + const context = this.createExecuteContext(streamStates, config, streamPromises); + + const rawResult = await this.task.execute(input, context); + const normalizedResult = this.normalizeExecuteResult(rawResult); + + await this.consumeStreamsFromResult( + normalizedResult, + streamStates, + config, + streamPromises + ); + + await this.finalizeContextDrivenStreams(streamStates, config); + + await Promise.all(streamPromises); + + this.applyFinalStreamResults(normalizedResult, streamStates); + + this.streamStates = null; + + return await this.executeTaskReactive(input, normalizedResult as Output); + } + + protected prepareStreamStates(): Map { + const descriptor = this.task.streaming(); + const states = new Map(); + for (const [portId, portDescriptor] of Object.entries(descriptor.outputs)) { + states.set(portId, { + portId, + descriptor: portDescriptor, + aggregate: portDescriptor.accumulator.initial(), + started: false, + done: false, + contextManaged: false, + returnedStream: false, + chunkCount: 0, + }); + } + this.streamStates = states; + return states; + } + + protected createExecuteContext( + streamStates: Map, + config: IRunConfig, + streamPromises: Promise[] + ): IExecuteContext { + const pushChunk = (portId: string, chunk: unknown) => { + const state = this.ensureStreamState(portId, streamStates); + state.contextManaged = true; + const pending = this.pushStreamChunk(state, chunk, config); + const tracked = pending.catch(async (error) => { + await this.handleStreamError(state, error, config); + throw error; + }); + streamPromises.push(tracked); + return tracked; + }; + + const closeStream = (portId: string) => { + const state = this.ensureStreamState(portId, streamStates); + const pending = 
this.closeStreamState(state, config); + const tracked = pending.catch(async (error) => { + await this.handleStreamError(state, error, config); + throw error; + }); + streamPromises.push(tracked); + return tracked; + }; + + const attachStreamController = ( + portId: string, + controller: ReadableStreamDefaultController + ) => { + const state = this.ensureStreamState(portId, streamStates); + const originalEnqueue = controller.enqueue.bind(controller); + controller.enqueue = ((chunk: Chunk) => { + originalEnqueue(chunk); + void pushChunk(portId, chunk); + }) as typeof controller.enqueue; + + const originalClose = controller.close.bind(controller); + controller.close = (() => { + originalClose(); + void closeStream(portId); + }) as typeof controller.close; + + const originalError = controller.error.bind(controller); + controller.error = ((reason?: unknown) => { + originalError(reason); + const trackedError = this.handleStreamError( + state, + reason ?? new Error("Stream error"), + config + ).catch((error) => { + throw error; + }); + streamPromises.push(trackedError); + }) as typeof controller.error; + + return controller; + }; + + return { signal: this.abortController!.signal, updateProgress: this.handleProgress.bind(this), nodeProvenance: this.nodeProvenance, own: this.own, - }); - return await this.executeTaskReactive(input, result || ({} as Output)); + pushChunk, + closeStream, + attachStreamController, + }; + } + + protected normalizeExecuteResult( + result: Output | undefined + ): Record { + if (result && typeof result === "object") { + return { ...(result as Record) }; + } + return {}; + } + + protected async consumeStreamsFromResult( + result: Record, + streamStates: Map, + config: IRunConfig, + streamPromises: Promise[] + ): Promise { + for (const [portId, state] of streamStates.entries()) { + if (result[portId] === undefined) { + continue; + } + const value = result[portId]; + if (isTaskStream(value)) { + state.returnedStream = true; + delete result[portId]; + 
const promise = this.consumeStream(state, value as TaskStream, config); + streamPromises.push(promise); + } else { + state.aggregate = value; + state.done = true; + this.task.runOutputData = { + ...this.task.runOutputData, + [state.portId]: state.aggregate, + }; + } + } + } + + protected async finalizeContextDrivenStreams( + streamStates: Map, + config: IRunConfig + ): Promise { + for (const state of streamStates.values()) { + if (state.contextManaged && !state.done) { + await this.closeStreamState(state, config); + } + } + } + + protected applyFinalStreamResults( + result: Record, + streamStates: Map + ): void { + for (const state of streamStates.values()) { + if (!state.done) { + continue; + } + result[state.portId] = state.aggregate; + } + } + + protected ensureStreamState( + portId: string, + streamStates: Map + ): TaskStreamRuntimeState { + const state = streamStates.get(portId); + if (!state) { + throw new TaskConfigurationError( + `Task "${this.task.type}" attempted to stream on undeclared output port "${portId}".` + ); + } + return state; + } + + protected async startStream( + state: TaskStreamRuntimeState, + config: IRunConfig + ): Promise { + if (state.started) return; + state.started = true; + if (this.task.status !== TaskStatus.STREAMING) { + this.task.status = TaskStatus.STREAMING; + this.task.emit("status", this.task.status); + } + this.task.emit("stream_start", state.portId); + if (config.onStreamStart) { + await config.onStreamStart(this.task, state.portId, state.descriptor); + } + } + + protected async pushStreamChunk( + state: TaskStreamRuntimeState, + chunk: unknown, + config: IRunConfig + ): Promise { + if (this.abortController?.signal.aborted) { + throw new TaskAbortedError(); + } + await this.startStream(state, config); + state.aggregate = state.descriptor.accumulator.accumulate(state.aggregate, chunk); + state.chunkCount += 1; + this.task.runOutputData = { + ...this.task.runOutputData, + [state.portId]: state.aggregate, + }; + 
this.task.emit("stream_chunk", state.portId, chunk, state.aggregate); + if (config.onStreamChunk) { + await config.onStreamChunk(this.task, state.portId, chunk, state.aggregate); + } + } + + protected async closeStreamState( + state: TaskStreamRuntimeState, + config: IRunConfig + ): Promise { + if (state.done) return; + state.done = true; + state.aggregate = state.descriptor.accumulator.complete(state.aggregate); + this.task.runOutputData = { + ...this.task.runOutputData, + [state.portId]: state.aggregate, + }; + this.task.emit("stream_end", state.portId, state.aggregate); + if (config.onStreamEnd) { + await config.onStreamEnd(this.task, state.portId, state.aggregate); + } + } + + protected async handleStreamError( + state: TaskStreamRuntimeState, + error: unknown, + config: IRunConfig + ): Promise { + state.done = true; + if (error instanceof TaskAbortedError) { + throw error; + } + const taskError = + error instanceof TaskError + ? error + : new TaskFailedError( + error instanceof Error ? error.message : String(error ?? 
"Stream error") + ); + if (config.onStreamError) { + await config.onStreamError(this.task, state.portId, taskError); + } + throw taskError; + } + + protected consumeStream( + state: TaskStreamRuntimeState, + stream: TaskStream, + config: IRunConfig + ): Promise { + return (async () => { + try { + for await (const chunk of toAsyncIterable(stream)) { + await this.pushStreamChunk(state, chunk, config); + } + await this.closeStreamState(state, config); + } catch (error) { + await this.handleStreamError(state, error, config); + } + })(); } /** @@ -206,6 +509,7 @@ export class TaskRunner< this.nodeProvenance = {}; this.running = true; + this.streamStates = null; this.task.startedAt = new Date(); this.task.progress = 0; @@ -254,6 +558,7 @@ export class TaskRunner< this.task.status = TaskStatus.ABORTING; this.task.progress = 100; this.task.error = new TaskAbortedError(); + this.streamStates = null; this.task.emit("abort", this.task.error); this.task.emit("status", this.task.status); } @@ -273,6 +578,7 @@ export class TaskRunner< this.task.status = TaskStatus.COMPLETED; this.abortController = undefined; this.nodeProvenance = {}; + this.streamStates = null; this.task.emit("complete"); this.task.emit("status", this.task.status); @@ -289,6 +595,7 @@ export class TaskRunner< this.task.completedAt = new Date(); this.abortController = undefined; this.nodeProvenance = {}; + this.streamStates = null; this.task.emit("skipped"); this.task.emit("status", this.task.status); } @@ -316,6 +623,7 @@ export class TaskRunner< err instanceof TaskError ? 
err : new TaskFailedError(err?.message || "Task failed"); this.abortController = undefined; this.nodeProvenance = {}; + this.streamStates = null; this.task.emit("error", this.task.error); this.task.emit("status", this.task.status); } diff --git a/packages/task-graph/src/task/TaskSchema.ts b/packages/task-graph/src/task/TaskSchema.ts index a4aa378d..56fa0607 100644 --- a/packages/task-graph/src/task/TaskSchema.ts +++ b/packages/task-graph/src/task/TaskSchema.ts @@ -15,6 +15,7 @@ import { Type, } from "@sinclair/typebox"; import type { JSONSchema7 } from "json-schema"; +import type { TaskStreamReadiness } from "./TaskStream"; export function TypeReplicateArray( type: T, @@ -56,5 +57,47 @@ export interface DataPortSchema readonly anyOf?: readonly DataPortSchema[]; readonly oneOf?: readonly DataPortSchema[]; readonly not?: DataPortSchema; + readonly ["x-stream"]?: DataPortStreamAnnotation; readonly [K: `x-${string}`]: unknown; } + +export const DATA_PORT_STREAM_METADATA_KEY = "x-stream" as const; + +export interface DataPortStreamAnnotation { + readonly streaming: true; + readonly readiness: TaskStreamReadiness; + readonly chunkSchema: DataPortSchema | null; +} + +export function createStreamAnnotation( + readiness: TaskStreamReadiness, + chunkSchema: DataPortSchema | null = null +): DataPortStreamAnnotation { + return { + streaming: true, + readiness, + chunkSchema, + }; +} + +export function withStreamAnnotation( + schema: DataPortSchema, + annotation: DataPortStreamAnnotation +): DataPortSchema { + return { + ...schema, + [DATA_PORT_STREAM_METADATA_KEY]: annotation, + }; +} + +export function getStreamAnnotation(schema: DataPortSchema): DataPortStreamAnnotation | null { + const annotation = schema[DATA_PORT_STREAM_METADATA_KEY]; + if ( + annotation && + typeof annotation === "object" && + (annotation as DataPortStreamAnnotation).streaming === true + ) { + return annotation as DataPortStreamAnnotation; + } + return null; +} diff --git 
a/packages/task-graph/src/task/TaskStream.ts b/packages/task-graph/src/task/TaskStream.ts new file mode 100644 index 00000000..dab8caee --- /dev/null +++ b/packages/task-graph/src/task/TaskStream.ts @@ -0,0 +1,146 @@ +// ******************************************************************************* +// * PODLEY.AI: Your Agentic AI library * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import type { DataPortSchema } from "./TaskSchema"; + +/** + * Describes a streaming-compatible value exposed by a task. + * Tasks may either yield values through AsyncIterables or ReadableStreams. + */ +export type TaskStream = AsyncIterable | ReadableStream; + +/** + * Defines when a downstream consumer can begin execution relative to a streaming output. + * - `first-chunk`: consumers may start when the first chunk is available. + * - `final`: consumers must wait for the stream to finish. + */ +export type TaskStreamReadiness = "first-chunk" | "final"; + +/** + * Aggregator used to project a stream of chunks into a final output value. + * Implementations must be pure; they receive the current aggregate state for every chunk + * and are expected to return a new state instance. + */ +export interface TaskStreamAccumulator { + readonly initial: () => Aggregate; + readonly accumulate: (current: Aggregate, chunk: Chunk) => Aggregate; + readonly complete: (current: Aggregate) => Aggregate; +} + +/** + * Metadata describing how a specific output port behaves when streaming. + */ +export interface TaskStreamPortDescriptor { + readonly chunkSchema: DataPortSchema | null; + readonly readiness: TaskStreamReadiness; + readonly accumulator: TaskStreamAccumulator; +} + +/** + * Metadata describing all stream-capable outputs for a task. 
+ */ +export interface TaskStreamingDescriptor { + readonly outputs: Readonly>>; +} + +/** + * Type guard for ReadableStream values. + */ +export function isReadableStream( + value: unknown +): value is ReadableStream { + return ( + typeof value === "object" && + value !== null && + typeof (value as { getReader?: unknown }).getReader === "function" + ); +} + +/** + * Type guard for AsyncIterable values. + */ +export function isAsyncIterable(value: unknown): value is AsyncIterable { + return ( + value !== null && + typeof value === "object" && + typeof (value as { [Symbol.asyncIterator]?: unknown })[Symbol.asyncIterator] === "function" + ); +} + +/** + * Type guard that matches either ReadableStream or AsyncIterable values. + */ +export function isTaskStream(value: unknown): value is TaskStream { + return isReadableStream(value) || isAsyncIterable(value); +} + +/** + * Converts a ReadableStream into an AsyncIterableIterator. + */ +export async function* readableStreamToAsyncIterable( + stream: ReadableStream +): AsyncIterableIterator { + const reader = stream.getReader(); + try { + while (true) { + const result = await reader.read(); + if (result.done) { + return; + } + yield result.value; + } + } finally { + reader.releaseLock(); + } +} + +/** + * Ensures a streaming value exposes the AsyncIterable protocol. + */ +export function toAsyncIterable(stream: TaskStream): AsyncIterableIterator { + if (isAsyncIterable(stream)) { + return (async function* iterate() { + for await (const chunk of stream) { + yield chunk; + } + })(); + } + return readableStreamToAsyncIterable(stream); +} + +/** + * Creates a simple accumulator that collects stream chunks into an array. + */ +export function createArrayAccumulator(): TaskStreamAccumulator { + return { + initial: () => [], + accumulate: (current, chunk) => [...current, chunk], + complete: (current) => current, + }; +} + +/** + * Creates an accumulator that concatenates string chunks. 
+ */ +export function createStringAccumulator(): TaskStreamAccumulator { + return { + initial: () => "", + accumulate: (current, chunk) => `${current}${chunk}`, + complete: (current) => current, + }; +} + +/** + * Creates an accumulator that always retains the most recent chunk. + */ +export function createLatestValueAccumulator(): TaskStreamAccumulator { + return { + initial: () => null, + accumulate: (_, chunk) => chunk, + complete: (current) => current, + }; +} diff --git a/packages/task-graph/src/task/TaskTypes.ts b/packages/task-graph/src/task/TaskTypes.ts index f2b8ea1d..b270de36 100644 --- a/packages/task-graph/src/task/TaskTypes.ts +++ b/packages/task-graph/src/task/TaskTypes.ts @@ -7,6 +7,13 @@ import { TaskOutputRepository } from "../storage/TaskOutputRepository"; import type { Task } from "./Task"; +export type { + TaskStream, + TaskStreamAccumulator, + TaskStreamPortDescriptor, + TaskStreamReadiness, + TaskStreamingDescriptor, +} from "./TaskStream"; /** * Enum representing the possible states of a task @@ -24,6 +31,8 @@ export enum TaskStatus { SKIPPED = "SKIPPED", /** Task is currently running */ PROCESSING = "PROCESSING", + /** Task is emitting streaming output */ + STREAMING = "STREAMING", /** Task has completed successfully */ COMPLETED = "COMPLETED", /** Task is in the process of being aborted */ diff --git a/packages/test/src/test/task-graph/StreamingTaskGraph.test.ts b/packages/test/src/test/task-graph/StreamingTaskGraph.test.ts new file mode 100644 index 00000000..b676949e --- /dev/null +++ b/packages/test/src/test/task-graph/StreamingTaskGraph.test.ts @@ -0,0 +1,201 @@ +// ******************************************************************************* +// * PODLEY.AI: Your Agentic AI library * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { describe, expect, it } from "bun:test"; 
+import { + Task, + TaskGraph, + TaskGraphRunner, + Dataflow, + type IExecuteContext, + type TaskConfig, + type TaskStreamingDescriptor, + type DataPortSchema, + createStringAccumulator, +} from "@podley/task-graph"; +import { Type, type TObject } from "@sinclair/typebox"; +import { sleep } from "@podley/util"; + +type ProducerInput = { + chunks: string[]; + delay: number; +}; + +type ProducerOutput = { + output: string; +}; + +abstract class BaseStreamingProducer extends Task { + static readiness: "first-chunk" | "final" = "first-chunk"; + protected readonly events: string[]; + + constructor(events: string[], input: ProducerInput, config?: TaskConfig) { + super(input, config); + this.events = events; + } + + static inputSchema(): TObject { + return Type.Object({ + chunks: Type.Array( + Type.String({ + description: "Chunk to emit", + }) + ), + delay: Type.Number({ + description: "Delay between chunks (ms)", + default: 0, + }), + }); + } + + static outputSchema(): TObject { + return Type.Object({ + output: Type.String({ + description: "Concatenated result", + }), + }); + } + + static streaming(): TaskStreamingDescriptor { + return { + outputs: { + output: { + chunkSchema: Type.String({ + description: "Streamed chunk", + }) as DataPortSchema, + readiness: this.readiness, + accumulator: createStringAccumulator(), + }, + }, + }; + } + + async execute(input: ProducerInput, context: IExecuteContext): Promise { + let final = ""; + for (const chunk of input.chunks) { + this.events.push(`chunk:${chunk}`); + await context.pushChunk("output", chunk); + final += chunk; + if (input.delay > 0) { + await sleep(input.delay); + } + } + await context.closeStream("output"); + this.events.push("producer-complete"); + return { output: final }; + } +} + +describe("Streaming task orchestration", () => { + it("allows dependants to start once the first chunk arrives", async () => { + class FirstChunkProducer extends BaseStreamingProducer { + static override readiness: "first-chunk" | 
"final" = "first-chunk"; + } + + class RecordingConsumer extends Task<{ input: string }, { length: number }> { + static readonly type = "RecordingConsumer"; + private readonly events: string[]; + + constructor(events: string[], config?: TaskConfig) { + super({}, config); + this.events = events; + } + + static inputSchema(): TObject { + return Type.Object({ + input: Type.String({ + description: "Aggregated streaming input", + default: "", + }), + }); + } + + static outputSchema(): TObject { + return Type.Object({ + length: Type.Number({ description: "Length of received string" }), + }); + } + + async execute(input: { input: string }): Promise<{ length: number }> { + this.events.push("consumer-start"); + return { length: input.input.length }; + } + } + + const events: string[] = []; + const graph = new TaskGraph(); + const producer = new FirstChunkProducer(events, { chunks: ["A", "B", "C"], delay: 20 }, { id: "producer" }); + const consumer = new RecordingConsumer(events, { id: "consumer" }); + + graph.addTasks([producer, consumer]); + graph.addDataflow(new Dataflow("producer", "output", "consumer", "input")); + + const runner = new TaskGraphRunner(graph); + const results = await runner.runGraph(); + + expect(events.indexOf("consumer-start")).toBeGreaterThan(events.indexOf("chunk:A")); + expect(events.indexOf("consumer-start")).toBeLessThan(events.indexOf("producer-complete")); + + const producerResult = results.find((entry) => entry.id === "producer"); + expect(producerResult?.data.output).toBe("ABC"); + + const consumerResult = results.find((entry) => entry.id === "consumer"); + expect(consumerResult?.data.length).toBe(3); + }); + + it("waits for stream completion when readiness is final", async () => { + class FinalProducer extends BaseStreamingProducer { + static override readiness: "first-chunk" | "final" = "final"; + } + + class CapturingConsumer extends Task<{ input: string }, { seen: string }> { + static readonly type = "CapturingConsumer"; + private readonly 
events: string[]; + + constructor(events: string[], config?: TaskConfig) { + super({}, config); + this.events = events; + } + + static inputSchema(): TObject { + return Type.Object({ + input: Type.String({ + description: "Aggregated streaming input", + default: "", + }), + }); + } + + static outputSchema(): TObject { + return Type.Object({ + seen: Type.String({ description: "Observed string" }), + }); + } + + async execute(input: { input: string }): Promise<{ seen: string }> { + this.events.push("consumer-start"); + return { seen: input.input }; + } + } + + const events: string[] = []; + const graph = new TaskGraph(); + const producer = new FinalProducer(events, { chunks: ["X", "Y"], delay: 20 }, { id: "producer" }); + const consumer = new CapturingConsumer(events, { id: "consumer" }); + + graph.addTasks([producer, consumer]); + graph.addDataflow(new Dataflow("producer", "output", "consumer", "input")); + + const runner = new TaskGraphRunner(graph); + const results = await runner.runGraph(); + + expect(events.indexOf("consumer-start")).toBeGreaterThan(events.indexOf("producer-complete")); + + const consumerResult = results.find((entry) => entry.id === "consumer"); + expect(consumerResult?.data.seen).toBe("XY"); + }); +});