diff --git a/.cursor/rules/testing.mdc b/.cursor/rules/testing.mdc index 0809cd15..2f373b87 100644 --- a/.cursor/rules/testing.mdc +++ b/.cursor/rules/testing.mdc @@ -4,4 +4,4 @@ alwaysApply: true Use bun, not jest, not npm, not node -To run tests: `bun test`. To run specific ones: `bun run `. +To run tests: `bun test`. To run specific ones: `bun test `. diff --git a/bun.lock b/bun.lock index 491d0497..fd123c85 100644 --- a/bun.lock +++ b/bun.lock @@ -257,8 +257,8 @@ "name": "@workglow/util", "version": "0.0.85", "dependencies": { + "@sroussey/json-schema-library": "^10.5.3", "@sroussey/json-schema-to-ts": "3.1.3", - "json-schema-library": "^10.5.1", }, }, }, @@ -613,7 +613,7 @@ "@rollup/rollup-win32-x64-msvc": ["@rollup/rollup-win32-x64-msvc@4.53.3", "", { "os": "win32", "cpu": "x64" }, "sha512-UhTd8u31dXadv0MopwGgNOBpUVROFKWVQgAg5N1ESyCz8AuBcMqm4AuTjrwgQKGDfoFuz02EuMRHQIw/frmYKQ=="], - "@sagold/json-pointer": ["@sagold/json-pointer@7.2.0", "", {}, "sha512-RZpwGl1yhNuzQVKOADJx65TrWL7T6HTGs2Rpv7KlbFY0CfbFWNAKsisvC/uGfchknCGJEnoxz9uPAdmgoAE3IA=="], + "@sagold/json-pointer": ["@sagold/json-pointer@7.2.1", "", {}, "sha512-8EX4r5Royl5M3qNPTh5W5njdOtRqbWgQfVv26DbzjGj2/55b60EqvbiqUIglzw7fADfY/Io6jDJxGNJHmC+g8g=="], "@sagold/json-query": ["@sagold/json-query@6.2.0", "", { "dependencies": { "@sagold/json-pointer": "^5.1.2", "ebnf": "^1.9.1" } }, "sha512-7bOIdUE6eHeoWtFm8TvHQHfTVSZuCs+3RpOKmZCDBIOrxpvF/rNFTeuvIyjHva/RR0yVS3kQtr+9TW72LQEZjA=="], @@ -621,6 +621,8 @@ "@sroussey/changesets-cli": ["@sroussey/changesets-cli@2.29.7", "", { "dependencies": { "@changesets/apply-release-plan": "^7.0.13", "@changesets/assemble-release-plan": "^6.0.9", "@changesets/changelog-git": "^0.2.1", "@changesets/config": "^3.1.1", "@changesets/errors": "^0.2.0", "@changesets/get-dependents-graph": "^2.1.3", "@changesets/get-release-plan": "^4.0.13", "@changesets/git": "^3.0.4", "@changesets/logger": "^0.1.1", "@changesets/pre": "^2.0.2", "@changesets/read": "^0.6.5", "@changesets/should-skip-package": "^0.1.2", "@changesets/types": "^6.1.0", "@changesets/write": "^0.4.0", "@inquirer/external-editor": "^1.0.0", "@manypkg/get-packages": "^1.1.3", "ansi-colors": "^4.1.3", "ci-info": "^3.7.0", "enquirer": "^2.4.1", "fs-extra": "^7.0.1", "mri": "^1.2.0", "p-limit": "^2.2.0", "package-manager-detector": "^0.2.0", "picocolors": "^1.1.0", "resolve-from": "^5.0.0", "semver": "^7.7.3", "spawndamnit": "^3.0.1", "term-size": "^2.1.0" }, "bin": { "changeset": "bin.js" } }, "sha512-y47qkQTbei/sDIk0//S1bBXnOxPkrmeXjHxljubE9xqS+k8h+RAt+J0RyzJxk3SUYgY7D2vYJMQw4s9865F34w=="], + "@sroussey/json-schema-library": ["@sroussey/json-schema-library@10.5.3", "", { "dependencies": { "@sagold/json-pointer": "^7.2.1", "@sagold/json-query": "^6.2.0", "deepmerge": "^4.3.1", "fast-copy": "^3.0.2", "fast-deep-equal": "^3.1.3", "smtp-address-parser": "1.0.10", "uri-js": "^4.4.1", "valid-url": "^1.0.9" } }, "sha512-B+4Q84gJk56qAuM/4UAOm/pmmy2p+YtLChl0JBhgz8Qk4J0lDxgJh0JtT/CX5mew7Eq5nv6t9zuqdFLM6yQoDQ=="], + "@sroussey/json-schema-to-ts": ["@sroussey/json-schema-to-ts@3.1.3", "", { "dependencies": { "ts-algebra": "^2.0.0" } }, "sha512-N4j/Mz1YkZHvQfStIvtS4DiQLltzzU84jFt6qoo0DsUHV+n3UDfduWlYQSwov8gS9iJliIJ4L4Vb15k5HVdLwg=="], "@sroussey/transformers": ["@sroussey/transformers@3.8.2", "", { "dependencies": { "@huggingface/jinja": "^0.5.3", "onnxruntime-node": "1.23.2", "onnxruntime-web": "1.23.2", "sharp": "^0.34.5" } }, "sha512-K9g7aGnUZ8xdBBhrt6rZIB1rnFY1H4VggGCsrJoEL6tvhm5/Z+VpAHGduvdNS/s2a3lTsJt+scxfTU4DQ0T5JA=="], @@ -1241,8 +1243,6 @@ "json-buffer": 
["json-buffer@3.0.1", "", {}, "sha512-4bV5BfR2mqfQTJm+V5tPPdf+ZpuhiIvTuAB5g8kcrXOZpTT/QwwVRWBywX1ozr6lEuPdbHxwaJlm9G6mI2sfSQ=="], - "json-schema-library": ["json-schema-library@10.5.1", "", { "dependencies": { "@sagold/json-pointer": "^7.2.0", "@sagold/json-query": "^6.2.0", "deepmerge": "^4.3.1", "fast-copy": "^3.0.2", "fast-deep-equal": "^3.1.3", "smtp-address-parser": "1.0.10", "uri-js": "^4.4.1", "valid-url": "^1.0.9" } }, "sha512-QDKmtWbgHoxzZEBZ3XESZBQprpgfSlOezQC+wKukZJzNOlBc8nomWZxYBY4qFGKawmtWkLRmZUDW34WlKVhAug=="], - "json-schema-traverse": ["json-schema-traverse@0.4.1", "", {}, "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg=="], "json-stable-stringify-without-jsonify": ["json-stable-stringify-without-jsonify@1.0.1", "", {}, "sha512-Bdboy+l7tA3OGW6FjyFHWkP5LuByj1Tk33Ljyq0axyzdk9//JSi2u3fP1QSmd1KNwq6VOKYGlAu87CisVir6Pw=="], diff --git a/docs/background/06_run_graph_orchestration.md b/docs/background/06_run_graph_orchestration.md index 200f9da6..ac809fb6 100644 --- a/docs/background/06_run_graph_orchestration.md +++ b/docs/background/06_run_graph_orchestration.md @@ -10,7 +10,7 @@ The editor DAG is defined by the end user and saved in the database (tasks and d ## Graph -The graph is a DAG. It is a list of tasks and a list of dataflows. The tasks are the nodes and the dataflows are the connections between task outputs and inputs, plus status and provenance. +The graph is a DAG. It is a list of tasks and a list of dataflows. The tasks are the nodes and the dataflows are the connections between task outputs and inputs, plus status. We expose events for graphs, tasks, and dataflows. A suspend/resume could be added for bulk creation. This helps keep a UI in sync as the graph runs. diff --git a/docs/background/07_intrumentation.md b/docs/background/07_intrumentation.md index 6579d775..4b3ace64 100644 --- a/docs/background/07_intrumentation.md +++ b/docs/background/07_intrumentation.md @@ -7,5 +7,5 @@ Instrumentation is the process of adding code to a program to collect data about Some of these tools cost money, so we need to track and estimate costs. - Tasks emit status/progress events (`TaskStatus`, progress percent) -- Dataflows emit start/complete/error events and carry provenance +- Dataflows emit start/complete/error events - Task graphs emit start/progress/complete/error events diff --git a/docs/developers/02_architecture.md b/docs/developers/02_architecture.md index 680605c4..49fb3bd0 100644 --- a/docs/developers/02_architecture.md +++ b/docs/developers/02_architecture.md @@ -242,11 +242,10 @@ classDiagram class TaskGraphRunner{ Map layers - Map provenanceInput TaskGraph dag TaskOutputRepository repository assignLayers(Task[] sortedNodes) - runGraph(TaskInput parentProvenance) TaskOutput + runGraph(TaskInput input) TaskOutput runGraphReactive() TaskOutput } @@ -255,7 +254,6 @@ classDiagram The TaskGraphRunner is responsible for executing tasks in a task graph. 
Key features include: - **Layer-based Execution**: Tasks are organized into layers based on dependencies, allowing parallel execution of independent tasks -- **Provenance Tracking**: Tracks the lineage and input data that led to each task's output - **Caching Support**: Can use a TaskOutputRepository to cache task outputs and avoid re-running tasks - **Reactive Mode**: Supports reactive execution where tasks can respond to input changes without full re-execution - **Smart Task Scheduling**: Automatically determines task execution order based on dependencies diff --git a/docs/developers/03_extending.md b/docs/developers/03_extending.md index 8f740c5d..13d24d45 100644 --- a/docs/developers/03_extending.md +++ b/docs/developers/03_extending.md @@ -6,6 +6,7 @@ This document covers how to write your own tasks. For a more practical guide to - [Tasks must have a `run()` method](#tasks-must-have-a-run-method) - [Define Inputs and Outputs](#define-inputs-and-outputs) - [Register the Task](#register-the-task) +- [Schema Format Annotations](#schema-format-annotations) - [Job Queues and LLM tasks](#job-queues-and-llm-tasks) - [Write a new Compound Task](#write-a-new-compound-task) - [Reactive Task UIs](#reactive-task-uis) @@ -117,7 +118,7 @@ To use the Task in Workflow, there are a few steps: ```ts export const simpleDebug = (input: DebugLogTaskInput) => { - return new SimpleDebugTask(input).run(); + return new SimpleDebugTask({} as DebugLogTaskInput, {}).run(input); }; declare module "@workglow/task-graph" { @@ -129,6 +130,103 @@ declare module "@workglow/task-graph" { Workflow.prototype.simpleDebug = CreateWorkflow(SimpleDebugTask); ``` +## Schema Format Annotations + +When defining task input schemas, you can use `format` annotations to enable automatic resolution of string identifiers to object instances. The TaskRunner inspects input schemas and resolves annotated string values before task execution. 
+
+### Built-in Format Annotations
+
+The system supports several format annotations out of the box:
+
+| Format                | Description                          | Helper Function            |
+| --------------------- | ------------------------------------ | -------------------------- |
+| `model`               | Any AI model configuration           | —                          |
+| `model:TaskName`      | Model compatible with specific task  | —                          |
+| `repository:tabular`  | Tabular data repository              | `TypeTabularRepository()`  |
+| `repository:vector`   | Vector storage repository            | `TypeVectorRepository()`   |
+| `repository:document` | Document repository                  | `TypeDocumentRepository()` |
+
+### Example: Using Format Annotations
+
+```typescript
+import { Task, type DataPortSchema } from "@workglow/task-graph";
+import { TypeTabularRepository } from "@workglow/storage";
+import { FromSchema } from "@workglow/util";
+
+const MyTaskInputSchema = {
+  type: "object",
+  properties: {
+    // Model input - accepts string ID or ModelConfig object
+    model: {
+      title: "AI Model",
+      description: "Model for text generation",
+      format: "model:TextGenerationTask",
+      oneOf: [
+        { type: "string", title: "Model ID" },
+        { type: "object", title: "Model Config" },
+      ],
+    },
+    // Repository input - uses helper function
+    dataSource: TypeTabularRepository({
+      title: "Data Source",
+      description: "Repository containing source data",
+    }),
+    // Regular string input (no resolution)
+    prompt: { type: "string", title: "Prompt" },
+  },
+  required: ["model", "dataSource", "prompt"],
+} as const satisfies DataPortSchema;
+
+type MyTaskInput = FromSchema<typeof MyTaskInputSchema>;
+
+export class MyTask extends Task {
+  static readonly type = "MyTask";
+  static inputSchema = () => MyTaskInputSchema;
+
+  async executeReactive(input: MyTaskInput) {
+    // By the time execute runs, model is a ModelConfig object
+    // and dataSource is an ITabularRepository instance
+    const { model, dataSource, prompt } = input;
+    // ...
+  }
+}
+```
+
+### Creating Custom Format Resolvers
+
+You can extend the resolution system by registering custom resolvers:
+
+```typescript
+import { registerInputResolver } from "@workglow/util";
+
+// Register a resolver for "template:*" formats
+registerInputResolver("template", async (id, format, registry) => {
+  const templateRepo = registry.get(TEMPLATE_REPOSITORY);
+  const template = await templateRepo.findById(id);
+  if (!template) {
+    throw new Error(`Template "${id}" not found`);
+  }
+  return template;
+});
+```
+
+Then use it in your schemas:
+
+```typescript
+const inputSchema = {
+  type: "object",
+  properties: {
+    emailTemplate: {
+      type: "string",
+      format: "template:email",
+      title: "Email Template",
+    },
+  },
+};
+```
+
+When a task runs with `{ emailTemplate: "welcome-email" }`, the resolver automatically converts it to the template object before execution.
+
 ## Job Queues and LLM tasks
 
 We separate any long running tasks as Jobs. Jobs could potentially be run anywhere, either locally in the same thread, in separate threads, or on a remote server. A job queue will manage these for a single provider (like OpenAI, or a local Transformers.js ONNX runtime), and handle backoff, retries, etc.
@@ -148,3 +246,144 @@ Compound Tasks are not cached (though any or all of their children may be).
 ## Reactive Task UIs
 
 Tasks can be reactive at a certain level. This means that they can be triggered by changes in the data they depend on, without "running" the expensive job based task runs. This is useful for a UI node editor.
For example, you change a color in one task and it is propagated downstream without incurring costs for re-running the entire graph. It is like a spreadsheet where changing a cell can trigger a recalculation of other cells. This is implemented via a `runReactive()` method that is called when the data changes. Typically, the `run()` will call `runReactive()` on itself at the end of the method. + +## AI and RAG Tasks + +The `@workglow/ai` package provides a comprehensive set of tasks for building RAG (Retrieval-Augmented Generation) pipelines. These tasks are designed to chain together in workflows without requiring external loops. + +### Document Processing Tasks + +| Task | Description | +| ------------------------- | ----------------------------------------------------- | +| `StructuralParserTask` | Parses markdown/text into hierarchical document trees | +| `TextChunkerTask` | Splits text into chunks with configurable strategies | +| `HierarchicalChunkerTask` | Token-aware chunking that respects document structure | +| `TopicSegmenterTask` | Segments text by topic using heuristics or embeddings | +| `DocumentEnricherTask` | Adds summaries and entities to document nodes | + +### Vector and Embedding Tasks + +| Task | Description | +| ----------------------- | ---------------------------------------------- | +| `TextEmbeddingTask` | Generates embeddings using configurable models | +| `ChunkToVectorTask` | Transforms chunks to vector store format | +| `VectorStoreUpsertTask` | Stores vectors in a repository | +| `VectorStoreSearchTask` | Searches vectors by similarity | +| `VectorQuantizeTask` | Quantizes vectors for storage efficiency | + +### Retrieval and Generation Tasks + +| Task | Description | +| ------------------------ | --------------------------------------------- | +| `QueryExpanderTask` | Expands queries for better retrieval coverage | +| `HybridSearchTask` | Combines vector and full-text search | +| `RerankerTask` | Reranks search results for relevance | +| `HierarchyJoinTask` | Enriches results with parent context | +| `ContextBuilderTask` | Builds context for LLM prompts | +| `RetrievalTask` | Orchestrates end-to-end retrieval | +| `TextQuestionAnswerTask` | Generates answers from context | +| `TextGenerationTask` | General text generation | + +### Chainable RAG Pipeline Example + +Tasks chain together through compatible input/output schemas: + +```typescript +import { Workflow } from "@workglow/task-graph"; +import { InMemoryVectorRepository } from "@workglow/storage"; + +const vectorRepo = new InMemoryVectorRepository(); +await vectorRepo.setupDatabase(); + +// Document ingestion pipeline +await new Workflow() + .structuralParser({ + text: markdownContent, + title: "My Document", + format: "markdown", + }) + .documentEnricher({ + generateSummaries: true, + extractEntities: true, + }) + .hierarchicalChunker({ + maxTokens: 512, + overlap: 50, + strategy: "hierarchical", + }) + .textEmbedding({ + model: "Xenova/all-MiniLM-L6-v2", + }) + .chunkToVector() + .vectorStoreUpsert({ + repository: vectorRepo, + }) + .run(); +``` + +### Retrieval Pipeline Example + +```typescript +const answer = await new Workflow() + .textEmbedding({ + text: query, + model: "Xenova/all-MiniLM-L6-v2", + }) + .vectorStoreSearch({ + repository: vectorRepo, + topK: 10, + }) + .reranker({ + query, + topK: 5, + }) + .contextBuilder({ + format: "markdown", + maxLength: 2000, + }) + .textQuestionAnswer({ + question: query, + model: "Xenova/LaMini-Flan-T5-783M", + }) + .run(); +``` + +### Hierarchical Document 
Structure + +Documents are represented as trees with typed nodes: + +```typescript +type DocumentNode = + | DocumentRootNode // Root of document + | SectionNode // Headers, structural sections + | ParagraphNode // Text blocks + | SentenceNode // Fine-grained (optional) + | TopicNode; // Detected topic segments + +// Each node contains: +interface BaseNode { + nodeId: string; // Deterministic content-based ID + range: { start: number; end: number }; + text: string; + enrichment?: { + summary?: string; + entities?: Entity[]; + keywords?: string[]; + }; +} +``` + +### Task Data Flow + +Each task passes through what the next task needs: + +| Task | Passes Through | Adds | +| --------------------- | ----------------------- | ------------------------------------ | +| `structuralParser` | - | `docId`, `documentTree`, `nodeCount` | +| `documentEnricher` | `docId`, `documentTree` | `summaryCount`, `entityCount` | +| `hierarchicalChunker` | `docId` | `chunks`, `text[]`, `count` | +| `textEmbedding` | (implicit) | `vector[]` | +| `chunkToVector` | - | `ids[]`, `vectors[]`, `metadata[]` | +| `vectorStoreUpsert` | - | `count`, `ids` | + +This design eliminates the need for external loops - the entire pipeline chains together naturally. diff --git a/examples/cli/src/TaskCLI.ts b/examples/cli/src/TaskCLI.ts index 5b1df14b..a6177154 100644 --- a/examples/cli/src/TaskCLI.ts +++ b/examples/cli/src/TaskCLI.ts @@ -42,7 +42,7 @@ export function AddBaseCommands(program: Command) { ? (await getGlobalModelRepository().findByName(options.model))?.model_id : (await getGlobalModelRepository().findModelsByTask("TextEmbeddingTask"))?.map( (m) => m.model_id - ); + )?.[0]; if (!model) { program.error(`Unknown model ${options.model}`); @@ -67,7 +67,7 @@ export function AddBaseCommands(program: Command) { ? (await getGlobalModelRepository().findByName(options.model))?.model_id : (await getGlobalModelRepository().findModelsByTask("TextSummaryTask"))?.map( (m) => m.model_id - ); + )?.[0]; if (!model) { program.error(`Unknown model ${options.model}`); } else { @@ -92,7 +92,7 @@ export function AddBaseCommands(program: Command) { ? 
(await getGlobalModelRepository().findByName(options.model))?.model_id : (await getGlobalModelRepository().findModelsByTask("TextRewriterTask"))?.map( (m) => m.model_id - ); + )?.[0]; if (!model) { program.error(`Unknown model ${options.model}`); } else { diff --git a/packages/ai-provider/README.md b/packages/ai-provider/README.md index b652c88b..87d559a6 100644 --- a/packages/ai-provider/README.md +++ b/packages/ai-provider/README.md @@ -138,7 +138,7 @@ const task = new TextEmbeddingTask({ }); const result = await task.execute(); -// result.vector: TypedArray - Vector embedding +// result.vector: Vector - Vector embedding ``` **Text Translation:** diff --git a/packages/ai-provider/src/hf-transformers/common/HFT_JobRunFns.ts b/packages/ai-provider/src/hf-transformers/common/HFT_JobRunFns.ts index 3f48d98c..60c3598c 100644 --- a/packages/ai-provider/src/hf-transformers/common/HFT_JobRunFns.ts +++ b/packages/ai-provider/src/hf-transformers/common/HFT_JobRunFns.ts @@ -37,44 +37,45 @@ import { } from "@sroussey/transformers"; import type { AiProviderRunFn, - BackgroundRemovalTaskExecuteInput, - BackgroundRemovalTaskExecuteOutput, - DownloadModelTaskExecuteInput, - DownloadModelTaskExecuteOutput, - ImageClassificationTaskExecuteInput, - ImageClassificationTaskExecuteOutput, - ImageEmbeddingTaskExecuteInput, - ImageEmbeddingTaskExecuteOutput, - ImageSegmentationTaskExecuteInput, - ImageSegmentationTaskExecuteOutput, - ImageToTextTaskExecuteInput, - ImageToTextTaskExecuteOutput, - ObjectDetectionTaskExecuteInput, - ObjectDetectionTaskExecuteOutput, - TextClassificationTaskExecuteInput, - TextClassificationTaskExecuteOutput, - TextEmbeddingTaskExecuteInput, - TextEmbeddingTaskExecuteOutput, - TextFillMaskTaskExecuteInput, - TextFillMaskTaskExecuteOutput, - TextGenerationTaskExecuteInput, - TextGenerationTaskExecuteOutput, - TextLanguageDetectionTaskExecuteInput, - TextLanguageDetectionTaskExecuteOutput, - TextNamedEntityRecognitionTaskExecuteInput, - TextNamedEntityRecognitionTaskExecuteOutput, - TextQuestionAnswerTaskExecuteInput, - TextQuestionAnswerTaskExecuteOutput, - TextRewriterTaskExecuteInput, - TextRewriterTaskExecuteOutput, - TextSummaryTaskExecuteInput, - TextSummaryTaskExecuteOutput, - TextTranslationTaskExecuteInput, - TextTranslationTaskExecuteOutput, - TypedArray, - UnloadModelTaskExecuteInput, - UnloadModelTaskExecuteOutput, + BackgroundRemovalTaskInput, + BackgroundRemovalTaskOutput, + DownloadModelTaskRunInput, + DownloadModelTaskRunOutput, + ImageClassificationTaskInput, + ImageClassificationTaskOutput, + ImageEmbeddingTaskInput, + ImageEmbeddingTaskOutput, + ImageSegmentationTaskInput, + ImageSegmentationTaskOutput, + ImageToTextTaskInput, + ImageToTextTaskOutput, + ObjectDetectionTaskInput, + ObjectDetectionTaskOutput, + TextClassificationTaskInput, + TextClassificationTaskOutput, + TextEmbeddingTaskInput, + TextEmbeddingTaskOutput, + TextFillMaskTaskInput, + TextFillMaskTaskOutput, + TextGenerationTaskInput, + TextGenerationTaskOutput, + TextLanguageDetectionTaskInput, + TextLanguageDetectionTaskOutput, + TextNamedEntityRecognitionTaskInput, + TextNamedEntityRecognitionTaskOutput, + TextQuestionAnswerTaskInput, + TextQuestionAnswerTaskOutput, + TextRewriterTaskInput, + TextRewriterTaskOutput, + TextSummaryTaskInput, + TextSummaryTaskOutput, + TextTranslationTaskInput, + TextTranslationTaskOutput, + UnloadModelTaskRunInput, + UnloadModelTaskRunOutput, } from "@workglow/ai"; + +import { TypedArray } from "@workglow/util"; import { CallbackStatus } from "./HFT_CallbackStatus"; 
import { HTF_CACHE_NAME } from "./HFT_Constants"; import { HfTransformersOnnxModelConfig } from "./HFT_ModelSchema"; @@ -441,8 +442,8 @@ const getPipeline = async ( * This is shared between inline and worker implementations. */ export const HFT_Download: AiProviderRunFn< - DownloadModelTaskExecuteInput, - DownloadModelTaskExecuteOutput, + DownloadModelTaskRunInput, + DownloadModelTaskRunOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { // Download the model by creating a pipeline @@ -459,8 +460,8 @@ export const HFT_Download: AiProviderRunFn< * This is shared between inline and worker implementations. */ export const HFT_Unload: AiProviderRunFn< - UnloadModelTaskExecuteInput, - UnloadModelTaskExecuteOutput, + UnloadModelTaskRunInput, + UnloadModelTaskRunOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { // Delete the pipeline from the in-memory map @@ -524,8 +525,8 @@ const deleteModelCache = async (model_path: string): Promise => { */ export const HFT_TextEmbedding: AiProviderRunFn< - TextEmbeddingTaskExecuteInput, - TextEmbeddingTaskExecuteOutput, + TextEmbeddingTaskInput, + TextEmbeddingTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const generateEmbedding: FeatureExtractionPipeline = await getPipeline(model!, onProgress, { @@ -555,8 +556,8 @@ export const HFT_TextEmbedding: AiProviderRunFn< }; export const HFT_TextClassification: AiProviderRunFn< - TextClassificationTaskExecuteInput, - TextClassificationTaskExecuteOutput, + TextClassificationTaskInput, + TextClassificationTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { if (model?.provider_config?.pipeline === "zero-shot-classification") { @@ -611,8 +612,8 @@ export const HFT_TextClassification: AiProviderRunFn< }; export const HFT_TextLanguageDetection: AiProviderRunFn< - TextLanguageDetectionTaskExecuteInput, - TextLanguageDetectionTaskExecuteOutput, + TextLanguageDetectionTaskInput, + TextLanguageDetectionTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const TextClassification: TextClassificationPipeline = await getPipeline(model!, onProgress, { @@ -641,8 +642,8 @@ export const HFT_TextLanguageDetection: AiProviderRunFn< }; export const HFT_TextNamedEntityRecognition: AiProviderRunFn< - TextNamedEntityRecognitionTaskExecuteInput, - TextNamedEntityRecognitionTaskExecuteOutput, + TextNamedEntityRecognitionTaskInput, + TextNamedEntityRecognitionTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const textNamedEntityRecognition: TokenClassificationPipeline = await getPipeline( @@ -672,8 +673,8 @@ export const HFT_TextNamedEntityRecognition: AiProviderRunFn< }; export const HFT_TextFillMask: AiProviderRunFn< - TextFillMaskTaskExecuteInput, - TextFillMaskTaskExecuteOutput, + TextFillMaskTaskInput, + TextFillMaskTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const unmasker: FillMaskPipeline = await getPipeline(model!, onProgress, { @@ -700,8 +701,8 @@ export const HFT_TextFillMask: AiProviderRunFn< * This is shared between inline and worker implementations. 
*/ export const HFT_TextGeneration: AiProviderRunFn< - TextGenerationTaskExecuteInput, - TextGenerationTaskExecuteOutput, + TextGenerationTaskInput, + TextGenerationTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const generateText: TextGenerationPipeline = await getPipeline(model!, onProgress, { @@ -733,8 +734,8 @@ export const HFT_TextGeneration: AiProviderRunFn< * This is shared between inline and worker implementations. */ export const HFT_TextTranslation: AiProviderRunFn< - TextTranslationTaskExecuteInput, - TextTranslationTaskExecuteOutput, + TextTranslationTaskInput, + TextTranslationTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const translate: TranslationPipeline = await getPipeline(model!, onProgress, { @@ -749,12 +750,9 @@ export const HFT_TextTranslation: AiProviderRunFn< ...(signal ? { abort_signal: signal } : {}), } as any); - let translatedText: string | string[] = ""; - if (Array.isArray(result)) { - translatedText = result.map((r) => (r as TranslationSingle)?.translation_text || ""); - } else { - translatedText = (result as TranslationSingle)?.translation_text || ""; - } + const translatedText = Array.isArray(result) + ? (result[0] as TranslationSingle)?.translation_text || "" + : (result as TranslationSingle)?.translation_text || ""; return { text: translatedText, @@ -767,8 +765,8 @@ export const HFT_TextTranslation: AiProviderRunFn< * This is shared between inline and worker implementations. */ export const HFT_TextRewriter: AiProviderRunFn< - TextRewriterTaskExecuteInput, - TextRewriterTaskExecuteOutput, + TextRewriterTaskInput, + TextRewriterTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const generateText: TextGenerationPipeline = await getPipeline(model!, onProgress, { @@ -807,8 +805,8 @@ export const HFT_TextRewriter: AiProviderRunFn< * This is shared between inline and worker implementations. */ export const HFT_TextSummary: AiProviderRunFn< - TextSummaryTaskExecuteInput, - TextSummaryTaskExecuteOutput, + TextSummaryTaskInput, + TextSummaryTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const generateSummary: SummarizationPipeline = await getPipeline(model!, onProgress, { @@ -838,8 +836,8 @@ export const HFT_TextSummary: AiProviderRunFn< * This is shared between inline and worker implementations. */ export const HFT_TextQuestionAnswer: AiProviderRunFn< - TextQuestionAnswerTaskExecuteInput, - TextQuestionAnswerTaskExecuteOutput, + TextQuestionAnswerTaskInput, + TextQuestionAnswerTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { // Get the question answering pipeline @@ -869,8 +867,8 @@ export const HFT_TextQuestionAnswer: AiProviderRunFn< * Core implementation for image segmentation using Hugging Face Transformers. */ export const HFT_ImageSegmentation: AiProviderRunFn< - ImageSegmentationTaskExecuteInput, - ImageSegmentationTaskExecuteOutput, + ImageSegmentationTaskInput, + ImageSegmentationTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const segmenter: ImageSegmentationPipeline = await getPipeline(model!, onProgress, { @@ -902,8 +900,8 @@ export const HFT_ImageSegmentation: AiProviderRunFn< * Core implementation for image to text using Hugging Face Transformers. 
*/ export const HFT_ImageToText: AiProviderRunFn< - ImageToTextTaskExecuteInput, - ImageToTextTaskExecuteOutput, + ImageToTextTaskInput, + ImageToTextTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const captioner: ImageToTextPipeline = await getPipeline(model!, onProgress, { @@ -926,8 +924,8 @@ export const HFT_ImageToText: AiProviderRunFn< * Core implementation for background removal using Hugging Face Transformers. */ export const HFT_BackgroundRemoval: AiProviderRunFn< - BackgroundRemovalTaskExecuteInput, - BackgroundRemovalTaskExecuteOutput, + BackgroundRemovalTaskInput, + BackgroundRemovalTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const remover: BackgroundRemovalPipeline = await getPipeline(model!, onProgress, { @@ -949,8 +947,8 @@ export const HFT_BackgroundRemoval: AiProviderRunFn< * Core implementation for image embedding using Hugging Face Transformers. */ export const HFT_ImageEmbedding: AiProviderRunFn< - ImageEmbeddingTaskExecuteInput, - ImageEmbeddingTaskExecuteOutput, + ImageEmbeddingTaskInput, + ImageEmbeddingTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { const embedder: ImageFeatureExtractionPipeline = await getPipeline(model!, onProgress, { @@ -961,7 +959,7 @@ export const HFT_ImageEmbedding: AiProviderRunFn< return { vector: result.data as TypedArray, - }; + } as ImageEmbeddingTaskOutput; }; /** @@ -969,8 +967,8 @@ export const HFT_ImageEmbedding: AiProviderRunFn< * Auto-selects between regular and zero-shot classification. */ export const HFT_ImageClassification: AiProviderRunFn< - ImageClassificationTaskExecuteInput, - ImageClassificationTaskExecuteOutput, + ImageClassificationTaskInput, + ImageClassificationTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { if (model?.provider_config?.pipeline === "zero-shot-image-classification") { @@ -1024,8 +1022,8 @@ export const HFT_ImageClassification: AiProviderRunFn< * Auto-selects between regular and zero-shot detection. 
*/ export const HFT_ObjectDetection: AiProviderRunFn< - ObjectDetectionTaskExecuteInput, - ObjectDetectionTaskExecuteOutput, + ObjectDetectionTaskInput, + ObjectDetectionTaskOutput, HfTransformersOnnxModelConfig > = async (input, model, onProgress, signal) => { if (model?.provider_config?.pipeline === "zero-shot-object-detection") { diff --git a/packages/ai-provider/src/tf-mediapipe/common/TFMP_JobRunFns.ts b/packages/ai-provider/src/tf-mediapipe/common/TFMP_JobRunFns.ts index d01ff205..6ce04d0c 100644 --- a/packages/ai-provider/src/tf-mediapipe/common/TFMP_JobRunFns.ts +++ b/packages/ai-provider/src/tf-mediapipe/common/TFMP_JobRunFns.ts @@ -23,34 +23,34 @@ import { } from "@mediapipe/tasks-vision"; import type { AiProviderRunFn, - DownloadModelTaskExecuteInput, - DownloadModelTaskExecuteOutput, - FaceDetectorTaskExecuteInput, - FaceDetectorTaskExecuteOutput, - FaceLandmarkerTaskExecuteInput, - FaceLandmarkerTaskExecuteOutput, - GestureRecognizerTaskExecuteInput, - GestureRecognizerTaskExecuteOutput, - HandLandmarkerTaskExecuteInput, - HandLandmarkerTaskExecuteOutput, - ImageClassificationTaskExecuteInput, - ImageClassificationTaskExecuteOutput, - ImageEmbeddingTaskExecuteInput, - ImageEmbeddingTaskExecuteOutput, - ImageSegmentationTaskExecuteInput, - ImageSegmentationTaskExecuteOutput, - ObjectDetectionTaskExecuteInput, - ObjectDetectionTaskExecuteOutput, - PoseLandmarkerTaskExecuteInput, - PoseLandmarkerTaskExecuteOutput, - TextClassificationTaskExecuteInput, - TextClassificationTaskExecuteOutput, - TextEmbeddingTaskExecuteInput, - TextEmbeddingTaskExecuteOutput, - TextLanguageDetectionTaskExecuteInput, - TextLanguageDetectionTaskExecuteOutput, - UnloadModelTaskExecuteInput, - UnloadModelTaskExecuteOutput, + DownloadModelTaskRunInput, + DownloadModelTaskRunOutput, + FaceDetectorTaskInput, + FaceDetectorTaskOutput, + FaceLandmarkerTaskInput, + FaceLandmarkerTaskOutput, + GestureRecognizerTaskInput, + GestureRecognizerTaskOutput, + HandLandmarkerTaskInput, + HandLandmarkerTaskOutput, + ImageClassificationTaskInput, + ImageClassificationTaskOutput, + ImageEmbeddingTaskInput, + ImageEmbeddingTaskOutput, + ImageSegmentationTaskInput, + ImageSegmentationTaskOutput, + ObjectDetectionTaskInput, + ObjectDetectionTaskOutput, + PoseLandmarkerTaskInput, + PoseLandmarkerTaskOutput, + TextClassificationTaskInput, + TextClassificationTaskOutput, + TextEmbeddingTaskInput, + TextEmbeddingTaskOutput, + TextLanguageDetectionTaskInput, + TextLanguageDetectionTaskOutput, + UnloadModelTaskRunInput, + UnloadModelTaskRunOutput, } from "@workglow/ai"; import { PermanentJobError } from "@workglow/job-queue"; import { TFMPModelConfig } from "./TFMP_ModelSchema"; @@ -262,8 +262,8 @@ const getModelTask = async ( * This is shared between inline and worker implementations. */ export const TFMP_Download: AiProviderRunFn< - DownloadModelTaskExecuteInput, - DownloadModelTaskExecuteOutput, + DownloadModelTaskRunInput, + DownloadModelTaskRunOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { let task: TaskInstance; @@ -327,8 +327,8 @@ export const TFMP_Download: AiProviderRunFn< * This is shared between inline and worker implementations. 
*/ export const TFMP_TextEmbedding: AiProviderRunFn< - TextEmbeddingTaskExecuteInput, - TextEmbeddingTaskExecuteOutput, + TextEmbeddingTaskInput, + TextEmbeddingTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const textEmbedder = await getModelTask(model!, {}, onProgress, signal, TextEmbedder); @@ -350,8 +350,8 @@ export const TFMP_TextEmbedding: AiProviderRunFn< * This is shared between inline and worker implementations. */ export const TFMP_TextClassification: AiProviderRunFn< - TextClassificationTaskExecuteInput, - TextClassificationTaskExecuteOutput, + TextClassificationTaskInput, + TextClassificationTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const TextClassification = await getModelTask( @@ -387,8 +387,8 @@ export const TFMP_TextClassification: AiProviderRunFn< * This is shared between inline and worker implementations. */ export const TFMP_TextLanguageDetection: AiProviderRunFn< - TextLanguageDetectionTaskExecuteInput, - TextLanguageDetectionTaskExecuteOutput, + TextLanguageDetectionTaskInput, + TextLanguageDetectionTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const maxLanguages = input.maxLanguages === 0 ? -1 : input.maxLanguages; @@ -431,8 +431,8 @@ export const TFMP_TextLanguageDetection: AiProviderRunFn< * 3. If no other models are using the WASM fileset (count reaches 0), unloads the WASM */ export const TFMP_Unload: AiProviderRunFn< - UnloadModelTaskExecuteInput, - UnloadModelTaskExecuteOutput, + UnloadModelTaskRunInput, + UnloadModelTaskRunOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const model_path = model!.provider_config.model_path; @@ -471,8 +471,8 @@ export const TFMP_Unload: AiProviderRunFn< * Core implementation for image segmentation using MediaPipe. */ export const TFMP_ImageSegmentation: AiProviderRunFn< - ImageSegmentationTaskExecuteInput, - ImageSegmentationTaskExecuteOutput, + ImageSegmentationTaskInput, + ImageSegmentationTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const imageSegmenter = await getModelTask(model!, {}, onProgress, signal, ImageSegmenter); @@ -504,8 +504,8 @@ export const TFMP_ImageSegmentation: AiProviderRunFn< * Core implementation for image embedding using MediaPipe. */ export const TFMP_ImageEmbedding: AiProviderRunFn< - ImageEmbeddingTaskExecuteInput, - ImageEmbeddingTaskExecuteOutput, + ImageEmbeddingTaskInput, + ImageEmbeddingTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const imageEmbedder = await getModelTask(model!, {}, onProgress, signal, ImageEmbedder); @@ -519,15 +519,15 @@ export const TFMP_ImageEmbedding: AiProviderRunFn< return { vector: embedding, - }; + } as ImageEmbeddingTaskOutput; }; /** * Core implementation for image classification using MediaPipe. */ export const TFMP_ImageClassification: AiProviderRunFn< - ImageClassificationTaskExecuteInput, - ImageClassificationTaskExecuteOutput, + ImageClassificationTaskInput, + ImageClassificationTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const imageClassifier = await getModelTask( @@ -559,8 +559,8 @@ export const TFMP_ImageClassification: AiProviderRunFn< * Core implementation for object detection using MediaPipe. 
*/ export const TFMP_ObjectDetection: AiProviderRunFn< - ObjectDetectionTaskExecuteInput, - ObjectDetectionTaskExecuteOutput, + ObjectDetectionTaskInput, + ObjectDetectionTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const objectDetector = await getModelTask( @@ -598,8 +598,8 @@ export const TFMP_ObjectDetection: AiProviderRunFn< * Core implementation for gesture recognition using MediaPipe. */ export const TFMP_GestureRecognizer: AiProviderRunFn< - GestureRecognizerTaskExecuteInput, - GestureRecognizerTaskExecuteOutput, + GestureRecognizerTaskInput, + GestureRecognizerTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const gestureRecognizer = await getModelTask( @@ -650,8 +650,8 @@ export const TFMP_GestureRecognizer: AiProviderRunFn< * Core implementation for hand landmark detection using MediaPipe. */ export const TFMP_HandLandmarker: AiProviderRunFn< - HandLandmarkerTaskExecuteInput, - HandLandmarkerTaskExecuteOutput, + HandLandmarkerTaskInput, + HandLandmarkerTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const handLandmarker = await getModelTask( @@ -698,8 +698,8 @@ export const TFMP_HandLandmarker: AiProviderRunFn< * Core implementation for face detection using MediaPipe. */ export const TFMP_FaceDetector: AiProviderRunFn< - FaceDetectorTaskExecuteInput, - FaceDetectorTaskExecuteOutput, + FaceDetectorTaskInput, + FaceDetectorTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const faceDetector = await getModelTask( @@ -743,8 +743,8 @@ export const TFMP_FaceDetector: AiProviderRunFn< * Core implementation for face landmark detection using MediaPipe. */ export const TFMP_FaceLandmarker: AiProviderRunFn< - FaceLandmarkerTaskExecuteInput, - FaceLandmarkerTaskExecuteOutput, + FaceLandmarkerTaskInput, + FaceLandmarkerTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const faceLandmarker = await getModelTask( @@ -799,8 +799,8 @@ export const TFMP_FaceLandmarker: AiProviderRunFn< * Core implementation for pose landmark detection using MediaPipe. */ export const TFMP_PoseLandmarker: AiProviderRunFn< - PoseLandmarkerTaskExecuteInput, - PoseLandmarkerTaskExecuteOutput, + PoseLandmarkerTaskInput, + PoseLandmarkerTaskOutput, TFMPModelConfig > = async (input, model, onProgress, signal) => { const poseLandmarker = await getModelTask( diff --git a/packages/ai/README.md b/packages/ai/README.md index d8118516..9323a849 100644 --- a/packages/ai/README.md +++ b/packages/ai/README.md @@ -216,25 +216,6 @@ const result = await task.run(); // Output: { similarity: 0.85 } ``` -### Document Processing Tasks - -#### DocumentSplitterTask - -Splits documents into smaller chunks for processing. - -```typescript -import { DocumentSplitterTask } from "@workglow/ai"; - -const task = new DocumentSplitterTask({ - document: "Very long document content...", - chunkSize: 1000, - chunkOverlap: 200, -}); - -const result = await task.run(); -// Output: { chunks: ["chunk1...", "chunk2...", "chunk3..."] } -``` - ### Model Management Tasks #### DownloadModelTask @@ -415,30 +396,140 @@ const result = await workflow console.log("Final similarity score:", result.similarity); ``` -## Document Processing +## RAG (Retrieval-Augmented Generation) Pipelines + +The AI package provides a comprehensive set of tasks for building RAG pipelines. These tasks chain together in workflows without requiring external loops. 
+ +### Document Processing Tasks + +| Task | Description | +| ------------------------- | ----------------------------------------------------- | +| `StructuralParserTask` | Parses markdown/text into hierarchical document trees | +| `TextChunkerTask` | Splits text into chunks with configurable strategies | +| `HierarchicalChunkerTask` | Token-aware chunking that respects document structure | +| `TopicSegmenterTask` | Segments text by topic using heuristics or embeddings | +| `DocumentEnricherTask` | Adds summaries and entities to document nodes | + +### Vector and Storage Tasks -The package includes document processing capabilities: +| Task | Description | +| ----------------------- | ---------------------------------------- | +| `ChunkToVectorTask` | Transforms chunks to vector store format | +| `VectorStoreUpsertTask` | Stores vectors in a repository | +| `VectorStoreSearchTask` | Searches vectors by similarity | +| `VectorQuantizeTask` | Quantizes vectors for storage efficiency | + +### Retrieval and Generation Tasks + +| Task | Description | +| -------------------- | --------------------------------------------- | +| `QueryExpanderTask` | Expands queries for better retrieval coverage | +| `HybridSearchTask` | Combines vector and full-text search | +| `RerankerTask` | Reranks search results for relevance | +| `HierarchyJoinTask` | Enriches results with parent context | +| `ContextBuilderTask` | Builds context for LLM prompts | +| `RetrievalTask` | Orchestrates end-to-end retrieval | + +### Complete RAG Workflow Example ```typescript -import { Document, DocumentConverterMarkdown } from "@workglow/ai"; +import { Workflow } from "@workglow/task-graph"; +import { InMemoryVectorRepository } from "@workglow/storage"; -// Create a document -const doc = new Document("# My Document\n\nThis is content...", { title: "Sample Doc" }); +const vectorRepo = new InMemoryVectorRepository(); +await vectorRepo.setupDatabase(); -// Convert markdown to structured format -const converter = new DocumentConverterMarkdown(); -const processedDoc = await converter.convert(doc); +// Document ingestion - fully chainable, no loops required +await new Workflow() + .structuralParser({ + text: markdownContent, + title: "Documentation", + format: "markdown", + }) + .documentEnricher({ + generateSummaries: true, + extractEntities: true, + }) + .hierarchicalChunker({ + maxTokens: 512, + overlap: 50, + strategy: "hierarchical", + }) + .textEmbedding({ + model: "Xenova/all-MiniLM-L6-v2", + }) + .chunkToVector() + .vectorStoreUpsert({ + repository: vectorRepo, + }) + .run(); -// Use with document splitter -const splitter = new DocumentSplitterTask({ - document: processedDoc.content, - chunkSize: 500, - chunkOverlap: 50, -}); +// Query pipeline +const answer = await new Workflow() + .queryExpander({ + query: "What is transfer learning?", + method: "multi-query", + numVariations: 3, + }) + .textEmbedding({ + model: "Xenova/all-MiniLM-L6-v2", + }) + .vectorStoreSearch({ + repository: vectorRepo, + topK: 10, + scoreThreshold: 0.5, + }) + .reranker({ + query: "What is transfer learning?", + topK: 5, + }) + .contextBuilder({ + format: "markdown", + maxLength: 2000, + }) + .textQuestionAnswer({ + question: "What is transfer learning?", + model: "Xenova/LaMini-Flan-T5-783M", + }) + .run(); +``` + +### Hierarchical Document Structure -const chunks = await splitter.run(); +Documents are represented as trees with typed nodes: + +```typescript +type DocumentNode = + | DocumentRootNode // Root of document + | SectionNode // Headers, 
structural sections + | ParagraphNode // Text blocks + | SentenceNode // Fine-grained (optional) + | TopicNode; // Detected topic segments ``` +Each node contains: + +- `nodeId` - Deterministic content-based ID +- `range` - Source character offsets +- `text` - Content +- `enrichment` - Summaries, entities, keywords (optional) +- `children` - Child nodes (for parent nodes) + +### Task Data Flow + +Each task passes through what the next task needs: + +| Task | Passes Through | Adds | +| --------------------- | ----------------------- | ------------------------------------ | +| `structuralParser` | - | `docId`, `documentTree`, `nodeCount` | +| `documentEnricher` | `docId`, `documentTree` | `summaryCount`, `entityCount` | +| `hierarchicalChunker` | `docId` | `chunks`, `text[]`, `count` | +| `textEmbedding` | (implicit) | `vector[]` | +| `chunkToVector` | - | `ids[]`, `vectors[]`, `metadata[]` | +| `vectorStoreUpsert` | - | `count`, `ids` | + +This design eliminates the need for external loops - the entire pipeline chains together naturally. + ## Error Handling AI tasks include comprehensive error handling: @@ -466,6 +557,46 @@ try { ## Advanced Configuration +### Model Input Resolution + +AI tasks accept model inputs as either string identifiers or direct `ModelConfig` objects. When a string is provided, the TaskRunner automatically resolves it to a `ModelConfig` before task execution using the `ModelRepository`. + +```typescript +import { TextGenerationTask } from "@workglow/ai"; + +// Using a model ID (resolved from ModelRepository) +const task1 = new TextGenerationTask({ + model: "onnx:Xenova/gpt2:q8", + prompt: "Generate text", +}); + +// Using a direct ModelConfig object +const task2 = new TextGenerationTask({ + model: { + model_id: "onnx:Xenova/gpt2:q8", + provider: "hf-transformers-onnx", + tasks: ["TextGenerationTask"], + title: "GPT-2", + provider_config: { pipeline: "text-generation" }, + }, + prompt: "Generate text", +}); + +// Both approaches work identically +``` + +This resolution is handled by the input resolver system, which inspects schema `format` annotations (like `"model"` or `"model:TextGenerationTask"`) to determine how string values should be resolved. 
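+
+For example, a minimal sketch of a schema property that opts into model resolution (the property name and titles here are illustrative; the `format` values follow the table below, and the shape mirrors the example in the extending guide):
+
+```typescript
+const inputSchema = {
+  type: "object",
+  properties: {
+    model: {
+      title: "AI Model",
+      format: "model:TextGenerationTask",
+      oneOf: [
+        { type: "string", title: "Model ID" },
+        { type: "object", title: "Model Config" },
+      ],
+    },
+  },
+  required: ["model"],
+} as const;
+```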
+ +### Supported Format Annotations + +| Format | Description | Resolver | +| --------------------- | ---------------------------------------- | -------------------------- | +| `model` | Any AI model configuration | ModelRepository | +| `model:TaskName` | Model compatible with specific task type | ModelRepository | +| `repository:tabular` | Tabular data repository | TabularRepositoryRegistry | +| `repository:vector` | Vector storage repository | VectorRepositoryRegistry | +| `repository:document` | Document repository | DocumentRepositoryRegistry | + ### Custom Model Validation Tasks automatically validate that specified models exist and are compatible: diff --git a/packages/ai/src/common.ts b/packages/ai/src/common.ts index 12dbd7b9..ac800c78 100644 --- a/packages/ai/src/common.ts +++ b/packages/ai/src/common.ts @@ -5,12 +5,19 @@ */ export * from "./job/AiJob"; + export * from "./model/InMemoryModelRepository"; export * from "./model/ModelRegistry"; export * from "./model/ModelRepository"; export * from "./model/ModelSchema"; + export * from "./provider/AiProviderRegistry"; + export * from "./source/Document"; -export * from "./source/DocumentConverterMarkdown"; -export * from "./source/DocumentConverterText"; +export * from "./source/DocumentNode"; +export * from "./source/DocumentRepository"; +export * from "./source/DocumentRepositoryRegistry"; +export * from "./source/DocumentSchema"; +export * from "./source/StructuralParser"; + export * from "./task"; diff --git a/packages/ai/src/model/ModelRegistry.ts b/packages/ai/src/model/ModelRegistry.ts index 0d162955..2a7a3deb 100644 --- a/packages/ai/src/model/ModelRegistry.ts +++ b/packages/ai/src/model/ModelRegistry.ts @@ -4,9 +4,15 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { createServiceToken, globalServiceRegistry } from "@workglow/util"; +import { + createServiceToken, + globalServiceRegistry, + registerInputResolver, + ServiceRegistry, +} from "@workglow/util"; import { InMemoryModelRepository } from "./InMemoryModelRepository"; import { ModelRepository } from "./ModelRepository"; +import type { ModelConfig } from "./ModelSchema"; /** * Service token for the global model repository @@ -32,8 +38,36 @@ export function getGlobalModelRepository(): ModelRepository { /** * Sets the global model repository instance - * @param pr The model repository instance to register + * @param repository The model repository instance to register */ -export function setGlobalModelRepository(pr: ModelRepository): void { - globalServiceRegistry.registerInstance(MODEL_REPOSITORY, pr); +export function setGlobalModelRepository(repository: ModelRepository): void { + globalServiceRegistry.registerInstance(MODEL_REPOSITORY, repository); } + +/** + * Resolves a model ID to a ModelConfig from the repository. + * Used by the input resolver system. + */ +async function resolveModelFromRegistry( + id: string, + format: string, + registry: ServiceRegistry +): Promise { + const modelRepo = registry.has(MODEL_REPOSITORY) + ? 
registry.get(MODEL_REPOSITORY) + : getGlobalModelRepository(); + + if (Array.isArray(id)) { + const results = await Promise.all(id.map((i) => modelRepo.findByName(i))); + return results.filter((model) => model !== undefined) as ModelConfig[]; + } + + const model = await modelRepo.findByName(id); + if (!model) { + throw new Error(`Model "${id}" not found in repository`); + } + return model; +} + +// Register the model resolver for format: "model" and "model:*" +registerInputResolver("model", resolveModelFromRegistry); diff --git a/packages/ai/src/source/Document.ts b/packages/ai/src/source/Document.ts index a58a1c0a..9d58ca8a 100644 --- a/packages/ai/src/source/Document.ts +++ b/packages/ai/src/source/Document.ts @@ -4,170 +4,78 @@ * SPDX-License-Identifier: Apache-2.0 */ -enum DocumentType { - DOCUMENT = "document", - SECTION = "section", - TEXT = "text", - IMAGE = "image", - TABLE = "table", -} - -const doc_variants = [ - "tree", - "flat", - "tree-paragraphs", - "flat-paragraphs", - "tree-sentences", - "flat-sentences", -] as const; -type DocVariant = (typeof doc_variants)[number]; -const doc_parsers = ["txt", "md"] as const; // | "html" | "pdf" | "csv"; -type DocParser = (typeof doc_parsers)[number]; - -export interface DocumentMetadata { - title: string; -} - -export interface DocumentSectionMetadata { - title: string; -} +import type { ChunkNode, DocumentMetadata, DocumentNode } from "./DocumentSchema"; /** - * Represents a document with its content and metadata. + * Document represents a hierarchical document with chunks + * + * Key features: + * - Single source-of-truth tree structure (root node) + * - Single set of chunks + * - Separate persistence for document structure vs vectors */ export class Document { - public metadata: DocumentMetadata; + public readonly docId: string; + public readonly metadata: DocumentMetadata; + public readonly root: DocumentNode; + private chunks: ChunkNode[]; - constructor(content?: ContentType, metadata: DocumentMetadata = { title: "" }) { - this.metadata = metadata; - if (content) { - if (Array.isArray(content)) { - for (const line of content) { - this.addContent(line); - } - } else { - this.addContent(content); - } - } - } - - public addContent(content: ContentTypeItem) { - if (typeof content === "string") { - this.addText(content); - } else if (content instanceof DocumentBaseFragment || content instanceof DocumentSection) { - this.fragments.push(content); - } else { - throw new Error("Unknown content type"); - } - } - - public addSection(content?: ContentType, metadata?: DocumentSectionMetadata): DocumentSection { - const section = new DocumentSection(this, content, metadata); - this.fragments.push(section); - return section; - } - - public addText(content: string): TextFragment { - const f = new TextFragment(content); - this.fragments.push(f); - return f; - } - public addImage(content: unknown): ImageFragment { - const f = new ImageFragment(content); - this.fragments.push(f); - return f; - } - public addTable(content: unknown): TableFragment { - const f = new TableFragment(content); - this.fragments.push(f); - return f; - } - - public fragments: Array = []; - - toJSON(): unknown { - return { - type: DocumentType.DOCUMENT, - metadata: this.metadata, - fragments: this.fragments.map((f) => f.toJSON()), - }; - } -} - -export class DocumentSection extends Document { constructor( - public parent: Document, - content?: ContentType, - metadata?: DocumentSectionMetadata + docId: string, + root: DocumentNode, + metadata: DocumentMetadata, + chunks: 
ChunkNode[] = [] ) { - super(content, metadata); - this.parent = parent; + this.docId = docId; + this.root = root; + this.metadata = metadata; + this.chunks = chunks || []; } - toJSON(): unknown { - return { - type: DocumentType.SECTION, - metadata: this.metadata, - fragments: this.fragments.map((f) => f.toJSON()), - }; + /** + * Set chunks for the document + */ + setChunks(chunks: ChunkNode[]): void { + this.chunks = chunks; } -} -interface DocumentFragmentMetadata {} - -export class DocumentBaseFragment { - metadata?: DocumentFragmentMetadata; - constructor(metadata?: DocumentFragmentMetadata) { - this.metadata = metadata; + /** + * Get all chunks + */ + getChunks(): ChunkNode[] { + return this.chunks; } -} -export class TextFragment extends DocumentBaseFragment { - content: string; - constructor(content: string, metadata?: DocumentFragmentMetadata) { - super(metadata); - this.content = content; - } - toJSON(): unknown { - return { - type: DocumentType.TEXT, - metadata: this.metadata, - content: this.content, - }; + /** + * Find chunks by nodeId + */ + findChunksByNodeId(nodeId: string): ChunkNode[] { + return this.chunks.filter((chunk) => chunk.nodePath.includes(nodeId)); } -} -export class TableFragment extends DocumentBaseFragment { - content: any; - constructor(content: any, metadata?: DocumentFragmentMetadata) { - super(metadata); - this.content = content; - } - toJSON(): unknown { + /** + * Serialize to JSON + */ + toJSON(): { + docId: string; + metadata: DocumentMetadata; + root: DocumentNode; + chunks: ChunkNode[]; + } { return { - type: DocumentType.TABLE, + docId: this.docId, metadata: this.metadata, - content: this.content, + root: this.root, + chunks: this.chunks, }; } -} -export class ImageFragment extends DocumentBaseFragment { - content: any; - constructor(content: any, metadata?: DocumentFragmentMetadata) { - super(metadata); - this.content = content; - } - toJSON(): unknown { - return { - type: DocumentType.IMAGE, - metadata: this.metadata, - content: this.content, - }; + /** + * Deserialize from JSON + */ + static fromJSON(json: string): Document { + const obj = JSON.parse(json); + const doc = new Document(obj.docId, obj.root, obj.metadata, obj.chunks); + return doc; } } - -export type DocumentFragment = TextFragment | TableFragment | ImageFragment; - -export type ContentTypeItem = string | DocumentFragment | DocumentSection; -export type ContentType = ContentTypeItem | ContentTypeItem[]; diff --git a/packages/ai/src/source/DocumentConverter.ts b/packages/ai/src/source/DocumentConverter.ts deleted file mode 100644 index b89ba6ca..00000000 --- a/packages/ai/src/source/DocumentConverter.ts +++ /dev/null @@ -1,18 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import { Document, DocumentMetadata } from "./Document"; - -/** - * Abstract class for converting different types of content into a Document. 
- */ -export abstract class DocumentConverter { - public metadata: DocumentMetadata; - constructor(metadata: DocumentMetadata) { - this.metadata = metadata; - } - public abstract convert(): Document; -} diff --git a/packages/ai/src/source/DocumentConverterMarkdown.ts b/packages/ai/src/source/DocumentConverterMarkdown.ts deleted file mode 100644 index 55e88330..00000000 --- a/packages/ai/src/source/DocumentConverterMarkdown.ts +++ /dev/null @@ -1,120 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import { Document, type DocumentMetadata, type DocumentSection } from "./Document"; -import { DocumentConverter } from "./DocumentConverter"; - -export class DocumentConverterMarkdown extends DocumentConverter { - constructor( - metadata: DocumentMetadata, - public markdown: string - ) { - super(metadata); - } - public convert(): Document { - const parser = new MarkdownParser(this.metadata.title); - const document = parser.parse(this.markdown); - return document; - } -} - -class MarkdownParser { - private document: Document; - private currentSection: Document | DocumentSection; - private textBuffer: string[] = []; // Buffer to accumulate text lines - - constructor(title: string) { - this.document = new Document(title); - this.currentSection = this.document; - } - - parse(markdown: string): Document { - const lines = markdown.split("\n"); - - lines.forEach((line, index) => { - if (this.isHeader(line)) { - this.flushTextBuffer(); - const { level, content } = this.parseHeader(line); - this.currentSection = - level === 1 ? this.document.addSection(content) : this.currentSection.addSection(content); - } else if (this.isTableStart(line)) { - this.flushTextBuffer(); - const tableLines = this.collectTableLines(lines, index); - this.currentSection.addTable(tableLines.join("\n")); - } else if (this.isImageInline(line)) { - this.parseLineWithPossibleImages(line); - } else { - this.textBuffer.push(line); // Accumulate text lines in the buffer - } - }); - - this.flushTextBuffer(); // Flush any remaining text in the buffer - return this.document; - } - - private flushTextBuffer() { - if (this.textBuffer.length > 0) { - const textContent = this.textBuffer.join("\n").trim(); - if (textContent) { - this.currentSection.addText(textContent); - } - this.textBuffer = []; // Clear the buffer after flushing - } - } - - private parseLineWithPossibleImages(line: string) { - // Split the line by image markdown, keeping the delimiter (image markdown) - const parts = line.split(/(!\[.*?\]\(.*?\))/).filter((part) => part !== ""); - parts.forEach((part) => { - if (this.isImage(part)) { - const { alt, src } = this.parseImage(part); - this.flushTextBuffer(); - this.currentSection.addImage({ alt, src }); - } else { - this.textBuffer.push(part); - } - }); - this.flushTextBuffer(); - } - - private isHeader(line: string): boolean { - return /^#{1,6}\s/.test(line); - } - - private parseHeader(line: string): { level: number; content: string } { - const match = line.match(/^(#{1,6})\s+(.*)$/); - return match ? 
{ level: match[1].length, content: match[2] } : { level: 0, content: "" }; - } - - private isTableStart(line: string): boolean { - return line.trim().startsWith("|") && line.includes("|", line.indexOf("|") + 1); - } - - private collectTableLines(lines: string[], startIndex: number): string[] { - const tableLines = []; - for (let i = startIndex; i < lines.length && this.isTableLine(lines[i]); i++) { - tableLines.push(lines[i]); - } - return tableLines; - } - - private isTableLine(line: string): boolean { - return line.includes("|"); - } - - private isImageInline(line: string): boolean { - return line.includes("![") && line.includes("]("); - } - - private isImage(part: string): boolean { - return /^!\[.*\]\(.*\)$/.test(part); - } - - private parseImage(markdown: string): { alt: string; src: string } { - const match = markdown.match(/^!\[(.*)\]\((.*)\)$/); - return match ? { alt: match[1], src: match[2] } : { alt: "", src: "" }; - } -} diff --git a/packages/ai/src/source/DocumentConverterText.ts b/packages/ai/src/source/DocumentConverterText.ts deleted file mode 100644 index f9337bda..00000000 --- a/packages/ai/src/source/DocumentConverterText.ts +++ /dev/null @@ -1,20 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import { Document, DocumentMetadata } from "./Document"; -import { DocumentConverter } from "./DocumentConverter"; - -export class DocumentConverterText extends DocumentConverter { - constructor( - metadata: DocumentMetadata, - public text: string - ) { - super(metadata); - } - public convert(): Document { - return new Document(this.text, this.metadata); - } -} diff --git a/packages/ai/src/source/DocumentNode.ts b/packages/ai/src/source/DocumentNode.ts new file mode 100644 index 00000000..5bf021a5 --- /dev/null +++ b/packages/ai/src/source/DocumentNode.ts @@ -0,0 +1,134 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { sha256 } from "@workglow/util"; + +import { + NodeKind, + type DocumentNode, + type DocumentRootNode, + type NodeKind as NodeKindType, + type NodeRange, + type SectionNode, + type TopicNode, +} from "./DocumentSchema"; + +/** + * Utility functions for ID generation + */ +export class NodeIdGenerator { + /** + * Generate docId from source URI and content hash + */ + static async generateDocId(sourceUri: string, content: string): Promise { + const contentHash = await sha256(content); + const combined = `${sourceUri}|${contentHash}`; + const hash = await sha256(combined); + return `doc_${hash.substring(0, 16)}`; + } + + /** + * Generate nodeId for structural nodes (document, section) + */ + static async generateStructuralNodeId( + docId: string, + kind: NodeKindType, + range: NodeRange + ): Promise { + const combined = `${docId}|${kind}|${range.startOffset}:${range.endOffset}`; + const hash = await sha256(combined); + return `node_${hash.substring(0, 16)}`; + } + + /** + * Generate nodeId for child nodes (paragraph, topic) + */ + static async generateChildNodeId(parentNodeId: string, ordinal: number): Promise { + const combined = `${parentNodeId}|${ordinal}`; + const hash = await sha256(combined); + return `node_${hash.substring(0, 16)}`; + } + + /** + * Generate chunkId + */ + static async generateChunkId( + docId: string, + leafNodeId: string, + chunkOrdinal: number + ): Promise { + const combined = `${docId}|${leafNodeId}|${chunkOrdinal}`; + const hash = await sha256(combined); + return `chunk_${hash.substring(0, 16)}`; + } +} + +/** + * 
Approximate token counting (v1) + */ +export function estimateTokens(text: string): number { + return Math.ceil(text.length / 4); +} + +/** + * Helper to check if a node has children + */ +export function hasChildren( + node: DocumentNode +): node is DocumentRootNode | SectionNode | TopicNode { + return ( + node.kind === NodeKind.DOCUMENT || + node.kind === NodeKind.SECTION || + node.kind === NodeKind.TOPIC + ); +} + +/** + * Helper to get all children of a node + */ +export function getChildren(node: DocumentNode): DocumentNode[] { + if (hasChildren(node)) { + return node.children; + } + return []; +} + +/** + * Traverse document tree depth-first + */ +export function* traverseDepthFirst(node: DocumentNode): Generator { + yield node; + if (hasChildren(node)) { + for (const child of node.children) { + yield* traverseDepthFirst(child); + } + } +} + +/** + * Get node path from root to target node + */ +export function getNodePath(root: DocumentNode, targetNodeId: string): string[] | undefined { + const path: string[] = []; + + function search(node: DocumentNode): boolean { + path.push(node.nodeId); + if (node.nodeId === targetNodeId) { + return true; + } + if (hasChildren(node)) { + for (const child of node.children) { + if (search(child)) { + return true; + } + } + } + path.pop(); + return false; + } + + return search(root) ? path : undefined; +} diff --git a/packages/ai/src/source/DocumentRepository.ts b/packages/ai/src/source/DocumentRepository.ts new file mode 100644 index 00000000..4570f3da --- /dev/null +++ b/packages/ai/src/source/DocumentRepository.ts @@ -0,0 +1,245 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { + ITabularRepository, + IVectorRepository, + SearchResult, + VectorSearchOptions, +} from "@workglow/storage"; +import type { DataPortSchemaObject, FromSchema, TypedArray } from "@workglow/util"; +import { Document } from "./Document"; +import { ChunkNode, DocumentNode } from "./DocumentSchema"; + +/** + * Schema for storing documents in tabular storage + */ +export const DocumentStorageSchema = { + type: "object", + properties: { + docId: { + type: "string", + title: "Document ID", + description: "Unique identifier for the document", + }, + data: { + type: "string", + title: "Document Data", + description: "JSON-serialized document", + }, + }, + required: ["docId", "data"], + additionalProperties: false, +} as const satisfies DataPortSchemaObject; + +type DocumentStorageEntity = FromSchema; + +/** + * Document repository that uses TabularStorage for persistence and VectorStorage for search. + * This is a unified implementation that composes storage backends rather than using + * inheritance/interface patterns. + */ +export class DocumentRepository { + private tabularStorage: ITabularRepository< + typeof DocumentStorageSchema, + ["docId"], + DocumentStorageEntity, + { docId: string }, + Document + >; + private vectorStorage?: IVectorRepository; + + /** + * Creates a new DocumentRepository instance. 
+ * + * @param tabularStorage - Pre-initialized tabular storage for document persistence + * @param vectorStorage - Pre-initialized vector storage for chunk similarity search + * + * @example + * ```typescript + * const tabularStorage = new InMemoryTabularRepository(DocumentStorageSchema, ["docId"]); + * await tabularStorage.setupDatabase(); + * + * const vectorStorage = new InMemoryVectorRepository(); + * await vectorStorage.setupDatabase(); + * + * const docRepo = new DocumentRepository(tabularStorage, vectorStorage); + * ``` + */ + constructor( + tabularStorage: ITabularRepository< + typeof DocumentStorageSchema, + ["docId"], + DocumentStorageEntity, + { docId: string }, + Document + >, + vectorStorage?: IVectorRepository + ) { + this.tabularStorage = tabularStorage; + this.vectorStorage = vectorStorage; + } + + /** + * Upsert a document + */ + async upsert(document: Document): Promise { + const serialized = JSON.stringify(document.toJSON ? document.toJSON() : document); + await this.tabularStorage.put({ + docId: document.docId, + data: serialized, + }); + } + + /** + * Get a document by ID + */ + async get(docId: string): Promise { + const entity = await this.tabularStorage.get({ docId }); + if (!entity) { + return undefined; + } + return Document.fromJSON(entity.data); + } + + /** + * Delete a document + */ + async delete(docId: string): Promise { + await this.tabularStorage.delete({ docId }); + } + + /** + * Get a specific node by ID + */ + async getNode(docId: string, nodeId: string): Promise { + const doc = await this.get(docId); + if (!doc) { + return undefined; + } + + // Traverse tree to find node + const traverse = (node: any): any => { + if (node.nodeId === nodeId) { + return node; + } + if (node.children && Array.isArray(node.children)) { + for (const child of node.children) { + const found = traverse(child); + if (found) return found; + } + } + return undefined; + }; + + return traverse(doc.root); + } + + /** + * Get ancestors of a node (from root to node) + */ + async getAncestors(docId: string, nodeId: string): Promise { + const doc = await this.get(docId); + if (!doc) { + return []; + } + + // Get path from root to target node + const path: string[] = []; + const findPath = (node: any): boolean => { + path.push(node.nodeId); + if (node.nodeId === nodeId) { + return true; + } + if (node.children && Array.isArray(node.children)) { + for (const child of node.children) { + if (findPath(child)) { + return true; + } + } + } + path.pop(); + return false; + }; + + if (!findPath(doc.root)) { + return []; + } + + // Collect nodes along the path + const ancestors: any[] = []; + let currentNode: any = doc.root; + ancestors.push(currentNode); + + for (let i = 1; i < path.length; i++) { + const targetId = path[i]; + if (currentNode.children && Array.isArray(currentNode.children)) { + const found = currentNode.children.find((child: any) => child.nodeId === targetId); + if (found) { + currentNode = found; + ancestors.push(currentNode); + } else { + break; + } + } else { + break; + } + } + + return ancestors; + } + + /** + * Get chunks for a document + */ + async getChunks(docId: string): Promise { + const doc = await this.get(docId); + if (!doc) { + return []; + } + return doc.getChunks(); + } + + /** + * Find chunks that contain a specific nodeId in their path + */ + async findChunksByNodeId(docId: string, nodeId: string): Promise { + const doc = await this.get(docId); + if (!doc) { + return []; + } + if (doc.findChunksByNodeId) { + return doc.findChunksByNodeId(nodeId); + } + // 
Fallback implementation + const chunks = doc.getChunks(); + return chunks.filter((chunk) => chunk.nodePath && chunk.nodePath.includes(nodeId)); + } + + /** + * List all document IDs + */ + async list(): Promise { + const entities = await this.tabularStorage.getAll(); + if (!entities) { + return []; + } + return entities.map((e) => e.docId); + } + + /** + * Search for similar vectors using the vector storage + * @param query - Query vector to search for + * @param options - Search options (topK, filter, scoreThreshold) + * @returns Array of search results sorted by similarity + */ + async search( + query: TypedArray, + options?: VectorSearchOptions + ): Promise[]> { + return this.vectorStorage?.similaritySearch(query, options) || []; + } +} diff --git a/packages/ai/src/source/DocumentRepositoryRegistry.ts b/packages/ai/src/source/DocumentRepositoryRegistry.ts new file mode 100644 index 00000000..bdaca417 --- /dev/null +++ b/packages/ai/src/source/DocumentRepositoryRegistry.ts @@ -0,0 +1,80 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + createServiceToken, + globalServiceRegistry, + registerInputResolver, + ServiceRegistry, +} from "@workglow/util"; +import type { DocumentRepository } from "./DocumentRepository"; + +/** + * Service token for the document repository registry + * Maps repository IDs to DocumentRepository instances + */ +export const DOCUMENT_REPOSITORIES = createServiceToken>( + "document.repositories" +); + +// Register default factory if not already registered +if (!globalServiceRegistry.has(DOCUMENT_REPOSITORIES)) { + globalServiceRegistry.register( + DOCUMENT_REPOSITORIES, + (): Map => new Map(), + true + ); +} + +/** + * Gets the global document repository registry + * @returns Map of document repository ID to instance + */ +export function getGlobalDocumentRepositories(): Map { + return globalServiceRegistry.get(DOCUMENT_REPOSITORIES); +} + +/** + * Registers a document repository globally by ID + * @param id The unique identifier for this repository + * @param repository The repository instance to register + */ +export function registerDocumentRepository(id: string, repository: DocumentRepository): void { + const repos = getGlobalDocumentRepositories(); + repos.set(id, repository); +} + +/** + * Gets a document repository by ID from the global registry + * @param id The repository identifier + * @returns The repository instance or undefined if not found + */ +export function getDocumentRepository(id: string): DocumentRepository | undefined { + return getGlobalDocumentRepositories().get(id); +} + +/** + * Resolves a repository ID to a DocumentRepository from the registry. + * Used by the input resolver system. + */ +async function resolveDocumentRepositoryFromRegistry( + id: string, + format: string, + registry: ServiceRegistry +): Promise { + const repos = registry.has(DOCUMENT_REPOSITORIES) + ? 
registry.get<Map<string, DocumentRepository>>(DOCUMENT_REPOSITORIES)
+    : getGlobalDocumentRepositories();
+
+  const repo = repos.get(id);
+  if (!repo) {
+    throw new Error(`Document repository "${id}" not found in registry`);
+  }
+  return repo;
+}
+
+// Register the repository resolver for format: "repository:document"
+registerInputResolver("repository:document", resolveDocumentRepositoryFromRegistry);
diff --git a/packages/ai/src/source/DocumentSchema.ts b/packages/ai/src/source/DocumentSchema.ts
new file mode 100644
index 00000000..8e5886af
--- /dev/null
+++ b/packages/ai/src/source/DocumentSchema.ts
@@ -0,0 +1,630 @@
+/**
+ * @license
+ * Copyright 2025 Steven Roussey
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import type { DataPortSchema, FromSchema, JsonSchema } from "@workglow/util";
+
+/**
+ * Node kind discriminator for hierarchical document structure
+ */
+export const NodeKind = {
+  DOCUMENT: "document",
+  SECTION: "section",
+  PARAGRAPH: "paragraph",
+  SENTENCE: "sentence",
+  TOPIC: "topic",
+} as const;
+
+export type NodeKind = (typeof NodeKind)[keyof typeof NodeKind];
+
+// =============================================================================
+// Schema Definitions
+// =============================================================================
+
+/**
+ * Schema for source range of a node (character offsets)
+ */
+export const NodeRangeSchema = {
+  type: "object",
+  properties: {
+    startOffset: {
+      type: "integer",
+      title: "Start Offset",
+      description: "Starting character offset",
+    },
+    endOffset: {
+      type: "integer",
+      title: "End Offset",
+      description: "Ending character offset",
+    },
+  },
+  required: ["startOffset", "endOffset"],
+  additionalProperties: false,
+} as const satisfies DataPortSchema;
+
+export type NodeRange = FromSchema<typeof NodeRangeSchema>;
+
+/**
+ * Schema for named entity extracted from text
+ */
+export const EntitySchema = {
+  type: "object",
+  properties: {
+    text: {
+      type: "string",
+      title: "Text",
+      description: "Entity text",
+    },
+    type: {
+      type: "string",
+      title: "Type",
+      description: "Entity type (e.g., PERSON, ORG, LOC)",
+    },
+    score: {
+      type: "number",
+      title: "Score",
+      description: "Confidence score",
+    },
+  },
+  required: ["text", "type", "score"],
+  additionalProperties: false,
+} as const satisfies DataPortSchema;
+
+export type Entity = FromSchema<typeof EntitySchema>;
+
+/**
+ * Schema for enrichment data attached to a node
+ */
+export const NodeEnrichmentSchema = {
+  type: "object",
+  properties: {
+    summary: {
+      type: "string",
+      title: "Summary",
+      description: "Summary of the node content",
+    },
+    entities: {
+      type: "array",
+      items: EntitySchema,
+      title: "Entities",
+      description: "Named entities extracted from the node",
+    },
+    keywords: {
+      type: "array",
+      items: { type: "string" },
+      title: "Keywords",
+      description: "Keywords associated with the node",
+    },
+  },
+  additionalProperties: false,
+} as const satisfies DataPortSchema;
+
+export type NodeEnrichment = FromSchema<typeof NodeEnrichmentSchema>;
+
+/**
+ * Schema for base document node fields (used for runtime validation)
+ * Note: Individual node types and DocumentNode union are defined as interfaces
+ * below because FromSchema cannot properly infer recursive discriminated unions.
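+ *
+ * For illustration, a small tree built with those hand-written interfaces might look
+ * like the sketch below (node IDs, offsets, and text are placeholder values):
+ *
+ * @example
+ * ```typescript
+ * const paragraph: ParagraphNode = {
+ *   nodeId: "node_0123456789abcdef",
+ *   kind: NodeKind.PARAGRAPH,
+ *   range: { startOffset: 8, endOffset: 42 },
+ *   text: "A short paragraph of body text.",
+ * };
+ *
+ * const section: SectionNode = {
+ *   nodeId: "node_fedcba9876543210",
+ *   kind: NodeKind.SECTION,
+ *   level: 2,
+ *   title: "Introduction",
+ *   range: { startOffset: 0, endOffset: 42 },
+ *   text: "Introduction",
+ *   children: [paragraph],
+ * };
+ * ```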
+ */ +export const DocumentNodeBaseSchema = { + type: "object", + properties: { + nodeId: { + type: "string", + title: "Node ID", + description: "Unique identifier for this node", + }, + kind: { + type: "string", + enum: Object.values(NodeKind), + title: "Kind", + description: "Node type discriminator", + }, + range: NodeRangeSchema, + text: { + type: "string", + title: "Text", + description: "Text content of the node", + }, + enrichment: NodeEnrichmentSchema, + }, + required: ["nodeId", "kind", "range", "text"], + additionalProperties: true, +} as const satisfies DataPortSchema; + +/** + * Schema for document node (generic, for runtime validation) + * This is a simplified schema for task input/output validation. + * The actual TypeScript types use a proper discriminated union. + */ +export const DocumentNodeSchema = { + type: "object", + title: "Document Node", + description: "A node in the hierarchical document tree", + properties: { + ...DocumentNodeBaseSchema.properties, + level: { + type: "integer", + title: "Level", + description: "Header level for section nodes", + }, + title: { + type: "string", + title: "Title", + description: "Section title", + }, + children: { + type: "array", + title: "Children", + description: "Child nodes", + }, + }, + required: [...DocumentNodeBaseSchema.required], + additionalProperties: false, +} as const satisfies DataPortSchema; + +/** + * Schema for paragraph node + */ +export const ParagraphNodeSchema = { + type: "object", + properties: { + ...DocumentNodeBaseSchema.properties, + kind: { + type: "string", + const: NodeKind.PARAGRAPH, + title: "Kind", + description: "Node type discriminator", + }, + }, + required: [...DocumentNodeBaseSchema.required], + additionalProperties: false, +} as const satisfies DataPortSchema; + +/** + * Schema for sentence node + */ +export const SentenceNodeSchema = { + type: "object", + properties: { + ...DocumentNodeBaseSchema.properties, + kind: { + type: "string", + const: NodeKind.SENTENCE, + title: "Kind", + description: "Node type discriminator", + }, + }, + required: [...DocumentNodeBaseSchema.required], + additionalProperties: false, +} as const satisfies DataPortSchema; + +/** + * Schema for section node + */ +export const SectionNodeSchema = { + type: "object", + properties: { + ...DocumentNodeBaseSchema.properties, + kind: { + type: "string", + const: NodeKind.SECTION, + title: "Kind", + description: "Node type discriminator", + }, + level: { + type: "integer", + minimum: 1, + maximum: 6, + title: "Level", + description: "Header level (1-6 for markdown)", + }, + title: { + type: "string", + title: "Title", + description: "Section title", + }, + children: { + type: "array", + items: DocumentNodeSchema, + title: "Children", + description: "Child nodes", + }, + }, + required: [...DocumentNodeBaseSchema.required, "level", "title", "children"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +/** + * Schema for topic node + */ +export const TopicNodeSchema = { + type: "object", + properties: { + ...DocumentNodeBaseSchema.properties, + kind: { + type: "string", + const: NodeKind.TOPIC, + title: "Kind", + description: "Node type discriminator", + }, + children: { + type: "array", + items: DocumentNodeSchema, + title: "Children", + description: "Child nodes", + }, + }, + required: [...DocumentNodeBaseSchema.required, "children"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +/** + * Schema for document root node + */ +export const DocumentRootNodeSchema = { + type: 
"object", + properties: { + ...DocumentNodeBaseSchema.properties, + kind: { + type: "string", + const: NodeKind.DOCUMENT, + title: "Kind", + description: "Node type discriminator", + }, + title: { + type: "string", + title: "Title", + description: "Document title", + }, + children: { + type: "array", + items: DocumentNodeSchema, + title: "Children", + description: "Child nodes", + }, + }, + required: [...DocumentNodeBaseSchema.required, "title", "children"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +// ============================================================================= +// Manually-defined interfaces for recursive discriminated union types +// These provide better TypeScript inference than FromSchema for recursive types +// ============================================================================= + +/** + * Base document node fields + */ +interface DocumentNodeBase { + readonly nodeId: string; + readonly kind: NodeKind; + readonly range: NodeRange; + readonly text: string; + readonly enrichment?: NodeEnrichment; +} + +/** + * Document root node + */ +export interface DocumentRootNode extends DocumentNodeBase { + readonly kind: typeof NodeKind.DOCUMENT; + readonly title: string; + readonly children: DocumentNode[]; +} + +/** + * Section node (from markdown headers or structural divisions) + */ +export interface SectionNode extends DocumentNodeBase { + readonly kind: typeof NodeKind.SECTION; + readonly level: number; + readonly title: string; + readonly children: DocumentNode[]; +} + +/** + * Paragraph node + */ +export interface ParagraphNode extends DocumentNodeBase { + readonly kind: typeof NodeKind.PARAGRAPH; +} + +/** + * Sentence node (optional fine-grained segmentation) + */ +export interface SentenceNode extends DocumentNodeBase { + readonly kind: typeof NodeKind.SENTENCE; +} + +/** + * Topic segment node (from TopicSegmenter) + */ +export interface TopicNode extends DocumentNodeBase { + readonly kind: typeof NodeKind.TOPIC; + readonly children: DocumentNode[]; +} + +/** + * Discriminated union of all document node types + */ +export type DocumentNode = + | DocumentRootNode + | SectionNode + | ParagraphNode + | SentenceNode + | TopicNode; + +// ============================================================================= +// Token Budget and Chunk Schemas +// ============================================================================= + +/** + * Schema for token budget configuration + */ +export const TokenBudgetSchema = { + type: "object", + properties: { + maxTokensPerChunk: { + type: "integer", + title: "Max Tokens Per Chunk", + description: "Maximum tokens allowed per chunk", + }, + overlapTokens: { + type: "integer", + title: "Overlap Tokens", + description: "Number of tokens to overlap between chunks", + }, + reservedTokens: { + type: "integer", + title: "Reserved Tokens", + description: "Tokens reserved for metadata or context", + }, + }, + required: ["maxTokensPerChunk", "overlapTokens", "reservedTokens"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type TokenBudget = FromSchema; + +/** + * Schema for chunk enrichment + */ +export const ChunkEnrichmentSchema = { + type: "object", + properties: { + summary: { + type: "string", + title: "Summary", + description: "Summary of the chunk content", + }, + entities: { + type: "array", + items: EntitySchema, + title: "Entities", + description: "Named entities extracted from the chunk", + }, + }, + additionalProperties: false, +} as const satisfies 
DataPortSchema; + +export type ChunkEnrichment = FromSchema; + +/** + * Schema for chunk node (output of HierarchicalChunker) + */ +export const ChunkNodeSchema = () => + ({ + type: "object", + properties: { + chunkId: { + type: "string", + title: "Chunk ID", + description: "Unique identifier for this chunk", + }, + docId: { + type: "string", + title: "Document ID", + description: "ID of the parent document", + }, + text: { + type: "string", + title: "Text", + description: "Text content of the chunk", + }, + nodePath: { + type: "array", + items: { type: "string" }, + title: "Node Path", + description: "Node IDs from root to leaf", + }, + depth: { + type: "integer", + title: "Depth", + description: "Depth in the document tree", + }, + enrichment: ChunkEnrichmentSchema, + }, + required: ["chunkId", "docId", "text", "nodePath", "depth"], + additionalProperties: false, + }) as const satisfies DataPortSchema; + +export type ChunkNode = FromSchema>; + +// ============================================================================= +// Chunk Metadata Schemas (for vector store) +// ============================================================================= + +/** + * Schema for chunk metadata stored in vector database + * This is the metadata output from ChunkToVectorTask + */ +export const ChunkMetadataSchema = { + type: "object", + properties: { + docId: { + type: "string", + title: "Document ID", + description: "ID of the parent document", + }, + chunkId: { + type: "string", + title: "Chunk ID", + description: "Unique identifier for this chunk", + }, + leafNodeId: { + type: "string", + title: "Leaf Node ID", + description: "ID of the leaf node this chunk belongs to", + }, + depth: { + type: "integer", + title: "Depth", + description: "Depth in the document tree", + }, + text: { + type: "string", + title: "Text", + description: "Text content of the chunk", + }, + nodePath: { + type: "array", + items: { type: "string" }, + title: "Node Path", + description: "Node IDs from root to leaf", + }, + summary: { + type: "string", + title: "Summary", + description: "Summary of the chunk content", + }, + entities: { + type: "array", + items: EntitySchema, + title: "Entities", + description: "Named entities extracted from the chunk", + }, + }, + required: ["docId", "chunkId", "leafNodeId", "depth", "text", "nodePath"], + additionalProperties: true, +} as const satisfies DataPortSchema; + +export type ChunkMetadata = FromSchema; + +/** + * Schema for chunk metadata array (for use in task schemas) + */ +export const ChunkMetadataArraySchema = { + type: "array", + items: ChunkMetadataSchema, + title: "Chunk Metadata", + description: "Metadata for each chunk", +} as const satisfies JsonSchema; + +/** + * Schema for enriched chunk metadata (after HierarchyJoinTask) + * Extends ChunkMetadata with hierarchy information from document repository + */ +export const EnrichedChunkMetadataSchema = { + type: "object", + properties: { + docId: { + type: "string", + title: "Document ID", + description: "ID of the parent document", + }, + chunkId: { + type: "string", + title: "Chunk ID", + description: "Unique identifier for this chunk", + }, + leafNodeId: { + type: "string", + title: "Leaf Node ID", + description: "ID of the leaf node this chunk belongs to", + }, + depth: { + type: "integer", + title: "Depth", + description: "Depth in the document tree", + }, + text: { + type: "string", + title: "Text", + description: "Text content of the chunk", + }, + nodePath: { + type: "array", + items: { type: "string" }, + title: 
"Node Path", + description: "Node IDs from root to leaf", + }, + summary: { + type: "string", + title: "Summary", + description: "Summary of the chunk content", + }, + entities: { + type: "array", + items: EntitySchema, + title: "Entities", + description: "Named entities (rolled up from hierarchy)", + }, + parentSummaries: { + type: "array", + items: { type: "string" }, + title: "Parent Summaries", + description: "Summaries from ancestor nodes", + }, + sectionTitles: { + type: "array", + items: { type: "string" }, + title: "Section Titles", + description: "Titles of ancestor section nodes", + }, + }, + required: ["docId", "chunkId", "leafNodeId", "depth", "text", "nodePath"], + additionalProperties: true, +} as const satisfies DataPortSchema; + +export type EnrichedChunkMetadata = FromSchema; + +/** + * Schema for enriched chunk metadata array (for use in task schemas) + */ +export const EnrichedChunkMetadataArraySchema = { + type: "array", + items: EnrichedChunkMetadataSchema, + title: "Enriched Metadata", + description: "Metadata enriched with hierarchy information", +} as const satisfies JsonSchema; + +/** + * Schema for document metadata + */ +export const DocumentMetadataSchema = { + type: "object", + properties: { + title: { + type: "string", + title: "Title", + description: "Document title", + }, + sourceUri: { + type: "string", + title: "Source URI", + description: "Original source URI of the document", + }, + createdAt: { + type: "string", + title: "Created At", + description: "ISO timestamp of creation", + }, + }, + required: ["title"], + additionalProperties: true, +} as const satisfies DataPortSchema; + +export type DocumentMetadata = FromSchema; diff --git a/packages/ai/src/source/MasterDocument.ts b/packages/ai/src/source/MasterDocument.ts deleted file mode 100644 index 4d2aef33..00000000 --- a/packages/ai/src/source/MasterDocument.ts +++ /dev/null @@ -1,50 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import { Document, DocumentMetadata, TextFragment } from "./Document"; -import { DocumentConverter } from "./DocumentConverter"; - -/** - * MasterDocument represents a container for managing multiple versions/variants of a document. - * It maintains the original document and its transformed variants for different use cases. - * - * Key features: - * - Stores original document and metadata - * - Maintains a master version and variants - * - Automatically creates paragraph-split variant - * - * The paragraph variant splits text fragments by newlines while preserving other fragment types, - * which is useful for more granular text processing. 
- */ - -export class MasterDocument { - public metadata: DocumentMetadata; - public original: DocumentConverter; - public master: Document; - public variants: Document[] = []; - constructor(original: DocumentConverter, metadata: DocumentMetadata) { - this.metadata = Object.assign(original.metadata, metadata); - this.original = original; - this.master = original.convert(); - this.variants.push(paragraphVariant(this.master)); - } -} - -function paragraphVariant(doc: Document): Document { - const newdoc = new Document("", doc.metadata); - for (const node of doc.fragments) { - if (node instanceof TextFragment) { - const newnodes = node.content - .split("\n") - .filter((t) => t) - .map((paragraph) => new TextFragment(paragraph)); - newdoc.fragments.push(...newnodes); - } else { - newdoc.fragments.push(node); - } - } - return newdoc; -} diff --git a/packages/ai/src/source/StructuralParser.ts b/packages/ai/src/source/StructuralParser.ts new file mode 100644 index 00000000..7b439e75 --- /dev/null +++ b/packages/ai/src/source/StructuralParser.ts @@ -0,0 +1,254 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { NodeIdGenerator } from "./DocumentNode"; +import { + type DocumentRootNode, + NodeKind, + type ParagraphNode, + type SectionNode, +} from "./DocumentSchema"; + +/** + * Parse markdown into a hierarchical DocumentNode tree + */ +export class StructuralParser { + /** + * Parse markdown text into a hierarchical document tree + */ + static async parseMarkdown( + docId: string, + text: string, + title: string + ): Promise { + const lines = text.split("\n"); + let currentOffset = 0; + + const root: DocumentRootNode = { + nodeId: await NodeIdGenerator.generateStructuralNodeId(docId, NodeKind.DOCUMENT, { + startOffset: 0, + endOffset: text.length, + }), + kind: NodeKind.DOCUMENT, + range: { startOffset: 0, endOffset: text.length }, + text: title, + title, + children: [], + }; + + let currentParentStack: Array = [root]; + let textBuffer: string[] = []; + let textBufferStartOffset = 0; + + const flushTextBuffer = async () => { + if (textBuffer.length > 0) { + const content = textBuffer.join("\n").trim(); + if (content) { + const paragraphStartOffset = textBufferStartOffset; + const paragraphEndOffset = currentOffset; + + const paragraph: ParagraphNode = { + nodeId: await NodeIdGenerator.generateChildNodeId( + currentParentStack[currentParentStack.length - 1].nodeId, + currentParentStack[currentParentStack.length - 1].children.length + ), + kind: NodeKind.PARAGRAPH, + range: { + startOffset: paragraphStartOffset, + endOffset: paragraphEndOffset, + }, + text: content, + }; + + currentParentStack[currentParentStack.length - 1].children.push(paragraph); + } + textBuffer = []; + } + }; + + for (const line of lines) { + const lineLength = line.length + 1; // +1 for newline + + // Check if line is a header + const headerMatch = line.match(/^(#{1,6})\s+(.*)$/); + if (headerMatch) { + await flushTextBuffer(); + + const level = headerMatch[1].length; + const headerTitle = headerMatch[2]; + + // Pop stack until we find appropriate parent + while ( + currentParentStack.length > 1 && + currentParentStack[currentParentStack.length - 1].kind === NodeKind.SECTION && + (currentParentStack[currentParentStack.length - 1] as SectionNode).level >= level + ) { + const poppedSection = currentParentStack.pop() as SectionNode; + // Update endOffset of popped section + const updatedSection: SectionNode = { + ...poppedSection, + range: { + ...poppedSection.range, + 
endOffset: currentOffset, + }, + }; + // Replace in parent's children + const parent = currentParentStack[currentParentStack.length - 1]; + parent.children[parent.children.length - 1] = updatedSection; + } + + const sectionStartOffset = currentOffset; + const section: SectionNode = { + nodeId: await NodeIdGenerator.generateStructuralNodeId(docId, NodeKind.SECTION, { + startOffset: sectionStartOffset, + endOffset: text.length, // Will be updated when section closes + }), + kind: NodeKind.SECTION, + level, + title: headerTitle, + range: { + startOffset: sectionStartOffset, + endOffset: text.length, + }, + text: headerTitle, + children: [], + }; + + currentParentStack[currentParentStack.length - 1].children.push(section); + currentParentStack.push(section); + } else { + // Accumulate text + if (textBuffer.length === 0) { + textBufferStartOffset = currentOffset; + } + textBuffer.push(line); + } + + currentOffset += lineLength; + } + + await flushTextBuffer(); + + // Close any remaining sections + while (currentParentStack.length > 1) { + const section = currentParentStack.pop() as SectionNode; + const updatedSection: SectionNode = { + ...section, + range: { + ...section.range, + endOffset: text.length, + }, + }; + const parent = currentParentStack[currentParentStack.length - 1]; + parent.children[parent.children.length - 1] = updatedSection; + } + + return root; + } + + /** + * Parse plain text into a hierarchical document tree + * Splits by double newlines to create paragraphs + */ + static async parsePlainText( + docId: string, + text: string, + title: string + ): Promise { + const root: DocumentRootNode = { + nodeId: await NodeIdGenerator.generateStructuralNodeId(docId, NodeKind.DOCUMENT, { + startOffset: 0, + endOffset: text.length, + }), + kind: NodeKind.DOCUMENT, + range: { startOffset: 0, endOffset: text.length }, + text: title, + title, + children: [], + }; + + // Split by double newlines to get paragraphs while tracking offsets + const paragraphRegex = /\n\s*\n/g; + let lastIndex = 0; + let paragraphIndex = 0; + let match: RegExpExecArray | null; + + while ((match = paragraphRegex.exec(text)) !== null) { + const rawParagraph = text.slice(lastIndex, match.index); + const paragraphText = rawParagraph.trim(); + + if (paragraphText.length > 0) { + const trimmedRelativeStart = rawParagraph.indexOf(paragraphText); + const startOffset = lastIndex + trimmedRelativeStart; + const endOffset = startOffset + paragraphText.length; + + const paragraph: ParagraphNode = { + nodeId: await NodeIdGenerator.generateChildNodeId(root.nodeId, paragraphIndex), + kind: NodeKind.PARAGRAPH, + range: { + startOffset, + endOffset, + }, + text: paragraphText, + }; + + root.children.push(paragraph); + paragraphIndex++; + } + + lastIndex = paragraphRegex.lastIndex; + } + + // Handle trailing paragraph after the last double newline, if any + if (lastIndex < text.length) { + const rawParagraph = text.slice(lastIndex); + const paragraphText = rawParagraph.trim(); + + if (paragraphText.length > 0) { + const trimmedRelativeStart = rawParagraph.indexOf(paragraphText); + const startOffset = lastIndex + trimmedRelativeStart; + const endOffset = startOffset + paragraphText.length; + + const paragraph: ParagraphNode = { + nodeId: await NodeIdGenerator.generateChildNodeId(root.nodeId, paragraphIndex), + kind: NodeKind.PARAGRAPH, + range: { + startOffset, + endOffset, + }, + text: paragraphText, + }; + + root.children.push(paragraph); + } + } + return root; + } + + /** + * Auto-detect format and parse + */ + static parse( + 
docId: string, + text: string, + title: string, + format?: "markdown" | "text" + ): Promise { + if (format === "markdown" || (!format && this.looksLikeMarkdown(text))) { + return this.parseMarkdown(docId, text, title); + } + return this.parsePlainText(docId, text, title); + } + + /** + * Check if text contains markdown header patterns + * Looks for lines starting with 1-6 hash symbols followed by whitespace + */ + private static looksLikeMarkdown(text: string): boolean { + // Check for markdown header patterns: line starting with # followed by space + return /^#{1,6}\s/m.test(text); + } +} diff --git a/packages/ai/src/source/index.ts b/packages/ai/src/source/index.ts deleted file mode 100644 index 8b9992fd..00000000 --- a/packages/ai/src/source/index.ts +++ /dev/null @@ -1,10 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -export * from "./Document"; -export * from "./DocumentConverterMarkdown"; -export * from "./DocumentConverterText"; -export * from "./MasterDocument"; diff --git a/packages/ai/src/task/BackgroundRemovalTask.ts b/packages/ai/src/task/BackgroundRemovalTask.ts index 5be33802..b3e1a81f 100644 --- a/packages/ai/src/task/BackgroundRemovalTask.ts +++ b/packages/ai/src/task/BackgroundRemovalTask.ts @@ -6,15 +6,10 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:BackgroundRemovalTask")); +const modelSchema = TypeModel("model:BackgroundRemovalTask"); const processedImageSchema = { type: "string", @@ -27,7 +22,7 @@ const processedImageSchema = { export const BackgroundRemovalInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, }, required: ["image", "model"], @@ -37,11 +32,7 @@ export const BackgroundRemovalInputSchema = { export const BackgroundRemovalOutputSchema = { type: "object", properties: { - image: { - oneOf: [processedImageSchema, { type: "array", items: processedImageSchema }], - title: processedImageSchema.title, - description: processedImageSchema.description, - }, + image: processedImageSchema, }, required: ["image"], additionalProperties: false, @@ -49,12 +40,6 @@ export const BackgroundRemovalOutputSchema = { export type BackgroundRemovalTaskInput = FromSchema; export type BackgroundRemovalTaskOutput = FromSchema; -export type BackgroundRemovalTaskExecuteInput = DeReplicateFromSchema< - typeof BackgroundRemovalInputSchema ->; -export type BackgroundRemovalTaskExecuteOutput = DeReplicateFromSchema< - typeof BackgroundRemovalOutputSchema ->; /** * Removes backgrounds from images using computer vision models @@ -88,7 +73,7 @@ export const backgroundRemoval = ( input: BackgroundRemovalTaskInput, config?: JobQueueTaskConfig ) => { - return new BackgroundRemovalTask(input, config).run(); + return new BackgroundRemovalTask({} as BackgroundRemovalTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/ChunkToVectorTask.ts b/packages/ai/src/task/ChunkToVectorTask.ts new file mode 100644 index 00000000..df4fcf4c --- /dev/null +++ b/packages/ai/src/task/ChunkToVectorTask.ts @@ -0,0 
+1,179 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { + DataPortSchema, + FromSchema, + TypedArraySchema, + TypedArraySchemaOptions, +} from "@workglow/util"; +import { ChunkNodeSchema, type ChunkNode } from "../source/DocumentSchema"; + +const inputSchema = { + type: "object", + properties: { + docId: { + type: "string", + title: "Document ID", + description: "The document ID", + }, + chunks: { + type: "array", + items: ChunkNodeSchema(), + title: "Chunks", + description: "Array of chunk nodes", + }, + vectors: { + type: "array", + items: TypedArraySchema({ + title: "Vector", + description: "Vector embedding", + }), + title: "Vectors", + description: "Embeddings from TextEmbeddingTask", + }, + }, + required: [], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + ids: { + type: "array", + items: { type: "string" }, + title: "IDs", + description: "Chunk IDs for vector store", + }, + vectors: { + type: "array", + items: TypedArraySchema({ + title: "Vector", + description: "Vector embedding", + }), + title: "Vectors", + description: "Vector embeddings", + }, + metadata: { + type: "array", + items: { + type: "object", + title: "Metadata", + description: "Metadata for vector store", + }, + title: "Metadata", + description: "Metadata for each vector", + }, + texts: { + type: "array", + items: { type: "string" }, + title: "Texts", + description: "Chunk texts (for reference)", + }, + }, + required: ["ids", "vectors", "metadata", "texts"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type ChunkToVectorTaskInput = FromSchema; +export type ChunkToVectorTaskOutput = FromSchema; + +/** + * Task to transform chunk nodes and embeddings into vector store format + * Bridges HierarchicalChunker + TextEmbedding → VectorStoreUpsert + */ +export class ChunkToVectorTask extends Task< + ChunkToVectorTaskInput, + ChunkToVectorTaskOutput, + JobQueueTaskConfig +> { + public static type = "ChunkToVectorTask"; + public static category = "Document"; + public static title = "Chunk to Vector Transform"; + public static description = "Transform chunks and embeddings to vector store format"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: ChunkToVectorTaskInput, + context: IExecuteContext + ): Promise { + const { chunks, vectors } = input; + + const chunkArray = chunks as ChunkNode[]; + + if (!chunkArray || !vectors) { + throw new Error("Both chunks and vector are required"); + } + + if (chunkArray.length !== vectors.length) { + throw new Error(`Mismatch: ${chunkArray.length} chunks but ${vectors.length} vectors`); + } + + const ids: string[] = []; + const metadata: any[] = []; + const texts: string[] = []; + + for (let i = 0; i < chunkArray.length; i++) { + const chunk = chunkArray[i]; + + ids.push(chunk.chunkId); + texts.push(chunk.text); + + metadata.push({ + docId: chunk.docId, + chunkId: chunk.chunkId, + leafNodeId: chunk.nodePath[chunk.nodePath.length - 1], + depth: chunk.depth, + text: chunk.text, + nodePath: chunk.nodePath, + // Include enrichment if present + ...(chunk.enrichment || {}), + }); + } + + 
return { + ids, + vectors, + metadata, + texts, + }; + } +} + +TaskRegistry.registerTask(ChunkToVectorTask); + +export const chunkToVector = (input: ChunkToVectorTaskInput, config?: JobQueueTaskConfig) => { + return new ChunkToVectorTask({} as ChunkToVectorTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + chunkToVector: CreateWorkflow< + ChunkToVectorTaskInput, + ChunkToVectorTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.chunkToVector = CreateWorkflow(ChunkToVectorTask); diff --git a/packages/ai/src/task/ContextBuilderTask.ts b/packages/ai/src/task/ContextBuilderTask.ts new file mode 100644 index 00000000..19dee6dc --- /dev/null +++ b/packages/ai/src/task/ContextBuilderTask.ts @@ -0,0 +1,339 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + CreateWorkflow, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { DataPortSchema, FromSchema } from "@workglow/util"; + +export const ContextFormat = { + SIMPLE: "simple", + NUMBERED: "numbered", + XML: "xml", + MARKDOWN: "markdown", + JSON: "json", +} as const; + +export type ContextFormat = (typeof ContextFormat)[keyof typeof ContextFormat]; + +const inputSchema = { + type: "object", + properties: { + chunks: { + type: "array", + items: { type: "string" }, + title: "Text Chunks", + description: "Retrieved text chunks to format", + }, + metadata: { + type: "array", + items: { + type: "object", + title: "Metadata", + description: "Metadata for each chunk", + }, + title: "Metadata", + description: "Metadata for each chunk (optional)", + }, + scores: { + type: "array", + items: { type: "number" }, + title: "Scores", + description: "Relevance scores for each chunk (optional)", + }, + format: { + type: "string", + enum: Object.values(ContextFormat), + title: "Format", + description: "Format for the context output", + default: ContextFormat.SIMPLE, + }, + maxLength: { + type: "number", + title: "Max Length", + description: "Maximum length of context in characters (0 = unlimited)", + minimum: 0, + default: 0, + }, + includeMetadata: { + type: "boolean", + title: "Include Metadata", + description: "Whether to include metadata in the context", + default: false, + }, + separator: { + type: "string", + title: "Separator", + description: "Separator between chunks", + default: "\n\n", + }, + }, + required: ["chunks"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + context: { + type: "string", + title: "Context", + description: "Formatted context string for LLM", + }, + chunksUsed: { + type: "number", + title: "Chunks Used", + description: "Number of chunks included in context", + }, + totalLength: { + type: "number", + title: "Total Length", + description: "Total length of context in characters", + }, + }, + required: ["context", "chunksUsed", "totalLength"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type ContextBuilderTaskInput = FromSchema; +export type ContextBuilderTaskOutput = FromSchema; + +/** + * Task for formatting retrieved chunks into context for LLM prompts. + * Supports various formatting styles and length constraints. 
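+ *
+ * A minimal usage sketch of the `contextBuilder` helper exported below. The chunk
+ * texts, metadata objects, and scores are placeholder values, and the resolved
+ * result is assumed to match the output schema above.
+ *
+ * @example
+ * ```typescript
+ * const { context, chunksUsed, totalLength } = await contextBuilder({
+ *   chunks: ["First retrieved passage...", "Second retrieved passage..."],
+ *   metadata: [{ source: "doc-1" }, { source: "doc-2" }],
+ *   scores: [0.92, 0.87],
+ *   format: ContextFormat.NUMBERED,
+ *   includeMetadata: true,
+ *   maxLength: 2000,
+ * });
+ * // context contains "[1] ..." style entries with inline (score=..., source=...) notes
+ * ```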
+ */ +export class ContextBuilderTask extends Task< + ContextBuilderTaskInput, + ContextBuilderTaskOutput, + JobQueueTaskConfig +> { + public static type = "ContextBuilderTask"; + public static category = "RAG"; + public static title = "Context Builder"; + public static description = "Format retrieved chunks into context for LLM prompts"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async executeReactive( + input: ContextBuilderTaskInput, + output: ContextBuilderTaskOutput + ): Promise { + const { + chunks, + metadata = [], + scores = [], + format = ContextFormat.SIMPLE, + maxLength = 0, + includeMetadata = false, + separator = "\n\n", + } = input; + + let context = ""; + let chunksUsed = 0; + + for (let i = 0; i < chunks.length; i++) { + const chunk = chunks[i]; + const meta = metadata[i]; + const score = scores[i]; + + let formattedChunk = this.formatChunk(chunk, meta, score, i, format, includeMetadata); + + // Check length constraint + if (maxLength > 0) { + const potentialLength = context.length + formattedChunk.length + separator.length; + if (potentialLength > maxLength) { + // Try to fit partial chunk if it's the first one + if (chunksUsed === 0) { + const available = maxLength - context.length; + if (available > 100) { + // Only include partial if we have reasonable space + formattedChunk = formattedChunk.substring(0, available - 3) + "..."; + context += formattedChunk; + chunksUsed++; + } + } + break; + } + } + + if (chunksUsed > 0) { + context += separator; + } + context += formattedChunk; + chunksUsed++; + } + + return { + context, + chunksUsed, + totalLength: context.length, + }; + } + + private formatChunk( + chunk: string, + metadata: any, + score: number | undefined, + index: number, + format: ContextFormat, + includeMetadata: boolean + ): string { + switch (format) { + case ContextFormat.NUMBERED: + return this.formatNumbered(chunk, metadata, score, index, includeMetadata); + case ContextFormat.XML: + return this.formatXML(chunk, metadata, score, index, includeMetadata); + case ContextFormat.MARKDOWN: + return this.formatMarkdown(chunk, metadata, score, index, includeMetadata); + case ContextFormat.JSON: + return this.formatJSON(chunk, metadata, score, index, includeMetadata); + case ContextFormat.SIMPLE: + default: + return chunk; + } + } + + private formatNumbered( + chunk: string, + metadata: any, + score: number | undefined, + index: number, + includeMetadata: boolean + ): string { + let result = `[${index + 1}] ${chunk}`; + if (includeMetadata && metadata) { + const metaStr = this.formatMetadataInline(metadata, score); + if (metaStr) { + result += ` ${metaStr}`; + } + } + return result; + } + + private formatXML( + chunk: string, + metadata: any, + score: number | undefined, + index: number, + includeMetadata: boolean + ): string { + let result = ``; + if (includeMetadata && (metadata || score !== undefined)) { + result += "\n "; + if (score !== undefined) { + result += `\n ${score.toFixed(4)}`; + } + if (metadata) { + for (const [key, value] of Object.entries(metadata)) { + result += `\n <${key}>${this.escapeXML(String(value))}`; + } + } + result += "\n "; + result += `\n ${this.escapeXML(chunk)}`; + result += "\n"; + } else { + result += `${this.escapeXML(chunk)}`; + } + return result; + } + + private formatMarkdown( + chunk: string, + metadata: any, + score: number | undefined, + index: 
number,
+    includeMetadata: boolean
+  ): string {
+    let result = `### Chunk ${index + 1}\n\n`;
+    if (includeMetadata && (metadata || score !== undefined)) {
+      result += "**Metadata:**\n";
+      if (score !== undefined) {
+        result += `- Score: ${score.toFixed(4)}\n`;
+      }
+      if (metadata) {
+        for (const [key, value] of Object.entries(metadata)) {
+          result += `- ${key}: ${value}\n`;
+        }
+      }
+      result += "\n";
+    }
+    result += chunk;
+    return result;
+  }
+
+  private formatJSON(
+    chunk: string,
+    metadata: any,
+    score: number | undefined,
+    index: number,
+    includeMetadata: boolean
+  ): string {
+    const obj: any = {
+      index: index + 1,
+      content: chunk,
+    };
+    if (includeMetadata) {
+      if (score !== undefined) {
+        obj.score = score;
+      }
+      if (metadata) {
+        obj.metadata = metadata;
+      }
+    }
+    return JSON.stringify(obj);
+  }
+
+  private formatMetadataInline(metadata: any, score: number | undefined): string {
+    const parts: string[] = [];
+    if (score !== undefined) {
+      parts.push(`score=${score.toFixed(4)}`);
+    }
+    if (metadata) {
+      for (const [key, value] of Object.entries(metadata)) {
+        parts.push(`${key}=${value}`);
+      }
+    }
+    return parts.length > 0 ? `(${parts.join(", ")})` : "";
+  }
+
+  private escapeXML(str: string): string {
+    return str
+      .replace(/&/g, "&amp;")
+      .replace(/</g, "&lt;")
+      .replace(/>/g, "&gt;")
+      .replace(/"/g, "&quot;")
+      .replace(/'/g, "&#39;");
+  }
+}
+
+TaskRegistry.registerTask(ContextBuilderTask);
+
+export const contextBuilder = (input: ContextBuilderTaskInput, config?: JobQueueTaskConfig) => {
+  return new ContextBuilderTask({} as ContextBuilderTaskInput, config).run(input);
+};
+
+declare module "@workglow/task-graph" {
+  interface Workflow {
+    contextBuilder: CreateWorkflow<
+      ContextBuilderTaskInput,
+      ContextBuilderTaskOutput,
+      JobQueueTaskConfig
+    >;
+  }
+}
+
+Workflow.prototype.contextBuilder = CreateWorkflow(ContextBuilderTask);
diff --git a/packages/ai/src/task/DocumentEnricherTask.ts b/packages/ai/src/task/DocumentEnricherTask.ts
new file mode 100644
index 00000000..1c9d5886
--- /dev/null
+++ b/packages/ai/src/task/DocumentEnricherTask.ts
@@ -0,0 +1,412 @@
+/**
+ * @license
+ * Copyright 2025 Steven Roussey
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import {
+  CreateWorkflow,
+  IExecuteContext,
+  JobQueueTaskConfig,
+  Task,
+  TaskRegistry,
+  Workflow,
+} from "@workglow/task-graph";
+import { DataPortSchema, FromSchema } from "@workglow/util";
+import { TypeModel } from "../common";
+import { ModelConfig } from "../model/ModelSchema";
+import { getChildren, hasChildren } from "../source/DocumentNode";
+import { type DocumentNode, type Entity, type NodeEnrichment } from "../source/DocumentSchema";
+import { TextNamedEntityRecognitionTask } from "./TextNamedEntityRecognitionTask";
+import { TextSummaryTask } from "./TextSummaryTask";
+
+const inputSchema = {
+  type: "object",
+  properties: {
+    docId: {
+      type: "string",
+      title: "Document ID",
+      description: "The document ID",
+    },
+    documentTree: {
+      title: "Document Tree",
+      description: "The hierarchical document tree to enrich",
+    },
+    generateSummaries: {
+      type: "boolean",
+      title: "Generate Summaries",
+      description: "Whether to generate summaries for sections",
+      default: true,
+    },
+    extractEntities: {
+      type: "boolean",
+      title: "Extract Entities",
+      description: "Whether to extract named entities",
+      default: true,
+    },
+    summaryModel: TypeModel("model:TextSummaryTask", {
+      title: "Summary Model",
+      description: "Model to use for summary generation (optional)",
+    }),
+    summaryThreshold: {
+      type: "number",
+      title: "Summary Threshold",
+
description: "Minimum combined text length (node + children) to warrant generating a summary", + default: 500, + }, + nerModel: TypeModel("model:TextNamedEntityRecognitionTask", { + title: "NER Model", + description: "Model to use for named entity recognition (optional)", + }), + }, + required: [], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + docId: { + type: "string", + title: "Document ID", + description: "The document ID (passed through)", + }, + documentTree: { + title: "Document Tree", + description: "The enriched document tree", + }, + summaryCount: { + type: "number", + title: "Summary Count", + description: "Number of summaries generated", + }, + entityCount: { + type: "number", + title: "Entity Count", + description: "Number of entities extracted", + }, + }, + required: ["docId", "documentTree", "summaryCount", "entityCount"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type DocumentEnricherTaskInput = FromSchema; +export type DocumentEnricherTaskOutput = FromSchema; + +/** + * Task for enriching document nodes with summaries and entities + * Uses bottom-up propagation to roll up child information to parents + */ +export class DocumentEnricherTask extends Task< + DocumentEnricherTaskInput, + DocumentEnricherTaskOutput, + JobQueueTaskConfig +> { + public static type = "DocumentEnricherTask"; + public static category = "Document"; + public static title = "Document Enricher"; + public static description = "Enrich document nodes with summaries and entities"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: DocumentEnricherTaskInput, + context: IExecuteContext + ): Promise { + const { + docId, + documentTree, + generateSummaries = true, + extractEntities = true, + summaryModel: summaryModelConfig, + summaryThreshold = 500, + nerModel: nerModelConfig, + } = input; + + const root = documentTree as DocumentNode; + const summaryModel = summaryModelConfig ? (summaryModelConfig as ModelConfig) : undefined; + const nerModel = nerModelConfig ? (nerModelConfig as ModelConfig) : undefined; + let summaryCount = 0; + let entityCount = 0; + + const extract = + extractEntities && nerModel + ? async (text: string) => { + const result = await context + .own(new TextNamedEntityRecognitionTask({ text, model: nerModel })) + .run(); + return result.entities.map((e) => ({ + type: e.entity, + text: e.word, + score: e.score, + })); + } + : undefined; + + // Bottom-up enrichment + const enrichedRoot = await this.enrichNode( + root, + context, + generateSummaries && summaryModel ? 
summaryModel : undefined, + summaryThreshold, + extract, + (count) => (summaryCount += count), + (count) => (entityCount += count) + ); + + return { + docId: docId as string, + documentTree: enrichedRoot, + summaryCount, + entityCount, + }; + } + + /** + * Enrich a node recursively (bottom-up) + */ + private async enrichNode( + node: DocumentNode, + context: IExecuteContext, + summaryModel: ModelConfig | undefined, + summaryThreshold: number, + extract: ((text: string) => Promise) | undefined, + onSummary: (count: number) => void, + onEntity: (count: number) => void + ): Promise { + // If node has children, enrich them first + let enrichedChildren: DocumentNode[] | undefined; + if (hasChildren(node)) { + const children = getChildren(node); + enrichedChildren = await Promise.all( + children.map((child) => + this.enrichNode( + child, + context, + summaryModel, + summaryThreshold, + extract, + onSummary, + onEntity + ) + ) + ); + } + + // Generate enrichment for this node + const enrichment: NodeEnrichment = {}; + + // Generate summary (for sections and documents) + if (summaryModel && (node.kind === "section" || node.kind === "document")) { + if (enrichedChildren && enrichedChildren.length > 0) { + // Summary of children + enrichment.summary = await this.generateSummary( + node, + enrichedChildren, + context, + summaryModel, + summaryThreshold + ); + } else { + // Leaf section summary + enrichment.summary = await this.generateLeafSummary( + node.text, + context, + summaryModel, + summaryThreshold + ); + } + if (enrichment.summary) { + onSummary(1); + } + } + + // Extract entities + if (extract) { + enrichment.entities = await this.extractEntities(node, enrichedChildren, extract); + if (enrichment.entities) { + onEntity(enrichment.entities.length); + } + } + + // Create enriched node + const enrichedNode: DocumentNode = { + ...node, + enrichment: Object.keys(enrichment).length > 0 ? 
enrichment : undefined, + }; + + if (enrichedChildren) { + (enrichedNode as any).children = enrichedChildren; + } + + return enrichedNode; + } + + /** + * Private method to summarize text using the TextSummaryTask + */ + private async summarize( + text: string, + context: IExecuteContext, + model: ModelConfig + ): Promise { + // TODO: Handle truncation of text if needed, based on model configuration + return (await context.own(new TextSummaryTask()).run({ text, model })).text; + } + + /** + * Generate summary for a node with children + */ + private async generateSummary( + node: DocumentNode, + children: DocumentNode[], + context: IExecuteContext, + model: ModelConfig, + threshold: number + ): Promise { + const textParts: string[] = []; + + // Include the node's own text + const nodeText = node.text?.trim(); + if (nodeText) { + textParts.push(nodeText); + } + + // Include children summaries/texts + const childTexts = children + .map((child) => { + if (child.enrichment?.summary) { + return child.enrichment.summary; + } + return child.text; + }) + .join(" ") + .trim(); + + if (childTexts) { + textParts.push(childTexts); + } + + const combinedText = textParts.join(" ").trim(); + if (!combinedText) { + return undefined; + } + + // Check if summary is warranted based on threshold + if (combinedText.length < threshold) { + return undefined; + } + + const summaryParts: string[] = []; + + // Summarize the node's own text first + if (nodeText) { + const nodeSummary = await this.summarize(nodeText, context, model); + if (nodeSummary) { + summaryParts.push(nodeSummary); + } + } + + // Include children summaries/texts + if (childTexts) { + summaryParts.push(childTexts); + } + + const combinedSummaries = summaryParts.join(" ").trim(); + if (!combinedSummaries) { + return undefined; + } + + const result = await this.summarize(combinedSummaries, context, model); + return result; + } + + /** + * Generate summary for a leaf node + */ + private async generateLeafSummary( + text: string, + context: IExecuteContext, + model: ModelConfig, + threshold: number + ): Promise { + const trimmedText = text.trim(); + if (!trimmedText) { + return undefined; + } + + // Check if summary is warranted based on threshold + if (trimmedText.length < threshold) { + return undefined; + } + + const result = await this.summarize(trimmedText, context, model); + return result; + } + + /** + * Extract and roll up entities from node and children + */ + private async extractEntities( + node: DocumentNode, + children: DocumentNode[] | undefined, + extract: ((text: string) => Promise) | undefined + ): Promise { + const entities: Entity[] = []; + + // Collect from children first + if (children) { + for (const child of children) { + if (child.enrichment?.entities) { + entities.push(...child.enrichment.entities); + } + } + } + + const text = node.text.trim(); + if (text && extract) { + const nodeEntities = await extract(text); + if (nodeEntities?.length) { + entities.push(...nodeEntities); + } + } + + // Deduplicate by text + const unique = new Map(); + for (const entity of entities) { + const key = `${entity.text}::${entity.type}`; + const existing = unique.get(key); + if (!existing || entity.score > existing.score) { + unique.set(key, entity); + } + } + + const result = Array.from(unique.values()); + return result.length > 0 ? 
result : undefined; + } +} + +TaskRegistry.registerTask(DocumentEnricherTask); + +export const documentEnricher = (input: DocumentEnricherTaskInput, config?: JobQueueTaskConfig) => { + return new DocumentEnricherTask({} as DocumentEnricherTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + documentEnricher: CreateWorkflow< + DocumentEnricherTaskInput, + DocumentEnricherTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.documentEnricher = CreateWorkflow(DocumentEnricherTask); diff --git a/packages/ai/src/task/DocumentSplitterTask.ts b/packages/ai/src/task/DocumentSplitterTask.ts deleted file mode 100644 index 864f7e1d..00000000 --- a/packages/ai/src/task/DocumentSplitterTask.ts +++ /dev/null @@ -1,98 +0,0 @@ -/** - * @license - * Copyright 2025 Steven Roussey - * SPDX-License-Identifier: Apache-2.0 - */ - -import { - CreateWorkflow, - JobQueueTaskConfig, - Task, - TaskRegistry, - Workflow, -} from "@workglow/task-graph"; -import { DataPortSchema, FromSchema } from "@workglow/util"; -import { Document, DocumentFragment } from "../source/Document"; - -const inputSchema = { - type: "object", - properties: { - parser: { - type: "string", - enum: ["txt", "md"], - title: "Document Kind", - description: "The kind of document (txt or md)", - }, - // file: Type.Instance(Document), - }, - required: ["parser"], - additionalProperties: false, -} as const satisfies DataPortSchema; - -const outputSchema = { - type: "object", - properties: { - texts: { - type: "array", - items: { type: "string" }, - title: "Text Chunks", - description: "The text chunks of the document", - }, - }, - required: ["texts"], - additionalProperties: false, -} as const satisfies DataPortSchema; - -export type DocumentSplitterTaskInput = FromSchema; -export type DocumentSplitterTaskOutput = FromSchema; - -export class DocumentSplitterTask extends Task< - DocumentSplitterTaskInput, - DocumentSplitterTaskOutput, - JobQueueTaskConfig -> { - public static type = "DocumentSplitterTask"; - public static category = "Document"; - public static title = "Document Splitter"; - public static description = "Splits documents into text chunks for processing"; - public static inputSchema(): DataPortSchema { - return inputSchema as DataPortSchema; - } - public static outputSchema(): DataPortSchema { - return outputSchema as DataPortSchema; - } - - flattenFragmentsToTexts(item: DocumentFragment | Document): string[] { - if (item instanceof Document) { - const texts: string[] = []; - item.fragments.forEach((fragment) => { - texts.push(...this.flattenFragmentsToTexts(fragment)); - }); - return texts; - } else { - return [item.content]; - } - } - - async executeReactive(): Promise { - return { texts: this.flattenFragmentsToTexts(this.runInputData.file) }; - } -} - -TaskRegistry.registerTask(DocumentSplitterTask); - -export const documentSplitter = (input: DocumentSplitterTaskInput) => { - return new DocumentSplitterTask(input).run(); -}; - -declare module "@workglow/task-graph" { - interface Workflow { - documentSplitter: CreateWorkflow< - DocumentSplitterTaskInput, - DocumentSplitterTaskOutput, - JobQueueTaskConfig - >; - } -} - -Workflow.prototype.documentSplitter = CreateWorkflow(DocumentSplitterTask); diff --git a/packages/ai/src/task/DownloadModelTask.ts b/packages/ai/src/task/DownloadModelTask.ts index 6586f9b9..6eb9a0fd 100644 --- a/packages/ai/src/task/DownloadModelTask.ts +++ b/packages/ai/src/task/DownloadModelTask.ts @@ -7,9 +7,9 @@ import { CreateWorkflow, 
JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model")); +const modelSchema = TypeModel("model"); const DownloadModelInputSchema = { type: "object", @@ -31,10 +31,6 @@ const DownloadModelOutputSchema = { export type DownloadModelTaskRunInput = FromSchema; export type DownloadModelTaskRunOutput = FromSchema; -export type DownloadModelTaskExecuteInput = DeReplicateFromSchema; -export type DownloadModelTaskExecuteOutput = DeReplicateFromSchema< - typeof DownloadModelOutputSchema ->; /** * Download a model from a remote source and cache it locally. @@ -103,7 +99,7 @@ TaskRegistry.registerTask(DownloadModelTask); * @returns Promise resolving to the downloaded model(s) */ export const downloadModel = (input: DownloadModelTaskRunInput, config?: JobQueueTaskConfig) => { - return new DownloadModelTask(input, config).run(); + return new DownloadModelTask({} as DownloadModelTaskRunInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/FaceDetectorTask.ts b/packages/ai/src/task/FaceDetectorTask.ts index 7989a46f..465d8717 100644 --- a/packages/ai/src/task/FaceDetectorTask.ts +++ b/packages/ai/src/task/FaceDetectorTask.ts @@ -6,15 +6,10 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:FaceDetectorTask")); +const modelSchema = TypeModel("model:FaceDetectorTask"); /** * A bounding box for face detection. @@ -99,7 +94,7 @@ const TypeFaceDetection = { export const FaceDetectorInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, minDetectionConfidence: { type: "number", @@ -142,8 +137,6 @@ export const FaceDetectorOutputSchema = { export type FaceDetectorTaskInput = FromSchema; export type FaceDetectorTaskOutput = FromSchema; -export type FaceDetectorTaskExecuteInput = DeReplicateFromSchema; -export type FaceDetectorTaskExecuteOutput = DeReplicateFromSchema; /** * Detects faces in images using MediaPipe Face Detector. 
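The convenience wrappers in this diff now build the task with empty defaults and forward the real input to `run(input)` (see `downloadModel` and `faceDetector` above). A minimal sketch of the new call shape, assuming the wrappers are re-exported from an `@workglow/ai` entry point; the model ids and `as any` casts are placeholders for whatever `TypeModel(...)` actually accepts:

```ts
// Sketch of the updated wrapper pattern shown in this diff.
// Assumptions: export path "@workglow/ai"; model ids and image shape are illustrative.
import { downloadModel, faceDetector } from "@workglow/ai";

async function example(image: unknown) {
  // Input goes to run(), not the constructor; the wrapper creates the task
  // instance with empty defaults internally.
  const downloaded = await downloadModel({ model: "hypothetical/model-id" as any });

  const faces = await faceDetector({
    image: image as any, // TypeImageInput; exact shape not shown in this hunk
    model: "hypothetical/face-detector-model" as any,
    minDetectionConfidence: 0.5, // optional, from FaceDetectorInputSchema
  });

  return { downloaded, faces };
}
```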
@@ -176,7 +169,7 @@ TaskRegistry.registerTask(FaceDetectorTask); * @returns Promise resolving to the detected faces with bounding boxes and keypoints */ export const faceDetector = (input: FaceDetectorTaskInput, config?: JobQueueTaskConfig) => { - return new FaceDetectorTask(input, config).run(); + return new FaceDetectorTask({} as FaceDetectorTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/FaceLandmarkerTask.ts b/packages/ai/src/task/FaceLandmarkerTask.ts index 2dc151bd..961bc436 100644 --- a/packages/ai/src/task/FaceLandmarkerTask.ts +++ b/packages/ai/src/task/FaceLandmarkerTask.ts @@ -6,15 +6,10 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:FaceLandmarkerTask")); +const modelSchema = TypeModel("model:FaceLandmarkerTask"); /** * A landmark point with x, y, z coordinates. @@ -102,7 +97,7 @@ const TypeFaceLandmarkerDetection = { export const FaceLandmarkerInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, numFaces: { type: "number", @@ -177,12 +172,6 @@ export const FaceLandmarkerOutputSchema = { export type FaceLandmarkerTaskInput = FromSchema; export type FaceLandmarkerTaskOutput = FromSchema; -export type FaceLandmarkerTaskExecuteInput = DeReplicateFromSchema< - typeof FaceLandmarkerInputSchema ->; -export type FaceLandmarkerTaskExecuteOutput = DeReplicateFromSchema< - typeof FaceLandmarkerOutputSchema ->; /** * Detects facial landmarks and expressions in images using MediaPipe Face Landmarker. @@ -216,7 +205,7 @@ TaskRegistry.registerTask(FaceLandmarkerTask); * @returns Promise resolving to the detected facial landmarks, blendshapes, and transformation matrices */ export const faceLandmarker = (input: FaceLandmarkerTaskInput, config?: JobQueueTaskConfig) => { - return new FaceLandmarkerTask(input, config).run(); + return new FaceLandmarkerTask({} as FaceLandmarkerTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/GestureRecognizerTask.ts b/packages/ai/src/task/GestureRecognizerTask.ts index 64868453..706b6c44 100644 --- a/packages/ai/src/task/GestureRecognizerTask.ts +++ b/packages/ai/src/task/GestureRecognizerTask.ts @@ -6,15 +6,10 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:GestureRecognizerTask")); +const modelSchema = TypeModel("model:GestureRecognizerTask"); /** * A landmark point with x, y, z coordinates. 
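Each task also registers a chainable method on `Workflow` via `CreateWorkflow` (the `declare module` blocks and `Workflow.prototype.*` assignments throughout this diff). A rough composition sketch under the assumption that those builder methods accept the task input directly and that the workflow is executed with a `run()` call, neither of which is spelled out here:

```ts
// Hypothetical composition sketch; the chaining and run() semantics are assumptions.
import { Workflow } from "@workglow/task-graph";

async function buildAndRun(image: unknown) {
  const workflow = new Workflow();

  // Methods added by Workflow.prototype.<name> = CreateWorkflow(<Task>):
  workflow
    .downloadModel({ model: "hypothetical/model-id" as any })
    .gestureRecognizer({
      image: image as any,
      model: "hypothetical/gesture-model" as any,
      numHands: 2, // from GestureRecognizerInputSchema
    });

  // run() here is an assumption about the Workflow API; it is not shown in this diff.
  await (workflow as any).run();
}
```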
@@ -122,7 +117,7 @@ const TypeHandGestureDetection = { export const GestureRecognizerInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, numHands: { type: "number", @@ -183,12 +178,6 @@ export const GestureRecognizerOutputSchema = { export type GestureRecognizerTaskInput = FromSchema; export type GestureRecognizerTaskOutput = FromSchema; -export type GestureRecognizerTaskExecuteInput = DeReplicateFromSchema< - typeof GestureRecognizerInputSchema ->; -export type GestureRecognizerTaskExecuteOutput = DeReplicateFromSchema< - typeof GestureRecognizerOutputSchema ->; /** * Recognizes hand gestures in images using MediaPipe Gesture Recognizer. @@ -225,7 +214,7 @@ export const gestureRecognizer = ( input: GestureRecognizerTaskInput, config?: JobQueueTaskConfig ) => { - return new GestureRecognizerTask(input, config).run(); + return new GestureRecognizerTask({} as GestureRecognizerTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/HandLandmarkerTask.ts b/packages/ai/src/task/HandLandmarkerTask.ts index 1d0beec8..739e92a1 100644 --- a/packages/ai/src/task/HandLandmarkerTask.ts +++ b/packages/ai/src/task/HandLandmarkerTask.ts @@ -6,15 +6,10 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:HandLandmarkerTask")); +const modelSchema = TypeModel("model:HandLandmarkerTask"); /** * A landmark point with x, y, z coordinates. @@ -95,7 +90,7 @@ const TypeHandDetection = { export const HandLandmarkerInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, numHands: { type: "number", @@ -156,12 +151,6 @@ export const HandLandmarkerOutputSchema = { export type HandLandmarkerTaskInput = FromSchema; export type HandLandmarkerTaskOutput = FromSchema; -export type HandLandmarkerTaskExecuteInput = DeReplicateFromSchema< - typeof HandLandmarkerInputSchema ->; -export type HandLandmarkerTaskExecuteOutput = DeReplicateFromSchema< - typeof HandLandmarkerOutputSchema ->; /** * Detects hand landmarks in images using MediaPipe Hand Landmarker. 
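With `TypeReplicateArray` removed from the vision task inputs, each invocation now takes a single image rather than an image-or-array. One hedged way to process a batch under that change (import path, model id, and image type are assumptions):

```ts
// Sketch: fan out over images yourself now that vision task inputs are single-valued.
import { handLandmarker } from "@workglow/ai"; // assumed export path

async function detectHandsInBatch(images: unknown[]) {
  return Promise.all(
    images.map((image) =>
      handLandmarker({
        image: image as any, // TypeImageInput
        model: "hypothetical/hand-landmarker-model" as any,
        numHands: 2, // from HandLandmarkerInputSchema
      })
    )
  );
}
```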
@@ -194,7 +183,7 @@ TaskRegistry.registerTask(HandLandmarkerTask); * @returns Promise resolving to the detected hand landmarks and handedness */ export const handLandmarker = (input: HandLandmarkerTaskInput, config?: JobQueueTaskConfig) => { - return new HandLandmarkerTask(input, config).run(); + return new HandLandmarkerTask({} as HandLandmarkerTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/HierarchicalChunkerTask.ts b/packages/ai/src/task/HierarchicalChunkerTask.ts new file mode 100644 index 00000000..95b21baf --- /dev/null +++ b/packages/ai/src/task/HierarchicalChunkerTask.ts @@ -0,0 +1,303 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { DataPortSchema, FromSchema } from "@workglow/util"; + +import { estimateTokens, getChildren, hasChildren, NodeIdGenerator } from "../source/DocumentNode"; +import { + ChunkNodeSchema, + type ChunkNode, + type DocumentNode, + type TokenBudget, +} from "../source/DocumentSchema"; + +const inputSchema = { + type: "object", + properties: { + docId: { + type: "string", + title: "Document ID", + description: "The ID of the document", + }, + documentTree: { + title: "Document Tree", + description: "The hierarchical document tree to chunk", + }, + maxTokens: { + type: "number", + title: "Max Tokens", + description: "Maximum tokens per chunk", + minimum: 50, + default: 512, + }, + overlap: { + type: "number", + title: "Overlap", + description: "Overlap in tokens between chunks", + minimum: 0, + default: 50, + }, + reservedTokens: { + type: "number", + title: "Reserved Tokens", + description: "Reserved tokens for metadata/wrappers", + minimum: 0, + default: 10, + }, + strategy: { + type: "string", + enum: ["hierarchical", "flat", "sentence"], + title: "Chunking Strategy", + description: "Strategy for chunking", + default: "hierarchical", + }, + }, + required: [], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + docId: { + type: "string", + title: "Document ID", + description: "The document ID (passed through)", + }, + chunks: { + type: "array", + items: ChunkNodeSchema(), + title: "Chunks", + description: "Array of chunk nodes", + }, + text: { + type: "array", + items: { type: "string" }, + title: "Texts", + description: "Chunk texts (for TextEmbeddingTask)", + }, + count: { + type: "number", + title: "Count", + description: "Number of chunks generated", + }, + }, + required: ["docId", "chunks", "text", "count"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type HierarchicalChunkerTaskInput = FromSchema; +export type HierarchicalChunkerTaskOutput = FromSchema; + +/** + * Task for hierarchical chunking that respects token budgets and document structure + */ +export class HierarchicalChunkerTask extends Task< + HierarchicalChunkerTaskInput, + HierarchicalChunkerTaskOutput, + JobQueueTaskConfig +> { + public static type = "HierarchicalChunkerTask"; + public static category = "Document"; + public static title = "Hierarchical Chunker"; + public static description = "Chunk documents hierarchically respecting token budgets"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema 
{ + return outputSchema as DataPortSchema; + } + + async execute( + input: HierarchicalChunkerTaskInput, + context: IExecuteContext + ): Promise { + const { + docId, + documentTree, + maxTokens = 512, + overlap = 50, + reservedTokens = 10, + strategy = "hierarchical", + } = input; + + if (!docId) { + throw new Error("docId is required"); + } + if (!documentTree) { + throw new Error("documentTree is required"); + } + + const root = documentTree as DocumentNode; + const tokenBudget: TokenBudget = { + maxTokensPerChunk: maxTokens, + overlapTokens: overlap, + reservedTokens, + }; + + const chunks: ChunkNode[] = []; + + if (strategy === "hierarchical") { + await this.chunkHierarchically(root, [], docId, tokenBudget, chunks); + } else { + // Flat chunking: treat entire document as flat text + await this.chunkFlat(root, docId, tokenBudget, chunks); + } + + return { + docId, + chunks, + text: chunks.map((c) => c.text), + count: chunks.length, + }; + } + + /** + * Hierarchical chunking that respects document structure + */ + private async chunkHierarchically( + node: DocumentNode, + nodePath: string[], + docId: string, + tokenBudget: TokenBudget, + chunks: ChunkNode[] + ): Promise { + const currentPath = [...nodePath, node.nodeId]; + + // If node has no children, it's a leaf - chunk its text + if (!hasChildren(node)) { + await this.chunkText(node.text, currentPath, docId, tokenBudget, chunks, node.nodeId); + return; + } + + // For nodes with children, recursively chunk children + const children = getChildren(node); + for (const child of children) { + await this.chunkHierarchically(child, currentPath, docId, tokenBudget, chunks); + } + } + + /** + * Chunk a single text string + */ + private async chunkText( + text: string, + nodePath: string[], + docId: string, + tokenBudget: TokenBudget, + chunks: ChunkNode[], + leafNodeId: string + ): Promise { + const maxChars = (tokenBudget.maxTokensPerChunk - tokenBudget.reservedTokens) * 4; + const overlapChars = tokenBudget.overlapTokens * 4; + + if (estimateTokens(text) <= tokenBudget.maxTokensPerChunk - tokenBudget.reservedTokens) { + // Text fits in one chunk + const chunkId = await NodeIdGenerator.generateChunkId(docId, leafNodeId, 0); + chunks.push({ + chunkId, + docId, + text, + nodePath, + depth: nodePath.length, + }); + return; + } + + // Split into multiple chunks with overlap + let chunkOrdinal = 0; + let startOffset = 0; + + while (startOffset < text.length) { + const endOffset = Math.min(startOffset + maxChars, text.length); + const chunkText = text.substring(startOffset, endOffset); + + const chunkId = await NodeIdGenerator.generateChunkId(docId, leafNodeId, chunkOrdinal); + + chunks.push({ + chunkId, + docId, + text: chunkText, + nodePath, + depth: nodePath.length, + }); + + chunkOrdinal++; + startOffset += maxChars - overlapChars; + + // Prevent infinite loop + if (overlapChars >= maxChars) { + startOffset = endOffset; + } + } + } + + /** + * Flat chunking (ignores hierarchy) + */ + private async chunkFlat( + root: DocumentNode, + docId: string, + tokenBudget: TokenBudget, + chunks: ChunkNode[] + ): Promise { + // Collect all text from the tree + const allText = this.collectAllText(root); + await this.chunkText(allText, [root.nodeId], docId, tokenBudget, chunks, root.nodeId); + } + + /** + * Collect all text from a node and its descendants + */ + private collectAllText(node: DocumentNode): string { + const texts: string[] = []; + + const traverse = (n: DocumentNode) => { + if (!hasChildren(n)) { + texts.push(n.text); + } else { + for (const 
child of getChildren(n)) { + traverse(child); + } + } + }; + + traverse(node); + return texts.join("\n\n"); + } +} + +TaskRegistry.registerTask(HierarchicalChunkerTask); + +export const hierarchicalChunker = ( + input: HierarchicalChunkerTaskInput, + config?: JobQueueTaskConfig +) => { + return new HierarchicalChunkerTask({} as HierarchicalChunkerTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + hierarchicalChunker: CreateWorkflow< + HierarchicalChunkerTaskInput, + HierarchicalChunkerTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.hierarchicalChunker = CreateWorkflow(HierarchicalChunkerTask); diff --git a/packages/ai/src/task/HierarchyJoinTask.ts b/packages/ai/src/task/HierarchyJoinTask.ts new file mode 100644 index 00000000..010dabd8 --- /dev/null +++ b/packages/ai/src/task/HierarchyJoinTask.ts @@ -0,0 +1,248 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { DocumentRepository } from "../source/DocumentRepository"; +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { DataPortSchema, FromSchema } from "@workglow/util"; + +import { + type ChunkMetadata, + ChunkMetadataArraySchema, + EnrichedChunkMetadataArraySchema, +} from "../source/DocumentSchema"; + +const inputSchema = { + type: "object", + properties: { + documentRepository: { + title: "Document Repository", + description: "The document repository to query for hierarchy", + }, + chunks: { + type: "array", + items: { type: "string" }, + title: "Chunks", + description: "Retrieved text chunks", + }, + ids: { + type: "array", + items: { type: "string" }, + title: "Chunk IDs", + description: "IDs of retrieved chunks", + }, + metadata: ChunkMetadataArraySchema, + scores: { + type: "array", + items: { type: "number" }, + title: "Scores", + description: "Similarity scores for each result", + }, + includeParentSummaries: { + type: "boolean", + title: "Include Parent Summaries", + description: "Whether to include summaries from parent nodes", + default: true, + }, + includeEntities: { + type: "boolean", + title: "Include Entities", + description: "Whether to include entities from the node hierarchy", + default: true, + }, + }, + required: ["documentRepository", "chunks", "ids", "metadata", "scores"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + chunks: { + type: "array", + items: { type: "string" }, + title: "Chunks", + description: "Retrieved text chunks", + }, + ids: { + type: "array", + items: { type: "string" }, + title: "Chunk IDs", + description: "IDs of retrieved chunks", + }, + metadata: EnrichedChunkMetadataArraySchema, + scores: { + type: "array", + items: { type: "number" }, + title: "Scores", + description: "Similarity scores", + }, + count: { + type: "number", + title: "Count", + description: "Number of results", + }, + }, + required: ["chunks", "ids", "metadata", "scores", "count"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type HierarchyJoinTaskInput = FromSchema; +export type HierarchyJoinTaskOutput = FromSchema; + +/** + * Task for enriching search results with hierarchy information + * Joins chunk IDs back to document repository to get parent summaries and entities + */ +export class HierarchyJoinTask extends Task< + HierarchyJoinTaskInput, + HierarchyJoinTaskOutput, + 
JobQueueTaskConfig +> { + public static type = "HierarchyJoinTask"; + public static category = "RAG"; + public static title = "Hierarchy Join"; + public static description = "Enrich search results with document hierarchy context"; + public static cacheable = false; // Has external dependency + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: HierarchyJoinTaskInput, + context: IExecuteContext + ): Promise { + const { + documentRepository, + chunks, + ids, + metadata, + scores, + includeParentSummaries = true, + includeEntities = true, + } = input; + + const repo = documentRepository as DocumentRepository; + const enrichedMetadata: any[] = []; + + for (let i = 0; i < ids.length; i++) { + const chunkId = ids[i]; + const originalMetadata: ChunkMetadata | undefined = metadata[i]; + + if (!originalMetadata) { + // Skip if metadata is missing + enrichedMetadata.push({} as ChunkMetadata); + continue; + } + + // Extract docId and nodeId from metadata + const docId = originalMetadata.docId; + const leafNodeId = originalMetadata.leafNodeId; + + if (!docId || !leafNodeId) { + // Can't enrich without IDs + enrichedMetadata.push(originalMetadata); + continue; + } + + try { + // Get ancestors from document repository + const ancestors = await repo.getAncestors(docId, leafNodeId); + + const enriched: any = { ...originalMetadata }; + + // Add parent summaries + if (includeParentSummaries && ancestors.length > 0) { + const parentSummaries: string[] = []; + const sectionTitles: string[] = []; + + for (const ancestor of ancestors) { + if (ancestor.enrichment?.summary) { + parentSummaries.push(ancestor.enrichment.summary); + } + if (ancestor.kind === "section" && (ancestor as any).title) { + sectionTitles.push((ancestor as any).title); + } + } + + if (parentSummaries.length > 0) { + enriched.parentSummaries = parentSummaries; + } + if (sectionTitles.length > 0) { + enriched.sectionTitles = sectionTitles; + } + } + + // Add entities (rolled up from ancestors) + if (includeEntities && ancestors.length > 0) { + const allEntities: any[] = []; + + for (const ancestor of ancestors) { + if (ancestor.enrichment?.entities) { + allEntities.push(...ancestor.enrichment.entities); + } + } + + // Deduplicate entities + const uniqueEntities = new Map(); + for (const entity of allEntities) { + const existing = uniqueEntities.get(entity.text); + if (!existing || entity.score > existing.score) { + uniqueEntities.set(entity.text, entity); + } + } + + if (uniqueEntities.size > 0) { + enriched.entities = Array.from(uniqueEntities.values()); + } + } + + enrichedMetadata.push(enriched); + } catch (error) { + // If join fails, keep original metadata + console.error(`Failed to join hierarchy for chunk ${chunkId}:`, error); + enrichedMetadata.push(originalMetadata); + } + } + + return { + chunks, + ids, + metadata: enrichedMetadata, + scores, + count: chunks.length, + }; + } +} + +TaskRegistry.registerTask(HierarchyJoinTask); + +export const hierarchyJoin = (input: HierarchyJoinTaskInput, config?: JobQueueTaskConfig) => { + return new HierarchyJoinTask({} as HierarchyJoinTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + hierarchyJoin: CreateWorkflow< + HierarchyJoinTaskInput, + HierarchyJoinTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.hierarchyJoin = CreateWorkflow(HierarchyJoinTask); diff --git 
a/packages/ai/src/task/HybridSearchTask.ts b/packages/ai/src/task/HybridSearchTask.ts new file mode 100644 index 00000000..4d27b76c --- /dev/null +++ b/packages/ai/src/task/HybridSearchTask.ts @@ -0,0 +1,233 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { IVectorRepository, TypeVectorRepository } from "@workglow/storage"; +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { + DataPortSchema, + FromSchema, + TypedArraySchema, + TypedArraySchemaOptions, +} from "@workglow/util"; + +const inputSchema = { + type: "object", + properties: { + repository: TypeVectorRepository({ + title: "Vector Repository", + description: "The vector repository instance to search in (must support hybridSearch)", + }), + queryVector: TypedArraySchema({ + title: "Query Vector", + description: "The query vector for semantic search", + }), + queryText: { + type: "string", + title: "Query Text", + description: "The query text for full-text search", + }, + topK: { + type: "number", + title: "Top K", + description: "Number of top results to return", + minimum: 1, + default: 10, + }, + filter: { + type: "object", + title: "Metadata Filter", + description: "Filter results by metadata fields", + }, + scoreThreshold: { + type: "number", + title: "Score Threshold", + description: "Minimum combined score threshold (0-1)", + minimum: 0, + maximum: 1, + default: 0, + }, + vectorWeight: { + type: "number", + title: "Vector Weight", + description: "Weight for vector similarity (0-1), remainder goes to text relevance", + minimum: 0, + maximum: 1, + default: 0.7, + }, + returnVectors: { + type: "boolean", + title: "Return Vectors", + description: "Whether to return vector embeddings in results", + default: false, + }, + }, + required: ["repository", "queryVector", "queryText"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + chunks: { + type: "array", + items: { type: "string" }, + title: "Text Chunks", + description: "Retrieved text chunks", + }, + ids: { + type: "array", + items: { type: "string" }, + title: "IDs", + description: "IDs of retrieved chunks", + }, + metadata: { + type: "array", + items: { + type: "object", + title: "Metadata", + description: "Metadata of retrieved chunk", + }, + title: "Metadata", + description: "Metadata of retrieved chunks", + }, + scores: { + type: "array", + items: { type: "number" }, + title: "Scores", + description: "Combined relevance scores for each result", + }, + vectors: { + type: "array", + items: TypedArraySchema({ + title: "Vector", + description: "Vector embedding", + }), + title: "Vectors", + description: "Vector embeddings (if returnVectors is true)", + }, + count: { + type: "number", + title: "Count", + description: "Number of results returned", + }, + }, + required: ["chunks", "ids", "metadata", "scores", "count"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type HybridSearchTaskInput = FromSchema; +export type HybridSearchTaskOutput = FromSchema; + +/** + * Task for hybrid search combining vector similarity and full-text search. + * Requires a vector repository that supports hybridSearch (e.g., SeekDB, Postgres with pgvector). 
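A hedged usage sketch for the hybrid search task defined in this file, via its `hybridSearch` wrapper. Only the input/output field names come from the schemas below; how the repository and query vector are produced is left to the caller:

```ts
// Sketch: combine a query embedding with the raw query text for hybrid retrieval.
import { hybridSearch } from "@workglow/ai"; // assumed export path

async function retrieve(
  repository: unknown, // an IVectorRepository that implements hybridSearch
  queryVector: Float32Array, // e.g. produced by a text embedding task
  queryText: string
) {
  const { chunks, ids, scores, metadata, count } = await hybridSearch({
    repository: repository as any,
    queryVector: queryVector as any,
    queryText,
    topK: 10,
    vectorWeight: 0.7, // 70% vector similarity, 30% full-text relevance
    scoreThreshold: 0.2,
  });

  return { chunks, ids, scores, metadata, count };
}
```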
+ * + * Hybrid search improves retrieval by combining: + * - Semantic similarity (vector search) - understands meaning + * - Keyword matching (full-text search) - finds exact terms + */ +export class HybridSearchTask extends Task< + HybridSearchTaskInput, + HybridSearchTaskOutput, + JobQueueTaskConfig +> { + public static type = "HybridSearchTask"; + public static category = "RAG"; + public static title = "Hybrid Search"; + public static description = "Combined vector + full-text search for improved retrieval"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: HybridSearchTaskInput, + context: IExecuteContext + ): Promise { + const { + repository, + queryVector, + queryText, + topK = 10, + filter, + scoreThreshold = 0, + vectorWeight = 0.7, + returnVectors = false, + } = input; + + // Repository is resolved by input resolver system before execution + const repo = repository as unknown as IVectorRepository; + + // Check if repository supports hybrid search + if (!repo.hybridSearch) { + throw new Error( + "Repository does not support hybrid search. Use SeekDbVectorRepository or PostgresVectorRepository." + ); + } + + // Convert to Float32Array for repository search (repo expects Float32Array by default) + const searchVector = + queryVector instanceof Float32Array ? queryVector : new Float32Array(queryVector); + + // Perform hybrid search + const results = await repo.hybridSearch(searchVector, { + textQuery: queryText, + topK, + filter, + scoreThreshold, + vectorWeight, + }); + + // Extract text chunks from metadata + const chunks = results.map((r) => { + const meta = r.metadata as any; + return meta.text || meta.content || meta.chunk || JSON.stringify(meta); + }); + + const output: HybridSearchTaskOutput = { + chunks, + ids: results.map((r) => r.id), + metadata: results.map((r) => r.metadata), + scores: results.map((r) => r.score), + count: results.length, + }; + + if (returnVectors) { + output.vectors = results.map((r) => r.vector); + } + + return output; + } +} + +TaskRegistry.registerTask(HybridSearchTask); + +export const hybridSearch = async ( + input: HybridSearchTaskInput, + config?: JobQueueTaskConfig +): Promise => { + return new HybridSearchTask({} as HybridSearchTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + hybridSearch: CreateWorkflow; + } +} + +Workflow.prototype.hybridSearch = CreateWorkflow(HybridSearchTask); diff --git a/packages/ai/src/task/ImageClassificationTask.ts b/packages/ai/src/task/ImageClassificationTask.ts index 9dd75954..858c40f3 100644 --- a/packages/ai/src/task/ImageClassificationTask.ts +++ b/packages/ai/src/task/ImageClassificationTask.ts @@ -6,21 +6,15 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeCategory, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeCategory, TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:ImageClassificationTask")); +const modelSchema = TypeModel("model:ImageClassificationTask"); export const ImageClassificationInputSchema = { type: "object", properties: { - 
image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, categories: { type: "array", @@ -64,12 +58,6 @@ export const ImageClassificationOutputSchema = { export type ImageClassificationTaskInput = FromSchema; export type ImageClassificationTaskOutput = FromSchema; -export type ImageClassificationTaskExecuteInput = DeReplicateFromSchema< - typeof ImageClassificationInputSchema ->; -export type ImageClassificationTaskExecuteOutput = DeReplicateFromSchema< - typeof ImageClassificationOutputSchema ->; /** * Classifies images into categories using vision models. @@ -105,7 +93,7 @@ export const imageClassification = ( input: ImageClassificationTaskInput, config?: JobQueueTaskConfig ) => { - return new ImageClassificationTask(input, config).run(); + return new ImageClassificationTask({} as ImageClassificationTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/ImageEmbeddingTask.ts b/packages/ai/src/task/ImageEmbeddingTask.ts index 80194f52..94e0219c 100644 --- a/packages/ai/src/task/ImageEmbeddingTask.ts +++ b/packages/ai/src/task/ImageEmbeddingTask.ts @@ -5,17 +5,16 @@ */ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; -import { DataPortSchema, FromSchema } from "@workglow/util"; import { - DeReplicateFromSchema, + DataPortSchema, + FromSchema, TypedArraySchema, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; + TypedArraySchemaOptions, +} from "@workglow/util"; +import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:ImageEmbeddingTask")); +const modelSchema = TypeModel("model:ImageEmbeddingTask"); const embeddingSchema = TypedArraySchema({ title: "Embedding", @@ -25,7 +24,7 @@ const embeddingSchema = TypedArraySchema({ export const ImageEmbeddingInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, }, required: ["image", "model"], @@ -35,23 +34,19 @@ export const ImageEmbeddingInputSchema = { export const ImageEmbeddingOutputSchema = { type: "object", properties: { - vector: { - oneOf: [embeddingSchema, { type: "array", items: embeddingSchema }], - title: "Embedding", - description: "The image embedding vector", - }, + vector: embeddingSchema, }, required: ["vector"], additionalProperties: false, } as const satisfies DataPortSchema; -export type ImageEmbeddingTaskInput = FromSchema; -export type ImageEmbeddingTaskOutput = FromSchema; -export type ImageEmbeddingTaskExecuteInput = DeReplicateFromSchema< - typeof ImageEmbeddingInputSchema +export type ImageEmbeddingTaskInput = FromSchema< + typeof ImageEmbeddingInputSchema, + TypedArraySchemaOptions >; -export type ImageEmbeddingTaskExecuteOutput = DeReplicateFromSchema< - typeof ImageEmbeddingOutputSchema +export type ImageEmbeddingTaskOutput = FromSchema< + typeof ImageEmbeddingOutputSchema, + TypedArraySchemaOptions >; /** @@ -83,7 +78,7 @@ TaskRegistry.registerTask(ImageEmbeddingTask); * @returns Promise resolving to the image embedding vector */ export const imageEmbedding = (input: ImageEmbeddingTaskInput, config?: JobQueueTaskConfig) => { - return new ImageEmbeddingTask(input, config).run(); + return new ImageEmbeddingTask({} as ImageEmbeddingTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git 
a/packages/ai/src/task/ImageSegmentationTask.ts b/packages/ai/src/task/ImageSegmentationTask.ts index 17204dbd..4c4e96ed 100644 --- a/packages/ai/src/task/ImageSegmentationTask.ts +++ b/packages/ai/src/task/ImageSegmentationTask.ts @@ -6,20 +6,15 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:ImageSegmentationTask")); +const modelSchema = TypeModel("model:ImageSegmentationTask"); export const ImageSegmentationInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, threshold: { type: "number", @@ -88,12 +83,6 @@ export const ImageSegmentationOutputSchema = { export type ImageSegmentationTaskInput = FromSchema; export type ImageSegmentationTaskOutput = FromSchema; -export type ImageSegmentationTaskExecuteInput = DeReplicateFromSchema< - typeof ImageSegmentationInputSchema ->; -export type ImageSegmentationTaskExecuteOutput = DeReplicateFromSchema< - typeof ImageSegmentationOutputSchema ->; /** * Segments images into regions using computer vision models @@ -128,7 +117,7 @@ export const imageSegmentation = ( input: ImageSegmentationTaskInput, config?: JobQueueTaskConfig ) => { - return new ImageSegmentationTask(input, config).run(); + return new ImageSegmentationTask({} as ImageSegmentationTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/ImageToTextTask.ts b/packages/ai/src/task/ImageToTextTask.ts index c7fff8bd..2cf6b919 100644 --- a/packages/ai/src/task/ImageToTextTask.ts +++ b/packages/ai/src/task/ImageToTextTask.ts @@ -6,15 +6,10 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:ImageToTextTask")); +const modelSchema = TypeModel("model:ImageToTextTask"); const generatedTextSchema = { type: "string", @@ -25,7 +20,7 @@ const generatedTextSchema = { export const ImageToTextInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, maxTokens: { type: "number", @@ -55,8 +50,6 @@ export const ImageToTextOutputSchema = { export type ImageToTextTaskInput = FromSchema; export type ImageToTextTaskOutput = FromSchema; -export type ImageToTextTaskExecuteInput = DeReplicateFromSchema; -export type ImageToTextTaskExecuteOutput = DeReplicateFromSchema; /** * Generates text descriptions from images using vision-language models @@ -88,7 +81,7 @@ TaskRegistry.registerTask(ImageToTextTask); * @returns Promise resolving to the generated text description */ export const imageToText = (input: ImageToTextTaskInput, config?: JobQueueTaskConfig) => { - return new ImageToTextTask(input, config).run(); + return new ImageToTextTask({} as ImageToTextTaskInput, config).run(input); }; declare module 
"@workglow/task-graph" { diff --git a/packages/ai/src/task/ObjectDetectionTask.ts b/packages/ai/src/task/ObjectDetectionTask.ts index 78244e31..6132d1cd 100644 --- a/packages/ai/src/task/ObjectDetectionTask.ts +++ b/packages/ai/src/task/ObjectDetectionTask.ts @@ -6,16 +6,10 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeBoundingBox, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeBoundingBox, TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:ObjectDetectionTask")); +const modelSchema = TypeModel("model:ObjectDetectionTask"); const detectionSchema = { type: "object", @@ -41,7 +35,7 @@ const detectionSchema = { export const ObjectDetectionInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, labels: { type: "array", @@ -85,12 +79,6 @@ export const ObjectDetectionOutputSchema = { export type ObjectDetectionTaskInput = FromSchema; export type ObjectDetectionTaskOutput = FromSchema; -export type ObjectDetectionTaskExecuteInput = DeReplicateFromSchema< - typeof ObjectDetectionInputSchema ->; -export type ObjectDetectionTaskExecuteOutput = DeReplicateFromSchema< - typeof ObjectDetectionOutputSchema ->; /** * Detects objects in images using vision models. @@ -123,7 +111,7 @@ TaskRegistry.registerTask(ObjectDetectionTask); * @returns Promise resolving to the detected objects with labels, scores, and bounding boxes */ export const objectDetection = (input: ObjectDetectionTaskInput, config?: JobQueueTaskConfig) => { - return new ObjectDetectionTask(input, config).run(); + return new ObjectDetectionTask({} as ObjectDetectionTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/PoseLandmarkerTask.ts b/packages/ai/src/task/PoseLandmarkerTask.ts index 1c596e47..8f0a45f3 100644 --- a/packages/ai/src/task/PoseLandmarkerTask.ts +++ b/packages/ai/src/task/PoseLandmarkerTask.ts @@ -6,15 +6,10 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; -import { - DeReplicateFromSchema, - TypeImageInput, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeImageInput, TypeModel } from "./base/AiTaskSchemas"; import { AiVisionTask } from "./base/AiVisionTask"; -const modelSchema = TypeReplicateArray(TypeModel("model:PoseLandmarkerTask")); +const modelSchema = TypeModel("model:PoseLandmarkerTask"); /** * A landmark point with x, y, z coordinates and visibility/presence scores. 
@@ -105,7 +100,7 @@ const TypePoseDetection = { export const PoseLandmarkerInputSchema = { type: "object", properties: { - image: TypeReplicateArray(TypeImageInput), + image: TypeImageInput, model: modelSchema, numPoses: { type: "number", @@ -173,12 +168,6 @@ export const PoseLandmarkerOutputSchema = { export type PoseLandmarkerTaskInput = FromSchema; export type PoseLandmarkerTaskOutput = FromSchema; -export type PoseLandmarkerTaskExecuteInput = DeReplicateFromSchema< - typeof PoseLandmarkerInputSchema ->; -export type PoseLandmarkerTaskExecuteOutput = DeReplicateFromSchema< - typeof PoseLandmarkerOutputSchema ->; /** * Detects pose landmarks in images using MediaPipe Pose Landmarker. @@ -211,7 +200,7 @@ TaskRegistry.registerTask(PoseLandmarkerTask); * @returns Promise resolving to the detected pose landmarks and optional segmentation masks */ export const poseLandmarker = (input: PoseLandmarkerTaskInput, config?: JobQueueTaskConfig) => { - return new PoseLandmarkerTask(input, config).run(); + return new PoseLandmarkerTask({} as PoseLandmarkerTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/QueryExpanderTask.ts b/packages/ai/src/task/QueryExpanderTask.ts new file mode 100644 index 00000000..b3804b19 --- /dev/null +++ b/packages/ai/src/task/QueryExpanderTask.ts @@ -0,0 +1,318 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { DataPortSchema, FromSchema } from "@workglow/util"; + +export const QueryExpansionMethod = { + MULTI_QUERY: "multi-query", + HYDE: "hyde", + SYNONYMS: "synonyms", + PARAPHRASE: "paraphrase", +} as const; + +export type QueryExpansionMethod = (typeof QueryExpansionMethod)[keyof typeof QueryExpansionMethod]; + +const inputSchema = { + type: "object", + properties: { + query: { + type: "string", + title: "Query", + description: "The original query to expand", + }, + method: { + type: "string", + enum: Object.values(QueryExpansionMethod), + title: "Expansion Method", + description: "Method to use for query expansion", + default: QueryExpansionMethod.MULTI_QUERY, + }, + numVariations: { + type: "number", + title: "Number of Variations", + description: "Number of query variations to generate", + minimum: 1, + maximum: 10, + default: 3, + }, + model: { + type: "string", + title: "Model", + description: "LLM model to use for expansion (for HyDE and paraphrase methods)", + }, + }, + required: ["query"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + queries: { + type: "array", + items: { type: "string" }, + title: "Expanded Queries", + description: "Generated query variations", + }, + originalQuery: { + type: "string", + title: "Original Query", + description: "The original input query", + }, + method: { + type: "string", + title: "Method Used", + description: "The expansion method that was used", + }, + count: { + type: "number", + title: "Count", + description: "Number of queries generated", + }, + }, + required: ["queries", "originalQuery", "method", "count"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type QueryExpanderTaskInput = FromSchema; +export type QueryExpanderTaskOutput = FromSchema; + +/** + * Task for expanding queries to improve retrieval coverage. 
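A hedged sketch of driving retrieval with the expanded queries this task produces. Only `queryExpander` and its input/output fields come from this file; `searchOnce` is a placeholder for whatever search task it is chained with:

```ts
// Sketch: expand a query, search once per variation, and deduplicate results by id.
import { queryExpander } from "@workglow/ai"; // assumed export path

async function expandedSearch(
  query: string,
  searchOnce: (q: string) => Promise<{ ids: string[]; chunks: string[] }> // hypothetical
) {
  const { queries } = await queryExpander({ query, method: "multi-query", numVariations: 3 });

  const seen = new Map<string, string>();
  for (const q of queries) {
    const { ids, chunks } = await searchOnce(q);
    ids.forEach((id, i) => {
      if (!seen.has(id)) seen.set(id, chunks[i]);
    });
  }
  return Array.from(seen.values());
}
```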
+ * Supports multiple expansion methods including multi-query, HyDE, and paraphrasing. + * + * Note: HyDE and paraphrase methods require an LLM model. + * For now, this implements simple rule-based expansion. + */ +export class QueryExpanderTask extends Task< + QueryExpanderTaskInput, + QueryExpanderTaskOutput, + JobQueueTaskConfig +> { + public static type = "QueryExpanderTask"; + public static category = "RAG"; + public static title = "Query Expander"; + public static description = "Expand queries to improve retrieval coverage"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: QueryExpanderTaskInput, + context: IExecuteContext + ): Promise { + const { query, method = QueryExpansionMethod.MULTI_QUERY, numVariations = 3 } = input; + + let queries: string[]; + + switch (method) { + case QueryExpansionMethod.HYDE: + queries = this.hydeExpansion(query, numVariations); + break; + case QueryExpansionMethod.SYNONYMS: + queries = this.synonymExpansion(query, numVariations); + break; + case QueryExpansionMethod.PARAPHRASE: + queries = this.paraphraseExpansion(query, numVariations); + break; + case QueryExpansionMethod.MULTI_QUERY: + default: + queries = this.multiQueryExpansion(query, numVariations); + break; + } + + // Always include original query + if (!queries.includes(query)) { + queries.unshift(query); + } + + return { + queries, + originalQuery: query, + method, + count: queries.length, + }; + } + + /** + * Multi-query expansion: Generate variations by rephrasing the question + */ + private multiQueryExpansion(query: string, numVariations: number): string[] { + const queries: string[] = [query]; + + // Simple rule-based variations + const variations: string[] = []; + + // Question word variations + if (query.toLowerCase().startsWith("what")) { + variations.push(query.replace(/^what/i, "Which")); + variations.push(query.replace(/^what/i, "Can you explain")); + } else if (query.toLowerCase().startsWith("how")) { + variations.push(query.replace(/^how/i, "What is the method to")); + variations.push(query.replace(/^how/i, "In what way")); + } else if (query.toLowerCase().startsWith("why")) { + variations.push(query.replace(/^why/i, "What is the reason")); + variations.push(query.replace(/^why/i, "For what purpose")); + } else if (query.toLowerCase().startsWith("where")) { + variations.push(query.replace(/^where/i, "In which location")); + variations.push(query.replace(/^where/i, "At what place")); + } + + // Add "Tell me about" variation + if (!query.toLowerCase().startsWith("tell me")) { + variations.push(`Tell me about ${query.toLowerCase()}`); + } + + // Add "Explain" variation + if (!query.toLowerCase().startsWith("explain")) { + variations.push(`Explain ${query.toLowerCase()}`); + } + + // Take up to numVariations + for (let i = 0; i < Math.min(numVariations - 1, variations.length); i++) { + if (variations[i] && !queries.includes(variations[i])) { + queries.push(variations[i]); + } + } + + return queries; + } + + /** + * HyDE (Hypothetical Document Embeddings): Generate hypothetical answers + */ + private hydeExpansion(query: string, numVariations: number): string[] { + // TODO: in a real implementation, this would call a model to generate hypothetical answer templates + const queries: string[] = [query]; + + const templates = [ + `The answer to "${query}" is that`, + `Regarding ${query}, it 
is important to note that`, + `${query} can be explained by the fact that`, + `In response to ${query}, one should consider that`, + ]; + + for (let i = 0; i < Math.min(numVariations - 1, templates.length); i++) { + queries.push(templates[i]); + } + + return queries; + } + + /** + * Synonym expansion: Replace keywords with synonyms + */ + private synonymExpansion(query: string, numVariations: number): string[] { + const queries: string[] = [query]; + + // Simple synonym dictionary (in production, use a proper thesaurus) + const synonyms: Record = { + find: ["locate", "discover", "search for"], + create: ["make", "build", "generate"], + delete: ["remove", "erase", "eliminate"], + update: ["modify", "change", "edit"], + show: ["display", "present", "reveal"], + explain: ["describe", "clarify", "elaborate"], + help: ["assist", "aid", "support"], + problem: ["issue", "challenge", "difficulty"], + solution: ["answer", "resolution", "fix"], + method: ["approach", "technique", "way"], + }; + + const words = query.toLowerCase().split(/\s+/); + let variationsGenerated = 0; + + for (const [word, syns] of Object.entries(synonyms)) { + if (variationsGenerated >= numVariations - 1) break; + + const wordIndex = words.indexOf(word); + if (wordIndex !== -1) { + for (const syn of syns) { + if (variationsGenerated >= numVariations - 1) break; + + const newWords = [...words]; + newWords[wordIndex] = syn; + const newQuery = newWords.join(" "); + + // Preserve original capitalization pattern + const capitalizedQuery = this.preserveCapitalization(query, newQuery); + if (!queries.includes(capitalizedQuery)) { + queries.push(capitalizedQuery); + variationsGenerated++; + } + } + } + } + + return queries; + } + + /** + * Paraphrase expansion: Rephrase the query + * TODO: This should use an LLM for better paraphrasing + */ + private paraphraseExpansion(query: string, numVariations: number): string[] { + const queries: string[] = [query]; + + // Simple paraphrase templates + const paraphrases: string[] = []; + + // Add context + paraphrases.push(`I need information about ${query.toLowerCase()}`); + paraphrases.push(`Can you help me understand ${query.toLowerCase()}`); + paraphrases.push(`I'm looking for details on ${query.toLowerCase()}`); + + for (let i = 0; i < Math.min(numVariations - 1, paraphrases.length); i++) { + if (!queries.includes(paraphrases[i])) { + queries.push(paraphrases[i]); + } + } + + return queries; + } + + /** + * Preserve capitalization pattern from original to new query + */ + private preserveCapitalization(original: string, modified: string): string { + if (original[0] === original[0].toUpperCase()) { + return modified.charAt(0).toUpperCase() + modified.slice(1); + } + return modified; + } +} + +TaskRegistry.registerTask(QueryExpanderTask); + +export const queryExpander = (input: QueryExpanderTaskInput, config?: JobQueueTaskConfig) => { + return new QueryExpanderTask({} as QueryExpanderTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + queryExpander: CreateWorkflow< + QueryExpanderTaskInput, + QueryExpanderTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.queryExpander = CreateWorkflow(QueryExpanderTask); diff --git a/packages/ai/src/task/RerankerTask.ts b/packages/ai/src/task/RerankerTask.ts new file mode 100644 index 00000000..bcd5d8c0 --- /dev/null +++ b/packages/ai/src/task/RerankerTask.ts @@ -0,0 +1,341 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + 
CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { DataPortSchema, FromSchema } from "@workglow/util"; +import { TextClassificationTask } from "./TextClassificationTask"; + +const inputSchema = { + type: "object", + properties: { + query: { + type: "string", + title: "Query", + description: "The query to rerank results against", + }, + chunks: { + type: "array", + items: { type: "string" }, + title: "Text Chunks", + description: "Retrieved text chunks to rerank", + }, + scores: { + type: "array", + items: { type: "number" }, + title: "Initial Scores", + description: "Initial retrieval scores (optional)", + }, + metadata: { + type: "array", + items: { + type: "object", + title: "Metadata", + description: "Metadata for each chunk", + }, + title: "Metadata", + description: "Metadata for each chunk (optional)", + }, + topK: { + type: "number", + title: "Top K", + description: "Number of top results to return after reranking", + minimum: 1, + }, + method: { + type: "string", + enum: ["cross-encoder", "reciprocal-rank-fusion", "simple"], + title: "Reranking Method", + description: "Method to use for reranking", + default: "simple", + }, + model: { + type: "string", + title: "Reranker Model", + description: "Cross-encoder model to use for reranking", + }, + }, + required: ["query", "chunks"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + chunks: { + type: "array", + items: { type: "string" }, + title: "Reranked Chunks", + description: "Chunks reordered by relevance", + }, + scores: { + type: "array", + items: { type: "number" }, + title: "Reranked Scores", + description: "New relevance scores", + }, + metadata: { + type: "array", + items: { + type: "object", + title: "Metadata", + description: "Metadata for each chunk", + }, + title: "Metadata", + description: "Metadata for reranked chunks", + }, + originalIndices: { + type: "array", + items: { type: "number" }, + title: "Original Indices", + description: "Original indices of reranked chunks", + }, + count: { + type: "number", + title: "Count", + description: "Number of results returned", + }, + }, + required: ["chunks", "scores", "originalIndices", "count"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type RerankerTaskInput = FromSchema; +export type RerankerTaskOutput = FromSchema; + +interface RankedItem { + chunk: string; + score: number; + metadata?: any; + originalIndex: number; +} + +/** + * Task for reranking retrieved chunks to improve relevance. + * Supports multiple reranking methods including cross-encoder models. + * + * Note: Cross-encoder reranking requires a model to be loaded. + * For now, this implements simple heuristic-based reranking. 
+ */ +export class RerankerTask extends Task { + public static type = "RerankerTask"; + public static category = "RAG"; + public static title = "Reranker"; + public static description = "Rerank retrieved chunks to improve relevance"; + public static cacheable = true; + private resolvedCrossEncoderModel?: string | null; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute(input: RerankerTaskInput, context: IExecuteContext): Promise { + const { query, chunks, scores = [], metadata = [], topK, method = "simple", model } = input; + + let rankedItems: RankedItem[]; + + switch (method) { + case "cross-encoder": + rankedItems = await this.crossEncoderRerank( + query, + chunks, + scores, + metadata, + model, + context + ); + break; + case "reciprocal-rank-fusion": + rankedItems = this.reciprocalRankFusion(chunks, scores, metadata); + break; + case "simple": + default: + rankedItems = this.simpleRerank(query, chunks, scores, metadata); + break; + } + + // Apply topK if specified + if (topK && topK < rankedItems.length) { + rankedItems = rankedItems.slice(0, topK); + } + + return { + chunks: rankedItems.map((item) => item.chunk), + scores: rankedItems.map((item) => item.score), + metadata: rankedItems.map((item) => item.metadata), + originalIndices: rankedItems.map((item) => item.originalIndex), + count: rankedItems.length, + }; + } + + private async crossEncoderRerank( + query: string, + chunks: string[], + scores: number[], + metadata: any[], + model: string | undefined, + context: IExecuteContext + ): Promise { + if (chunks.length === 0) { + return []; + } + + if (!model) { + throw new Error( + "No cross-encoder model found. Please provide a model or register a TextClassificationTask model." + ); + } + + const items = await Promise.all( + chunks.map(async (chunk, index) => { + const pairText = `${query} [SEP] ${chunk}`; + const task = context.own( + new TextClassificationTask({ text: pairText, model: model, maxCategories: 2 }) + ); + const result = await task.run(); + const crossScore = this.extractCrossEncoderScore(result.categories); + return { + chunk, + score: Number.isFinite(crossScore) ? 
crossScore : scores[index] || 0, + metadata: metadata[index], + originalIndex: index, + }; + }) + ); + + items.sort((a, b) => b.score - a.score); + return items; + } + + private extractCrossEncoderScore( + categories: Array<{ label: string; score: number }> | undefined + ): number { + if (!categories || categories.length === 0) { + return 0; + } + const preferred = categories.find((category) => + /^(label_1|positive|relevant|yes|true)$/i.test(category.label) + ); + if (preferred) { + return preferred.score; + } + let best = categories[0].score; + for (let i = 1; i < categories.length; i++) { + if (categories[i].score > best) { + best = categories[i].score; + } + } + return best; + } + + /** + * Simple heuristic-based reranking using keyword matching and position + */ + private simpleRerank( + query: string, + chunks: string[], + scores: number[], + metadata: any[] + ): RankedItem[] { + const queryLower = query.toLowerCase(); + const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); + + const items: RankedItem[] = chunks.map((chunk, index) => { + const chunkLower = chunk.toLowerCase(); + const initialScore = scores[index] || 0; + + // Calculate keyword match score + let keywordScore = 0; + let exactMatchBonus = 0; + + for (const word of queryWords) { + // Count occurrences + const regex = new RegExp(word, "gi"); + const matches = chunkLower.match(regex); + if (matches) { + keywordScore += matches.length; + } + } + + // Bonus for exact query match + if (chunkLower.includes(queryLower)) { + exactMatchBonus = 0.5; + } + + // Normalize keyword score + const normalizedKeywordScore = Math.min(keywordScore / (queryWords.length * 3), 1); + + // Position penalty (prefer earlier results, but not too heavily) + const positionPenalty = Math.log(index + 1) / 10; + + // Combined score + const combinedScore = + initialScore * 0.4 + normalizedKeywordScore * 0.4 + exactMatchBonus * 0.2 - positionPenalty; + + return { + chunk, + score: combinedScore, + metadata: metadata[index], + originalIndex: index, + }; + }); + + // Sort by score descending + items.sort((a, b) => b.score - a.score); + + return items; + } + + /** + * Reciprocal Rank Fusion for combining multiple rankings + * Useful when you have multiple retrieval methods + */ + private reciprocalRankFusion(chunks: string[], scores: number[], metadata: any[]): RankedItem[] { + const k = 60; // RRF constant + + const items: RankedItem[] = chunks.map((chunk, index) => { + // RRF score = 1 / (k + rank) + // Here we use the initial ranking (index) as the rank + const rrfScore = 1 / (k + index + 1); + + return { + chunk, + score: rrfScore, + metadata: metadata[index], + originalIndex: index, + }; + }); + + // Sort by RRF score descending + items.sort((a, b) => b.score - a.score); + + return items; + } +} + +TaskRegistry.registerTask(RerankerTask); + +export const reranker = (input: RerankerTaskInput, config?: JobQueueTaskConfig) => { + return new RerankerTask({} as RerankerTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + reranker: CreateWorkflow; + } +} + +Workflow.prototype.reranker = CreateWorkflow(RerankerTask); diff --git a/packages/ai/src/task/RetrievalTask.ts b/packages/ai/src/task/RetrievalTask.ts new file mode 100644 index 00000000..2dcc6763 --- /dev/null +++ b/packages/ai/src/task/RetrievalTask.ts @@ -0,0 +1,243 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { IVectorRepository, TypeVectorRepository } from 
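Likewise not part of the diff: a hedged sketch of the reranker helper defined above, using the default heuristic method. retrievedChunks and retrievedScores are placeholders for the output of a prior retrieval step, and the "@workglow/ai" import path is an assumption.

import { reranker } from "@workglow/ai"; // assumed entry point

declare const retrievedChunks: string[]; // placeholder: chunks from a prior retrieval step
declare const retrievedScores: number[]; // placeholder: their initial similarity scores

const reranked = await reranker({
  query: "index maintenance strategies for vector databases",
  chunks: retrievedChunks,
  scores: retrievedScores,
  topK: 5,
  method: "simple", // "cross-encoder" additionally requires a TextClassificationTask model
});

// reranked.chunks are reordered by score; reranked.originalIndices map back to the input order.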
"@workglow/storage"; +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { + DataPortSchema, + FromSchema, + TypedArray, + TypedArraySchema, + TypedArraySchemaOptions, +} from "@workglow/util"; +import { TypeModel } from "./base/AiTaskSchemas"; +import { TextEmbeddingTask } from "./TextEmbeddingTask"; + +const inputSchema = { + type: "object", + properties: { + repository: TypeVectorRepository({ + title: "Vector Repository", + description: "The vector repository instance to search in", + }), + query: { + oneOf: [ + { type: "string" }, + TypedArraySchema({ + title: "Query Vector", + description: "Pre-computed query vector", + }), + ], + title: "Query", + description: "Query string or pre-computed query vector", + }, + model: TypeModel("model:TextEmbeddingTask", { + title: "Model", + description: + "Text embedding model to use for query embedding (required when query is a string)", + }), + topK: { + type: "number", + title: "Top K", + description: "Number of top results to return", + minimum: 1, + default: 5, + }, + filter: { + type: "object", + title: "Metadata Filter", + description: "Filter results by metadata fields", + }, + scoreThreshold: { + type: "number", + title: "Score Threshold", + description: "Minimum similarity score threshold (0-1)", + minimum: 0, + maximum: 1, + default: 0, + }, + returnVectors: { + type: "boolean", + title: "Return Vectors", + description: "Whether to return vector embeddings in results", + default: false, + }, + }, + required: ["repository", "query"], + if: { + properties: { + query: { type: "string" }, + }, + }, + then: { + required: ["repository", "query", "model"], + }, + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + chunks: { + type: "array", + items: { type: "string" }, + title: "Text Chunks", + description: "Retrieved text chunks", + }, + ids: { + type: "array", + items: { type: "string" }, + title: "IDs", + description: "IDs of retrieved chunks", + }, + metadata: { + type: "array", + items: { + type: "object", + title: "Metadata", + description: "Metadata of retrieved chunk", + }, + title: "Metadata", + description: "Metadata of retrieved chunks", + }, + scores: { + type: "array", + items: { type: "number" }, + title: "Scores", + description: "Similarity scores for each result", + }, + vectors: { + type: "array", + items: TypedArraySchema({ + title: "Vector", + description: "Vector embedding", + }), + title: "Vectors", + description: "Vector embeddings (if returnVectors is true)", + }, + count: { + type: "number", + title: "Count", + description: "Number of results returned", + }, + }, + required: ["chunks", "ids", "metadata", "scores", "count"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type RetrievalTaskInput = FromSchema; +export type RetrievalTaskOutput = FromSchema; + +/** + * End-to-end retrieval task that combines embedding generation (if needed) and vector search. + * Simplifies the RAG pipeline by handling the full retrieval process. 
+ */ +export class RetrievalTask extends Task< + RetrievalTaskInput, + RetrievalTaskOutput, + JobQueueTaskConfig +> { + public static type = "RetrievalTask"; + public static category = "RAG"; + public static title = "Retrieval"; + public static description = "End-to-end retrieval: embed query and search for similar chunks"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute(input: RetrievalTaskInput, context: IExecuteContext): Promise { + const { + repository, + query, + topK = 5, + filter, + model, + scoreThreshold = 0, + returnVectors = false, + } = input; + + // Repository is resolved by input resolver system before execution + const repo = repository as unknown as IVectorRepository; + + // Determine query vector + let queryVector: TypedArray; + if (typeof query === "string") { + // If query is a string, model must be provided (enforced by schema) + if (!model) { + throw new Error( + "Model is required when query is a string. Please provide a model with format 'model:TextEmbeddingTask'." + ); + } + const embeddingTask = context.own(new TextEmbeddingTask({ text: query, model })); + const embeddingResult = await embeddingTask.run(); + queryVector = Array.isArray(embeddingResult.vector) + ? embeddingResult.vector[0] + : embeddingResult.vector; + } else { + // Query is already a vector + queryVector = query as TypedArray; + } + + // Convert to Float32Array for repository search (repo expects Float32Array by default) + const searchVector = + queryVector instanceof Float32Array ? queryVector : new Float32Array(queryVector); + + // Search vector repository + const results = await repo.similaritySearch(searchVector, { + topK, + filter, + scoreThreshold, + }); + + // Extract text chunks from metadata + // Assumes metadata has a 'text' or 'content' field + const chunks = results.map((r) => { + const meta = r.metadata as any; + return meta.text || meta.content || meta.chunk || JSON.stringify(meta); + }); + + const output: RetrievalTaskOutput = { + chunks, + ids: results.map((r) => r.id), + metadata: results.map((r) => r.metadata), + scores: results.map((r) => r.score), + count: results.length, + }; + + if (returnVectors) { + output.vectors = results.map((r) => r.vector); + } + + return output; + } +} + +TaskRegistry.registerTask(RetrievalTask); + +export const retrieval = (input: RetrievalTaskInput, config?: JobQueueTaskConfig) => { + return new RetrievalTask({} as RetrievalTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + retrieval: CreateWorkflow; + } +} + +Workflow.prototype.retrieval = CreateWorkflow(RetrievalTask); diff --git a/packages/ai/src/task/StructuralParserTask.ts b/packages/ai/src/task/StructuralParserTask.ts new file mode 100644 index 00000000..2ab766ba --- /dev/null +++ b/packages/ai/src/task/StructuralParserTask.ts @@ -0,0 +1,161 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { DataPortSchema, FromSchema } from "@workglow/util"; +import { NodeIdGenerator } from "../source/DocumentNode"; +import { DocumentNode } from "../source/DocumentSchema"; +import { StructuralParser } from "../source/StructuralParser"; + +const inputSchema = { + type: 
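A sketch of the retrieval helper defined above, again not part of the diff. vectorRepo stands in for an existing IVectorRepository instance from @workglow/storage (its construction is not shown here), embeddingModel is a placeholder for a registered TextEmbeddingTask model, and the "@workglow/ai" import path is assumed.

import { retrieval } from "@workglow/ai"; // assumed entry point
import { IVectorRepository } from "@workglow/storage";

declare const vectorRepo: IVectorRepository; // placeholder: a populated vector repository
declare const embeddingModel: string;        // placeholder: a registered TextEmbeddingTask model

const results = await retrieval({
  repository: vectorRepo,
  query: "What chunking strategy works best for long markdown documents?",
  model: embeddingModel, // required because the query is a string, not a pre-computed vector
  topK: 5,
  scoreThreshold: 0.25,
});

// results.chunks, results.scores, and results.metadata are aligned arrays; results.count is their length.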
"object", + properties: { + text: { + type: "string", + title: "Text", + description: "The text content to parse", + }, + title: { + type: "string", + title: "Title", + description: "Document title", + }, + format: { + type: "string", + enum: ["markdown", "text", "auto"], + title: "Format", + description: "Document format (auto-detects if not specified)", + default: "auto", + }, + sourceUri: { + type: "string", + title: "Source URI", + description: "Source identifier for document ID generation", + }, + docId: { + type: "string", + title: "Document ID", + description: "Pre-generated document ID (optional)", + }, + }, + required: ["text", "title"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + docId: { + type: "string", + title: "Document ID", + description: "Generated or provided document ID", + }, + documentTree: { + title: "Document Tree", + description: "Parsed hierarchical document tree", + }, + nodeCount: { + type: "number", + title: "Node Count", + description: "Total number of nodes in the tree", + }, + }, + required: ["docId", "documentTree", "nodeCount"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type StructuralParserTaskInput = FromSchema; +export type StructuralParserTaskOutput = FromSchema; + +/** + * Task for parsing documents into hierarchical tree structure + * Supports markdown and plain text with automatic format detection + */ +export class StructuralParserTask extends Task< + StructuralParserTaskInput, + StructuralParserTaskOutput, + JobQueueTaskConfig +> { + public static type = "StructuralParserTask"; + public static category = "Document"; + public static title = "Structural Parser"; + public static description = "Parse documents into hierarchical tree structure"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: StructuralParserTaskInput, + context: IExecuteContext + ): Promise { + const { text, title, format = "auto", sourceUri, docId: providedDocId } = input; + + // Generate or use provided docId + const docId = + providedDocId || (await NodeIdGenerator.generateDocId(sourceUri || "document", text)); + + // Parse based on format + let documentTree: DocumentNode; + if (format === "markdown") { + documentTree = await StructuralParser.parseMarkdown(docId, text, title); + } else if (format === "text") { + documentTree = await StructuralParser.parsePlainText(docId, text, title); + } else { + // Auto-detect + documentTree = await StructuralParser.parse(docId, text, title); + } + + // Count nodes + const nodeCount = this.countNodes(documentTree); + + return { + docId, + documentTree, + nodeCount, + }; + } + + private countNodes(node: any): number { + let count = 1; + if (node.children && Array.isArray(node.children)) { + for (const child of node.children) { + count += this.countNodes(child); + } + } + return count; + } +} + +TaskRegistry.registerTask(StructuralParserTask); + +export const structuralParser = (input: StructuralParserTaskInput, config?: JobQueueTaskConfig) => { + return new StructuralParserTask({} as StructuralParserTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + structuralParser: CreateWorkflow< + StructuralParserTaskInput, + StructuralParserTaskOutput, + JobQueueTaskConfig + >; + } +} + 
+Workflow.prototype.structuralParser = CreateWorkflow(StructuralParserTask); diff --git a/packages/ai/src/task/TextChunkerTask.ts b/packages/ai/src/task/TextChunkerTask.ts new file mode 100644 index 00000000..99cce8f9 --- /dev/null +++ b/packages/ai/src/task/TextChunkerTask.ts @@ -0,0 +1,358 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { DataPortSchema, FromSchema } from "@workglow/util"; + +export const ChunkingStrategy = { + FIXED: "fixed", + SENTENCE: "sentence", + PARAGRAPH: "paragraph", + SEMANTIC: "semantic", +} as const; + +export type ChunkingStrategy = (typeof ChunkingStrategy)[keyof typeof ChunkingStrategy]; + +const inputSchema = { + type: "object", + properties: { + text: { + type: "string", + title: "Text", + description: "The text to chunk", + }, + chunkSize: { + type: "number", + title: "Chunk Size", + description: "Maximum size of each chunk in characters", + minimum: 1, + default: 512, + }, + chunkOverlap: { + type: "number", + title: "Chunk Overlap", + description: "Number of characters to overlap between chunks", + minimum: 0, + default: 50, + }, + strategy: { + type: "string", + enum: Object.values(ChunkingStrategy), + title: "Chunking Strategy", + description: "Strategy to use for chunking text", + default: ChunkingStrategy.FIXED, + }, + }, + required: ["text"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + chunks: { + type: "array", + items: { type: "string" }, + title: "Text Chunks", + description: "The chunked text segments", + }, + metadata: { + type: "array", + items: { + type: "object", + properties: { + index: { type: "number" }, + startChar: { type: "number" }, + endChar: { type: "number" }, + length: { type: "number" }, + }, + additionalProperties: false, + }, + title: "Chunk Metadata", + description: "Metadata for each chunk", + }, + }, + required: ["chunks", "metadata"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type TextChunkerTaskInput = FromSchema; +export type TextChunkerTaskOutput = FromSchema; + +interface ChunkMetadata { + index: number; + startChar: number; + endChar: number; + length: number; +} + +/** + * Task for chunking text into smaller segments with configurable strategies. + * Supports fixed-size, sentence-based, paragraph-based, and semantic chunking. 
+ */ +export class TextChunkerTask extends Task< + TextChunkerTaskInput, + TextChunkerTaskOutput, + JobQueueTaskConfig +> { + public static type = "TextChunkerTask"; + public static category = "Document"; + public static title = "Text Chunker"; + public static description = + "Splits text into chunks using various strategies (fixed, sentence, paragraph)"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: TextChunkerTaskInput, + context: IExecuteContext + ): Promise { + const { text, chunkSize = 512, chunkOverlap = 50, strategy = ChunkingStrategy.FIXED } = input; + + let chunks: string[]; + let metadata: ChunkMetadata[]; + + switch (strategy) { + case ChunkingStrategy.SENTENCE: + ({ chunks, metadata } = this.chunkBySentence(text, chunkSize, chunkOverlap)); + break; + case ChunkingStrategy.PARAGRAPH: + ({ chunks, metadata } = this.chunkByParagraph(text, chunkSize, chunkOverlap)); + break; + case ChunkingStrategy.SEMANTIC: + // For now, semantic is the same as sentence-based + // TODO: Implement true semantic chunking with embeddings + ({ chunks, metadata } = this.chunkBySentence(text, chunkSize, chunkOverlap)); + break; + case ChunkingStrategy.FIXED: + default: + ({ chunks, metadata } = this.chunkFixed(text, chunkSize, chunkOverlap)); + break; + } + + return { chunks, metadata }; + } + + /** + * Fixed-size chunking with overlap + */ + private chunkFixed( + text: string, + chunkSize: number, + chunkOverlap: number + ): { chunks: string[]; metadata: ChunkMetadata[] } { + const chunks: string[] = []; + const metadata: ChunkMetadata[] = []; + let startChar = 0; + let index = 0; + + while (startChar < text.length) { + const endChar = Math.min(startChar + chunkSize, text.length); + const chunk = text.substring(startChar, endChar); + chunks.push(chunk); + metadata.push({ + index, + startChar, + endChar, + length: chunk.length, + }); + + // Move forward by chunkSize - chunkOverlap, but at least 1 character to prevent infinite loop + const step = Math.max(1, chunkSize - chunkOverlap); + startChar += step; + index++; + } + + return { chunks, metadata }; + } + + /** + * Sentence-based chunking that respects sentence boundaries + */ + private chunkBySentence( + text: string, + chunkSize: number, + chunkOverlap: number + ): { chunks: string[]; metadata: ChunkMetadata[] } { + // Split by sentence boundaries (., !, ?, followed by space or newline) + const sentenceRegex = /[.!?]+[\s\n]+/g; + const sentences: string[] = []; + const sentenceStarts: number[] = []; + let lastIndex = 0; + let match: RegExpExecArray | null; + + while ((match = sentenceRegex.exec(text)) !== null) { + const sentence = text.substring(lastIndex, match.index + match[0].length); + sentences.push(sentence); + sentenceStarts.push(lastIndex); + lastIndex = match.index + match[0].length; + } + + // Add remaining text as last sentence + if (lastIndex < text.length) { + sentences.push(text.substring(lastIndex)); + sentenceStarts.push(lastIndex); + } + + // Group sentences into chunks + const chunks: string[] = []; + const metadata: ChunkMetadata[] = []; + let currentChunk = ""; + let currentStartChar = 0; + let index = 0; + + for (let i = 0; i < sentences.length; i++) { + const sentence = sentences[i]; + const sentenceStart = sentenceStarts[i]; + + // If adding this sentence would exceed chunkSize, save current chunk + if (currentChunk.length 
> 0 && currentChunk.length + sentence.length > chunkSize) { + chunks.push(currentChunk.trim()); + metadata.push({ + index, + startChar: currentStartChar, + endChar: currentStartChar + currentChunk.length, + length: currentChunk.trim().length, + }); + index++; + + // Start new chunk with overlap + if (chunkOverlap > 0) { + // Find sentences to include in overlap + let overlapText = ""; + let j = i - 1; + while (j >= 0 && overlapText.length < chunkOverlap) { + overlapText = sentences[j] + overlapText; + j--; + } + currentChunk = overlapText + sentence; + currentStartChar = sentenceStarts[Math.max(0, j + 1)]; + } else { + currentChunk = sentence; + currentStartChar = sentenceStart; + } + } else { + if (currentChunk.length === 0) { + currentStartChar = sentenceStart; + } + currentChunk += sentence; + } + } + + // Add final chunk + if (currentChunk.length > 0) { + chunks.push(currentChunk.trim()); + metadata.push({ + index, + startChar: currentStartChar, + endChar: currentStartChar + currentChunk.length, + length: currentChunk.trim().length, + }); + } + + return { chunks, metadata }; + } + + /** + * Paragraph-based chunking that respects paragraph boundaries + */ + private chunkByParagraph( + text: string, + chunkSize: number, + chunkOverlap: number + ): { chunks: string[]; metadata: ChunkMetadata[] } { + // Split by paragraph boundaries (double newline or more) + const paragraphs = text.split(/\n\s*\n/).filter((p) => p.trim().length > 0); + const chunks: string[] = []; + const metadata: ChunkMetadata[] = []; + let currentChunk = ""; + let currentStartChar = 0; + let index = 0; + let charPosition = 0; + + for (let i = 0; i < paragraphs.length; i++) { + const paragraph = paragraphs[i].trim(); + const paragraphStart = text.indexOf(paragraph, charPosition); + charPosition = paragraphStart + paragraph.length; + + // If adding this paragraph would exceed chunkSize, save current chunk + if (currentChunk.length > 0 && currentChunk.length + paragraph.length + 2 > chunkSize) { + chunks.push(currentChunk.trim()); + metadata.push({ + index, + startChar: currentStartChar, + endChar: currentStartChar + currentChunk.length, + length: currentChunk.trim().length, + }); + index++; + + // Start new chunk with overlap + if (chunkOverlap > 0 && i > 0) { + // Include previous paragraph(s) for overlap + let overlapText = ""; + let j = i - 1; + while (j >= 0 && overlapText.length < chunkOverlap) { + overlapText = paragraphs[j].trim() + "\n\n" + overlapText; + j--; + } + currentChunk = overlapText + paragraph; + currentStartChar = paragraphStart - overlapText.length; + } else { + currentChunk = paragraph; + currentStartChar = paragraphStart; + } + } else { + if (currentChunk.length === 0) { + currentStartChar = paragraphStart; + currentChunk = paragraph; + } else { + currentChunk += "\n\n" + paragraph; + } + } + } + + // Add final chunk + if (currentChunk.length > 0) { + chunks.push(currentChunk.trim()); + metadata.push({ + index, + startChar: currentStartChar, + endChar: currentStartChar + currentChunk.length, + length: currentChunk.trim().length, + }); + } + + return { chunks, metadata }; + } +} + +TaskRegistry.registerTask(TextChunkerTask); + +export const textChunker = (input: TextChunkerTaskInput, config?: JobQueueTaskConfig) => { + return new TextChunkerTask({} as TextChunkerTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + textChunker: CreateWorkflow; + } +} + +Workflow.prototype.textChunker = CreateWorkflow(TextChunkerTask); diff --git 
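Not part of the diff: a hedged sketch of the textChunker helper added above, using sentence-aware chunking. longDocumentText is a placeholder and the "@workglow/ai" import path is assumed.

import { textChunker, ChunkingStrategy } from "@workglow/ai"; // assumed entry point

declare const longDocumentText: string; // placeholder: the document to split

const { chunks, metadata } = await textChunker({
  text: longDocumentText,
  chunkSize: 512,   // maximum characters per chunk
  chunkOverlap: 50, // characters shared between consecutive chunks
  strategy: ChunkingStrategy.SENTENCE,
});

// metadata[i] holds { index, startChar, endChar, length } for chunks[i].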
a/packages/ai/src/task/TextClassificationTask.ts b/packages/ai/src/task/TextClassificationTask.ts index 170d8686..0b7c6215 100644 --- a/packages/ai/src/task/TextClassificationTask.ts +++ b/packages/ai/src/task/TextClassificationTask.ts @@ -7,18 +7,18 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model:TextClassificationTask")); +const modelSchema = TypeModel("model:TextClassificationTask"); export const TextClassificationInputSchema = { type: "object", properties: { - text: TypeReplicateArray({ + text: { type: "string", title: "Text", description: "The text to classify", - }), + }, candidateLabels: { type: "array", items: { @@ -75,12 +75,6 @@ export const TextClassificationOutputSchema = { export type TextClassificationTaskInput = FromSchema; export type TextClassificationTaskOutput = FromSchema; -export type TextClassificationTaskExecuteInput = DeReplicateFromSchema< - typeof TextClassificationInputSchema ->; -export type TextClassificationTaskExecuteOutput = DeReplicateFromSchema< - typeof TextClassificationOutputSchema ->; /** * Classifies text into categories using language models. @@ -115,7 +109,7 @@ export const textClassification = ( input: TextClassificationTaskInput, config?: JobQueueTaskConfig ) => { - return new TextClassificationTask(input, config).run(); + return new TextClassificationTask({} as TextClassificationTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextEmbeddingTask.ts b/packages/ai/src/task/TextEmbeddingTask.ts index 9f50e9a7..252d003e 100644 --- a/packages/ai/src/task/TextEmbeddingTask.ts +++ b/packages/ai/src/task/TextEmbeddingTask.ts @@ -5,26 +5,25 @@ */ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; -import { DataPortSchema, FromSchema } from "@workglow/util"; -import { AiTask } from "./base/AiTask"; import { - DeReplicateFromSchema, + DataPortSchema, + FromSchema, TypedArraySchema, TypedArraySchemaOptions, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +} from "@workglow/util"; +import { AiTask } from "./base/AiTask"; +import { TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model:TextEmbeddingTask")); +const modelSchema = TypeModel("model:TextEmbeddingTask"); export const TextEmbeddingInputSchema = { type: "object", properties: { - text: TypeReplicateArray({ + text: { type: "string", title: "Text", description: "The text to embed", - }), + }, model: modelSchema, }, required: ["text", "model"], @@ -34,12 +33,10 @@ export const TextEmbeddingInputSchema = { export const TextEmbeddingOutputSchema = { type: "object", properties: { - vector: TypeReplicateArray( - TypedArraySchema({ - title: "Vector", - description: "The vector embedding of the text", - }) - ), + vector: TypedArraySchema({ + title: "Vector", + description: "The vector embedding of the text", + }), }, required: ["vector"], additionalProperties: false, @@ -53,10 +50,6 @@ export type TextEmbeddingTaskOutput = FromSchema< typeof TextEmbeddingOutputSchema, TypedArraySchemaOptions >; -export type TextEmbeddingTaskExecuteInput = DeReplicateFromSchema; -export type 
TextEmbeddingTaskExecuteOutput = DeReplicateFromSchema< - typeof TextEmbeddingOutputSchema ->; /** * A task that generates vector embeddings for text using a specified embedding model. @@ -86,7 +79,7 @@ TaskRegistry.registerTask(TextEmbeddingTask); * @returns Promise resolving to the generated embeddings */ export const textEmbedding = async (input: TextEmbeddingTaskInput, config?: JobQueueTaskConfig) => { - return new TextEmbeddingTask(input, config).run(); + return new TextEmbeddingTask({} as TextEmbeddingTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextFillMaskTask.ts b/packages/ai/src/task/TextFillMaskTask.ts index dbec9052..a308c0c9 100644 --- a/packages/ai/src/task/TextFillMaskTask.ts +++ b/packages/ai/src/task/TextFillMaskTask.ts @@ -7,18 +7,18 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model:TextFillMaskTask")); +const modelSchema = TypeModel("model:TextFillMaskTask"); export const TextFillMaskInputSchema = { type: "object", properties: { - text: TypeReplicateArray({ + text: { type: "string", title: "Text", description: "The text with a mask token to fill", - }), + }, model: modelSchema, }, required: ["text", "model"], @@ -62,8 +62,6 @@ export const TextFillMaskOutputSchema = { export type TextFillMaskTaskInput = FromSchema; export type TextFillMaskTaskOutput = FromSchema; -export type TextFillMaskTaskExecuteInput = DeReplicateFromSchema; -export type TextFillMaskTaskExecuteOutput = DeReplicateFromSchema; /** * Fills masked tokens in text using language models @@ -90,7 +88,7 @@ TaskRegistry.registerTask(TextFillMaskTask); * @returns Promise resolving to the predicted tokens with scores and complete sequences */ export const textFillMask = (input: TextFillMaskTaskInput, config?: JobQueueTaskConfig) => { - return new TextFillMaskTask(input, config).run(); + return new TextFillMaskTask({} as TextFillMaskTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextGenerationTask.ts b/packages/ai/src/task/TextGenerationTask.ts index 03f1eb00..0bea2a0c 100644 --- a/packages/ai/src/task/TextGenerationTask.ts +++ b/packages/ai/src/task/TextGenerationTask.ts @@ -7,7 +7,7 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; const generatedTextSchema = { type: "string", @@ -15,17 +15,17 @@ const generatedTextSchema = { description: "The generated text", } as const; -const modelSchema = TypeReplicateArray(TypeModel("model:TextGenerationTask")); +const modelSchema = TypeModel("model:TextGenerationTask"); export const TextGenerationInputSchema = { type: "object", properties: { model: modelSchema, - prompt: TypeReplicateArray({ + prompt: { type: "string", title: "Prompt", description: "The prompt to generate text from", - }), + }, maxTokens: { type: "number", title: "Max Tokens", @@ -74,11 +74,7 @@ export const TextGenerationInputSchema = { 
export const TextGenerationOutputSchema = { type: "object", properties: { - text: { - oneOf: [generatedTextSchema, { type: "array", items: generatedTextSchema }], - title: generatedTextSchema.title, - description: generatedTextSchema.description, - }, + text: generatedTextSchema, }, required: ["text"], additionalProperties: false, @@ -86,12 +82,6 @@ export const TextGenerationOutputSchema = { export type TextGenerationTaskInput = FromSchema; export type TextGenerationTaskOutput = FromSchema; -export type TextGenerationTaskExecuteInput = DeReplicateFromSchema< - typeof TextGenerationInputSchema ->; -export type TextGenerationTaskExecuteOutput = DeReplicateFromSchema< - typeof TextGenerationOutputSchema ->; export class TextGenerationTask extends AiTask< TextGenerationTaskInput, @@ -116,7 +106,7 @@ TaskRegistry.registerTask(TextGenerationTask); * Task for generating text using a language model */ export const textGeneration = (input: TextGenerationTaskInput, config?: JobQueueTaskConfig) => { - return new TextGenerationTask(input, config).run(); + return new TextGenerationTask({} as TextGenerationTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextLanguageDetectionTask.ts b/packages/ai/src/task/TextLanguageDetectionTask.ts index 59cdb8b9..c12c6c34 100644 --- a/packages/ai/src/task/TextLanguageDetectionTask.ts +++ b/packages/ai/src/task/TextLanguageDetectionTask.ts @@ -7,18 +7,18 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model:TextLanguageDetectionTask")); +const modelSchema = TypeModel("model:TextLanguageDetectionTask"); export const TextLanguageDetectionInputSchema = { type: "object", properties: { - text: TypeReplicateArray({ + text: { type: "string", title: "Text", description: "The text to detect the language of", - }), + }, maxLanguages: { type: "number", minimum: 0, @@ -100,12 +100,6 @@ export const TextLanguageDetectionOutputSchema = { export type TextLanguageDetectionTaskInput = FromSchema; export type TextLanguageDetectionTaskOutput = FromSchema; -export type TextLanguageDetectionTaskExecuteInput = DeReplicateFromSchema< - typeof TextLanguageDetectionInputSchema ->; -export type TextLanguageDetectionTaskExecuteOutput = DeReplicateFromSchema< - typeof TextLanguageDetectionOutputSchema ->; /** * Detects the language of text using language models @@ -138,7 +132,7 @@ export const textLanguageDetection = ( input: TextLanguageDetectionTaskInput, config?: JobQueueTaskConfig ) => { - return new TextLanguageDetectionTask(input, config).run(); + return new TextLanguageDetectionTask({} as TextLanguageDetectionTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextNamedEntityRecognitionTask.ts b/packages/ai/src/task/TextNamedEntityRecognitionTask.ts index 1a91ba95..b6b42dc7 100644 --- a/packages/ai/src/task/TextNamedEntityRecognitionTask.ts +++ b/packages/ai/src/task/TextNamedEntityRecognitionTask.ts @@ -7,18 +7,18 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { 
DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model:TextNamedEntityRecognitionTask")); +const modelSchema = TypeModel("model:TextNamedEntityRecognitionTask"); export const TextNamedEntityRecognitionInputSchema = { type: "object", properties: { - text: TypeReplicateArray({ + text: { type: "string", title: "Text", description: "The text to extract named entities from", - }), + }, blockList: { type: "array", items: { @@ -76,12 +76,6 @@ export type TextNamedEntityRecognitionTaskInput = FromSchema< export type TextNamedEntityRecognitionTaskOutput = FromSchema< typeof TextNamedEntityRecognitionOutputSchema >; -export type TextNamedEntityRecognitionTaskExecuteInput = DeReplicateFromSchema< - typeof TextNamedEntityRecognitionInputSchema ->; -export type TextNamedEntityRecognitionTaskExecuteOutput = DeReplicateFromSchema< - typeof TextNamedEntityRecognitionOutputSchema ->; /** * Extracts named entities from text using language models @@ -114,7 +108,7 @@ export const textNamedEntityRecognition = ( input: TextNamedEntityRecognitionTaskInput, config?: JobQueueTaskConfig ) => { - return new TextNamedEntityRecognitionTask(input, config).run(); + return new TextNamedEntityRecognitionTask({} as TextNamedEntityRecognitionTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextQuestionAnswerTask.ts b/packages/ai/src/task/TextQuestionAnswerTask.ts index c2100ee2..f928eb98 100644 --- a/packages/ai/src/task/TextQuestionAnswerTask.ts +++ b/packages/ai/src/task/TextQuestionAnswerTask.ts @@ -7,7 +7,7 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; const contextSchema = { type: "string", @@ -27,13 +27,13 @@ const textSchema = { description: "The generated text", } as const; -const modelSchema = TypeReplicateArray(TypeModel("model:TextQuestionAnswerTask")); +const modelSchema = TypeModel("model:TextQuestionAnswerTask"); export const TextQuestionAnswerInputSchema = { type: "object", properties: { - context: TypeReplicateArray(contextSchema), - question: TypeReplicateArray(questionSchema), + context: contextSchema, + question: questionSchema, model: modelSchema, }, required: ["context", "question", "model"], @@ -43,11 +43,7 @@ export const TextQuestionAnswerInputSchema = { export const TextQuestionAnswerOutputSchema = { type: "object", properties: { - text: { - oneOf: [textSchema, { type: "array", items: textSchema }], - title: textSchema.title, - description: textSchema.description, - }, + text: textSchema, }, required: ["text"], additionalProperties: false, @@ -55,12 +51,6 @@ export const TextQuestionAnswerOutputSchema = { export type TextQuestionAnswerTaskInput = FromSchema; export type TextQuestionAnswerTaskOutput = FromSchema; -export type TextQuestionAnswerTaskExecuteInput = DeReplicateFromSchema< - typeof TextQuestionAnswerInputSchema ->; -export type TextQuestionAnswerTaskExecuteOutput = DeReplicateFromSchema< - typeof TextQuestionAnswerOutputSchema ->; /** * This is a special case of text generation that takes a context and a question @@ -94,7 +84,7 @@ export const textQuestionAnswer = ( input: 
TextQuestionAnswerTaskInput, config?: JobQueueTaskConfig ) => { - return new TextQuestionAnswerTask(input, config).run(); + return new TextQuestionAnswerTask({} as TextQuestionAnswerTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextRewriterTask.ts b/packages/ai/src/task/TextRewriterTask.ts index 7d031d91..275207c8 100644 --- a/packages/ai/src/task/TextRewriterTask.ts +++ b/packages/ai/src/task/TextRewriterTask.ts @@ -7,23 +7,23 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model:TextRewriterTask")); +const modelSchema = TypeModel("model:TextRewriterTask"); export const TextRewriterInputSchema = { type: "object", properties: { - text: TypeReplicateArray({ + text: { type: "string", title: "Text", description: "The text to rewrite", - }), - prompt: TypeReplicateArray({ + }, + prompt: { type: "string", title: "Prompt", description: "The prompt to direct the rewriting", - }), + }, model: modelSchema, }, required: ["text", "prompt", "model"], @@ -45,8 +45,6 @@ export const TextRewriterOutputSchema = { export type TextRewriterTaskInput = FromSchema; export type TextRewriterTaskOutput = FromSchema; -export type TextRewriterTaskExecuteInput = DeReplicateFromSchema; -export type TextRewriterTaskExecuteOutput = DeReplicateFromSchema; /** * This is a special case of text generation that takes a prompt and text to rewrite @@ -73,7 +71,7 @@ TaskRegistry.registerTask(TextRewriterTask); * @returns Promise resolving to the rewritten text output(s) */ export const textRewriter = (input: TextRewriterTaskInput, config?: JobQueueTaskConfig) => { - return new TextRewriterTask(input, config).run(); + return new TextRewriterTask({} as TextRewriterTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextSummaryTask.ts b/packages/ai/src/task/TextSummaryTask.ts index 675643a1..fdff1f52 100644 --- a/packages/ai/src/task/TextSummaryTask.ts +++ b/packages/ai/src/task/TextSummaryTask.ts @@ -7,18 +7,18 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model:TextSummaryTask")); +const modelSchema = TypeModel("model:TextSummaryTask"); export const TextSummaryInputSchema = { type: "object", properties: { - text: TypeReplicateArray({ + text: { type: "string", title: "Text", description: "The text to summarize", - }), + }, model: modelSchema, }, required: ["text", "model"], @@ -40,8 +40,6 @@ export const TextSummaryOutputSchema = { export type TextSummaryTaskInput = FromSchema; export type TextSummaryTaskOutput = FromSchema; -export type TextSummaryTaskExecuteInput = DeReplicateFromSchema; -export type TextSummaryTaskExecuteOutput = DeReplicateFromSchema; /** * This summarizes a piece of text @@ -70,7 +68,7 @@ TaskRegistry.registerTask(TextSummaryTask); * @returns Promise resolving to the summarized text output(s) */ 
export const textSummary = async (input: TextSummaryTaskInput, config?: JobQueueTaskConfig) => { - return new TextSummaryTask(input, config).run(); + return new TextSummaryTask({} as TextSummaryTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TextTranslationTask.ts b/packages/ai/src/task/TextTranslationTask.ts index 988c682e..f5a3c8f1 100644 --- a/packages/ai/src/task/TextTranslationTask.ts +++ b/packages/ai/src/task/TextTranslationTask.ts @@ -7,14 +7,9 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { - DeReplicateFromSchema, - TypeLanguage, - TypeModel, - TypeReplicateArray, -} from "./base/AiTaskSchemas"; +import { TypeLanguage, TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model:TextTranslationTask")); +const modelSchema = TypeModel("model:TextTranslationTask"); const translationTextSchema = { type: "string", @@ -25,27 +20,23 @@ const translationTextSchema = { export const TextTranslationInputSchema = { type: "object", properties: { - text: TypeReplicateArray({ + text: { type: "string", title: "Text", description: "The text to translate", + }, + source_lang: TypeLanguage({ + title: "Source Language", + description: "The source language", + minLength: 2, + maxLength: 2, + }), + target_lang: TypeLanguage({ + title: "Target Language", + description: "The target language", + minLength: 2, + maxLength: 2, }), - source_lang: TypeReplicateArray( - TypeLanguage({ - title: "Source Language", - description: "The source language", - minLength: 2, - maxLength: 2, - }) - ), - target_lang: TypeReplicateArray( - TypeLanguage({ - title: "Target Language", - description: "The target language", - minLength: 2, - maxLength: 2, - }) - ), model: modelSchema, }, required: ["text", "source_lang", "target_lang", "model"], @@ -55,11 +46,7 @@ export const TextTranslationInputSchema = { export const TextTranslationOutputSchema = { type: "object", properties: { - text: { - oneOf: [translationTextSchema, { type: "array", items: translationTextSchema }], - title: translationTextSchema.title, - description: translationTextSchema.description, - }, + text: translationTextSchema, target_lang: TypeLanguage({ title: "Output Language", description: "The output language", @@ -73,12 +60,6 @@ export const TextTranslationOutputSchema = { export type TextTranslationTaskInput = FromSchema; export type TextTranslationTaskOutput = FromSchema; -export type TextTranslationTaskExecuteInput = DeReplicateFromSchema< - typeof TextTranslationInputSchema ->; -export type TextTranslationTaskExecuteOutput = DeReplicateFromSchema< - typeof TextTranslationOutputSchema ->; /** * This translates text from one language to another @@ -108,7 +89,7 @@ TaskRegistry.registerTask(TextTranslationTask); * @returns Promise resolving to the translated text output(s) */ export const textTranslation = (input: TextTranslationTaskInput, config?: JobQueueTaskConfig) => { - return new TextTranslationTask(input, config).run(); + return new TextTranslationTask({} as TextTranslationTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/TopicSegmenterTask.ts b/packages/ai/src/task/TopicSegmenterTask.ts new file mode 100644 index 00000000..415fc55e --- /dev/null +++ b/packages/ai/src/task/TopicSegmenterTask.ts @@ -0,0 +1,439 @@ +/** + * @license + * 
Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { DataPortSchema, FromSchema } from "@workglow/util"; + +export const SegmentationMethod = { + HEURISTIC: "heuristic", + EMBEDDING_SIMILARITY: "embedding-similarity", + HYBRID: "hybrid", +} as const; + +export type SegmentationMethod = (typeof SegmentationMethod)[keyof typeof SegmentationMethod]; + +const inputSchema = { + type: "object", + properties: { + text: { + type: "string", + title: "Text", + description: "The text to segment into topics", + }, + method: { + type: "string", + enum: Object.values(SegmentationMethod), + title: "Segmentation Method", + description: "Method to use for topic segmentation", + default: SegmentationMethod.HEURISTIC, + }, + minSegmentSize: { + type: "number", + title: "Min Segment Size", + description: "Minimum segment size in characters", + minimum: 50, + default: 100, + }, + maxSegmentSize: { + type: "number", + title: "Max Segment Size", + description: "Maximum segment size in characters", + minimum: 100, + default: 2000, + }, + similarityThreshold: { + type: "number", + title: "Similarity Threshold", + description: "Threshold for embedding similarity (0-1, lower = more splits)", + minimum: 0, + maximum: 1, + default: 0.5, + }, + }, + required: ["text"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + segments: { + type: "array", + items: { + type: "object", + properties: { + text: { type: "string" }, + startOffset: { type: "number" }, + endOffset: { type: "number" }, + }, + required: ["text", "startOffset", "endOffset"], + additionalProperties: false, + }, + title: "Segments", + description: "Detected topic segments", + }, + count: { + type: "number", + title: "Count", + description: "Number of segments detected", + }, + }, + required: ["segments", "count"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type TopicSegmenterTaskInput = FromSchema; +export type TopicSegmenterTaskOutput = FromSchema; + +/** + * Task for segmenting text into topic-based sections + * Uses hybrid approach: heuristics + optional embedding similarity + */ +export class TopicSegmenterTask extends Task< + TopicSegmenterTaskInput, + TopicSegmenterTaskOutput, + JobQueueTaskConfig +> { + public static type = "TopicSegmenterTask"; + public static category = "Document"; + public static title = "Topic Segmenter"; + public static description = "Segment text into topic-based sections using hybrid approach"; + public static cacheable = true; + private static readonly EMBEDDING_DIMENSIONS = 256; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: TopicSegmenterTaskInput, + context: IExecuteContext + ): Promise { + const { + text, + method = SegmentationMethod.HEURISTIC, + minSegmentSize = 100, + maxSegmentSize = 2000, + similarityThreshold = 0.5, + } = input; + + let segments: Array<{ text: string; startOffset: number; endOffset: number }>; + + switch (method) { + case SegmentationMethod.EMBEDDING_SIMILARITY: + segments = this.embeddingSegmentation( + text, + minSegmentSize, + maxSegmentSize, + similarityThreshold + ); + break; + case SegmentationMethod.HYBRID: + // Start with heuristic, 
optionally refine with embeddings + segments = this.heuristicSegmentation(text, minSegmentSize, maxSegmentSize); + // TODO: Add embedding refinement step + break; + case SegmentationMethod.HEURISTIC: + default: + segments = this.heuristicSegmentation(text, minSegmentSize, maxSegmentSize); + break; + } + + return { + segments, + count: segments.length, + }; + } + + /** + * Embedding-based segmentation using hashed token vectors and cosine similarity + */ + private embeddingSegmentation( + text: string, + minSegmentSize: number, + maxSegmentSize: number, + similarityThreshold: number + ): Array<{ text: string; startOffset: number; endOffset: number }> { + const paragraphs = this.splitIntoParagraphs(text); + if (paragraphs.length === 0) { + return []; + } + + const embeddings = paragraphs.map((p) => + this.embedParagraph(p.text, TopicSegmenterTask.EMBEDDING_DIMENSIONS) + ); + + const segments: Array<{ text: string; startOffset: number; endOffset: number }> = []; + let currentSegmentParagraphs: Array<{ text: string; offset: number }> = []; + let currentSegmentSize = 0; + + for (let i = 0; i < paragraphs.length; i++) { + const paragraph = paragraphs[i]; + const paragraphSize = paragraph.text.length; + const exceedsMax = + currentSegmentSize + paragraphSize > maxSegmentSize && currentSegmentSize >= minSegmentSize; + + let shouldSplit = false; + if (i > 0 && currentSegmentSize >= minSegmentSize) { + const prev = embeddings[i - 1]; + const curr = embeddings[i]; + const similarity = this.cosineSimilarityWithNorms( + prev.vector, + prev.norm, + curr.vector, + curr.norm + ); + shouldSplit = similarity < similarityThreshold; + } + + if ((exceedsMax || shouldSplit) && currentSegmentParagraphs.length > 0) { + segments.push(this.createSegment(currentSegmentParagraphs)); + currentSegmentParagraphs = []; + currentSegmentSize = 0; + } + + currentSegmentParagraphs.push(paragraph); + currentSegmentSize += paragraphSize; + } + + if (currentSegmentParagraphs.length > 0) { + segments.push(this.createSegment(currentSegmentParagraphs)); + } + + return this.mergeSmallSegments(segments, minSegmentSize); + } + + /** + * Heuristic segmentation based on paragraph breaks and transition markers + */ + private heuristicSegmentation( + text: string, + minSegmentSize: number, + maxSegmentSize: number + ): Array<{ text: string; startOffset: number; endOffset: number }> { + const segments: Array<{ text: string; startOffset: number; endOffset: number }> = []; + + // Split by double newlines (paragraph breaks) + const paragraphs = this.splitIntoParagraphs(text); + + let currentSegmentParagraphs: Array<{ text: string; offset: number }> = []; + let currentSegmentSize = 0; + + for (const paragraph of paragraphs) { + const paragraphSize = paragraph.text.length; + + // Check if adding this paragraph would exceed max size + if ( + currentSegmentSize + paragraphSize > maxSegmentSize && + currentSegmentSize >= minSegmentSize + ) { + // Flush current segment + if (currentSegmentParagraphs.length > 0) { + const segment = this.createSegment(currentSegmentParagraphs); + segments.push(segment); + currentSegmentParagraphs = []; + currentSegmentSize = 0; + } + } + + // Check for transition markers + const hasTransition = this.hasTransitionMarker(paragraph.text); + if ( + hasTransition && + currentSegmentSize >= minSegmentSize && + currentSegmentParagraphs.length > 0 + ) { + // Flush current segment before transition + const segment = this.createSegment(currentSegmentParagraphs); + segments.push(segment); + currentSegmentParagraphs = []; + 
currentSegmentSize = 0; + } + + currentSegmentParagraphs.push(paragraph); + currentSegmentSize += paragraphSize; + } + + // Flush remaining segment + if (currentSegmentParagraphs.length > 0) { + const segment = this.createSegment(currentSegmentParagraphs); + segments.push(segment); + } + + // Merge small segments + return this.mergeSmallSegments(segments, minSegmentSize); + } + + /** + * Create a hashed token embedding for fast similarity checks + */ + private embedParagraph(text: string, dimensions: number): { vector: Float32Array; norm: number } { + const vector = new Float32Array(dimensions); + const tokens = text.toLowerCase().match(/[a-z0-9]+/g); + if (!tokens) { + return { vector, norm: 0 }; + } + + for (const token of tokens) { + let hash = 2166136261; + for (let i = 0; i < token.length; i++) { + hash ^= token.charCodeAt(i); + hash = Math.imul(hash, 16777619); + } + const index = (hash >>> 0) % dimensions; + vector[index] += 1; + } + + let sumSquares = 0; + for (let i = 0; i < vector.length; i++) { + const value = vector[i]; + sumSquares += value * value; + } + + return { vector, norm: sumSquares > 0 ? Math.sqrt(sumSquares) : 0 }; + } + + private cosineSimilarityWithNorms( + a: Float32Array, + aNorm: number, + b: Float32Array, + bNorm: number + ): number { + if (aNorm === 0 || bNorm === 0) { + return 0; + } + + let dot = 0; + for (let i = 0; i < a.length; i++) { + dot += a[i] * b[i]; + } + + return dot / (aNorm * bNorm); + } + + /** + * Split text into paragraphs with offsets + */ + private splitIntoParagraphs(text: string): Array<{ text: string; offset: number }> { + const paragraphs: Array<{ text: string; offset: number }> = []; + const splits = text.split(/\n\s*\n/); + + let currentOffset = 0; + for (const split of splits) { + const trimmed = split.trim(); + if (trimmed.length > 0) { + const offset = text.indexOf(trimmed, currentOffset); + paragraphs.push({ text: trimmed, offset }); + currentOffset = offset + trimmed.length; + } + } + + return paragraphs; + } + + /** + * Check if paragraph contains transition markers + */ + private hasTransitionMarker(text: string): boolean { + const transitionMarkers = [ + /^(however|therefore|thus|consequently|in conclusion|in summary|furthermore|moreover|additionally|meanwhile|nevertheless|on the other hand)/i, + /^(first|second|third|finally|lastly)/i, + /^\d+\./, // Numbered list + ]; + + return transitionMarkers.some((pattern) => pattern.test(text)); + } + + /** + * Create a segment from paragraphs + */ + private createSegment(paragraphs: Array<{ text: string; offset: number }>): { + text: string; + startOffset: number; + endOffset: number; + } { + const text = paragraphs.map((p) => p.text).join("\n\n"); + const startOffset = paragraphs[0].offset; + const endOffset = + paragraphs[paragraphs.length - 1].offset + paragraphs[paragraphs.length - 1].text.length; + + return { text, startOffset, endOffset }; + } + + /** + * Merge segments that are too small + */ + private mergeSmallSegments( + segments: Array<{ text: string; startOffset: number; endOffset: number }>, + minSegmentSize: number + ): Array<{ text: string; startOffset: number; endOffset: number }> { + if (segments.length <= 1) { + return segments; + } + + const merged: Array<{ text: string; startOffset: number; endOffset: number }> = []; + let i = 0; + + while (i < segments.length) { + const current = segments[i]; + + if (current.text.length < minSegmentSize && i + 1 < segments.length) { + // Merge with next + const next = segments[i + 1]; + const mergedSegment = { + text: current.text 
+ "\n\n" + next.text, + startOffset: current.startOffset, + endOffset: next.endOffset, + }; + merged.push(mergedSegment); + i += 2; + } else if (current.text.length < minSegmentSize && merged.length > 0) { + // Merge with previous + const previous = merged[merged.length - 1]; + merged[merged.length - 1] = { + text: previous.text + "\n\n" + current.text, + startOffset: previous.startOffset, + endOffset: current.endOffset, + }; + i++; + } else { + merged.push(current); + i++; + } + } + + return merged; + } +} + +TaskRegistry.registerTask(TopicSegmenterTask); + +export const topicSegmenter = (input: TopicSegmenterTaskInput, config?: JobQueueTaskConfig) => { + return new TopicSegmenterTask({} as TopicSegmenterTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + topicSegmenter: CreateWorkflow< + TopicSegmenterTaskInput, + TopicSegmenterTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.topicSegmenter = CreateWorkflow(TopicSegmenterTask); diff --git a/packages/ai/src/task/UnloadModelTask.ts b/packages/ai/src/task/UnloadModelTask.ts index 8a027d7b..3398f267 100644 --- a/packages/ai/src/task/UnloadModelTask.ts +++ b/packages/ai/src/task/UnloadModelTask.ts @@ -7,9 +7,9 @@ import { CreateWorkflow, JobQueueTaskConfig, TaskRegistry, Workflow } from "@workglow/task-graph"; import { DataPortSchema, FromSchema } from "@workglow/util"; import { AiTask } from "./base/AiTask"; -import { DeReplicateFromSchema, TypeModel, TypeReplicateArray } from "./base/AiTaskSchemas"; +import { TypeModel } from "./base/AiTaskSchemas"; -const modelSchema = TypeReplicateArray(TypeModel("model")); +const modelSchema = TypeModel("model"); const UnloadModelInputSchema = { type: "object", @@ -31,8 +31,6 @@ const UnloadModelOutputSchema = { export type UnloadModelTaskRunInput = FromSchema; export type UnloadModelTaskRunOutput = FromSchema; -export type UnloadModelTaskExecuteInput = DeReplicateFromSchema; -export type UnloadModelTaskExecuteOutput = DeReplicateFromSchema; /** * Unload a model from memory and clear its cache. 
@@ -67,7 +65,7 @@ TaskRegistry.registerTask(UnloadModelTask); * @returns Promise resolving to the unloaded model(s) */ export const unloadModel = (input: UnloadModelTaskRunInput, config?: JobQueueTaskConfig) => { - return new UnloadModelTask(input, config).run(); + return new UnloadModelTask({} as UnloadModelTaskRunInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/ai/src/task/VectorQuantizeTask.ts b/packages/ai/src/task/VectorQuantizeTask.ts new file mode 100644 index 00000000..9ed102dc --- /dev/null +++ b/packages/ai/src/task/VectorQuantizeTask.ts @@ -0,0 +1,257 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + CreateWorkflow, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { + DataPortSchema, + FromSchema, + normalizeNumberArray, + TensorType, + TypedArray, + TypedArraySchema, + TypedArraySchemaOptions, +} from "@workglow/util"; + +const inputSchema = { + type: "object", + properties: { + vector: { + anyOf: [ + TypedArraySchema({ + title: "Vector", + description: "The vector to quantize", + }), + { + type: "array", + items: TypedArraySchema({ + title: "Vector", + description: "Vector to quantize", + }), + }, + ], + title: "Input Vector(s)", + description: "Vector or array of vectors to quantize", + }, + targetType: { + type: "string", + enum: Object.values(TensorType), + title: "Target Type", + description: "Target quantization type", + default: TensorType.INT8, + }, + normalize: { + type: "boolean", + title: "Normalize", + description: "Normalize vector before quantization", + default: true, + }, + }, + required: ["vector", "targetType"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + vector: { + anyOf: [ + TypedArraySchema({ + title: "Quantized Vector", + description: "The quantized vector", + }), + { + type: "array", + items: TypedArraySchema({ + title: "Quantized Vector", + description: "Quantized vector", + }), + }, + ], + title: "Output Vector(s)", + description: "Quantized vector or array of vectors", + }, + originalType: { + type: "string", + enum: Object.values(TensorType), + title: "Original Type", + description: "Original vector type", + }, + targetType: { + type: "string", + enum: Object.values(TensorType), + title: "Target Type", + description: "Target quantization type", + }, + }, + required: ["vector", "originalType", "targetType"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type VectorQuantizeTaskInput = FromSchema; +export type VectorQuantizeTaskOutput = FromSchema; + +/** + * Task for quantizing vectors to reduce storage and improve performance. + * Supports various quantization types including binary, int8, uint8, int16, uint16. 
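A quick illustration of the symmetric int8 path this class implements below (normalize, clamp to [-1, 1], scale by 127; the uint8/uint16 paths use min-max scaling instead). This is a standalone sketch of the arithmetic only, not the task's API:

```typescript
// Standalone sketch of the int8 scheme used by quantizeToInt8 below:
// clamp each (already normalized) component to [-1, 1] and scale by 127.
function toInt8(values: number[]): Int8Array {
  return new Int8Array(values.map((v) => Math.round(Math.max(-1, Math.min(1, v)) * 127)));
}

toInt8([1, 0.5, -1, 0.004]); // -> Int8Array [127, 64, -127, 1]; 1 byte per component vs 4 for Float32
```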
+ */ +export class VectorQuantizeTask extends Task< + VectorQuantizeTaskInput, + VectorQuantizeTaskOutput, + JobQueueTaskConfig +> { + public static type = "VectorQuantizeTask"; + public static category = "Vector Processing"; + public static title = "Quantize Vector"; + public static description = "Quantize vectors to reduce storage and improve performance"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async executeReactive(input: VectorQuantizeTaskInput): Promise { + const { vector, targetType, normalize = true } = input; + const isArray = Array.isArray(vector); + const vectors = isArray ? vector : [vector]; + const originalType = this.getVectorType(vectors[0]); + + const quantized = vectors.map((v) => this.vectorQuantize(v, targetType, normalize)); + + return { + vector: isArray ? quantized : quantized[0], + originalType, + targetType, + }; + } + + private getVectorType(vector: TypedArray): TensorType { + if (vector instanceof Float16Array) return TensorType.FLOAT16; + if (vector instanceof Float32Array) return TensorType.FLOAT32; + if (vector instanceof Float64Array) return TensorType.FLOAT64; + if (vector instanceof Int8Array) return TensorType.INT8; + if (vector instanceof Uint8Array) return TensorType.UINT8; + if (vector instanceof Int16Array) return TensorType.INT16; + if (vector instanceof Uint16Array) return TensorType.UINT16; + throw new Error(`Unknown vector type: ${typeof vector}`); + } + + private vectorQuantize( + vector: TypedArray, + targetType: TensorType, + normalize: boolean + ): TypedArray { + let values = Array.from(vector) as number[]; + + // Normalize if requested + if (normalize) { + values = normalizeNumberArray(values, false); + } + + switch (targetType) { + case TensorType.FLOAT16: + return new Float16Array(values); + + case TensorType.FLOAT32: + return new Float32Array(values); + + case TensorType.FLOAT64: + return new Float64Array(values); + + case TensorType.INT8: + return this.quantizeToInt8(values); + + case TensorType.UINT8: + return this.quantizeToUint8(values); + + case TensorType.INT16: + return this.quantizeToInt16(values); + + case TensorType.UINT16: + return this.quantizeToUint16(values); + + default: + return new Float32Array(values); + } + } + + /** + * Find min and max values in a single pass for better performance + */ + private findMinMax(values: number[]): { min: number; max: number } { + if (values.length === 0) { + return { min: 0, max: 1 }; + } + + let min = values[0]; + let max = values[0]; + + for (let i = 1; i < values.length; i++) { + const val = values[i]; + if (val < min) min = val; + if (val > max) max = val; + } + + return { min, max }; + } + + private quantizeToInt8(values: number[]): Int8Array { + // Assume values are in [-1, 1] range after normalization + // Scale to [-127, 127] to avoid overflow at -128 + return new Int8Array(values.map((v) => Math.round(Math.max(-1, Math.min(1, v)) * 127))); + } + + private quantizeToUint8(values: number[]): Uint8Array { + // Find min/max for scaling in a single pass + const { min, max } = this.findMinMax(values); + const range = max - min || 1; + + // Scale to [0, 255] + return new Uint8Array(values.map((v) => Math.round(((v - min) / range) * 255))); + } + + private quantizeToInt16(values: number[]): Int16Array { + // Assume values are in [-1, 1] range after normalization + // Scale to [-32767, 32767] + return new 
Int16Array(values.map((v) => Math.round(Math.max(-1, Math.min(1, v)) * 32767))); + } + + private quantizeToUint16(values: number[]): Uint16Array { + // Find min/max for scaling in a single pass + const { min, max } = this.findMinMax(values); + const range = max - min || 1; + + // Scale to [0, 65535] + return new Uint16Array(values.map((v) => Math.round(((v - min) / range) * 65535))); + } +} + +TaskRegistry.registerTask(VectorQuantizeTask); + +export const vectorQuantize = (input: VectorQuantizeTaskInput, config?: JobQueueTaskConfig) => { + return new VectorQuantizeTask({} as VectorQuantizeTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + vectorQuantize: CreateWorkflow< + VectorQuantizeTaskInput, + VectorQuantizeTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.vectorQuantize = CreateWorkflow(VectorQuantizeTask); diff --git a/packages/ai/src/task/VectorSimilarityTask.ts b/packages/ai/src/task/VectorSimilarityTask.ts index 48898ba7..cece5beb 100644 --- a/packages/ai/src/task/VectorSimilarityTask.ts +++ b/packages/ai/src/task/VectorSimilarityTask.ts @@ -5,15 +5,21 @@ */ import { - ArrayTask, CreateWorkflow, + GraphAsTask, JobQueueTaskConfig, - TaskError, TaskRegistry, Workflow, } from "@workglow/task-graph"; -import { DataPortSchema, FromSchema } from "@workglow/util"; -import { TypedArray, TypedArraySchema, TypedArraySchemaOptions } from "./base/AiTaskSchemas"; +import { + cosineSimilarity, + DataPortSchema, + FromSchema, + hammingSimilarity, + jaccardSimilarity, + TypedArraySchema, + TypedArraySchemaOptions, +} from "@workglow/util"; export const SimilarityFn = { COSINE: "cosine", @@ -21,6 +27,12 @@ export const SimilarityFn = { HAMMING: "hamming", } as const; +const similarityFunctions = { + cosine: cosineSimilarity, + jaccard: jaccardSimilarity, + hamming: hammingSimilarity, +} as const; + export type SimilarityFn = (typeof SimilarityFn)[keyof typeof SimilarityFn]; const SimilarityInputSchema = { @@ -44,7 +56,7 @@ const SimilarityInputSchema = { minimum: 1, default: 10, }, - similarity: { + method: { type: "string", enum: Object.values(SimilarityFn), title: "Similarity 𝑓", @@ -52,7 +64,7 @@ const SimilarityInputSchema = { default: SimilarityFn.COSINE, }, }, - required: ["query", "input", "similarity"], + required: ["query", "input", "method"], additionalProperties: false, } as const satisfies DataPortSchema; @@ -88,7 +100,7 @@ export type VectorSimilarityTaskOutput = FromSchema< TypedArraySchemaOptions >; -export class VectorSimilarityTask extends ArrayTask< +export class VectorSimilarityTask extends GraphAsTask< VectorSimilarityTaskInput, VectorSimilarityTaskOutput, JobQueueTaskConfig @@ -107,15 +119,10 @@ export class VectorSimilarityTask extends ArrayTask< return SimilarityOutputSchema as DataPortSchema; } - // @ts-ignore (TODO: fix this) - async executeReactive( - { query, input, similarity, topK }: VectorSimilarityTaskInput, - oldOutput: VectorSimilarityTaskOutput - ) { + async executeReactive({ query, input, method, topK }: VectorSimilarityTaskInput) { let similarities = []; - const fns = { cosine }; - const fnName = similarity as keyof typeof fns; - const fn = fns[fnName]; + const fnName = method as keyof typeof similarityFunctions; + const fn = similarityFunctions[fnName]; for (const embedding of input) { similarities.push({ @@ -137,7 +144,7 @@ export class VectorSimilarityTask extends ArrayTask< TaskRegistry.registerTask(VectorSimilarityTask); export const similarity = (input: VectorSimilarityTaskInput, 
config?: JobQueueTaskConfig) => { - return new VectorSimilarityTask(input, config).run(); + return new VectorSimilarityTask({} as VectorSimilarityTaskInput, config).run(input); }; declare module "@workglow/task-graph" { @@ -151,41 +158,3 @@ declare module "@workglow/task-graph" { } Workflow.prototype.similarity = CreateWorkflow(VectorSimilarityTask); - -// =============================================================================== - -export function inner(arr1: TypedArray, arr2: TypedArray): number { - // @ts-ignore - return 1 - arr1.reduce((acc, val, i) => acc + val * arr2[i], 0); -} - -export function magnitude(arr: TypedArray) { - // @ts-ignore - return Math.sqrt(arr.reduce((acc, val) => acc + val * val, 0)); -} - -function cosine(arr1: TypedArray, arr2: TypedArray) { - const dotProduct = inner(arr1, arr2); - const magnitude1 = magnitude(arr1); - const magnitude2 = magnitude(arr2); - return 1 - dotProduct / (magnitude1 * magnitude2); -} - -export function normalize(vector: TypedArray): TypedArray { - const mag = magnitude(vector); - - if (mag === 0) { - throw new TaskError("Cannot normalize a zero vector."); - } - - const normalized = vector.map((val) => Number(val) / mag); - - if (vector instanceof Float64Array) { - return new Float64Array(normalized); - } - if (vector instanceof Float32Array) { - return new Float32Array(normalized); - } - // For integer arrays and bigint[], use Float32Array since normalization produces floats - return new Float32Array(normalized); -} diff --git a/packages/ai/src/task/VectorStoreSearchTask.ts b/packages/ai/src/task/VectorStoreSearchTask.ts new file mode 100644 index 00000000..8f629260 --- /dev/null +++ b/packages/ai/src/task/VectorStoreSearchTask.ts @@ -0,0 +1,173 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { IVectorRepository, TypeVectorRepository } from "@workglow/storage"; +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { + DataPortSchema, + FromSchema, + TypedArray, + TypedArraySchema, + TypedArraySchemaOptions, +} from "@workglow/util"; + +const inputSchema = { + type: "object", + properties: { + repository: TypeVectorRepository({ + title: "Vector Repository", + description: "The vector repository instance to search in", + }), + query: TypedArraySchema({ + title: "Query Vector", + description: "The query vector to search for similar vectors", + }), + topK: { + type: "number", + title: "Top K", + description: "Number of top results to return", + minimum: 1, + default: 10, + }, + filter: { + type: "object", + title: "Metadata Filter", + description: "Filter results by metadata fields", + }, + scoreThreshold: { + type: "number", + title: "Score Threshold", + description: "Minimum similarity score threshold (0-1)", + minimum: 0, + maximum: 1, + default: 0, + }, + }, + required: ["repository", "query"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + ids: { + type: "array", + items: { type: "string" }, + title: "IDs", + description: "IDs of matching vectors", + }, + vectors: { + type: "array", + items: TypedArraySchema({ + title: "Vector", + description: "Matching vector embedding", + }), + title: "Vectors", + description: "Matching vector embeddings", + }, + metadata: { + type: "array", + items: { + type: "object", + title: "Metadata", + description: "Metadata of matching vector", + }, + title: "Metadata", 
+ description: "Metadata of matching vectors", + }, + scores: { + type: "array", + items: { type: "number" }, + title: "Scores", + description: "Similarity scores for each result", + }, + count: { + type: "number", + title: "Count", + description: "Number of results returned", + }, + }, + required: ["ids", "vectors", "metadata", "scores", "count"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type VectorStoreSearchTaskInput = FromSchema; +export type VectorStoreSearchTaskOutput = FromSchema; + +/** + * Task for searching similar vectors in a vector repository. + * Returns top-K most similar vectors with their metadata and scores. + */ +export class VectorStoreSearchTask extends Task< + VectorStoreSearchTaskInput, + VectorStoreSearchTaskOutput, + JobQueueTaskConfig +> { + public static type = "VectorStoreSearchTask"; + public static category = "Vector Store"; + public static title = "Vector Store Search"; + public static description = "Search for similar vectors in a vector repository"; + public static cacheable = true; + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: VectorStoreSearchTaskInput, + context: IExecuteContext + ): Promise { + const { repository, query, topK = 10, filter, scoreThreshold = 0 } = input; + + const repo = repository as IVectorRepository; + + const results = await repo.similaritySearch(query, { + topK, + filter, + scoreThreshold, + }); + + return { + ids: results.map((r) => r.id), + vectors: results.map((r) => r.vector), + metadata: results.map((r) => r.metadata), + scores: results.map((r) => r.score), + count: results.length, + }; + } +} + +TaskRegistry.registerTask(VectorStoreSearchTask); + +export const vectorStoreSearch = ( + input: VectorStoreSearchTaskInput, + config?: JobQueueTaskConfig +) => { + return new VectorStoreSearchTask({} as VectorStoreSearchTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + vectorStoreSearch: CreateWorkflow< + VectorStoreSearchTaskInput, + VectorStoreSearchTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.vectorStoreSearch = CreateWorkflow(VectorStoreSearchTask); diff --git a/packages/ai/src/task/VectorStoreUpsertTask.ts b/packages/ai/src/task/VectorStoreUpsertTask.ts new file mode 100644 index 00000000..851a07a5 --- /dev/null +++ b/packages/ai/src/task/VectorStoreUpsertTask.ts @@ -0,0 +1,183 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { IVectorRepository, TypeVectorRepository } from "@workglow/storage"; +import { + CreateWorkflow, + IExecuteContext, + JobQueueTaskConfig, + Task, + TaskRegistry, + Workflow, +} from "@workglow/task-graph"; +import { + DataPortSchema, + FromSchema, + TypedArray, + TypedArraySchema, + TypedArraySchemaOptions, +} from "@workglow/util"; + +const inputSchema = { + type: "object", + properties: { + repository: TypeVectorRepository({ + title: "Vector Repository", + description: "The vector repository instance to store vectors in", + }), + ids: { + oneOf: [{ type: "string" }, { type: "array", items: { type: "string" } }], + title: "IDs", + description: "Unique identifier(s) for the vector(s)", + }, + vectors: { + oneOf: [ + TypedArraySchema({ + title: "Vector", + description: "The vector embedding", + }), + { + type: "array", + items: TypedArraySchema({ + title: "Vector", + 
description: "The vector embedding", + }), + }, + ], + title: "Vectors", + description: "Vector embedding(s) to store", + }, + metadata: { + oneOf: [ + { + type: "object", + title: "Metadata", + description: "Metadata associated with the vector", + }, + { + type: "array", + items: { + type: "object", + title: "Metadata", + description: "Metadata associated with the vector", + }, + }, + ], + title: "Metadata", + description: "Metadata associated with the vector(s)", + }, + }, + required: ["repository", "ids", "vectors", "metadata"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +const outputSchema = { + type: "object", + properties: { + count: { + type: "number", + title: "Count", + description: "Number of vectors upserted", + }, + ids: { + type: "array", + items: { type: "string" }, + title: "IDs", + description: "IDs of upserted vectors", + }, + }, + required: ["count", "ids"], + additionalProperties: false, +} as const satisfies DataPortSchema; + +export type VectorStoreUpsertTaskInput = FromSchema; +export type VectorStoreUpsertTaskOutput = FromSchema; + +/** + * Task for upserting (insert or update) vectors into a vector repository. + * Supports both single and bulk operations. + */ +export class VectorStoreUpsertTask extends Task< + VectorStoreUpsertTaskInput, + VectorStoreUpsertTaskOutput, + JobQueueTaskConfig +> { + public static type = "VectorStoreUpsertTask"; + public static category = "Vector Store"; + public static title = "Vector Store Upsert"; + public static description = "Store vector embeddings with metadata in a vector repository"; + public static cacheable = false; // Has side effects + + public static inputSchema(): DataPortSchema { + return inputSchema as DataPortSchema; + } + + public static outputSchema(): DataPortSchema { + return outputSchema as DataPortSchema; + } + + async execute( + input: VectorStoreUpsertTaskInput, + context: IExecuteContext + ): Promise { + const { repository, ids, vectors, metadata } = input; + + // Normalize inputs to arrays + const idArray = Array.isArray(ids) ? ids : [ids]; + const vectorArray = Array.isArray(vectors) ? vectors : [vectors]; + const metadataArray = Array.isArray(metadata) ? 
metadata : [metadata]; + + // Validate lengths match + if (idArray.length !== vectorArray.length || idArray.length !== metadataArray.length) { + throw new Error( + `Mismatched array lengths: ids(${idArray.length}), vectors(${vectorArray.length}), metadata(${metadataArray.length})` + ); + } + + const repo = repository as IVectorRepository; + + await context.updateProgress(1, "Upserting vectors"); + + // Bulk upsert if multiple items + if (idArray.length > 1) { + const entries = idArray.map((id, i) => ({ + id, + vector: vectorArray[i], + metadata: metadataArray[i], + })); + await repo.upsertBulk(entries); + } else if (idArray.length === 1) { + // Single upsert + await repo.upsert(idArray[0], vectorArray[0], metadataArray[0]); + } + + return { + count: idArray.length, + ids: idArray, + }; + } +} + +TaskRegistry.registerTask(VectorStoreUpsertTask); + +export const vectorStoreUpsert = ( + input: VectorStoreUpsertTaskInput, + config?: JobQueueTaskConfig +) => { + return new VectorStoreUpsertTask({} as VectorStoreUpsertTaskInput, config).run(input); +}; + +declare module "@workglow/task-graph" { + interface Workflow { + vectorStoreUpsert: CreateWorkflow< + VectorStoreUpsertTaskInput, + VectorStoreUpsertTaskOutput, + JobQueueTaskConfig + >; + } +} + +Workflow.prototype.vectorStoreUpsert = CreateWorkflow(VectorStoreUpsertTask); diff --git a/packages/ai/src/task/base/AiTask.ts b/packages/ai/src/task/base/AiTask.ts index 08e2c301..155ee4fd 100644 --- a/packages/ai/src/task/base/AiTask.ts +++ b/packages/ai/src/task/base/AiTask.ts @@ -16,11 +16,12 @@ import { TaskInput, type TaskOutput, } from "@workglow/task-graph"; -import { type JsonSchema } from "@workglow/util"; +import { type JsonSchema, type ServiceRegistry } from "@workglow/util"; import { AiJob, AiJobInput } from "../../job/AiJob"; -import { getGlobalModelRepository } from "../../model/ModelRegistry"; -import type { ModelConfig, ModelRecord } from "../../model/ModelSchema"; +import { MODEL_REPOSITORY } from "../../model/ModelRegistry"; +import type { ModelRepository } from "../../model/ModelRepository"; +import type { ModelConfig } from "../../model/ModelSchema"; function schemaFormat(schema: JsonSchema): string | undefined { return typeof schema === "object" && schema !== null && "format" in schema @@ -32,21 +33,19 @@ export interface AiSingleTaskInput extends TaskInput { model: string | ModelConfig; } -export interface AiArrayTaskInput extends TaskInput { - model: string | ModelConfig | (string | ModelConfig)[]; -} - /** * A base class for AI related tasks that run in a job queue. * Extends the JobQueueTask class to provide LLM-specific functionality. + * + * Model resolution is handled automatically by the TaskRunner before execution. + * By the time execute() is called, input.model is always a ModelConfig object. */ export class AiTask< - Input extends AiArrayTaskInput = AiArrayTaskInput, + Input extends AiSingleTaskInput = AiSingleTaskInput, Output extends TaskOutput = TaskOutput, Config extends JobQueueTaskConfig = JobQueueTaskConfig, > extends JobQueueTask { public static type: string = "AiTask"; - private modelCache?: { name: string; model: ModelRecord }; /** * Creates a new AiTask instance @@ -56,11 +55,9 @@ export class AiTask< const modelLabel = typeof input.model === "string" ? input.model - : Array.isArray(input.model) - ? undefined - : typeof input.model === "object" && input.model - ? input.model.model_id || input.model.title || input.model.provider - : undefined; + : typeof input.model === "object" && input.model + ? 
input.model.model_id || input.model.title || input.model.provider + : undefined; config.name ||= `${new.target.type || new.target.name}${ modelLabel ? " with model " + modelLabel : "" }`; @@ -74,58 +71,31 @@ export class AiTask< /** * Get the input to submit to the job queue. * Transforms the task input to AiJobInput format. - * @param input - The task input + * + * Note: By the time this is called, input.model has already been resolved + * to a ModelConfig by the TaskRunner's input resolution system. + * + * @param input - The task input (with resolved model) * @returns The AiJobInput to submit to the queue */ protected override async getJobInput(input: Input): Promise> { - if (Array.isArray(input.model)) { - console.error("AiTask: Model is an array", input); + // Model is guaranteed to be resolved by TaskRunner before this is called + const model = input.model as ModelConfig; + if (!model || typeof model !== "object") { throw new TaskConfigurationError( - "AiTask: Model is an array, only create job for single model tasks" + "AiTask: Model was not resolved to ModelConfig - this indicates a bug in the resolution system" ); } - const runtype = (this.constructor as any).runtype ?? (this.constructor as any).type; - const model = await this.getModelConfigForInput(input as AiSingleTaskInput); - // TODO: if the queue is not memory based, we need to convert to something that can structure clone to the queue - // const registeredQueue = await this.resolveQueue(input); - // const queueName = registeredQueue?.server.queueName; + const runtype = (this.constructor as any).runtype ?? (this.constructor as any).type; return { taskType: runtype, aiProvider: model.provider, - taskInput: { ...(input as any), model } as Input & { model: ModelConfig }, + taskInput: input as Input & { model: ModelConfig }, }; } - /** - * Resolves a model configuration for the given input. - * - * @remarks - * - If `input.model` is a string, it is resolved via the global model repository. - * - If `input.model` is already a config object, it is used directly. - */ - protected async getModelConfigForInput(input: AiSingleTaskInput): Promise { - const modelValue = input.model; - if (!modelValue) throw new TaskConfigurationError("AiTask: No model found"); - if (typeof modelValue === "string") { - const modelname = modelValue; - if (this.modelCache && this.modelCache.name === modelname) { - return this.modelCache.model; - } - const model = await getGlobalModelRepository().findByName(modelname); - if (!model) { - throw new TaskConfigurationError(`AiTask: No model ${modelname} found`); - } - this.modelCache = { name: modelname, model }; - return model; - } - if (typeof modelValue === "object") { - return modelValue; - } - throw new TaskConfigurationError("AiTask: Invalid model value"); - } - /** * Creates a new Job instance for direct execution (without a queue). 
* @param input - The task input @@ -149,42 +119,25 @@ export class AiTask< return job; } - protected async getModelForInput(input: AiSingleTaskInput): Promise { - const modelname = input.model; - if (!modelname) throw new TaskConfigurationError("AiTask: No model name found"); - if (typeof modelname !== "string") { - throw new TaskConfigurationError("AiTask: Model name is not a string"); - } - if (this.modelCache && this.modelCache.name === modelname) { - return this.modelCache.model; - } - const model = await getGlobalModelRepository().findByName(modelname); - if (!model) { - throw new TaskConfigurationError(`JobQueueTask: No model ${modelname} found`); - } - this.modelCache = { name: modelname, model }; - return model; - } - + /** + * Gets the default queue name based on the model's provider. + * After TaskRunner resolution, input.model is a ModelConfig. + */ protected override async getDefaultQueueName(input: Input): Promise { - if (typeof input.model === "string") { - const model = await this.getModelForInput(input as AiSingleTaskInput); - return model.provider; - } - if (typeof input.model === "object" && input.model !== null && !Array.isArray(input.model)) { - return (input.model as ModelConfig).provider; - } - return undefined; + const model = input.model as ModelConfig; + return model?.provider; } /** - * Validates that a model name really exists - * @param schema The schema to validate against - * @param item The item to validate - * @returns True if the item is valid, false otherwise + * Validates that model inputs are valid ModelConfig objects. + * + * Note: By the time this is called, string model IDs have already been + * resolved to ModelConfig objects by the TaskRunner's input resolution system. + * + * @param input The input to validate + * @returns True if the input is valid */ async validateInput(input: Input): Promise { - // TODO(str): this is very inefficient, we should cache the results, including intermediate results const inputSchema = this.inputSchema(); if (typeof inputSchema === "boolean") { if (inputSchema === false) { @@ -192,59 +145,41 @@ export class AiTask< } return true; } + + // Find properties with model:TaskName format - need task compatibility check const modelTaskProperties = Object.entries( (inputSchema.properties || {}) as Record ).filter(([key, schema]) => schemaFormat(schema)?.startsWith("model:")); - if (modelTaskProperties.length > 0) { - const taskModels = await getGlobalModelRepository().findModelsByTask(this.type); - for (const [key, propSchema] of modelTaskProperties) { - let requestedModels = Array.isArray(input[key]) ? input[key] : [input[key]]; - for (const model of requestedModels) { - if (typeof model === "string") { - const foundModel = taskModels?.find((m) => m.model_id === model); - if (!foundModel) { - throw new TaskConfigurationError( - `AiTask: Missing model for '${key}' named '${model}' for task '${this.type}'` - ); - } - } else if (typeof model === "object" && model !== null) { - // Inline configs are accepted without requiring repository access. - // If 'tasks' is provided, do a best-effort compatibility check. 
- const tasks = (model as ModelConfig).tasks; - if (Array.isArray(tasks) && tasks.length > 0 && !tasks.includes(this.type)) { - throw new TaskConfigurationError( - `AiTask: Inline model for '${key}' is not compatible with task '${this.type}'` - ); - } - } else { - throw new TaskConfigurationError(`AiTask: Invalid model for '${key}'`); - } + for (const [key] of modelTaskProperties) { + const model = input[key]; + if (typeof model === "object" && model !== null) { + // Check task compatibility if tasks array is specified + const tasks = (model as ModelConfig).tasks; + if (Array.isArray(tasks) && tasks.length > 0 && !tasks.includes(this.type)) { + throw new TaskConfigurationError( + `AiTask: Model for '${key}' is not compatible with task '${this.type}'` + ); } + } else if (model !== undefined && model !== null) { + // Should be a ModelConfig object after resolution + throw new TaskConfigurationError( + `AiTask: Invalid model for '${key}' - expected ModelConfig object` + ); } } + // Find properties with plain model format - just ensure they're objects const modelPlainProperties = Object.entries( (inputSchema.properties || {}) as Record ).filter(([key, schema]) => schemaFormat(schema) === "model"); - if (modelPlainProperties.length > 0) { - for (const [key, propSchema] of modelPlainProperties) { - let requestedModels = Array.isArray(input[key]) ? input[key] : [input[key]]; - for (const model of requestedModels) { - if (typeof model === "string") { - const foundModel = await getGlobalModelRepository().findByName(model); - if (!foundModel) { - throw new TaskConfigurationError( - `AiTask: Missing model for "${key}" named "${model}"` - ); - } - } else if (typeof model === "object" && model !== null) { - // Inline configs are accepted without requiring repository access. - } else { - throw new TaskConfigurationError(`AiTask: Invalid model for "${key}"`); - } - } + for (const [key] of modelPlainProperties) { + const model = input[key]; + if (model !== undefined && model !== null && typeof model !== "object") { + throw new TaskConfigurationError( + `AiTask: Invalid model for '${key}' - expected ModelConfig object` + ); } } @@ -253,7 +188,7 @@ export class AiTask< // dataflows can strip some models that are incompatible with the target task // if all of them are stripped, then the task will fail in validateInput - async narrowInput(input: Input): Promise { + async narrowInput(input: Input, registry: ServiceRegistry): Promise { // TODO(str): this is very inefficient, we should cache the results, including intermediate results const inputSchema = this.inputSchema(); if (typeof inputSchema === "boolean") { @@ -266,34 +201,25 @@ export class AiTask< (inputSchema.properties || {}) as Record ).filter(([key, schema]) => schemaFormat(schema)?.startsWith("model:")); if (modelTaskProperties.length > 0) { - const taskModels = await getGlobalModelRepository().findModelsByTask(this.type); + const modelRepo = registry.get(MODEL_REPOSITORY); + const taskModels = await modelRepo.findModelsByTask(this.type); for (const [key, propSchema] of modelTaskProperties) { - let requestedModels = Array.isArray(input[key]) ? 
input[key] : [input[key]]; - const requestedStrings = requestedModels.filter( - (m: unknown): m is string => typeof m === "string" - ); - const requestedInline = requestedModels.filter( - (m: unknown): m is ModelConfig => typeof m === "object" && m !== null - ); + const requestedModel = input[key]; - const usingStrings = requestedStrings.filter((model: string) => - taskModels?.find((m) => m.model_id === model) - ); - - const usingInline = requestedInline.filter((model: ModelConfig) => { + if (typeof requestedModel === "string") { + // Verify string model ID is compatible + const found = taskModels?.find((m) => m.model_id === requestedModel); + if (!found) { + (input as any)[key] = undefined; + } + } else if (typeof requestedModel === "object" && requestedModel !== null) { + // Verify inline config is compatible + const model = requestedModel as ModelConfig; const tasks = model.tasks; - // Filter out inline configs with explicit incompatible tasks arrays - // This matches the validation logic in validateInput if (Array.isArray(tasks) && tasks.length > 0 && !tasks.includes(this.type)) { - return false; + (input as any)[key] = undefined; } - return true; - }); - - const combined: (string | ModelConfig)[] = [...usingInline, ...usingStrings]; - - // we alter input to be the models that were found for this kind of input - (input as any)[key] = combined.length > 1 ? combined : combined[0]; + } } } return input; diff --git a/packages/ai/src/task/base/AiTaskSchemas.ts b/packages/ai/src/task/base/AiTaskSchemas.ts index 67d2fcaf..a77aa7ca 100644 --- a/packages/ai/src/task/base/AiTaskSchemas.ts +++ b/packages/ai/src/task/base/AiTaskSchemas.ts @@ -7,174 +7,12 @@ import { DataPortSchemaNonBoolean, FromSchema, - FromSchemaDefaultOptions, - FromSchemaOptions, JsonSchema, + type TypedArray, + TypedArraySchemaOptions, } from "@workglow/util"; import { ModelConfigSchema } from "../../model/ModelSchema"; -export type TypedArray = - | Float64Array - | Float32Array - | Int32Array - | Int16Array - | Int8Array - | Uint32Array - | Uint16Array - | Uint8Array - | Uint8ClampedArray; - -// Type-only value for use in deserialize patterns -const TypedArrayType = null as any as TypedArray; - -const TypedArraySchemaOptions = { - ...FromSchemaDefaultOptions, - deserialize: [ - // { - // pattern: { - // type: "number"; - // "format": "BigInt" | "Float64"; - // }; - // output: bigint; - // }, - // { - // pattern: { - // type: "number"; - // "format": "Float64Array"; - // }; - // output: Float64Array; - // }, - // { - // pattern: { - // type: "number"; - // "format": "Float32Array"; - // }; - // output: Float32Array; - // }, - // { - // pattern: { - // type: "number"; - // "format": "Int32Array"; - // }; - // output: Int32Array; - // }, - // { - // pattern: { - // type: "number"; - // "format": "Int16Array"; - // }; - // output: Int16Array; - // }, - // { - // pattern: { - // type: "number"; - // "format": "Int8Array"; - // }; - // output: Int8Array; - // }, - // { - // pattern: { - // type: "number"; - // "format": "Uint8Array"; - // }; - // output: Uint8Array; - // }, - // { - // pattern: { - // type: "number"; - // "format": "Uint16Array"; - // }; - // output: Uint16Array; - // }, - // { - // pattern: { - // type: "number"; - // "format": "Uint32Array"; - // }; - // output: Uint32Array; - // }, - // { - // pattern: { type: "array"; items: { type: "number" }; "format": "Uint8ClampedArray" }; - // output: Uint8ClampedArray; - // }, - { - pattern: { format: "TypedArray" }, - output: TypedArrayType, - }, - ], -} as const 
satisfies FromSchemaOptions; - -export type TypedArraySchemaOptions = typeof TypedArraySchemaOptions; - -export const TypedArraySchema = (annotations: Record = {}) => - ({ - oneOf: [ - { - type: "array", - items: { type: "number", format: "Float64" }, - title: "Float64Array", - description: "A 64-bit floating point array", - format: "Float64Array", - }, - { - type: "array", - items: { type: "number", format: "Float32" }, - title: "Float32Array", - description: "A 32-bit floating point array", - format: "Float32Array", - }, - { - type: "array", - items: { type: "number", format: "Int32" }, - title: "Int32Array", - description: "A 32-bit integer array", - format: "Int32Array", - }, - { - type: "array", - items: { type: "number", format: "Int16" }, - title: "Int16Array", - description: "A 16-bit integer array", - format: "Int16Array", - }, - { - type: "array", - items: { type: "number", format: "Int8" }, - title: "Int8Array", - }, - { - type: "array", - items: { type: "number", format: "Uint8" }, - title: "Uint8Array", - description: "A 8-bit unsigned integer array", - format: "Uint8Array", - }, - { - type: "array", - items: { type: "number", format: "Uint16" }, - title: "Uint16Array", - description: "A 16-bit unsigned integer array", - format: "Uint16Array", - }, - { - type: "array", - items: { type: "number", format: "Uint32" }, - title: "Uint32Array", - description: "A 32-bit unsigned integer array", - format: "Uint32Array", - }, - { - type: "array", - items: { type: "number", format: "Uint8Clamped" }, - title: "Uint8ClampedArray", - description: "A 8-bit unsigned integer array with values clamped to 0-255", - format: "Uint8ClampedArray", - }, - ], - format: "TypedArray", - ...annotations, - }) as const satisfies JsonSchema; - export const TypeLanguage = (annotations: Record = {}) => ({ type: "string", @@ -253,7 +91,7 @@ export const TypeReplicateArray = ( "x-replicate": true, }) as const; -export type TypedArrayFromSchema = FromSchema< +export type VectorFromSchema = FromSchema< SCHEMA, TypedArraySchemaOptions >; @@ -264,7 +102,7 @@ export type TypedArrayFromSchema = FromSchema< * Used to extract the single-value type from schemas with x-replicate annotation. * Uses distributive conditional types to filter out arrays from unions. * Checks for both array types and types with numeric index signatures (FromSchema array output). - * Preserves TypedArray types like Float64Array which also have numeric indices. + * Preserves Vector types like Float64Array which also have numeric indices. */ type UnwrapArrayUnion = T extends readonly any[] ? T extends TypedArray @@ -283,8 +121,8 @@ type UnwrapArrayUnion = T extends readonly any[] */ export type DeReplicateFromSchema }> = { [K in keyof S["properties"]]: S["properties"][K] extends { "x-replicate": true } - ? UnwrapArrayUnion> - : TypedArrayFromSchema; + ? UnwrapArrayUnion> + : VectorFromSchema; }; export type ImageSource = ImageBitmap | OffscreenCanvas | VideoFrame; diff --git a/packages/ai/src/task/base/AiVisionTask.ts b/packages/ai/src/task/base/AiVisionTask.ts index 16fbb7b4..52b5f2e6 100644 --- a/packages/ai/src/task/base/AiVisionTask.ts +++ b/packages/ai/src/task/base/AiVisionTask.ts @@ -19,16 +19,12 @@ export interface AiVisionTaskSingleInput extends TaskInput { model: string | ModelConfig; } -export interface AiVisionArrayTaskInput extends TaskInput { - model: string | ModelConfig | (string | ModelConfig)[]; -} - /** * A base class for AI related tasks that run in a job queue. 
* Extends the JobQueueTask class to provide LLM-specific functionality. */ export class AiVisionTask< - Input extends AiVisionArrayTaskInput = AiVisionArrayTaskInput, + Input extends AiVisionTaskSingleInput = AiVisionTaskSingleInput, Output extends TaskOutput = TaskOutput, Config extends JobQueueTaskConfig = JobQueueTaskConfig, > extends AiTask { diff --git a/packages/ai/src/task/index.ts b/packages/ai/src/task/index.ts index 91dbf459..afdf155c 100644 --- a/packages/ai/src/task/index.ts +++ b/packages/ai/src/task/index.ts @@ -7,18 +7,28 @@ export * from "./BackgroundRemovalTask"; export * from "./base/AiTask"; export * from "./base/AiTaskSchemas"; -export * from "./DocumentSplitterTask"; +export * from "./ChunkToVectorTask"; +export * from "./ContextBuilderTask"; +export * from "./DocumentEnricherTask"; export * from "./DownloadModelTask"; export * from "./FaceDetectorTask"; export * from "./FaceLandmarkerTask"; export * from "./GestureRecognizerTask"; export * from "./HandLandmarkerTask"; +export * from "./HierarchicalChunkerTask"; +export * from "./HierarchyJoinTask"; +export * from "./HybridSearchTask"; export * from "./ImageClassificationTask"; export * from "./ImageEmbeddingTask"; export * from "./ImageSegmentationTask"; export * from "./ImageToTextTask"; export * from "./ObjectDetectionTask"; export * from "./PoseLandmarkerTask"; +export * from "./QueryExpanderTask"; +export * from "./RerankerTask"; +export * from "./RetrievalTask"; +export * from "./TextChunkerTask"; +export * from "./StructuralParserTask"; export * from "./TextClassificationTask"; export * from "./TextEmbeddingTask"; export * from "./TextFillMaskTask"; @@ -29,5 +39,9 @@ export * from "./TextQuestionAnswerTask"; export * from "./TextRewriterTask"; export * from "./TextSummaryTask"; export * from "./TextTranslationTask"; +export * from "./TopicSegmenterTask"; export * from "./UnloadModelTask"; +export * from "./VectorQuantizeTask"; export * from "./VectorSimilarityTask"; +export * from "./VectorStoreSearchTask"; +export * from "./VectorStoreUpsertTask"; diff --git a/packages/debug/src/console/ConsoleFormatters.ts b/packages/debug/src/console/ConsoleFormatters.ts index 1e55021d..86c53ba9 100644 --- a/packages/debug/src/console/ConsoleFormatters.ts +++ b/packages/debug/src/console/ConsoleFormatters.ts @@ -73,8 +73,7 @@ class WorkflowConsoleFormatter extends ConsoleFormatter { body(obj: unknown, config?: Config): JsonMLElementDef { const body = new JsonMLElement("div"); - const graph: TaskGraph = - obj instanceof TaskGraph ? obj : (obj as Workflow).graph; + const graph: TaskGraph = obj instanceof TaskGraph ? 
obj : (obj as Workflow).graph; const nodes = body.createStyledList(); const tasks = graph.getTasks(); if (tasks.length) { @@ -314,7 +313,7 @@ class TaskConsoleFormatter extends ConsoleFormatter { const body = new JsonMLElement("div").setStyle("padding-left: 10px;"); const inputs = body.createStyledList("Inputs:"); - const allInboundDataflows = ((config as { graph?: TaskGraph })?.graph)?.getSourceDataflows( + const allInboundDataflows = (config as { graph?: TaskGraph })?.graph?.getSourceDataflows( task.config.id ); @@ -382,7 +381,6 @@ class TaskConsoleFormatter extends ConsoleFormatter { const taskConfig = body.createStyledList("Config:"); for (const [key, value] of Object.entries(task.config)) { if (value === undefined) continue; - if (key == "provenance") continue; const li = taskConfig.createListItem("", "padding-left: 20px;"); li.inputText(`${key}: `); li.createValueObject(value); @@ -750,7 +748,10 @@ interface NodeWithConfig { function computeLayout( graph: DirectedAcyclicGraph, canvasWidth: number -): { readonly positions: { readonly [id: string]: { readonly x: number; readonly y: number } }; readonly requiredHeight: number } { +): { + readonly positions: { readonly [id: string]: { readonly x: number; readonly y: number } }; + readonly requiredHeight: number; +} { const positions: { [id: string]: { x: number; y: number } } = {}; const layers: Map = new Map(); const depths: { [id: string]: number } = {}; @@ -873,4 +874,3 @@ export function installDevToolsFormatters(): void { new DAGConsoleFormatter() ); } - diff --git a/packages/storage/README.md b/packages/storage/README.md index 0b670966..aaf0c726 100644 --- a/packages/storage/README.md +++ b/packages/storage/README.md @@ -28,6 +28,7 @@ Modular storage solutions for Workglow.AI platform with multiple backend impleme - [Node.js Environment](#nodejs-environment) - [Bun Environment](#bun-environment) - [Advanced Features](#advanced-features) + - [Repository Registry](#repository-registry) - [Event-Driven Architecture](#event-driven-architecture) - [Compound Primary Keys](#compound-primary-keys) - [Custom File Layout (KV on filesystem)](#custom-file-layout-kv-on-filesystem) @@ -521,6 +522,95 @@ const cloudData = new SupabaseTabularRepository(supabase, "items", ItemSchema, [ ## Advanced Features +### Repository Registry + +Repositories can be registered globally by ID, allowing tasks to reference them by name rather than passing direct instances. This is useful for configuring repositories once at application startup and referencing them throughout your task graphs. + +#### Registering Repositories + +```typescript +import { + registerTabularRepository, + getTabularRepository, + InMemoryTabularRepository, +} from "@workglow/storage"; + +// Define your schema +const userSchema = { + type: "object", + properties: { + id: { type: "string" }, + name: { type: "string" }, + email: { type: "string" }, + }, + required: ["id", "name", "email"], + additionalProperties: false, +} as const; + +// Create and register a repository +const userRepo = new InMemoryTabularRepository(userSchema, ["id"] as const); +registerTabularRepository("users", userRepo); + +// Later, retrieve the repository by ID +const repo = getTabularRepository("users"); +``` + +#### Using Repositories in Tasks + +When using repositories with tasks, you can pass either the repository ID or a direct instance. The TaskRunner automatically resolves string IDs using the registry. 
+ +```typescript +import { TypeTabularRepository } from "@workglow/storage"; + +// In your task's input schema, use TypeTabularRepository +static inputSchema() { + return { + type: "object", + properties: { + dataSource: TypeTabularRepository({ + title: "User Repository", + description: "Repository containing user records", + }), + }, + required: ["dataSource"], + }; +} + +// Both approaches work: +await task.run({ dataSource: "users" }); // Resolved from registry +await task.run({ dataSource: userRepoInstance }); // Direct instance +``` + +#### Schema Helper Functions + +The package provides schema helper functions for defining repository inputs with proper format annotations: + +```typescript +import { + TypeTabularRepository, + TypeVectorRepository, + TypeDocumentRepository, +} from "@workglow/storage"; + +// Tabular repository (format: "repository:tabular") +const tabularSchema = TypeTabularRepository({ + title: "Data Source", + description: "Tabular data repository", +}); + +// Vector repository (format: "repository:vector") +const vectorSchema = TypeVectorRepository({ + title: "Embeddings Store", + description: "Vector embeddings repository", +}); + +// Document repository (format: "repository:document") +const docSchema = TypeDocumentRepository({ + title: "Document Store", + description: "Document storage repository", +}); +``` + ### Event-Driven Architecture All storage implementations support event emission for monitoring and reactive programming: diff --git a/packages/storage/src/browser.ts b/packages/storage/src/browser.ts index 4a0d756c..c0632a59 100644 --- a/packages/storage/src/browser.ts +++ b/packages/storage/src/browser.ts @@ -20,3 +20,5 @@ export * from "./limiter/IndexedDbRateLimiterStorage"; export * from "./limiter/SupabaseRateLimiterStorage"; export * from "./util/IndexedDbTable"; + +export * from "./vector/EdgeVecRepository"; diff --git a/packages/storage/src/common-server.ts b/packages/storage/src/common-server.ts index ab1bca3f..4de35422 100644 --- a/packages/storage/src/common-server.ts +++ b/packages/storage/src/common-server.ts @@ -25,6 +25,9 @@ export * from "./limiter/PostgresRateLimiterStorage"; export * from "./limiter/SqliteRateLimiterStorage"; export * from "./limiter/SupabaseRateLimiterStorage"; +export * from "./vector/PostgresVectorRepository"; +export * from "./vector/SqliteVectorRepository"; + // testing export * from "./kv/IndexedDbKvRepository"; export * from "./limiter/IndexedDbRateLimiterStorage"; diff --git a/packages/storage/src/common.ts b/packages/storage/src/common.ts index 21bdb392..977a89d1 100644 --- a/packages/storage/src/common.ts +++ b/packages/storage/src/common.ts @@ -8,6 +8,9 @@ export * from "./tabular/CachedTabularRepository"; export * from "./tabular/InMemoryTabularRepository"; export * from "./tabular/ITabularRepository"; export * from "./tabular/TabularRepository"; +export * from "./tabular/TabularRepositoryRegistry"; + +export * from "./schema/RepositorySchema"; export * from "./kv/IKvRepository"; export * from "./kv/InMemoryKvRepository"; @@ -22,3 +25,7 @@ export * from "./limiter/IRateLimiterStorage"; export * from "./util/HybridSubscriptionManager"; export * from "./util/PollingSubscriptionManager"; + +export * from "./vector/InMemoryVectorRepository"; +export * from "./vector/IVectorRepository"; +export * from "./vector/VectorRepositoryRegistry"; diff --git a/packages/storage/src/schema/RepositorySchema.ts b/packages/storage/src/schema/RepositorySchema.ts new file mode 100644 index 00000000..2b3c6105 --- /dev/null +++ 
b/packages/storage/src/schema/RepositorySchema.ts @@ -0,0 +1,91 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { JsonSchema } from "@workglow/util"; + +/** + * Semantic format types for repository schema annotations. + * These are used by the InputResolver to determine how to resolve string IDs. + */ +export type RepositorySemantic = "repository:tabular" | "repository:vector" | "repository:document"; + +/** + * Creates a JSON schema for a tabular repository input. + * The schema accepts either a string ID (resolved from registry) or a direct repository instance. + * + * @param options Additional schema options to merge + * @returns JSON schema for tabular repository input + * + * @example + * ```typescript + * const inputSchema = { + * type: "object", + * properties: { + * dataSource: TypeTabularRepository({ + * title: "User Database", + * description: "Repository containing user records", + * }), + * }, + * required: ["dataSource"], + * } as const; + * ``` + */ +export function TypeTabularRepository = {}>( + options: O = {} as O +) { + return { + title: "Tabular Repository", + description: "Repository ID or instance for tabular data storage", + ...options, + format: "repository:tabular" as const, + oneOf: [ + { type: "string" as const, title: "Repository ID" }, + { title: "Repository Instance", additionalProperties: true }, + ], + } as const satisfies JsonSchema; +} + +/** + * Creates a JSON schema for a vector repository input. + * The schema accepts either a string ID (resolved from registry) or a direct repository instance. + * + * @param options Additional schema options to merge + * @returns JSON schema for vector repository input + */ +export function TypeVectorRepository = {}>(options: O = {} as O) { + return { + title: "Vector Repository", + description: "Repository ID or instance for vector data storage", + ...options, + format: "repository:vector" as const, + anyOf: [ + { type: "string" as const, title: "Repository ID" }, + { title: "Repository Instance", additionalProperties: true }, + ], + } as const satisfies JsonSchema; +} + +/** + * Creates a JSON schema for a document repository input. + * The schema accepts either a string ID (resolved from registry) or a direct repository instance. 
+ * + * @param options Additional schema options to merge + * @returns JSON schema for document repository input + */ +export function TypeDocumentRepository = {}>( + options: O = {} as O +) { + return { + title: "Document Repository", + description: "Repository ID or instance for document data storage", + ...options, + format: "repository:document" as const, + anyOf: [ + { type: "string" as const, title: "Repository ID" }, + { title: "Repository Instance", additionalProperties: true }, + ], + } as const satisfies JsonSchema; +} diff --git a/packages/storage/src/tabular/TabularRepositoryRegistry.ts b/packages/storage/src/tabular/TabularRepositoryRegistry.ts new file mode 100644 index 00000000..887101dd --- /dev/null +++ b/packages/storage/src/tabular/TabularRepositoryRegistry.ts @@ -0,0 +1,87 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + createServiceToken, + globalServiceRegistry, + registerInputResolver, + ServiceRegistry, +} from "@workglow/util"; +import type { ITabularRepository } from "./ITabularRepository"; + +/** + * Service token for the tabular repository registry + * Maps repository IDs to ITabularRepository instances + */ +export const TABULAR_REPOSITORIES = createServiceToken< + Map> +>("storage.tabular.repositories"); + +// Register default factory if not already registered +if (!globalServiceRegistry.has(TABULAR_REPOSITORIES)) { + globalServiceRegistry.register( + TABULAR_REPOSITORIES, + (): Map> => new Map(), + true + ); +} + +/** + * Gets the global tabular repository registry + * @returns Map of tabular repository ID to instance + */ +export function getGlobalTabularRepositories(): Map< + string, + ITabularRepository +> { + return globalServiceRegistry.get(TABULAR_REPOSITORIES); +} + +/** + * Registers a tabular repository globally by ID + * @param id The unique identifier for this repository + * @param repository The repository instance to register + */ +export function registerTabularRepository( + id: string, + repository: ITabularRepository +): void { + const repos = getGlobalTabularRepositories(); + repos.set(id, repository); +} + +/** + * Gets a tabular repository by ID from the global registry + * @param id The repository identifier + * @returns The repository instance or undefined if not found + */ +export function getTabularRepository( + id: string +): ITabularRepository | undefined { + return getGlobalTabularRepositories().get(id); +} + +/** + * Resolves a repository ID to an instance from the registry. + * Used by the input resolver system. + */ +function resolveRepositoryFromRegistry( + id: string, + format: string, + registry: ServiceRegistry +): ITabularRepository { + const repos = registry.has(TABULAR_REPOSITORIES) + ? 
registry.get(TABULAR_REPOSITORIES) + : getGlobalTabularRepositories(); + const repo = repos.get(id); + if (!repo) { + throw new Error(`Tabular repository "${id}" not found in registry`); + } + return repo; +} + +// Register the repository resolver for format: "repository:tabular" +registerInputResolver("repository:tabular", resolveRepositoryFromRegistry); diff --git a/packages/storage/src/vector/EdgeVecRepository.ts b/packages/storage/src/vector/EdgeVecRepository.ts new file mode 100644 index 00000000..82d2ff78 --- /dev/null +++ b/packages/storage/src/vector/EdgeVecRepository.ts @@ -0,0 +1,397 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { cosineSimilarity, EventEmitter, TypedArray } from "@workglow/util"; +import { + HybridSearchOptions, + IVectorRepository, + SearchResult, + VectorEntry, + VectorEventListeners, + VectorSearchOptions, +} from "./IVectorRepository"; + +/** + * Check if metadata matches filter + */ +function matchesFilter(metadata: Metadata, filter: Partial): boolean { + for (const [key, value] of Object.entries(filter)) { + if (metadata[key as keyof Metadata] !== value) { + return false; + } + } + return true; +} + +/** + * Simple full-text search scoring (keyword matching) + */ +function textRelevance(text: string, query: string): number { + const textLower = text.toLowerCase(); + const queryLower = query.toLowerCase(); + const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); + if (queryWords.length === 0) { + return 0; + } + let matches = 0; + for (const word of queryWords) { + if (textLower.includes(word)) { + matches++; + } + } + return matches / queryWords.length; +} + +/** + * EdgeVec vector repository implementation. + * Optimized for edge/browser deployment with minimal dependencies. + * Stores vectors in memory with optional IndexedDB persistence. + * Designed for privacy-sensitive on-device RAG applications. + * + * Features: + * - Lightweight in-memory storage + * - Optional IndexedDB persistence for browser + * - WebGPU/WASM acceleration support (when available) + * - Supports quantized vectors (Int8Array, Uint8Array, etc.) + * - No server dependency + * - Privacy-first design + * + * @template Metadata - Type for metadata associated with vectors + * @template Vector - Type of vector array (Float32Array, Int8Array, etc.) + */ +export class EdgeVecRepository< + Metadata = Record, + VectorChoice extends TypedArray = Float32Array, +> + extends EventEmitter> + implements IVectorRepository +{ + private vectors: Map> = new Map(); + private dbName?: string; + private db?: IDBDatabase; + private initialized = false; + private useWebGPU = false; + private gpuDevice?: any; + + /** + * Creates a new EdgeVec repository + * @param options - Configuration options + */ + constructor( + options: { + /** IndexedDB database name for persistence (browser only) */ + dbName?: string; + /** Enable WebGPU acceleration if available */ + enableWebGPU?: boolean; + } = {} + ) { + super(); + this.dbName = options.dbName; + this.useWebGPU = options.enableWebGPU ?? 
false; + } + + async setupDatabase(): Promise { + if (this.initialized) { + return; + } + + // Initialize WebGPU if requested and available + if (this.useWebGPU && typeof navigator !== "undefined" && "gpu" in navigator) { + try { + const adapter = await (navigator as any).gpu.requestAdapter(); + if (adapter) { + this.gpuDevice = await adapter.requestDevice(); + } + } catch (error) { + console.warn("WebGPU initialization failed, falling back to CPU:", error); + } + } + + // Initialize IndexedDB if dbName provided (browser only) + if (this.dbName && typeof indexedDB !== "undefined") { + await this.initIndexedDB(); + await this.loadFromIndexedDB(); + } + + this.initialized = true; + } + + private async initIndexedDB(): Promise { + return new Promise((resolve, reject) => { + const request = indexedDB.open(this.dbName!, 1); + + request.onerror = () => reject(request.error); + request.onsuccess = () => { + this.db = request.result; + resolve(); + }; + + request.onupgradeneeded = (event) => { + const db = (event.target as IDBOpenDBRequest).result; + if (!db.objectStoreNames.contains("vectors")) { + db.createObjectStore("vectors", { keyPath: "id" }); + } + }; + }); + } + + private async loadFromIndexedDB(): Promise { + if (!this.db) return; + + return new Promise((resolve, reject) => { + const transaction = this.db!.transaction(["vectors"], "readonly"); + const store = transaction.objectStore("vectors"); + const request = store.getAll(); + + request.onerror = () => reject(request.error); + request.onsuccess = () => { + const entries = request.result as Array<{ + id: string; + vector: number[]; + metadata: Metadata; + }>; + for (const entry of entries) { + this.vectors.set(entry.id, { + id: entry.id, + vector: this.copyVector(new Float32Array(entry.vector)) as VectorChoice, + metadata: entry.metadata, + }); + } + resolve(); + }; + }); + } + + private async saveToIndexedDB(entry: VectorEntry): Promise { + if (!this.db) return; + + return new Promise((resolve, reject) => { + const transaction = this.db!.transaction(["vectors"], "readwrite"); + const store = transaction.objectStore("vectors"); + const request = store.put({ + id: entry.id, + vector: Array.from(entry.vector), + metadata: entry.metadata, + }); + + request.onerror = () => reject(request.error); + request.onsuccess = () => resolve(); + }); + } + + private async deleteFromIndexedDB(id: string): Promise { + if (!this.db) return; + + return new Promise((resolve, reject) => { + const transaction = this.db!.transaction(["vectors"], "readwrite"); + const store = transaction.objectStore("vectors"); + const request = store.delete(id); + + request.onerror = () => reject(request.error); + request.onsuccess = () => resolve(); + }); + } + + async upsert(id: string, vector: VectorChoice, metadata: Metadata): Promise { + const entry: VectorEntry = { + id, + vector: this.copyVector(vector) as VectorChoice, + metadata: { ...metadata } as Metadata, + }; + this.vectors.set(id, entry); + + if (this.db) { + await this.saveToIndexedDB(entry); + } + + this.emit("upsert", entry); + } + + async upsertBulk(items: VectorEntry[]): Promise { + for (const item of items) { + const entry: VectorEntry = { + id: item.id, + vector: this.copyVector(item.vector) as VectorChoice, + metadata: { ...item.metadata } as Metadata, + }; + this.vectors.set(item.id, entry); + + if (this.db) { + await this.saveToIndexedDB(entry); + } + + this.emit("upsert", entry); + } + } + + /** + * Copy a vector to avoid external mutations + */ + private copyVector(vector: TypedArray): TypedArray { 
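+    // Clone into a new array of the same concrete typed-array class so stored vectors cannot be mutated externally; unrecognized TypedArray types fall back to Float32Array.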
+ if (vector instanceof Float32Array) return new Float32Array(vector); + if (vector instanceof Float64Array) return new Float64Array(vector); + if (vector instanceof Int8Array) return new Int8Array(vector); + if (vector instanceof Uint8Array) return new Uint8Array(vector); + if (vector instanceof Int16Array) return new Int16Array(vector); + if (vector instanceof Uint16Array) return new Uint16Array(vector); + return new Float32Array(vector); + } + + async similaritySearch( + query: VectorChoice, + options: VectorSearchOptions = {} + ): Promise[]> { + const { topK = 10, filter, scoreThreshold = 0 } = options; + const results: SearchResult[] = []; + + // Use WebGPU acceleration if available + if (this.gpuDevice && this.vectors.size > 100) { + // TODO: Implement WebGPU-accelerated similarity computation + // For now, fall back to CPU + } + + // CPU-based similarity computation + for (const entry of this.vectors.values()) { + if (filter && !matchesFilter(entry.metadata, filter)) { + continue; + } + + const score = cosineSimilarity(query, entry.vector); + + if (score >= scoreThreshold) { + results.push({ + id: entry.id, + vector: entry.vector, + metadata: entry.metadata, + score, + }); + } + } + + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + this.emit("search", query, topResults); + return topResults; + } + + async hybridSearch( + query: VectorChoice, + options: HybridSearchOptions + ): Promise[]> { + const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; + + if (!textQuery || textQuery.trim().length === 0) { + return this.similaritySearch(query, { topK, filter, scoreThreshold }); + } + + const results: SearchResult[] = []; + + for (const entry of this.vectors.values()) { + if (filter && !matchesFilter(entry.metadata, filter)) { + continue; + } + + const vectorScore = cosineSimilarity(query, entry.vector); + const metadataText = JSON.stringify(entry.metadata).toLowerCase(); + const textScore = textRelevance(metadataText, textQuery); + const combinedScore = vectorWeight * vectorScore + (1 - vectorWeight) * textScore; + + if (combinedScore >= scoreThreshold) { + results.push({ + id: entry.id, + vector: entry.vector, + metadata: entry.metadata, + score: combinedScore, + }); + } + } + + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + this.emit("search", query, topResults); + return topResults; + } + + async get(id: string): Promise | undefined> { + const entry = this.vectors.get(id); + if (entry) { + return { + id: entry.id, + vector: this.copyVector(entry.vector) as VectorChoice, + metadata: { ...entry.metadata } as Metadata, + }; + } + return undefined; + } + + async delete(id: string): Promise { + if (this.vectors.has(id)) { + this.vectors.delete(id); + if (this.db) { + await this.deleteFromIndexedDB(id); + } + this.emit("delete", id); + } + } + + async deleteBulk(ids: string[]): Promise { + for (const id of ids) { + if (this.vectors.has(id)) { + this.vectors.delete(id); + if (this.db) { + await this.deleteFromIndexedDB(id); + } + this.emit("delete", id); + } + } + } + + async deleteByFilter(filter: Partial): Promise { + const idsToDelete: string[] = []; + for (const entry of this.vectors.values()) { + if (matchesFilter(entry.metadata, filter)) { + idsToDelete.push(entry.id); + } + } + await this.deleteBulk(idsToDelete); + } + + async size(): Promise { + return this.vectors.size; + } + + async clear(): Promise { + const ids = Array.from(this.vectors.keys()); + 
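// ids are captured above so a "delete" event can still be emitted for each entry after the map is cleared. +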
this.vectors.clear(); + + if (this.db) { + const transaction = this.db.transaction(["vectors"], "readwrite"); + const store = transaction.objectStore("vectors"); + store.clear(); + } + + for (const id of ids) { + this.emit("delete", id); + } + } + + destroy(): void { + this.vectors.clear(); + if (this.db) { + this.db.close(); + } + this.removeAllListeners(); + } + + /** + * Get WebGPU device if available + */ + getGPUDevice(): any | undefined { + return this.gpuDevice; + } +} diff --git a/packages/storage/src/vector/IVectorRepository.ts b/packages/storage/src/vector/IVectorRepository.ts new file mode 100644 index 00000000..aef2188a --- /dev/null +++ b/packages/storage/src/vector/IVectorRepository.ts @@ -0,0 +1,224 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { DataPortSchemaObject, EventParameters, TypedArray } from "@workglow/util"; + +/** + * Schema for vector storage in tabular format. + * In-memory implementations may store vector as TypedArray directly, + * while SQL implementations serialize to JSON string. + */ +export const VectorSchema = { + type: "object", + properties: { + id: { type: "string" }, + vector: { type: "string" }, // TypedArray in memory, JSON string in SQL + metadata: { type: "string" }, // JSON-serialized metadata + }, + required: ["id", "vector", "metadata"], + additionalProperties: false, +} as const satisfies DataPortSchemaObject; + +export type VectorEntity = { + readonly id: string; + readonly vector: string | TypedArray; + readonly metadata: string; +}; + +/** + * A vector entry with its associated metadata + */ +export interface VectorEntry< + Metadata = Record, + VectorChoice extends TypedArray = Float32Array, +> { + readonly id: string; + readonly vector: VectorChoice; + readonly metadata: Metadata; +} + +/** + * A search result with similarity score + */ +export interface SearchResult< + Metadata = Record, + VectorChoice extends TypedArray = Float32Array, +> { + readonly id: string; + readonly vector: VectorChoice; + readonly metadata: Metadata; + readonly score: number; +} + +/** + * Options for vector search operations + */ +export interface VectorSearchOptions< + Metadata = Record, + VectorChoice extends TypedArray = Float32Array, +> { + /** Maximum number of results to return */ + topK?: number; + /** Filter by metadata fields */ + filter?: Partial; + /** Minimum similarity score threshold */ + scoreThreshold?: number; +} + +/** + * Options for hybrid search (vector + full-text) + */ +export interface HybridSearchOptions< + Metadata = Record, + VectorChoice extends TypedArray = Float32Array, +> extends VectorSearchOptions { + /** Full-text query string */ + textQuery: string; + /** Weight for vector similarity (0-1), remainder goes to text relevance */ + vectorWeight?: number; +} + +/** + * Type definitions for vector repository events + */ +export type VectorEventListeners = { + upsert: (entry: VectorEntry) => void; + delete: (id: string) => void; + search: (query: VectorChoice, results: SearchResult[]) => void; +}; + +export type VectorEventName = keyof VectorEventListeners; +export type VectorEventListener< + Event extends VectorEventName, + Metadata, + VectorChoice extends TypedArray = Float32Array, +> = VectorEventListeners[Event]; + +export type VectorEventParameters< + Event extends VectorEventName, + Metadata, + VectorChoice extends TypedArray = Float32Array, +> = EventParameters, Event>; + +/** + * Interface defining the contract for vector storage repositories. 
+ * While the interface doesn't formally extend ITabularRepository (due to signature differences), + * implementations typically extend tabular repository implementations for code reuse. + * Provides operations for storing, retrieving, and searching vector embeddings. + * Supports various vector types including quantized formats. + * + * @typeParam Metadata - Type for metadata associated with vectors + * @typeParam VectorChoice - Type of vector array (Float32Array, Int8Array, etc.) + */ +export interface IVectorRepository< + Metadata = Record, + VectorChoice extends TypedArray = Float32Array, +> { + /** + * Upsert a vector entry (insert or update) + * @param id - Unique identifier for the vector + * @param vector - The vector embedding (Float32Array, Int8Array, etc.) + * @param metadata - Associated metadata + */ + upsert(id: string, vector: VectorChoice, metadata: Metadata): Promise; + + /** + * Upsert multiple vector entries in bulk + * @param items - Array of vector entries to upsert + */ + upsertBulk(items: VectorEntry[]): Promise; + + /** + * Search for similar vectors using similarity scoring + * @param query - Query vector to compare against + * @param options - Search options (topK, filter, scoreThreshold) + * @returns Array of search results sorted by similarity (highest first) + */ + similaritySearch( + query: VectorChoice, + options?: VectorSearchOptions + ): Promise[]>; + + /** + * Hybrid search combining vector similarity with full-text search + * This is optional and may not be supported by all implementations + * @param query - Query vector to compare against + * @param options - Hybrid search options including text query + * @returns Array of search results sorted by combined relevance + */ + hybridSearch?( + query: VectorChoice, + options: HybridSearchOptions + ): Promise[]>; + + /** + * Get a vector entry by ID + * @param id - Unique identifier + * @returns The vector entry or undefined if not found + */ + get(id: string): Promise | undefined>; + + /** + * Delete a vector entry by ID + * @param id - Unique identifier + */ + delete(id: string): Promise; + + /** + * Delete multiple vector entries by IDs + * @param ids - Array of unique identifiers + */ + deleteBulk(ids: string[]): Promise; + + /** + * Delete vectors matching metadata filter + * @param filter - Partial metadata to match + */ + deleteByFilter(filter: Partial): Promise; + + /** + * Get the number of vectors stored + * @returns Total count of vectors + */ + size(): Promise; + + /** + * Clear all vectors from the repository + */ + clear(): Promise; + + /** + * Set up the repository (create tables, indexes, etc.) 
+ * Must be called before using other methods + */ + setupDatabase(): Promise; + + /** + * Destroy the repository and free resources + */ + destroy(): void; + + // Event handling methods + on( + name: Event, + fn: VectorEventListener + ): void; + off( + name: Event, + fn: VectorEventListener + ): void; + emit( + name: Event, + ...args: VectorEventParameters + ): void; + once( + name: Event, + fn: VectorEventListener + ): void; + waitOn( + name: Event + ): Promise>; +} diff --git a/packages/storage/src/vector/InMemoryVectorRepository.ts b/packages/storage/src/vector/InMemoryVectorRepository.ts new file mode 100644 index 00000000..26bc1052 --- /dev/null +++ b/packages/storage/src/vector/InMemoryVectorRepository.ts @@ -0,0 +1,264 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { cosineSimilarity, EventEmitter, TypedArray } from "@workglow/util"; +import { InMemoryTabularRepository } from "../tabular/InMemoryTabularRepository"; +import { + HybridSearchOptions, + IVectorRepository, + SearchResult, + VectorEntry, + VectorEventListeners, + VectorSchema, + VectorSearchOptions, +} from "./IVectorRepository"; + +/** + * Check if metadata matches filter + */ +function matchesFilter(metadata: Metadata, filter: Partial): boolean { + for (const [key, value] of Object.entries(filter)) { + if (metadata[key as keyof Metadata] !== value) { + return false; + } + } + return true; +} + +/** + * Simple full-text search scoring (keyword matching) + */ +function textRelevance(text: string, query: string): number { + const textLower = text.toLowerCase(); + const queryLower = query.toLowerCase(); + const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); + if (queryWords.length === 0) { + return 0; + } + let matches = 0; + for (const word of queryWords) { + if (textLower.includes(word)) { + matches++; + } + } + return matches / queryWords.length; +} + +/** + * In-memory vector repository implementation. + * Uses InMemoryTabularRepository internally for storage. + * Suitable for testing and small-scale browser applications. + * Supports all vector types including quantized formats. + * + * @template Metadata - Type for metadata associated with vectors + * @template VectorChoice - Type of vector array (Float32Array, Int8Array, etc.) 
+ */ +export class InMemoryVectorRepository< + Metadata = Record, + VectorChoice extends TypedArray = Float32Array, +> + extends EventEmitter> + implements IVectorRepository +{ + private tabularRepo: InMemoryTabularRepository; + + /** + * Creates a new in-memory vector repository + */ + constructor() { + super(); + this.tabularRepo = new InMemoryTabularRepository(VectorSchema, ["id"] as const, []); + } + + async setupDatabase(): Promise { + await this.tabularRepo.setupDatabase(); + } + + async upsert(id: string, vector: VectorChoice, metadata: Metadata): Promise { + const entity = { + id, + vector: vector as any, // Store TypedArray directly in memory + metadata: JSON.stringify(metadata), + }; + await this.tabularRepo.put(entity); + this.emit("upsert", { id, vector, metadata }); + } + + async upsertBulk(items: VectorEntry[]): Promise { + const entities = items.map((item) => ({ + id: item.id, + vector: item.vector as any, + metadata: JSON.stringify(item.metadata), + })); + await this.tabularRepo.putBulk(entities); + for (const item of items) { + this.emit("upsert", item); + } + } + + /** + * Copy a vector to avoid external mutations + */ + private copyVector(vector: TypedArray): VectorChoice { + if (vector instanceof Float32Array) return new Float32Array(vector) as VectorChoice; + if (vector instanceof Float64Array) return new Float64Array(vector) as VectorChoice; + if (vector instanceof Int8Array) return new Int8Array(vector) as VectorChoice; + if (vector instanceof Uint8Array) return new Uint8Array(vector) as VectorChoice; + if (vector instanceof Int16Array) return new Int16Array(vector) as VectorChoice; + if (vector instanceof Uint16Array) return new Uint16Array(vector) as VectorChoice; + return new Float32Array(vector) as VectorChoice; + } + + async similaritySearch( + query: VectorChoice, + options: VectorSearchOptions = {} + ): Promise[]> { + const { topK = 10, filter, scoreThreshold = 0 } = options; + const results: SearchResult[] = []; + + const allEntities = (await this.tabularRepo.getAll()) || []; + + for (const entity of allEntities) { + const vector = entity.vector as unknown as VectorChoice; + const metadata = JSON.parse(entity.metadata) as Metadata; + + // Apply filter if provided + if (filter && !matchesFilter(metadata, filter)) { + continue; + } + + // Calculate similarity + const score = cosineSimilarity(query, vector); + + // Apply threshold + if (score < scoreThreshold) { + continue; + } + + results.push({ + id: entity.id, + vector, + metadata, + score, + }); + } + + // Sort by score descending and take top K + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + this.emit("search", query, topResults); + return topResults; + } + + async hybridSearch( + query: VectorChoice, + options: HybridSearchOptions + ): Promise[]> { + const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; + + if (!textQuery || textQuery.trim().length === 0) { + // Fall back to regular vector search if no text query + return this.similaritySearch(query, { topK, filter, scoreThreshold }); + } + + const results: SearchResult[] = []; + const allEntities = (await this.tabularRepo.getAll()) || []; + + for (const entity of allEntities) { + const vector = entity.vector as unknown as VectorChoice; + const metadata = JSON.parse(entity.metadata) as Metadata; + + // Apply filter if provided + if (filter && !matchesFilter(metadata, filter)) { + continue; + } + + // Calculate vector similarity + const vectorScore = cosineSimilarity(query, 
vector); + + // Calculate text relevance (simple keyword matching) + const metadataText = entity.metadata.toLowerCase(); + const textScore = textRelevance(metadataText, textQuery); + + // Combine scores + const combinedScore = vectorWeight * vectorScore + (1 - vectorWeight) * textScore; + + // Apply threshold + if (combinedScore < scoreThreshold) { + continue; + } + + results.push({ + id: entity.id, + vector, + metadata, + score: combinedScore, + }); + } + + // Sort by combined score descending and take top K + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + this.emit("search", query, topResults); + return topResults; + } + + async get(id: string): Promise | undefined> { + const entity = await this.tabularRepo.get({ id }); + if (entity) { + return { + id: entity.id, + vector: this.copyVector(entity.vector as unknown as TypedArray), + metadata: JSON.parse(entity.metadata) as Metadata, + }; + } + return undefined; + } + + async delete(id: string): Promise { + await this.tabularRepo.delete({ id }); + this.emit("delete", id); + } + + async deleteBulk(ids: string[]): Promise { + for (const id of ids) { + await this.delete(id); + } + } + + async deleteByFilter(filter: Partial): Promise { + const allEntities = (await this.tabularRepo.getAll()) || []; + const idsToDelete: string[] = []; + + for (const entity of allEntities) { + const metadata = JSON.parse(entity.metadata) as Metadata; + if (matchesFilter(metadata, filter)) { + idsToDelete.push(entity.id); + } + } + + await this.deleteBulk(idsToDelete); + } + + async size(): Promise { + return await this.tabularRepo.size(); + } + + async clear(): Promise { + const allEntities = (await this.tabularRepo.getAll()) || []; + await this.tabularRepo.deleteAll(); + for (const entity of allEntities) { + this.emit("delete", entity.id); + } + } + + destroy(): void { + this.tabularRepo.destroy(); + this.removeAllListeners(); + } +} diff --git a/packages/storage/src/vector/PostgresVectorRepository.ts b/packages/storage/src/vector/PostgresVectorRepository.ts new file mode 100644 index 00000000..46a790c7 --- /dev/null +++ b/packages/storage/src/vector/PostgresVectorRepository.ts @@ -0,0 +1,551 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { cosineSimilarity, EventEmitter, TypedArray } from "@workglow/util"; +import type { Pool } from "pg"; +import { PostgresTabularRepository } from "../tabular/PostgresTabularRepository"; +import { + HybridSearchOptions, + IVectorRepository, + SearchResult, + VectorEntry, + VectorEventListeners, + VectorSchema, + VectorSearchOptions, +} from "./IVectorRepository"; + +type VectorRow = { + id: string; + vector: string; + metadata: string; +}; + +/** + * PostgreSQL vector repository implementation using pgvector extension. + * Uses tabular repository underneath for consistency. + * Provides efficient vector similarity search with native database support. + * + * Requirements: + * - PostgreSQL database with pgvector extension installed + * - CREATE EXTENSION vector; + * + * @template Metadata - Type for metadata associated with vectors + * @template VectorChoice - Type of vector array (Float32Array, Int8Array, etc.) 
+ */ +export class PostgresVectorRepository< + Metadata = Record, + VectorChoice extends TypedArray = Float32Array, +> + extends EventEmitter> + implements IVectorRepository +{ + private tabularRepo: PostgresTabularRepository; + private db: Pool; + private table: string; + private vectorDimension: number; + private initialized = false; + private useNativeVector = false; + + /** + * Creates a new PostgreSQL vector repository + * @param db - PostgreSQL connection pool + * @param table - The name of the table to use for storage (defaults to 'vectors') + * @param vectorDimension - Dimension of vectors (e.g., 384, 768, 1536) + */ + constructor(db: Pool, table: string = "vectors", vectorDimension: number = 384) { + super(); + this.db = db; + this.table = table; + this.vectorDimension = vectorDimension; + this.tabularRepo = new PostgresTabularRepository( + db, + table, + VectorSchema, + ["id"] as const, + [] // We'll create custom indexes + ); + } + + async setupDatabase(): Promise { + if (this.initialized) { + return; + } + + // Check if pgvector is available + try { + await this.db.query("CREATE EXTENSION IF NOT EXISTS vector"); + this.useNativeVector = true; + + // Create table with native vector column + await this.db.query(` + CREATE TABLE IF NOT EXISTS "${this.table}" ( + id TEXT PRIMARY KEY, + vector vector(${this.vectorDimension}), + metadata JSONB NOT NULL + ) + `); + + // Create HNSW index for fast similarity search + await this.db.query(` + CREATE INDEX IF NOT EXISTS "${this.table}_vector_idx" + ON "${this.table}" + USING hnsw (vector vector_cosine_ops) + `); + + // Create GIN index on metadata for filtering + await this.db.query(` + CREATE INDEX IF NOT EXISTS "${this.table}_metadata_idx" + ON "${this.table}" + USING gin (metadata) + `); + } catch (error) { + console.warn("pgvector not available, falling back to tabular storage:", error); + this.useNativeVector = false; + // Fall back to tabular repository + await this.tabularRepo.setupDatabase(); + } + + this.initialized = true; + } + + async upsert(id: string, vector: VectorChoice, metadata: Metadata): Promise { + const vectorArray = Array.from(vector); + const vectorJson = JSON.stringify(vectorArray); + const metadataJson = JSON.stringify(metadata); + + if (this.useNativeVector) { + const vectorStr = `[${vectorArray.join(",")}]`; + await this.db.query( + ` + INSERT INTO "${this.table}" (id, vector, metadata) + VALUES ($1, $2, $3) + ON CONFLICT (id) DO UPDATE + SET vector = EXCLUDED.vector, metadata = EXCLUDED.metadata + `, + [id, vectorStr, metadataJson] + ); + } else { + await this.tabularRepo.put({ + id, + vector: vectorJson, + metadata: metadataJson, + }); + } + + this.emit("upsert", { id, vector, metadata }); + } + + async upsertBulk(items: VectorEntry[]): Promise { + if (items.length === 0) return; + + if (this.useNativeVector) { + const values: string[] = []; + const params: any[] = []; + let paramIndex = 1; + + for (const item of items) { + const vectorArray = Array.from(item.vector); + const vectorStr = `[${vectorArray.join(",")}]`; + const metadataJson = JSON.stringify(item.metadata); + + values.push(`($${paramIndex}, $${paramIndex + 1}, $${paramIndex + 2})`); + params.push(item.id, vectorStr, metadataJson); + paramIndex += 3; + } + + await this.db.query( + ` + INSERT INTO "${this.table}" (id, vector, metadata) + VALUES ${values.join(", ")} + ON CONFLICT (id) DO UPDATE + SET vector = EXCLUDED.vector, metadata = EXCLUDED.metadata + `, + params + ); + } else { + const rows: VectorRow[] = items.map((item) => ({ + id: item.id, 
+ vector: JSON.stringify(Array.from(item.vector)), + metadata: JSON.stringify(item.metadata), + })); + await this.tabularRepo.putBulk(rows); + } + + for (const item of items) { + this.emit("upsert", item); + } + } + + async similaritySearch( + query: VectorChoice, + options: VectorSearchOptions = {} + ): Promise[]> { + const { topK = 10, filter, scoreThreshold = 0 } = options; + + if (this.useNativeVector) { + const queryVector = `[${Array.from(query).join(",")}]`; + let sql = ` + SELECT + id, + metadata, + 1 - (vector <=> $1::vector) as score + FROM "${this.table}" + `; + + const params: any[] = [queryVector]; + let paramIndex = 2; + + if (filter && Object.keys(filter).length > 0) { + const conditions: string[] = []; + for (const [key, value] of Object.entries(filter)) { + conditions.push(`metadata->>'${key}' = $${paramIndex}`); + params.push(String(value)); + paramIndex++; + } + sql += ` WHERE ${conditions.join(" AND ")}`; + } + + if (scoreThreshold > 0) { + sql += filter ? " AND" : " WHERE"; + sql += ` (1 - (vector <=> $1::vector)) >= $${paramIndex}`; + params.push(scoreThreshold); + paramIndex++; + } + + sql += ` ORDER BY vector <=> $1::vector LIMIT $${paramIndex}`; + params.push(topK); + + const result = await this.db.query(sql, params); + + // Fetch vectors separately for each result + const results: SearchResult[] = []; + for (const row of result.rows) { + const vectorResult = await this.db.query( + `SELECT vector::text FROM "${this.table}" WHERE id = $1`, + [row.id] + ); + const vectorStr = vectorResult.rows[0]?.vector || "[]"; + const vectorArray = JSON.parse(vectorStr); + + results.push({ + id: row.id, + vector: this.deserializeVector(JSON.stringify(vectorArray)), + metadata: typeof row.metadata === "string" ? JSON.parse(row.metadata) : row.metadata, + score: parseFloat(row.score), + }); + } + + this.emit("search", query, results); + return results; + } else { + // Fall back to in-memory similarity calculation + return this.searchFallback(query, options); + } + } + + async hybridSearch( + query: VectorChoice, + options: HybridSearchOptions + ): Promise[]> { + const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; + + if (!textQuery || textQuery.trim().length === 0) { + return this.similaritySearch(query, { topK, filter, scoreThreshold }); + } + + if (this.useNativeVector) { + const queryVector = `[${Array.from(query).join(",")}]`; + const tsQuery = textQuery.split(/\s+/).join(" & "); + + let sql = ` + SELECT + id, + metadata, + ( + $2 * (1 - (vector <=> $1::vector)) + + $3 * ts_rank(to_tsvector('english', metadata::text), to_tsquery('english', $4)) + ) as score + FROM "${this.table}" + `; + + const params: any[] = [queryVector, vectorWeight, 1 - vectorWeight, tsQuery]; + let paramIndex = 5; + + if (filter && Object.keys(filter).length > 0) { + const conditions: string[] = []; + for (const [key, value] of Object.entries(filter)) { + conditions.push(`metadata->>'${key}' = $${paramIndex}`); + params.push(String(value)); + paramIndex++; + } + sql += ` WHERE ${conditions.join(" AND ")}`; + } + + if (scoreThreshold > 0) { + sql += filter ? 
" AND" : " WHERE"; + sql += ` ( + $2 * (1 - (vector <=> $1::vector)) + + $3 * ts_rank(to_tsvector('english', metadata::text), to_tsquery('english', $4)) + ) >= $${paramIndex}`; + params.push(scoreThreshold); + paramIndex++; + } + + sql += ` ORDER BY score DESC LIMIT $${paramIndex}`; + params.push(topK); + + const result = await this.db.query(sql, params); + + // Fetch vectors separately for each result + const results: SearchResult[] = []; + for (const row of result.rows) { + const vectorResult = await this.db.query( + `SELECT vector::text FROM "${this.table}" WHERE id = $1`, + [row.id] + ); + const vectorStr = vectorResult.rows[0]?.vector || "[]"; + const vectorArray = JSON.parse(vectorStr); + + results.push({ + id: row.id, + vector: this.deserializeVector(JSON.stringify(vectorArray)), + metadata: typeof row.metadata === "string" ? JSON.parse(row.metadata) : row.metadata, + score: parseFloat(row.score), + }); + } + + this.emit("search", query, results); + return results; + } else { + return this.hybridSearchFallback(query, options); + } + } + + async get(id: string): Promise | undefined> { + if (this.useNativeVector) { + const result = await this.db.query( + `SELECT id, vector::text as vector, metadata FROM "${this.table}" WHERE id = $1`, + [id] + ); + + if (result.rows.length > 0) { + const row = result.rows[0]; + const vectorArray = JSON.parse(row.vector); + return { + id: row.id, + vector: this.deserializeVector(JSON.stringify(vectorArray)), + metadata: typeof row.metadata === "string" ? JSON.parse(row.metadata) : row.metadata, + }; + } + return undefined; + } else { + const row = await this.tabularRepo.get({ id }); + if (row) { + return { + id: row.id, + vector: this.deserializeVector(row.vector), + metadata: JSON.parse(row.metadata) as Metadata, + }; + } + return undefined; + } + } + + async delete(id: string): Promise { + if (this.useNativeVector) { + await this.db.query(`DELETE FROM "${this.table}" WHERE id = $1`, [id]); + } else { + await this.tabularRepo.delete({ id }); + } + this.emit("delete", id); + } + + async deleteBulk(ids: string[]): Promise { + if (ids.length === 0) return; + + if (this.useNativeVector) { + await this.db.query(`DELETE FROM "${this.table}" WHERE id = ANY($1)`, [ids]); + } else { + for (const id of ids) { + await this.tabularRepo.delete({ id }); + } + } + + for (const id of ids) { + this.emit("delete", id); + } + } + + async deleteByFilter(filter: Partial): Promise { + if (Object.keys(filter).length === 0) return; + + if (this.useNativeVector) { + const conditions: string[] = []; + const params: any[] = []; + let paramIndex = 1; + + for (const [key, value] of Object.entries(filter)) { + conditions.push(`metadata->>'${key}' = $${paramIndex}`); + params.push(String(value)); + paramIndex++; + } + + await this.db.query(`DELETE FROM "${this.table}" WHERE ${conditions.join(" AND ")}`, params); + } else { + const allRows = (await this.tabularRepo.getAll()) || []; + const idsToDelete: string[] = []; + + for (const row of allRows) { + const metadata = JSON.parse(row.metadata) as Metadata; + if (this.matchesFilter(metadata, filter)) { + idsToDelete.push(row.id); + } + } + + await this.deleteBulk(idsToDelete); + } + } + + async size(): Promise { + if (this.useNativeVector) { + const result = await this.db.query(`SELECT COUNT(*) as count FROM "${this.table}"`); + return parseInt(result.rows[0].count); + } else { + return await this.tabularRepo.size(); + } + } + + async clear(): Promise { + if (this.useNativeVector) { + await this.db.query(`DELETE FROM 
"${this.table}"`); + } else { + await this.tabularRepo.deleteAll(); + } + } + + destroy(): void { + if (!this.useNativeVector) { + this.tabularRepo.destroy(); + } + this.removeAllListeners(); + } + + /** + * Fallback search using in-memory cosine similarity + */ + private async searchFallback( + query: VectorChoice, + options: VectorSearchOptions + ): Promise[]> { + const { topK = 10, filter, scoreThreshold = 0 } = options; + const allRows = (await this.tabularRepo.getAll()) || []; + const results: SearchResult[] = []; + + for (const row of allRows) { + const vector = this.deserializeVector(row.vector); + const metadata = JSON.parse(row.metadata) as Metadata; + + if (filter && !this.matchesFilter(metadata, filter)) { + continue; + } + + const score = cosineSimilarity(query, vector); + + if (score >= scoreThreshold) { + results.push({ id: row.id, vector, metadata, score }); + } + } + + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + this.emit("search", query, topResults); + return topResults; + } + + /** + * Fallback hybrid search + */ + private async hybridSearchFallback( + query: VectorChoice, + options: HybridSearchOptions + ): Promise[]> { + const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; + + const allRows = (await this.tabularRepo.getAll()) || []; + const results: SearchResult[] = []; + const queryLower = textQuery.toLowerCase(); + const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); + + for (const row of allRows) { + const vector = this.deserializeVector(row.vector); + const metadata = JSON.parse(row.metadata) as Metadata; + + if (filter && !this.matchesFilter(metadata, filter)) { + continue; + } + + const vectorScore = cosineSimilarity(query, vector); + const metadataText = row.metadata.toLowerCase(); + let textScore = 0; + if (queryWords.length > 0) { + let matches = 0; + for (const word of queryWords) { + if (metadataText.includes(word)) { + matches++; + } + } + textScore = matches / queryWords.length; + } + + const combinedScore = vectorWeight * vectorScore + (1 - vectorWeight) * textScore; + + if (combinedScore >= scoreThreshold) { + results.push({ id: row.id, vector, metadata, score: combinedScore }); + } + } + + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + this.emit("search", query, topResults); + return topResults; + } + + private deserializeVector(vectorJson: string): VectorChoice { + const array = JSON.parse(vectorJson); + const hasFloats = array.some((v: number) => v % 1 !== 0); + const hasNegatives = array.some((v: number) => v < 0); + + if (hasFloats) { + return new Float32Array(array) as VectorChoice; + } else if (hasNegatives) { + const min = Math.min(...array); + const max = Math.max(...array); + if (min >= -128 && max <= 127) { + return new Int8Array(array) as VectorChoice; + } else { + return new Int16Array(array) as VectorChoice; + } + } else { + const max = Math.max(...array); + if (max <= 255) { + return new Uint8Array(array) as VectorChoice; + } else { + return new Uint16Array(array) as VectorChoice; + } + } + } + + private matchesFilter(metadata: Metadata, filter: Partial): boolean { + for (const [key, value] of Object.entries(filter)) { + if (metadata[key as keyof Metadata] !== value) { + return false; + } + } + return true; + } +} diff --git a/packages/storage/src/vector/README.md b/packages/storage/src/vector/README.md new file mode 100644 index 00000000..c69196e8 --- /dev/null +++ 
b/packages/storage/src/vector/README.md @@ -0,0 +1,448 @@ +# Vector Storage Module + +A flexible vector storage solution with multiple backend implementations for RAG (Retrieval-Augmented Generation) pipelines. Provides a consistent interface for vector CRUD operations with similarity search and hybrid search capabilities. + +## Features + +- **Multiple Storage Backends:** + - 🧠 `InMemoryVectorRepository` - Fast in-memory storage for testing and small datasets + - 📁 `SqliteVectorRepository` - Persistent SQLite storage for local applications + - 🐘 `PostgresVectorRepository` - PostgreSQL with pgvector extension for production + - 🔍 `SeekDbVectorRepository` - SeekDB/OceanBase with native hybrid search + - 📱 `EdgeVecRepository` - Edge/browser deployment with IndexedDB and WebGPU support + +- **Quantized Vector Support:** + - Float32Array (standard 32-bit floating point) + - Float64Array (64-bit high precision) + - Int8Array (8-bit signed - binary quantization) + - Uint8Array (8-bit unsigned - quantization) + - Int16Array (16-bit signed - quantization) + - Uint16Array (16-bit unsigned - quantization) + +- **Advanced Search Capabilities:** + - Vector similarity search (cosine similarity) + - Hybrid search (vector + full-text) + - Metadata filtering + - Top-K retrieval with score thresholds + +- **Production Ready:** + - Type-safe interfaces + - Event emitters for monitoring + - Bulk operations support + - Efficient indexing strategies + +## Installation + +```bash +bun install @workglow/storage +``` + +## Usage + +### In-Memory Repository (Testing/Browser) + +```typescript +import { InMemoryVectorRepository } from "@workglow/storage"; + +// Standard Float32 vectors +const repo = new InMemoryVectorRepository<{ text: string; source: string }>(); +await repo.setupDatabase(); + +// Upsert vectors +await repo.upsert( + "doc1", + new Float32Array([0.1, 0.2, 0.3, ...]), + { text: "Hello world", source: "example.txt" } +); + +// Search for similar vectors +const results = await repo.similaritySearch( + new Float32Array([0.15, 0.25, 0.35, ...]), + { topK: 5, scoreThreshold: 0.7 } +); +``` + +### Quantized Vectors (Reduced Storage) + +```typescript +import { InMemoryVectorRepository } from "@workglow/storage"; + +// Use Int8Array for 4x smaller storage (binary quantization) +const repo = new InMemoryVectorRepository< + { text: string }, + Int8Array +>(); +await repo.setupDatabase(); + +// Store quantized vectors +await repo.upsert( + "doc1", + new Int8Array([127, -128, 64, ...]), + { text: "Quantized embedding" } +); + +// Search with quantized query +const results = await repo.similaritySearch( + new Int8Array([100, -50, 75, ...]), + { topK: 5 } +); +``` + +### SQLite Repository (Local Persistence) + +```typescript +import { SqliteVectorRepository } from "@workglow/storage"; + +const repo = new SqliteVectorRepository<{ text: string }>( + "./vectors.db", // database path + "embeddings" // table name +); +await repo.setupDatabase(); + +// Bulk upsert +await repo.upsertBulk([ + { id: "1", vector: new Float32Array([...]), metadata: { text: "..." } }, + { id: "2", vector: new Float32Array([...]), metadata: { text: "..." } }, +]); +``` + +### PostgreSQL with pgvector + +```typescript +import { Pool } from "pg"; +import { PostgresVectorRepository } from "@workglow/storage"; + +const pool = new Pool({ connectionString: "postgresql://..." 
}); +const repo = new PostgresVectorRepository<{ text: string; category: string }>( + pool, + "vectors", + 384 // vector dimension +); +await repo.setupDatabase(); + +// Hybrid search (vector + full-text) +const results = await repo.hybridSearch(queryVector, { + textQuery: "machine learning", + topK: 10, + vectorWeight: 0.7, + filter: { category: "ai" }, +}); +``` + +### SeekDB (Hybrid Search Database) + +```typescript +import mysql from "mysql2/promise"; +import { SeekDbVectorRepository } from "@workglow/storage"; + +const pool = mysql.createPool({ host: "...", database: "..." }); +const repo = new SeekDbVectorRepository<{ text: string }>( + pool, + "vectors", + 768 // vector dimension +); +await repo.setupDatabase(); + +// Native hybrid search +const results = await repo.hybridSearch(queryVector, { + textQuery: "neural networks", + topK: 5, + vectorWeight: 0.6, +}); +``` + +### EdgeVec (Browser/Edge Deployment) + +```typescript +import { EdgeVecRepository } from "@workglow/storage"; + +const repo = new EdgeVecRepository<{ text: string }>({ + dbName: "my-vectors", // IndexedDB name + enableWebGPU: true, // Enable GPU acceleration +}); +await repo.setupDatabase(); + +// Works entirely in the browser +await repo.upsert("1", vector, { text: "..." }); +const results = await repo.similaritySearch(queryVector, { topK: 3 }); +``` + +## API Documentation + +### Core Methods + +All repositories implement the `IVectorRepository` interface: + +```typescript +interface IVectorRepository<Metadata> { + // Setup + setupDatabase(): Promise<void>; + + // CRUD Operations + upsert(id: string, vector: Float32Array, metadata: Metadata): Promise<void>; + upsertBulk(items: VectorEntry<Metadata>[]): Promise<void>; + get(id: string): Promise<VectorEntry<Metadata> | undefined>; + delete(id: string): Promise<void>; + deleteBulk(ids: string[]): Promise<void>; + deleteByFilter(filter: Partial<Metadata>): Promise<void>; + + // Search + similaritySearch( + query: Float32Array, + options?: VectorSearchOptions<Metadata> + ): Promise<SearchResult<Metadata>[]>; + hybridSearch?( + query: Float32Array, + options: HybridSearchOptions<Metadata> + ): Promise<SearchResult<Metadata>[]>; + + // Utility + size(): Promise<number>; + clear(): Promise<void>; + destroy(): void; + + // Events + on(event: "upsert" | "delete" | "search", callback: Function): void; +} +``` + +### Search Options + +```typescript +interface VectorSearchOptions<Metadata> { + topK?: number; // Number of results (default: 10) + filter?: Partial<Metadata>; // Filter by metadata + scoreThreshold?: number; // Minimum score (0-1) +} + +interface HybridSearchOptions<Metadata> extends VectorSearchOptions<Metadata> { + textQuery: string; // Full-text query + vectorWeight?: number; // Vector weight 0-1 (default: 0.7) +} +``` + +## Quantization Benefits + +Quantized vectors can significantly reduce storage and improve performance: + +| Vector Type | Bytes/Dim | Storage vs Float32 | Use Case | +| ------------ | --------- | ------------------ | ------------------------------------ | +| Float32Array | 4 | 100% (baseline) | Standard embeddings | +| Float64Array | 8 | 200% | High precision needed | +| Int16Array | 2 | 50% | Good precision/size tradeoff | +| Int8Array | 1 | 25% | Binary quantization, max compression | +| Uint8Array | 1 | 25% | Quantized embeddings [0-255] | + +**Example:** A 768-dimensional embedding: + +- Float32: 3,072 bytes +- Int8: 768 bytes (75% reduction!)
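+
+As a rough illustration of where those savings come from, the sketch below scalar-quantizes a `Float32Array` embedding into an `Int8Array` before upserting it. The `quantizeToInt8` helper is hypothetical (not part of `@workglow/storage`); real pipelines would typically obtain quantized vectors from the embedding model or a calibrated quantizer.
+
+```typescript
+import { InMemoryVectorRepository } from "@workglow/storage";
+
+// Hypothetical helper: symmetric scalar quantization.
+// Scales each value by 127 / max(|v|) and rounds, trading precision for 4x smaller vectors.
+function quantizeToInt8(vector: Float32Array): Int8Array {
+  let maxAbs = 0;
+  for (const v of vector) maxAbs = Math.max(maxAbs, Math.abs(v));
+  const scale = maxAbs > 0 ? 127 / maxAbs : 1;
+  const out = new Int8Array(vector.length);
+  for (let i = 0; i < vector.length; i++) {
+    out[i] = Math.max(-128, Math.min(127, Math.round(vector[i] * scale)));
+  }
+  return out;
+}
+
+const repo = new InMemoryVectorRepository<{ text: string }, Int8Array>();
+await repo.setupDatabase();
+
+const embedding = new Float32Array([0.12, -0.98, 0.33, 0.5]); // normally 384+ dims
+await repo.upsert("doc1", quantizeToInt8(embedding), { text: "quantized example" });
+```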
+ +## Performance Considerations + +### InMemory + +- **Best for:** Testing, small datasets (<10K vectors), browser apps +- **Pros:** Fastest, no dependencies, supports all vector types +- **Cons:** No persistence, memory limited + +### SQLite + +- **Best for:** Local apps, medium datasets (<100K vectors) +- **Pros:** Persistent, single file, no server +- **Cons:** No native vector indexing, slower for large datasets + +### PostgreSQL + pgvector + +- **Best for:** Production, large datasets (>100K vectors) +- **Pros:** HNSW indexing, efficient, scalable +- **Cons:** Requires PostgreSQL server and pgvector extension + +### SeekDB + +- **Best for:** Hybrid search workloads, production +- **Pros:** Native hybrid search, MySQL-compatible +- **Cons:** Requires SeekDB/OceanBase instance + +### EdgeVec + +- **Best for:** Privacy-sensitive apps, offline-first, edge computing +- **Pros:** No server, IndexedDB persistence, WebGPU acceleration +- **Cons:** Limited by browser storage, smaller datasets + +## Integration with RAG Tasks + +The vector repositories integrate seamlessly with RAG tasks: + +```typescript +import { InMemoryVectorRepository } from "@workglow/storage"; +import { Workflow } from "@workglow/task-graph"; + +const repo = new InMemoryVectorRepository(); +await repo.setupDatabase(); + +const workflow = new Workflow() + // Load and chunk document + .fileLoader({ path: "./doc.md" }) + .textChunker({ chunkSize: 512, chunkOverlap: 50 }) + + // Generate embeddings + .textEmbedding({ model: "Xenova/all-MiniLM-L6-v2" }) + + // Store in vector repository + .vectorStoreUpsert({ repository: repo }); + +await workflow.run(); + +// Later: Search +const searchWorkflow = new Workflow() + .textEmbedding({ text: "What is RAG?", model: "..." }) + .vectorStoreSearch({ repository: repo, topK: 5 }) + .contextBuilder({ format: "markdown" }) + .textQuestionAnswer({ question: "What is RAG?" 
}); + +const result = await searchWorkflow.run(); +``` + +## Hierarchical Document Integration + +For document-level storage and hierarchical context enrichment, use vector repositories alongside document repositories: + +```typescript +import { InMemoryVectorRepository, InMemoryDocumentRepository } from "@workglow/storage"; +import { Workflow } from "@workglow/task-graph"; + +const vectorRepo = new InMemoryVectorRepository(); +const docRepo = new InMemoryDocumentRepository(); +await vectorRepo.setupDatabase(); + +// Ingestion with hierarchical structure +await new Workflow() + .structuralParser({ + text: markdownContent, + title: "Documentation", + format: "markdown", + }) + .hierarchicalChunker({ + maxTokens: 512, + overlap: 50, + strategy: "hierarchical", + }) + .textEmbedding({ model: "Xenova/all-MiniLM-L6-v2" }) + .chunkToVector() + .vectorStoreUpsert({ repository: vectorRepo }) + .run(); + +// Retrieval with parent context +const result = await new Workflow() + .textEmbedding({ text: query, model: "Xenova/all-MiniLM-L6-v2" }) + .vectorStoreSearch({ repository: vectorRepo, topK: 10 }) + .hierarchyJoin({ + documentRepository: docRepo, + includeParentSummaries: true, + includeEntities: true, + }) + .reranker({ query, topK: 5 }) + .contextBuilder({ format: "markdown" }) + .run(); +``` + +### Vector Metadata for Hierarchical Documents + +When using hierarchical chunking, base vector metadata (stored in vector database) includes: + +```typescript +metadata: { + docId: string, // Document identifier + chunkId: string, // Chunk identifier + leafNodeId: string, // Reference to document tree node + depth: number, // Hierarchy depth + text: string, // Chunk text content + nodePath: string[], // Node IDs from root to leaf + // From enrichment (optional): + summary?: string, // Summary of the chunk content + entities?: Entity[], // Named entities extracted from the chunk +} +``` + +After `HierarchyJoinTask`, enriched metadata includes additional fields: + +```typescript +enrichedMetadata: { + // ... all base metadata fields above ... + parentSummaries?: string[], // Summaries from ancestor nodes (looked up on-demand) + sectionTitles?: string[], // Titles of ancestor section nodes +} +``` + +Note: `parentSummaries` is not stored in the vector database. It is computed on-demand by `HierarchyJoinTask` using `docId` and `leafNodeId` to look up ancestors from the document repository. + +## Document Repository + +The `IDocumentRepository` interface provides storage for hierarchical document structures: + +```typescript +class DocumentRepository { + constructor( + tabularStorage: ITabularRepository, + vectorStorage: IVectorRepository + ); + + upsert(document: Document): Promise; + get(docId: string): Promise; + getNode(docId: string, nodeId: string): Promise; + getAncestors(docId: string, nodeId: string): Promise; + getChunks(docId: string): Promise; + findChunksByNodeId(docId: string, nodeId: string): Promise; + delete(docId: string): Promise; + list(): Promise; + search(query: TypedArray, options?: VectorSearchOptions): Promise; +} +``` + +### Document Repository + +The `DocumentRepository` class provides a unified interface for storing hierarchical documents and searching chunks. 
It uses composition of storage backends: + +| Component | Purpose | +|-----------|---------| +| `ITabularRepository` | Stores document structure and metadata | +| `IVectorRepository` | Enables similarity search on document chunks | + +**Example Usage:** + +```typescript +import { + DocumentRepository, + InMemoryTabularRepository, + InMemoryVectorRepository, +} from "@workglow/storage"; + +// Define schema for document storage +const DocumentStorageSchema = { + type: "object", + properties: { + docId: { type: "string" }, + data: { type: "string" }, + }, + required: ["docId", "data"], +} as const; + +// Initialize storage backends +const tabularStorage = new InMemoryTabularRepository(DocumentStorageSchema, ["docId"]); +await tabularStorage.setupDatabase(); + +const vectorStorage = new InMemoryVectorRepository(); +await vectorStorage.setupDatabase(); + +// Create document repository +const docRepo = new DocumentRepository(tabularStorage, vectorStorage); + +// Use the repository +await docRepo.upsert(document); +const results = await docRepo.search(queryVector, { topK: 5 }); +``` + +## License + +Apache 2.0 diff --git a/packages/storage/src/vector/SqliteVectorRepository.ts b/packages/storage/src/vector/SqliteVectorRepository.ts new file mode 100644 index 00000000..f321da86 --- /dev/null +++ b/packages/storage/src/vector/SqliteVectorRepository.ts @@ -0,0 +1,285 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { Sqlite } from "@workglow/sqlite"; +import { cosineSimilarity, EventEmitter, TypedArray } from "@workglow/util"; +import { SqliteTabularRepository } from "../tabular/SqliteTabularRepository"; +import { + HybridSearchOptions, + IVectorRepository, + SearchResult, + VectorEntry, + VectorEventListeners, + VectorSchema, + VectorSearchOptions, +} from "./IVectorRepository"; + +type VectorRow = { + id: string; + vector: string; + metadata: string; +}; + +/** + * SQLite vector repository implementation using tabular storage underneath. + * Stores vectors as JSON-encoded arrays with metadata. + * + * @template Metadata - Type for metadata associated with vectors + * @template Vector - Type of vector array (Float32Array, Int8Array, etc.) 
+ */ +export class SqliteVectorRepository< + Metadata = Record, + VectorChoice extends TypedArray = Float32Array, +> + extends EventEmitter> + implements IVectorRepository +{ + private tabularRepo: SqliteTabularRepository; + private initialized = false; + + /** + * Creates a new SQLite vector repository + * @param dbOrPath - Either a Database instance or a path to the SQLite database file + * @param table - The name of the table to use for storage (defaults to 'vectors') + */ + constructor(dbOrPath: string | Sqlite.Database, table: string = "vectors") { + super(); + this.tabularRepo = new SqliteTabularRepository( + dbOrPath, + table, + VectorSchema, + ["id"] as const, + [] // No additional indexes needed for now + ); + } + + async setupDatabase(): Promise { + if (this.initialized) { + return; + } + await this.tabularRepo.setupDatabase(); + this.initialized = true; + } + + async upsert(id: string, vector: VectorChoice, metadata: Metadata): Promise { + const row: VectorRow = { + id, + vector: JSON.stringify(Array.from(vector)), + metadata: JSON.stringify(metadata), + }; + + await this.tabularRepo.put(row); + this.emit("upsert", { id, vector, metadata }); + } + + async upsertBulk(items: VectorEntry[]): Promise { + const rows: VectorRow[] = items.map((item) => ({ + id: item.id, + vector: JSON.stringify(Array.from(item.vector)), + metadata: JSON.stringify(item.metadata), + })); + + await this.tabularRepo.putBulk(rows); + + for (const item of items) { + this.emit("upsert", item); + } + } + + async similaritySearch( + query: VectorChoice, + options: VectorSearchOptions = {} + ): Promise[]> { + const { topK = 10, filter, scoreThreshold = 0 } = options; + + // Get all vectors (or filtered subset) + const allRows = (await this.tabularRepo.getAll()) || []; + const results: SearchResult[] = []; + + for (const row of allRows) { + const vector = this.deserializeVector(row.vector); + const metadata = JSON.parse(row.metadata) as Metadata; + + // Apply metadata filter if provided + if (filter && !this.matchesFilter(metadata, filter)) { + continue; + } + + // Calculate similarity + const score = cosineSimilarity(query, vector); + + if (score >= scoreThreshold) { + results.push({ + id: row.id, + vector, + metadata, + score, + }); + } + } + + // Sort by score descending and take top K + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + this.emit("search", query, topResults); + return topResults; + } + + async hybridSearch( + query: VectorChoice, + options: HybridSearchOptions + ): Promise[]> { + const { topK = 10, filter, scoreThreshold = 0, textQuery, vectorWeight = 0.7 } = options; + + if (!textQuery || textQuery.trim().length === 0) { + return this.similaritySearch(query, { topK, filter, scoreThreshold }); + } + + const allRows = (await this.tabularRepo.getAll()) || []; + const results: SearchResult[] = []; + const queryLower = textQuery.toLowerCase(); + const queryWords = queryLower.split(/\s+/).filter((w) => w.length > 0); + + for (const row of allRows) { + const vector = this.deserializeVector(row.vector); + const metadata = JSON.parse(row.metadata) as Metadata; + + if (filter && !this.matchesFilter(metadata, filter)) { + continue; + } + + // Vector similarity + const vectorScore = cosineSimilarity(query, vector); + + // Text relevance + const metadataText = row.metadata.toLowerCase(); + let textScore = 0; + if (queryWords.length > 0) { + let matches = 0; + for (const word of queryWords) { + if (metadataText.includes(word)) { + matches++; + } + } + textScore = 
matches / queryWords.length; + } + + // Combined score + const combinedScore = vectorWeight * vectorScore + (1 - vectorWeight) * textScore; + + if (combinedScore >= scoreThreshold) { + results.push({ + id: row.id, + vector, + metadata, + score: combinedScore, + }); + } + } + + results.sort((a, b) => b.score - a.score); + const topResults = results.slice(0, topK); + + this.emit("search", query, topResults); + return topResults; + } + + async get(id: string): Promise | undefined> { + const row = await this.tabularRepo.get({ id }); + if (row) { + return { + id: row.id, + vector: this.deserializeVector(row.vector), + metadata: JSON.parse(row.metadata) as Metadata, + }; + } + return undefined; + } + + async delete(id: string): Promise { + await this.tabularRepo.delete({ id }); + this.emit("delete", id); + } + + async deleteBulk(ids: string[]): Promise { + for (const id of ids) { + await this.tabularRepo.delete({ id }); + this.emit("delete", id); + } + } + + async deleteByFilter(filter: Partial): Promise { + if (Object.keys(filter).length === 0) return; + + // Get all and filter in memory (SQLite doesn't have JSON query operators) + const allRows = (await this.tabularRepo.getAll()) || []; + const idsToDelete: string[] = []; + + for (const row of allRows) { + const metadata = JSON.parse(row.metadata) as Metadata; + if (this.matchesFilter(metadata, filter)) { + idsToDelete.push(row.id); + } + } + + await this.deleteBulk(idsToDelete); + } + + async size(): Promise { + return await this.tabularRepo.size(); + } + + async clear(): Promise { + await this.tabularRepo.deleteAll(); + } + + destroy(): void { + this.tabularRepo.destroy(); + this.removeAllListeners(); + } + + /** + * Deserialize vector from JSON string + */ + private deserializeVector(vectorJson: string): VectorChoice { + const array = JSON.parse(vectorJson); + // Try to infer the type from the values + const hasFloats = array.some((v: number) => v % 1 !== 0); + const hasNegatives = array.some((v: number) => v < 0); + + if (hasFloats) { + return new Float32Array(array) as VectorChoice; + } else if (hasNegatives) { + const min = Math.min(...array); + const max = Math.max(...array); + if (min >= -128 && max <= 127) { + return new Int8Array(array) as VectorChoice; + } else { + return new Int16Array(array) as VectorChoice; + } + } else { + const max = Math.max(...array); + if (max <= 255) { + return new Uint8Array(array) as VectorChoice; + } else { + return new Uint16Array(array) as VectorChoice; + } + } + } + + /** + * Check if metadata matches filter + */ + private matchesFilter(metadata: Metadata, filter: Partial): boolean { + for (const [key, value] of Object.entries(filter)) { + if (metadata[key as keyof Metadata] !== value) { + return false; + } + } + return true; + } +} diff --git a/packages/storage/src/vector/VectorRepositoryRegistry.ts b/packages/storage/src/vector/VectorRepositoryRegistry.ts new file mode 100644 index 00000000..cf836a8b --- /dev/null +++ b/packages/storage/src/vector/VectorRepositoryRegistry.ts @@ -0,0 +1,80 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + createServiceToken, + globalServiceRegistry, + registerInputResolver, + ServiceRegistry, +} from "@workglow/util"; +import type { IVectorRepository } from "./IVectorRepository"; + +/** + * Service token for the vector repository registry + * Maps repository IDs to IVectorRepository instances + */ +export const VECTOR_REPOSITORIES = createServiceToken>>( + "vector.repositories" +); + +// Register default 
factory if not already registered +if (!globalServiceRegistry.has(VECTOR_REPOSITORIES)) { + globalServiceRegistry.register( + VECTOR_REPOSITORIES, + (): Map> => new Map(), + true + ); +} + +/** + * Gets the global vector repository registry + * @returns Map of vector repository ID to instance + */ +export function getGlobalVectorRepositories(): Map> { + return globalServiceRegistry.get(VECTOR_REPOSITORIES); +} + +/** + * Registers a vector repository globally by ID + * @param id The unique identifier for this repository + * @param repository The repository instance to register + */ +export function registerVectorRepository(id: string, repository: IVectorRepository): void { + const repos = getGlobalVectorRepositories(); + repos.set(id, repository); +} + +/** + * Gets a vector repository by ID from the global registry + * @param id The repository identifier + * @returns The repository instance or undefined if not found + */ +export function getVectorRepository(id: string): IVectorRepository | undefined { + return getGlobalVectorRepositories().get(id); +} + +/** + * Resolves a repository ID to an IVectorRepository from the registry. + * Used by the input resolver system. + */ +async function resolveVectorRepositoryFromRegistry( + id: string, + format: string, + registry: ServiceRegistry +): Promise> { + const repos = registry.has(VECTOR_REPOSITORIES) + ? registry.get>>(VECTOR_REPOSITORIES) + : getGlobalVectorRepositories(); + + const repo = repos.get(id); + if (!repo) { + throw new Error(`Vector repository "${id}" not found in registry`); + } + return repo; +} + +// Register the repository resolver for format: "repository:vector" +registerInputResolver("repository:vector", resolveVectorRepositoryFromRegistry); diff --git a/packages/task-graph/README.md b/packages/task-graph/README.md index 203c9ce5..73824243 100644 --- a/packages/task-graph/README.md +++ b/packages/task-graph/README.md @@ -120,7 +120,7 @@ console.log(result); // { result: 60 } // 2.3 Create a helper function export const MultiplyBy2 = (input: { value: number }) => { - return new MultiplyBy2Task(input).run(); + return new MultiplyBy2Task().run(input); }; const first = await MultiplyBy2({ value: 15 }); const second = await MultiplyBy2({ value: first.result }); diff --git a/packages/task-graph/src/common.ts b/packages/task-graph/src/common.ts index 4184e91e..776cc7a5 100644 --- a/packages/task-graph/src/common.ts +++ b/packages/task-graph/src/common.ts @@ -8,6 +8,7 @@ export * from "./task/ArrayTask"; export * from "./task/ConditionalTask"; export * from "./task/GraphAsTask"; export * from "./task/GraphAsTaskRunner"; +export * from "./task/InputResolver"; export * from "./task/InputTask"; export * from "./task/ITask"; export * from "./task/JobQueueFactory"; diff --git a/packages/task-graph/src/task-graph/Dataflow.ts b/packages/task-graph/src/task-graph/Dataflow.ts index bac428a6..c8e58258 100644 --- a/packages/task-graph/src/task-graph/Dataflow.ts +++ b/packages/task-graph/src/task-graph/Dataflow.ts @@ -7,7 +7,7 @@ import { areSemanticallyCompatible, EventEmitter } from "@workglow/util"; import { TaskError } from "../task/TaskError"; import { DataflowJson } from "../task/TaskJSON"; -import { Provenance, TaskIdType, TaskOutput, TaskStatus } from "../task/TaskTypes"; +import { TaskIdType, TaskOutput, TaskStatus } from "../task/TaskTypes"; import { DataflowEventListener, DataflowEventListeners, @@ -48,7 +48,6 @@ export class Dataflow { ); } public value: any = undefined; - public provenance: Provenance = {}; public status: 
TaskStatus = TaskStatus.PENDING; public error: TaskError | undefined; @@ -56,7 +55,6 @@ export class Dataflow { this.status = TaskStatus.PENDING; this.error = undefined; this.value = undefined; - this.provenance = {}; this.emit("reset"); this.emit("status", this.status); } @@ -87,7 +85,7 @@ export class Dataflow { this.emit("status", this.status); } - setPortData(entireDataBlock: any, nodeProvenance: any) { + setPortData(entireDataBlock: any) { if (this.sourceTaskPortId === DATAFLOW_ALL_PORTS) { this.value = entireDataBlock; } else if (this.sourceTaskPortId === DATAFLOW_ERROR_PORT) { @@ -95,7 +93,6 @@ export class Dataflow { } else { this.value = entireDataBlock[this.sourceTaskPortId]; } - if (nodeProvenance) this.provenance = nodeProvenance; } getPortData(): TaskOutput { diff --git a/packages/task-graph/src/task-graph/ITaskGraph.ts b/packages/task-graph/src/task-graph/ITaskGraph.ts index e2f6dc84..ba115d2a 100644 --- a/packages/task-graph/src/task-graph/ITaskGraph.ts +++ b/packages/task-graph/src/task-graph/ITaskGraph.ts @@ -6,7 +6,7 @@ import { ITask } from "../task/ITask"; import { JsonTaskItem, TaskGraphJson } from "../task/TaskJSON"; -import { TaskIdType, TaskInput, TaskOutput, TaskStatus } from "../task/TaskTypes"; +import type { TaskIdType, TaskInput, TaskOutput, TaskStatus } from "../task/TaskTypes"; import { Dataflow, DataflowIdType } from "./Dataflow"; import type { TaskGraphRunConfig } from "./TaskGraph"; import type { TaskGraphEventListener, TaskGraphEvents } from "./TaskGraphEvents"; diff --git a/packages/task-graph/src/task-graph/README.md b/packages/task-graph/src/task-graph/README.md index e9673f4e..abdf6629 100644 --- a/packages/task-graph/src/task-graph/README.md +++ b/packages/task-graph/src/task-graph/README.md @@ -25,7 +25,6 @@ A robust TypeScript library for creating and managing task graphs with dependenc - Directed Acyclic Graph (DAG) structure for task dependencies - Data flow management between task inputs/outputs - Workflow builder API with fluent interface -- Provenance tracking - Caching of task results (same run on same input returns cached result) - Error handling and abortion support - Serial and parallel execution patterns @@ -87,7 +86,6 @@ const output = await workflow.run(); - Connects task outputs to inputs - Value propagation -- Provenance tracking ### TaskGraphRunner diff --git a/packages/task-graph/src/task-graph/TaskGraph.ts b/packages/task-graph/src/task-graph/TaskGraph.ts index 0f7e5e83..5a5cbe61 100644 --- a/packages/task-graph/src/task-graph/TaskGraph.ts +++ b/packages/task-graph/src/task-graph/TaskGraph.ts @@ -4,11 +4,11 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { DirectedAcyclicGraph, EventEmitter, uuid4 } from "@workglow/util"; +import { DirectedAcyclicGraph, EventEmitter, ServiceRegistry, uuid4 } from "@workglow/util"; import { TaskOutputRepository } from "../storage/TaskOutputRepository"; import type { ITask } from "../task/ITask"; import { JsonTaskItem, TaskGraphJson } from "../task/TaskJSON"; -import type { Provenance, TaskIdType, TaskInput, TaskOutput, TaskStatus } from "../task/TaskTypes"; +import type { TaskIdType, TaskInput, TaskOutput, TaskStatus } from "../task/TaskTypes"; import { ensureTask, type PipeFunction } from "./Conversions"; import { Dataflow, type DataflowIdType } from "./Dataflow"; import type { ITaskGraph } from "./ITaskGraph"; @@ -37,8 +37,8 @@ export interface TaskGraphRunConfig { outputCache?: TaskOutputRepository | boolean; /** Optional signal to abort the task graph */ parentSignal?: AbortSignal; - /** 
Optional provenance to use for this task graph */ - parentProvenance?: Provenance; + /** Optional service registry to use for this task graph (creates child from global if not provided) */ + registry?: ServiceRegistry; } class TaskGraphDAG extends DirectedAcyclicGraph< @@ -102,7 +102,6 @@ export class TaskGraph implements ITaskGraph { ): Promise> { return this.runner.runGraph(input, { outputCache: config?.outputCache || this.outputCache, - parentProvenance: config?.parentProvenance || {}, parentSignal: config?.parentSignal || undefined, }); } diff --git a/packages/task-graph/src/task-graph/TaskGraphRunner.ts b/packages/task-graph/src/task-graph/TaskGraphRunner.ts index fc2af972..bf944ae6 100644 --- a/packages/task-graph/src/task-graph/TaskGraphRunner.ts +++ b/packages/task-graph/src/task-graph/TaskGraphRunner.ts @@ -8,13 +8,14 @@ import { collectPropertyValues, ConvertAllToOptionalArray, globalServiceRegistry, + ServiceRegistry, uuid4, } from "@workglow/util"; import { TASK_OUTPUT_REPOSITORY, TaskOutputRepository } from "../storage/TaskOutputRepository"; import { ConditionalTask } from "../task/ConditionalTask"; import { ITask } from "../task/ITask"; import { TaskAbortedError, TaskConfigurationError, TaskError } from "../task/TaskError"; -import { Provenance, TaskInput, TaskOutput, TaskStatus } from "../task/TaskTypes"; +import { TaskInput, TaskOutput, TaskStatus } from "../task/TaskTypes"; import { DATAFLOW_ALL_PORTS } from "./Dataflow"; import { TaskGraph, TaskGraphRunConfig } from "./TaskGraph"; import { DependencyBasedScheduler, TopologicalScheduler } from "./TaskGraphScheduler"; @@ -50,7 +51,7 @@ export type GraphResult< /** * Class for running a task graph - * Manages the execution of tasks in a task graph, including provenance tracking and caching + * Manages the execution of tasks in a task graph, including caching */ export class TaskGraphRunner { /** @@ -59,11 +60,6 @@ export class TaskGraphRunner { protected running = false; protected reactiveRunning = false; - /** - * Map of provenance input for each task - */ - protected provenanceInput: Map; - /** * The task graph to run */ @@ -73,6 +69,10 @@ export class TaskGraphRunner { * Output cache repository */ protected outputCache?: TaskOutputRepository; + /** + * Service registry for this graph run + */ + protected registry: ServiceRegistry = globalServiceRegistry; /** * AbortController for cancelling graph execution */ @@ -99,7 +99,6 @@ export class TaskGraphRunner { protected reactiveScheduler = new TopologicalScheduler(graph) ) { this.graph = graph; - this.provenanceInput = new Map(); graph.outputCache = outputCache; this.handleProgress = this.handleProgress.bind(this); } @@ -136,10 +135,9 @@ export class TaskGraphRunner { // Only filter input for non-root tasks; root tasks get the full input const taskInput = isRootTask ? 
input : this.filterInputForTask(task, input); - const taskPromise = this.runTaskWithProvenance( + const taskPromise = this.runTask( task, - taskInput, - config?.parentProvenance || {} + taskInput ); this.inProgressTasks!.set(task.config.id, taskPromise); const taskResult = await taskPromise; @@ -247,6 +245,7 @@ export class TaskGraphRunner { await this.handleDisable(); } + /** * Filters graph-level input to only include properties that are not connected via dataflows for a given task * @param task The task to filter input for @@ -332,40 +331,26 @@ export class TaskGraphRunner { } } - /** - * Retrieves the provenance input for a task - * @param node The task to retrieve provenance input for - * @returns The provenance input for the task - */ - protected getInputProvenance(node: ITask): TaskInput { - const nodeProvenance: Provenance = {}; - this.graph.getSourceDataflows(node.config.id).forEach((dataflow) => { - Object.assign(nodeProvenance, dataflow.provenance); - }); - return nodeProvenance; - } /** * Pushes the output of a task to its target tasks * @param node The task that produced the output * @param results The output of the task - * @param nodeProvenance The provenance input for the task */ protected async pushOutputFromNodeToEdges( node: ITask, - results: TaskOutput, - nodeProvenance?: Provenance + results: TaskOutput ) { const dataflows = this.graph.getTargetDataflows(node.config.id); for (const dataflow of dataflows) { const compatibility = dataflow.semanticallyCompatible(this.graph, dataflow); // console.log("pushOutputFromNodeToEdges", dataflow.id, compatibility, Object.keys(results)); if (compatibility === "static") { - dataflow.setPortData(results, nodeProvenance); + dataflow.setPortData(results); } else if (compatibility === "runtime") { const task = this.graph.getTask(dataflow.targetTaskId)!; - const narrowed = await task.narrowInput({ ...results }); - dataflow.setPortData(narrowed, nodeProvenance); + const narrowed = await task.narrowInput({ ...results }, this.registry); + dataflow.setPortData(narrowed); } else { // don't push incompatible data } @@ -494,33 +479,25 @@ export class TaskGraphRunner { } /** - * Runs a task with provenance input + * Runs a task * @param task The task to run - * @param parentProvenance The provenance input for the task + * @param input The input for the task * @returns The output of the task */ - protected async runTaskWithProvenance( + protected async runTask( task: ITask, - input: TaskInput, - parentProvenance: Provenance + input: TaskInput ): Promise> { - // Update provenance for the current task - const nodeProvenance = { - ...parentProvenance, - ...this.getInputProvenance(task), - ...task.getProvenance(), - }; - this.provenanceInput.set(task.config.id, nodeProvenance); this.copyInputFromEdgesToNode(task); const results = await task.runner.run(input, { - nodeProvenance, outputCache: this.outputCache, updateProgress: async (task: ITask, progress: number, message?: string, ...args: any[]) => await this.handleProgress(task, progress, message, ...args), + registry: this.registry, }); - await this.pushOutputFromNodeToEdges(task, results, nodeProvenance); + await this.pushOutputFromNodeToEdges(task, results); return { id: task.config.id, @@ -572,10 +549,18 @@ export class TaskGraphRunner { * @param parentSignal Optional abort signal from parent */ protected async handleStart(config?: TaskGraphRunConfig): Promise { + // Setup registry - create child from global if not provided + if (config?.registry !== undefined) { + this.registry = 
config.registry; + } else { + // Create a child container that inherits from global but allows overrides + this.registry = new ServiceRegistry(globalServiceRegistry.container.createChildContainer()); + } + if (config?.outputCache !== undefined) { if (typeof config.outputCache === "boolean") { if (config.outputCache === true) { - this.outputCache = globalServiceRegistry.get(TASK_OUTPUT_REPOSITORY); + this.outputCache = this.registry.get(TASK_OUTPUT_REPOSITORY); } else { this.outputCache = undefined; } @@ -706,7 +691,7 @@ export class TaskGraphRunner { progress = Math.round(completed / total); } this.pushStatusFromNodeToEdges(this.graph, task); - await this.pushOutputFromNodeToEdges(task, task.runOutputData, task.getProvenance()); + await this.pushOutputFromNodeToEdges(task, task.runOutputData); this.graph.emit("graph_progress", progress, message, args); } } diff --git a/packages/task-graph/src/task-graph/Workflow.ts b/packages/task-graph/src/task-graph/Workflow.ts index 440e6d52..b75cb024 100644 --- a/packages/task-graph/src/task-graph/Workflow.ts +++ b/packages/task-graph/src/task-graph/Workflow.ts @@ -23,10 +23,10 @@ import { } from "./TaskGraphRunner"; // Type definitions for the workflow -export type CreateWorkflow = ( +export type CreateWorkflow = ( input?: Partial, config?: Partial -) => Workflow; +) => Workflow; // Event types export type WorkflowEventListeners = { @@ -57,9 +57,10 @@ let taskIdCounter = 0; * Class for building and managing a task graph * Provides methods for adding tasks, connecting outputs to inputs, and running the task graph */ -export class Workflow - implements IWorkflow -{ +export class Workflow< + Input extends DataPorts = DataPorts, + Output extends DataPorts = DataPorts, +> implements IWorkflow { /** * Creates a new Workflow * @@ -99,10 +100,10 @@ export class Workflow(taskClass: ITaskConstructor): CreateWorkflow { const helper = function ( - this: Workflow, + this: Workflow, input: Partial = {}, config: Partial = {} - ): Workflow { + ) { this._error = ""; const parent = getLastTask(this); @@ -150,7 +151,19 @@ export class Workflow boolean ): Map => { - // If either schema is true (accepts everything), skip auto-matching + if (typeof sourceSchema === "object") { + if ( + targetSchema === true || + (typeof targetSchema === "object" && targetSchema.additionalProperties === true) + ) { + for (const fromOutputPortId of Object.keys(sourceSchema.properties || {})) { + matches.set(fromOutputPortId, fromOutputPortId); + this.connect(parent.config.id, fromOutputPortId, task.config.id, fromOutputPortId); + } + return matches; + } + } + // If either schema is true or false, skip auto-matching // as we cannot determine the appropriate connections if (typeof sourceSchema === "boolean" || typeof targetSchema === "boolean") { return matches; @@ -221,7 +234,10 @@ export class Workflow; } /** @@ -296,7 +312,6 @@ export class Workflow(input, { parentSignal: this._abortController.signal, - parentProvenance: {}, outputCache: this._repository, }); const results = this.graph.mergeExecuteOutputsToRunOutput( @@ -320,6 +335,7 @@ export class Workflow(input, { - parentProvenance: this.nodeProvenance || {}, parentSignal: this.abortController?.signal, outputCache: this.outputCache, }); @@ -53,35 +52,6 @@ export class GraphAsTaskRunner< super.handleDisable(); } - // ======================================================================== - // Utility methods - // ======================================================================== - - private fixInput(input: Input): Input { - // 
inputs has turned each property into an array, so we need to flatten the input - // but only for properties marked with x-replicate in the schema - const inputSchema = this.task.inputSchema(); - if (typeof inputSchema === "boolean") { - return input; - } - - const flattenedInput = Object.entries(input).reduce((acc, [key, value]) => { - const inputDef = inputSchema.properties?.[key]; - const shouldFlatten = - Array.isArray(value) && - typeof inputDef === "object" && - inputDef !== null && - "x-replicate" in inputDef && - (inputDef as any)["x-replicate"] === true; - - if (shouldFlatten) { - return { ...acc, [key]: value[0] }; - } - return { ...acc, [key]: value }; - }, {}); - return flattenedInput as Input; - } - // ======================================================================== // TaskRunner method overrides and helpers // ======================================================================== @@ -97,7 +67,7 @@ export class GraphAsTaskRunner< this.task.compoundMerge ); } else { - const result = await super.executeTask(this.fixInput(input)); + const result = await super.executeTask(input); this.task.runOutputData = result ?? ({} as Output); } return this.task.runOutputData as Output; @@ -114,7 +84,7 @@ export class GraphAsTaskRunner< this.task.compoundMerge ); } else { - const reactiveResults = await super.executeTaskReactive(this.fixInput(input), output); + const reactiveResults = await super.executeTaskReactive(input, output); this.task.runOutputData = Object.assign({}, output, reactiveResults ?? {}) as Output; } return this.task.runOutputData as Output; diff --git a/packages/task-graph/src/task/ITask.ts b/packages/task-graph/src/task/ITask.ts index 2a528fd6..8f3b703c 100644 --- a/packages/task-graph/src/task/ITask.ts +++ b/packages/task-graph/src/task/ITask.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { DataPortSchema, EventEmitter } from "@workglow/util"; +import type { DataPortSchema, EventEmitter, ServiceRegistry } from "@workglow/util"; import { TaskOutputRepository } from "../storage/TaskOutputRepository"; import { ITaskGraph } from "../task-graph/ITaskGraph"; import { IWorkflow } from "../task-graph/IWorkflow"; @@ -19,16 +19,21 @@ import type { } from "./TaskEvents"; import type { JsonTaskItem, TaskGraphItemJson } from "./TaskJSON"; import { TaskRunner } from "./TaskRunner"; -import type { Provenance, TaskConfig, TaskInput, TaskOutput, TaskStatus } from "./TaskTypes"; +import type { + TaskConfig, + TaskInput, + TaskOutput, + TaskStatus, +} from "./TaskTypes"; /** * Context for task execution */ export interface IExecuteContext { signal: AbortSignal; - nodeProvenance: Provenance; updateProgress: (progress: number, message?: string, ...args: any[]) => Promise; own: (i: T) => T; + registry: ServiceRegistry; } export type IExecuteReactiveContext = Pick; @@ -37,7 +42,6 @@ export type IExecuteReactiveContext = Pick; * Configuration for running a task */ export interface IRunConfig { - nodeProvenance?: Provenance; outputCache?: TaskOutputRepository | boolean; updateProgress?: ( task: ITask, @@ -45,6 +49,7 @@ export interface IRunConfig { message?: string, ...args: any[] ) => Promise; + registry?: ServiceRegistry; } /** @@ -115,7 +120,7 @@ export interface ITaskIO { addInput(overrides: Record | undefined): boolean; validateInput(input: Record): Promise; get cacheable(): boolean; - narrowInput(input: Record): Promise>; + narrowInput(input: Record, registry: ServiceRegistry): Promise>; } export interface ITaskInternalGraph { @@ -142,7 +147,6 @@ export 
interface ITaskEvents { * Interface for task serialization */ export interface ITaskSerialization { - getProvenance(): Provenance; toJSON(): JsonTaskItem | TaskGraphItemJson; toDependencyJSON(): JsonTaskItem; id(): unknown; @@ -168,7 +172,9 @@ export interface ITask< Input extends TaskInput = TaskInput, Output extends TaskOutput = TaskOutput, Config extends TaskConfig = TaskConfig, -> extends ITaskState, +> + extends + ITaskState, ITaskIO, ITaskEvents, ITaskLifecycle, diff --git a/packages/task-graph/src/task/InputResolver.ts b/packages/task-graph/src/task/InputResolver.ts new file mode 100644 index 00000000..28db3938 --- /dev/null +++ b/packages/task-graph/src/task/InputResolver.ts @@ -0,0 +1,113 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { DataPortSchema, ServiceRegistry } from "@workglow/util"; +import { getInputResolvers } from "@workglow/util"; + +/** + * Configuration for the input resolver + */ +export interface InputResolverConfig { + readonly registry: ServiceRegistry; +} + +/** + * Extracts the format string from a schema, handling oneOf/anyOf wrappers. + */ +function getSchemaFormat(schema: unknown): string | undefined { + if (typeof schema !== "object" || schema === null) return undefined; + + const s = schema as Record; + + // Direct format + if (typeof s.format === "string") return s.format; + + // Check oneOf/anyOf for format + const variants = (s.oneOf ?? s.anyOf) as unknown[] | undefined; + if (Array.isArray(variants)) { + for (const variant of variants) { + if (typeof variant === "object" && variant !== null) { + const v = variant as Record; + if (typeof v.format === "string") return v.format; + } + } + } + + return undefined; +} + +/** + * Gets the format prefix from a format string. + * For "model:TextEmbedding" returns "model" + * For "repository:tabular" returns "repository" + */ +function getFormatPrefix(format: string): string { + const colonIndex = format.indexOf(":"); + return colonIndex >= 0 ? format.substring(0, colonIndex) : format; +} + +/** + * Resolves schema-annotated inputs by looking up string IDs from registries. + * String values with matching format annotations are resolved to their instances. + * Non-string values (objects/instances) are passed through unchanged. 
+ * + * @param input The task input object + * @param schema The task's input schema + * @param config Configuration including the service registry + * @returns The input with resolved values + * + * @example + * ```typescript + * // In TaskRunner.run() + * const resolvedInput = await resolveSchemaInputs( + * this.task.runInputData, + * (this.task.constructor as typeof Task).inputSchema(), + * { registry: this.registry } + * ); + * ``` + */ +export async function resolveSchemaInputs>( + input: T, + schema: DataPortSchema, + config: InputResolverConfig +): Promise { + if (typeof schema === "boolean") return input; + + const properties = schema.properties; + if (!properties || typeof properties !== "object") return input; + + const resolvers = getInputResolvers(); + const resolved: Record = { ...input }; + + for (const [key, propSchema] of Object.entries(properties)) { + const value = resolved[key]; + + const format = getSchemaFormat(propSchema); + if (!format) continue; + + // Try full format first (e.g., "repository:vector"), then fall back to prefix (e.g., "repository") + let resolver = resolvers.get(format); + if (!resolver) { + const prefix = getFormatPrefix(format); + resolver = resolvers.get(prefix); + } + + if (!resolver) continue; + + // Handle string values + if (typeof value === "string") { + resolved[key] = await resolver(value, format, config.registry); + } + // Handle arrays of strings - pass the entire array to the resolver + // (resolvers like resolveModelFromRegistry handle arrays even though typed as string) + else if (Array.isArray(value) && value.every((item) => typeof item === "string")) { + resolved[key] = await resolver(value as unknown as string, format, config.registry); + } + // Skip if not a string or array of strings (already resolved or direct instance) + } + + return resolved as T; +} diff --git a/packages/task-graph/src/task/JobQueueTask.ts b/packages/task-graph/src/task/JobQueueTask.ts index 65c3b337..c648fa12 100644 --- a/packages/task-graph/src/task/JobQueueTask.ts +++ b/packages/task-graph/src/task/JobQueueTask.ts @@ -5,7 +5,7 @@ */ import { Job, JobConstructorParam } from "@workglow/job-queue"; -import { ArrayTask } from "./ArrayTask"; +import { GraphAsTask } from "./GraphAsTask"; import { IExecuteContext } from "./ITask"; import { getJobQueueFactory } from "./JobQueueFactory"; import { JobTaskFailedError, TaskConfigurationError } from "./TaskError"; @@ -47,7 +47,7 @@ export abstract class JobQueueTask< Input extends TaskInput = TaskInput, Output extends TaskOutput = TaskOutput, Config extends JobQueueTaskConfig = JobQueueTaskConfig, -> extends ArrayTask { +> extends GraphAsTask { static readonly type: string = "JobQueueTask"; static canRunDirectly = true; diff --git a/packages/task-graph/src/task/README.md b/packages/task-graph/src/task/README.md index d04b4267..a767d9dd 100644 --- a/packages/task-graph/src/task/README.md +++ b/packages/task-graph/src/task/README.md @@ -13,6 +13,7 @@ This module provides a flexible task processing system with support for various - [Event Handling](#event-handling) - [Input/Output Schemas](#inputoutput-schemas) - [Registry \& Queues](#registry--queues) +- [Input Resolution](#input-resolution) - [Error Handling](#error-handling) - [Testing](#testing) - [Installation](#installation) @@ -30,6 +31,9 @@ This module provides a flexible task processing system with support for various ### A Simple Task ```typescript +import { Task, type DataPortSchema } from "@workglow/task-graph"; +import { Type } from "@sinclair/typebox"; + 
interface MyTaskInput { input: number; } @@ -178,6 +182,15 @@ static outputSchema = () => { }), }) satisfies DataPortSchema; }; + +type MyInput = FromSchema; +type MyOutput = FromSchema; + +class MyTask extends Task { + static readonly type = "MyTask"; + static inputSchema = () => MyInputSchema; + static outputSchema = () => MyOutputSchema; +} ``` ### Using Zod @@ -201,13 +214,16 @@ const outputSchemaZod = z.object({ type MyInput = z.infer; type MyOutput = z.infer; -static inputSchema = () => { - return inputSchemaZod.toJSONSchema() as DataPortSchema; -}; +class MyTask extends Task { + static readonly type = "MyTask"; + static inputSchema = () => { + return inputSchemaZod.toJSONSchema() as DataPortSchema; + }; -static outputSchema = () => { - return outputSchemaZod.toJSONSchema() as DataPortSchema; -}; + static outputSchema = () => { + return outputSchemaZod.toJSONSchema() as DataPortSchema; + }; +} ``` ## Registry & Queues @@ -226,6 +242,75 @@ const queue = getTaskQueueRegistry().getQueue("processing"); queue.add(new MyJobTask()); ``` +## Input Resolution + +The TaskRunner automatically resolves schema-annotated string inputs to their corresponding instances before task execution. This allows tasks to accept either string identifiers (like `"my-model"` or `"my-repository"`) or direct object instances, providing flexibility in how tasks are configured. + +### How It Works + +When a task's input schema includes properties with `format` annotations (such as `"model"`, `"model:TaskName"`, or `"repository:tabular"`), the TaskRunner inspects each input property: + +- **String values** are looked up in the appropriate registry and resolved to instances +- **Object values** (already instances) pass through unchanged + +This resolution happens automatically before `validateInput()` is called, so by the time `execute()` runs, all annotated inputs are guaranteed to be resolved objects. 
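+As a rough sketch, a property opts into this behavior simply by carrying a `format` annotation in the task's input schema. The class name, property names, and the exact format string below are illustrative assumptions, not part of the framework:
+
+```typescript
+import { Task } from "@workglow/task-graph";
+
+// Hypothetical task: the "format" value is matched against registered resolvers,
+// first as the full string ("repository:vector"), then by its "repository" prefix.
+class VectorLookupTask extends Task {
+  static readonly type = "VectorLookupTask";
+
+  static inputSchema() {
+    return {
+      type: "object",
+      properties: {
+        store: {
+          type: "string",
+          format: "repository:vector", // assumption: a resolver is registered for this format
+          title: "Vector Store",
+        },
+        query: { type: "string", title: "Query" },
+      },
+      required: ["store", "query"],
+    };
+  }
+}
+```
+
+With a schema like this, calling `run({ store: "my-vector-repo", query: "..." })` would hand the task the registered repository instance rather than the raw string, while passing an instance directly would leave it untouched.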
+ +### Example: Task with Repository Input + +```typescript +import { Task } from "@workglow/task-graph"; +import { TypeTabularRepository } from "@workglow/storage"; + +class DataProcessingTask extends Task<{ repository: ITabularRepository; query: string }> { + static readonly type = "DataProcessingTask"; + + static inputSchema() { + return { + type: "object", + properties: { + repository: TypeTabularRepository({ + title: "Data Source", + description: "Repository to query", + }), + query: { type: "string", title: "Query" }, + }, + required: ["repository", "query"], + }; + } + + async execute(input: DataProcessingTaskInput, context: IExecuteContext) { + // repository is guaranteed to be an ITabularRepository instance + const data = await input.repository.getAll(); + return { results: data }; + } +} + +// Usage with string ID (resolved automatically) +const task = new DataProcessingTask(); +await task.run({ repository: "my-registered-repo", query: "test" }); + +// Usage with direct instance (passed through) +await task.run({ repository: myRepositoryInstance, query: "test" }); +``` + +### Registering Custom Resolvers + +Extend the input resolution system by registering custom resolvers for new format prefixes: + +```typescript +import { registerInputResolver } from "@workglow/util"; + +// Register a resolver for "config:*" formats +registerInputResolver("config", async (id, format, registry) => { + const configRepo = registry.get(CONFIG_REPOSITORY); + const config = await configRepo.findById(id); + if (!config) { + throw new Error(`Configuration "${id}" not found`); + } + return config; +}); +``` + ## Error Handling ```typescript diff --git a/packages/task-graph/src/task/Task.ts b/packages/task-graph/src/task/Task.ts index 2bc7dd3d..5e407617 100644 --- a/packages/task-graph/src/task/Task.ts +++ b/packages/task-graph/src/task/Task.ts @@ -11,6 +11,7 @@ import { SchemaNode, uuid4, type DataPortSchema, + type ServiceRegistry, } from "@workglow/util"; import { DATAFLOW_ALL_PORTS } from "../task-graph/Dataflow"; import { TaskGraph } from "../task-graph/TaskGraph"; @@ -26,7 +27,6 @@ import type { JsonTaskItem, TaskGraphItemJson } from "./TaskJSON"; import { TaskRunner } from "./TaskRunner"; import { TaskStatus, - type Provenance, type TaskConfig, type TaskIdType, type TaskInput, @@ -307,11 +307,6 @@ export class Task< } protected _events: EventEmitter | undefined; - /** - * Provenance information for the task - */ - protected nodeProvenance: Provenance = {}; - /** * Creates a new task instance * @@ -380,11 +375,86 @@ export class Task< * Resets input data to defaults */ public resetInputData(): void { - // Use deep clone to avoid state leakage + this.runInputData = this.smartClone(this.defaults) as Record; + } + + /** + * Smart clone that deep-clones plain objects and arrays while preserving + * class instances (objects with non-Object prototype) by reference. + * Detects and throws an error on circular references. + * + * This is necessary because: + * - structuredClone cannot clone class instances (methods are lost) + * - JSON.parse/stringify loses methods and fails on circular references + * - Class instances like repositories should be passed by reference + * + * This breaks the idea of everything being json serializable, but it allows + * more efficient use cases. Do be careful with this though! Use sparingly. 
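+ * A hypothetical illustration (the variable names are assumptions, not from this codebase):
+ *   const defaults = { options: { depth: 2 }, repo: someRepositoryInstance };
+ *   const copy = this.smartClone(defaults);
+ *   // copy.options is a fresh deep copy; copy.repo === someRepositoryInstance (same reference)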
+ * + * @param obj The object to clone + * @param visited Set of objects in the current cloning path (for circular reference detection) + * @returns A cloned object with class instances preserved by reference + */ + private smartClone(obj: any, visited: WeakSet = new WeakSet()): any { + if (obj === null || obj === undefined) { + return obj; + } + + // Primitives (string, number, boolean, symbol, bigint) are returned as-is + if (typeof obj !== "object") { + return obj; + } + + // Check for circular references + if (visited.has(obj)) { + throw new Error( + "Circular reference detected in input data. " + + "Cannot clone objects with circular references." + ); + } + + // Clone TypedArrays (Float32Array, Int8Array, etc.) to avoid shared-mutation + // between defaults and runInputData, while preserving DataView by reference. + if (ArrayBuffer.isView(obj)) { + // Preserve DataView instances by reference (constructor signature differs) + if (typeof DataView !== "undefined" && obj instanceof DataView) { + return obj; + } + // For TypedArrays, create a new instance with the same data + const typedArray = obj as any; + return new (typedArray.constructor as any)(typedArray); + } + + // Preserve class instances (objects with non-Object/non-Array prototype) + // This includes repository instances, custom classes, etc. + if (!Array.isArray(obj)) { + const proto = Object.getPrototypeOf(obj); + if (proto !== Object.prototype && proto !== null) { + return obj; // Pass by reference + } + } + + // Add object to visited set before recursing + visited.add(obj); + try { - this.runInputData = structuredClone(this.defaults) as Record; - } catch (err) { - this.runInputData = JSON.parse(JSON.stringify(this.defaults)) as Record; + // Deep clone arrays, preserving class instances within + if (Array.isArray(obj)) { + return obj.map((item) => this.smartClone(item, visited)); + } + + // Deep clone plain objects + const result: Record = {}; + for (const key in obj) { + if (Object.prototype.hasOwnProperty.call(obj, key)) { + result[key] = this.smartClone(obj[key], visited); + } + } + return result; + } finally { + // Remove from visited set after processing to allow the same object + // in different branches (non-circular references) + visited.delete(obj); } } @@ -429,7 +499,7 @@ export class Task< // If additionalProperties is true, also copy any additional input properties if (schema.additionalProperties === true) { for (const [inputId, value] of Object.entries(input)) { - if (value !== undefined && !(inputId in properties)) { + if (!(inputId in properties)) { this.runInputData[inputId] = value; } } @@ -506,7 +576,7 @@ export class Task< // If additionalProperties is true, also accept any additional input properties if (inputSchema.additionalProperties === true) { for (const [inputId, value] of Object.entries(overrides)) { - if (value !== undefined && !(inputId in properties)) { + if (!(inputId in properties)) { if (!deepEqual(this.runInputData[inputId], value)) { this.runInputData[inputId] = value; changed = true; @@ -521,9 +591,13 @@ export class Task< /** * Stub for narrowing input. Override in subclasses for custom logic. 
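+ * A hypothetical override might look like (the token and port name are assumptions):
+ *   public async narrowInput(input, registry) {
+ *     const repo = registry.get(SOME_REPOSITORY_TOKEN); // assumed registered token
+ *     return { ...input, repository: repo };
+ *   }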
* @param input The input to narrow + * @param _registry Optional service registry for lookups * @returns The (possibly narrowed) input */ - public async narrowInput(input: Record): Promise> { + public async narrowInput( + input: Record, + _registry: ServiceRegistry + ): Promise> { return input; } @@ -664,13 +738,6 @@ export class Task< return this.config.id; } - /** - * Gets provenance information for the task - */ - public getProvenance(): Provenance { - return this.config.provenance ?? {}; - } - // ======================================================================== // Serialization methods // ======================================================================== @@ -684,6 +751,10 @@ export class Task< if (obj === null || obj === undefined) { return obj; } + // Preserve TypedArrays (Float32Array, Int8Array, etc.) + if (ArrayBuffer.isView(obj)) { + return obj; + } if (Array.isArray(obj)) { return obj.map((item) => this.stripSymbols(item)); } @@ -704,14 +775,12 @@ export class Task< * @returns The serialized task and subtasks */ public toJSON(): TaskGraphItemJson { - const provenance = this.getProvenance(); const extras = this.config.extras; let json: TaskGraphItemJson = this.stripSymbols({ id: this.config.id, type: this.type, ...(this.config.name ? { name: this.config.name } : {}), defaults: this.defaults, - ...(Object.keys(provenance).length ? { provenance } : {}), ...(extras && Object.keys(extras).length ? { extras } : {}), }); return json as TaskGraphItemJson; diff --git a/packages/task-graph/src/task/TaskEvents.ts b/packages/task-graph/src/task/TaskEvents.ts index 6eee2280..383389ec 100644 --- a/packages/task-graph/src/task/TaskEvents.ts +++ b/packages/task-graph/src/task/TaskEvents.ts @@ -5,8 +5,8 @@ */ import { EventParameters, type DataPortSchema } from "@workglow/util"; -import { TaskStatus } from "../common"; import { TaskAbortedError, TaskError } from "./TaskError"; +import { TaskStatus } from "./TaskTypes"; // ======================================================================== // Event Handling Types diff --git a/packages/task-graph/src/task/TaskJSON.test.ts b/packages/task-graph/src/task/TaskJSON.test.ts index a7e06874..4e5be6fe 100644 --- a/packages/task-graph/src/task/TaskJSON.test.ts +++ b/packages/task-graph/src/task/TaskJSON.test.ts @@ -130,7 +130,6 @@ describe("TaskJSON", () => { expect(json.type).toBe("TestTask"); expect(json.name).toBe("My Task"); expect(json.defaults).toEqual({ value: 42 }); - expect(json.provenance).toBeUndefined(); expect(json.extras).toBeUndefined(); }); @@ -141,27 +140,18 @@ describe("TaskJSON", () => { expect(json.defaults).toEqual({ value: 10, multiplier: 5 }); }); - test("should serialize task with provenance and extras", () => { + test("should serialize task with extras", () => { const task = new TestTask( { value: 100 }, { id: "task3", - provenance: { source: "test", version: "1.0" }, extras: { metadata: { key: "value" } }, } ); const json = task.toJSON(); - expect(json.provenance).toEqual({ source: "test", version: "1.0" }); expect(json.extras).toEqual({ metadata: { key: "value" } }); }); - - test("should not include empty provenance in JSON", () => { - const task = new TestTask({ value: 50 }, { id: "task4", provenance: {} }); - const json = task.toJSON(); - - expect(json.provenance).toBeUndefined(); - }); }); describe("TaskGraph.toJSON()", () => { @@ -234,18 +224,16 @@ describe("TaskJSON", () => { expect(task.defaults).toEqual({ value: 10, multiplier: 5 }); }); - test("should create a task with provenance and extras", () => { + 
test("should create a task with extras", () => { const json: TaskGraphItemJson = { id: "task3", type: "TestTask", defaults: { value: 100 }, - provenance: { source: "test", version: "1.0" }, extras: { metadata: { key: "value" } }, }; const task = createTaskFromGraphJSON(json); - expect(task.config.provenance).toEqual({ source: "test", version: "1.0" }); expect(task.config.extras).toEqual({ metadata: { key: "value" } }); }); @@ -386,14 +374,13 @@ describe("TaskJSON", () => { expect(restoredDataflows[0].targetTaskId).toBe(originalDataflows[0].targetTaskId); }); - test("should round-trip a task graph with defaults, provenance, and extras", () => { + test("should round-trip a task graph with defaults and extras", () => { const originalGraph = new TaskGraph(); const task1 = new TestTaskWithDefaults( { value: 10, multiplier: 3 }, { id: "task1", name: "Task with Defaults", - provenance: { source: "test", version: "1.0" }, extras: { metadata: { key: "value" } }, } ); @@ -404,7 +391,6 @@ describe("TaskJSON", () => { const restoredTask = restoredGraph.getTasks()[0]; expect(restoredTask.defaults).toEqual({ value: 10, multiplier: 3 }); - expect(restoredTask.config.provenance).toEqual({ source: "test", version: "1.0" }); expect(restoredTask.config.extras).toEqual({ metadata: { key: "value" } }); }); diff --git a/packages/task-graph/src/task/TaskJSON.ts b/packages/task-graph/src/task/TaskJSON.ts index 529e8b14..abd1ecb9 100644 --- a/packages/task-graph/src/task/TaskJSON.ts +++ b/packages/task-graph/src/task/TaskJSON.ts @@ -10,7 +10,7 @@ import { CompoundMergeStrategy } from "../task-graph/TaskGraphRunner"; import { TaskConfigurationError, TaskJSONError } from "../task/TaskError"; import { TaskRegistry } from "../task/TaskRegistry"; import { GraphAsTask } from "./GraphAsTask"; -import { DataPorts, Provenance, TaskConfig, TaskInput } from "./TaskTypes"; +import { DataPorts, TaskConfig, TaskInput } from "./TaskTypes"; // ======================================================================== // JSON Serialization Types @@ -53,9 +53,6 @@ export type JsonTaskItem = { /** Optional user data to use for this task, not used by the task framework except it will be exported as part of the task JSON*/ extras?: DataPorts; - /** Optional metadata about task origin */ - provenance?: Provenance; - /** Nested tasks for compound operations */ subtasks?: JsonTaskItem[]; }; /** @@ -67,7 +64,6 @@ export type TaskGraphItemJson = { type: string; name?: string; defaults?: TaskInput; - provenance?: Provenance; extras?: DataPorts; subgraph?: TaskGraphJson; merge?: CompoundMergeStrategy; @@ -88,10 +84,8 @@ export type DataflowJson = { const createSingleTaskFromJSON = (item: JsonTaskItem | TaskGraphItemJson) => { if (!item.id) throw new TaskJSONError("Task id required"); if (!item.type) throw new TaskJSONError("Task type required"); - if (item.defaults && (Array.isArray(item.defaults) || Array.isArray(item.provenance))) + if (item.defaults && Array.isArray(item.defaults)) throw new TaskJSONError("Task defaults must be an object"); - if (item.provenance && (Array.isArray(item.provenance) || typeof item.provenance !== "object")) - throw new TaskJSONError("Task provenance must be an object"); const taskClass = TaskRegistry.all.get(item.type); if (!taskClass) @@ -100,7 +94,6 @@ const createSingleTaskFromJSON = (item: JsonTaskItem | TaskGraphItemJson) => { const taskConfig: TaskConfig = { id: item.id, name: item.name, - provenance: item.provenance ?? {}, extras: item.extras, }; const task = new taskClass(item.defaults ?? 
{}, taskConfig); diff --git a/packages/task-graph/src/task/TaskRunner.ts b/packages/task-graph/src/task/TaskRunner.ts index b5a8bad8..bc41aa01 100644 --- a/packages/task-graph/src/task/TaskRunner.ts +++ b/packages/task-graph/src/task/TaskRunner.ts @@ -4,13 +4,15 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { globalServiceRegistry } from "@workglow/util"; +import { globalServiceRegistry, ServiceRegistry } from "@workglow/util"; import { TASK_OUTPUT_REPOSITORY, TaskOutputRepository } from "../storage/TaskOutputRepository"; import { ensureTask, type Taskish } from "../task-graph/Conversions"; +import { resolveSchemaInputs } from "./InputResolver"; import { IRunConfig, ITask } from "./ITask"; import { ITaskRunner } from "./ITaskRunner"; +import { Task } from "./Task"; import { TaskAbortedError, TaskError, TaskFailedError, TaskInvalidInputError } from "./TaskError"; -import { Provenance, TaskConfig, TaskInput, TaskOutput, TaskStatus } from "./TaskTypes"; +import { TaskConfig, TaskInput, TaskOutput, TaskStatus } from "./TaskTypes"; /** * Responsible for running tasks @@ -27,10 +29,6 @@ export class TaskRunner< protected running = false; protected reactiveRunning = false; - /** - * Provenance information for the task - */ - protected nodeProvenance: Provenance = {}; /** * The task to run @@ -47,6 +45,11 @@ export class TaskRunner< */ protected outputCache?: TaskOutputRepository; + /** + * The service registry for the task + */ + protected registry: ServiceRegistry = globalServiceRegistry; + /** * Constructor for TaskRunner * @param task The task to run @@ -72,6 +75,15 @@ export class TaskRunner< try { this.task.setInput(overrides); + + // Resolve schema-annotated inputs (models, repositories) before validation + const schema = (this.task.constructor as typeof Task).inputSchema(); + this.task.runInputData = (await resolveSchemaInputs( + this.task.runInputData as Record, + schema, + { registry: this.registry } + )) as Input; + const isValid = await this.task.validateInput(this.task.runInputData); if (!isValid) { throw new TaskInvalidInputError("Invalid input data"); @@ -118,6 +130,14 @@ export class TaskRunner< } this.task.setInput(overrides); + // Resolve schema-annotated inputs (models, repositories) before validation + const schema = (this.task.constructor as typeof Task).inputSchema(); + this.task.runInputData = (await resolveSchemaInputs( + this.task.runInputData as Record, + schema, + { registry: this.registry } + )) as Input; + await this.handleStartReactive(); try { @@ -165,11 +185,11 @@ export class TaskRunner< * Protected method to execute a task by delegating back to the task itself. */ protected async executeTask(input: Input): Promise { - const result = await this.task.execute(input, { + const result = await this.task.execute(input, { signal: this.abortController!.signal, updateProgress: this.handleProgress.bind(this), - nodeProvenance: this.nodeProvenance, own: this.own, + registry: this.registry, }); return await this.executeTaskReactive(input, result || ({} as Output)); } @@ -192,7 +212,6 @@ export class TaskRunner< protected async handleStart(config: IRunConfig = {}): Promise { if (this.task.status === TaskStatus.PROCESSING) return; - this.nodeProvenance = {}; this.running = true; this.task.startedAt = new Date(); @@ -204,8 +223,6 @@ export class TaskRunner< this.handleAbort(); }); - this.nodeProvenance = config.nodeProvenance ?? {}; - const cache = this.task.config.outputCache ?? 
config.outputCache; if (cache === true) { let instance = globalServiceRegistry.get(TASK_OUTPUT_REPOSITORY); @@ -220,6 +237,10 @@ export class TaskRunner< this.updateProgress = config.updateProgress; } + if (config.registry) { + this.registry = config.registry; + } + this.task.emit("start"); this.task.emit("status", this.task.status); } @@ -260,7 +281,6 @@ export class TaskRunner< this.task.progress = 100; this.task.status = TaskStatus.COMPLETED; this.abortController = undefined; - this.nodeProvenance = {}; this.task.emit("complete"); this.task.emit("status", this.task.status); @@ -276,7 +296,6 @@ export class TaskRunner< this.task.progress = 100; this.task.completedAt = new Date(); this.abortController = undefined; - this.nodeProvenance = {}; this.task.emit("disabled"); this.task.emit("status", this.task.status); } @@ -303,7 +322,6 @@ export class TaskRunner< this.task.error = err instanceof TaskError ? err : new TaskFailedError(err?.message || "Task failed"); this.abortController = undefined; - this.nodeProvenance = {}; this.task.emit("error", this.task.error); this.task.emit("status", this.task.status); } diff --git a/packages/task-graph/src/task/TaskTypes.ts b/packages/task-graph/src/task/TaskTypes.ts index 3d844459..6ac1a678 100644 --- a/packages/task-graph/src/task/TaskTypes.ts +++ b/packages/task-graph/src/task/TaskTypes.ts @@ -61,8 +61,6 @@ export type CompoundTaskOutput = [key: string]: unknown | unknown[] | undefined; }; -/** Type for task provenance metadata */ -export type Provenance = DataPorts; /** Type for task type names */ export type TaskTypeName = string; @@ -81,8 +79,6 @@ export interface IConfig { /** Optional display name for the task */ name?: string; - /** Optional metadata about task origin */ - provenance?: Provenance; /** Optional ID of the runner to use for this task */ runnerId?: string; diff --git a/packages/tasks/src/task/DebugLogTask.ts b/packages/tasks/src/task/DebugLogTask.ts index 803c7e9e..7eba96fa 100644 --- a/packages/tasks/src/task/DebugLogTask.ts +++ b/packages/tasks/src/task/DebugLogTask.ts @@ -89,8 +89,8 @@ export class DebugLogTask< TaskRegistry.registerTask(DebugLogTask); export const debugLog = (input: DebugLogTaskInput, config: TaskConfig = {}) => { - const task = new DebugLogTask(input, config); - return task.run(); + const task = new DebugLogTask({} as DebugLogTaskInput, config); + return task.run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/tasks/src/task/DelayTask.ts b/packages/tasks/src/task/DelayTask.ts index b1ffedb6..213b5259 100644 --- a/packages/tasks/src/task/DelayTask.ts +++ b/packages/tasks/src/task/DelayTask.ts @@ -88,8 +88,8 @@ TaskRegistry.registerTask(DelayTask); * @param {delay} - The delay in milliseconds */ export const delay = (input: DelayTaskInput, config: TaskConfig = {}) => { - const task = new DelayTask(input, config); - return task.run(); + const task = new DelayTask({} as DelayTaskInput, config); + return task.run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/tasks/src/task/FetchUrlTask.ts b/packages/tasks/src/task/FetchUrlTask.ts index 64208b13..5d7b48be 100644 --- a/packages/tasks/src/task/FetchUrlTask.ts +++ b/packages/tasks/src/task/FetchUrlTask.ts @@ -421,7 +421,7 @@ export const fetchUrl = async ( input: FetchUrlTaskInput, config: FetchUrlTaskConfig = {} ): Promise => { - const result = await new FetchUrlTask(input, config).run(); + const result = await new FetchUrlTask({} as FetchUrlTaskInput, config).run(input); return result as FetchUrlTaskOutput; }; diff 
--git a/packages/tasks/src/task/FileLoaderTask.server.ts b/packages/tasks/src/task/FileLoaderTask.server.ts index 761b813e..f0eb7a67 100644 --- a/packages/tasks/src/task/FileLoaderTask.server.ts +++ b/packages/tasks/src/task/FileLoaderTask.server.ts @@ -216,7 +216,7 @@ export class FileLoaderTask extends BaseFileLoaderTask { TaskRegistry.registerTask(FileLoaderTask); export const fileLoader = (input: FileLoaderTaskInput, config?: JobQueueTaskConfig) => { - return new FileLoaderTask(input, config).run(); + return new FileLoaderTask({} as FileLoaderTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/tasks/src/task/FileLoaderTask.ts b/packages/tasks/src/task/FileLoaderTask.ts index 9ecf0030..663a2088 100644 --- a/packages/tasks/src/task/FileLoaderTask.ts +++ b/packages/tasks/src/task/FileLoaderTask.ts @@ -408,7 +408,7 @@ export class FileLoaderTask extends Task< TaskRegistry.registerTask(FileLoaderTask); export const fileLoader = (input: FileLoaderTaskInput, config?: JobQueueTaskConfig) => { - return new FileLoaderTask(input, config).run(); + return new FileLoaderTask({} as FileLoaderTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/tasks/src/task/JavaScriptTask.ts b/packages/tasks/src/task/JavaScriptTask.ts index 12320213..0cd87903 100644 --- a/packages/tasks/src/task/JavaScriptTask.ts +++ b/packages/tasks/src/task/JavaScriptTask.ts @@ -74,7 +74,7 @@ export class JavaScriptTask extends Task { - return new JavaScriptTask(input, config).run(); + return new JavaScriptTask({} as JavaScriptTaskInput, config).run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/tasks/src/task/JsonTask.ts b/packages/tasks/src/task/JsonTask.ts index b62c7cf2..92e53d85 100644 --- a/packages/tasks/src/task/JsonTask.ts +++ b/packages/tasks/src/task/JsonTask.ts @@ -103,7 +103,7 @@ TaskRegistry.registerTask(JsonTask); * Convenience function to create and run a JsonTask */ export const json = (input: JsonTaskInput, config: TaskConfig = {}) => { - return new JsonTask(input, config).run(); + return new JsonTask({} as JsonTaskInput, config).run(input); }; // Add Json task workflow to Workflow interface diff --git a/packages/tasks/src/task/MergeTask.ts b/packages/tasks/src/task/MergeTask.ts index 60d72a1b..6cd3c307 100644 --- a/packages/tasks/src/task/MergeTask.ts +++ b/packages/tasks/src/task/MergeTask.ts @@ -85,8 +85,8 @@ export class MergeTask< TaskRegistry.registerTask(MergeTask); export const merge = (input: MergeTaskInput, config: TaskConfig = {}) => { - const task = new MergeTask(input, config); - return task.run(); + const task = new MergeTask({} as MergeTaskInput, config); + return task.run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/tasks/src/task/SplitTask.ts b/packages/tasks/src/task/SplitTask.ts index ada6290b..3f15504a 100644 --- a/packages/tasks/src/task/SplitTask.ts +++ b/packages/tasks/src/task/SplitTask.ts @@ -88,8 +88,8 @@ export class SplitTask< TaskRegistry.registerTask(SplitTask); export const split = (input: SplitTaskInput, config: TaskConfig = {}) => { - const task = new SplitTask(input, config); - return task.run(); + const task = new SplitTask({} as SplitTaskInput, config); + return task.run(input); }; declare module "@workglow/task-graph" { diff --git a/packages/test/src/samples/ONNXModelSamples.ts b/packages/test/src/samples/ONNXModelSamples.ts index 29489782..6aecb2e5 100644 --- a/packages/test/src/samples/ONNXModelSamples.ts +++ 
b/packages/test/src/samples/ONNXModelSamples.ts @@ -68,6 +68,18 @@ export async function registerHuggingfaceLocalModels(): Promise { }, metadata: {}, }, + { + model_id: "onnx:onnx-community/NeuroBERT-NER-ONNX:q8", + title: "NeuroBERT NER", + description: "onnx-community/NeuroBERT-NER-ONNX", + tasks: ["TextNamedEntityRecognitionTask"], + provider: HF_TRANSFORMERS_ONNX, + provider_config: { + pipeline: "token-classification", + model_path: "onnx-community/NeuroBERT-NER-ONNX", + }, + metadata: {}, + }, { model_id: "onnx:Xenova/distilbert-base-uncased-distilled-squad:q8", title: "distilbert-base-uncased-distilled-squad", diff --git a/packages/test/src/test/hierarchical/ChunkToVector.test.ts b/packages/test/src/test/hierarchical/ChunkToVector.test.ts new file mode 100644 index 00000000..cabaa9da --- /dev/null +++ b/packages/test/src/test/hierarchical/ChunkToVector.test.ts @@ -0,0 +1,133 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + ChunkNode, + ChunkToVectorTaskOutput, + HierarchicalChunkerTaskOutput, + NodeIdGenerator, + StructuralParser, +} from "@workglow/ai"; +import { Workflow } from "@workglow/task-graph"; +import { describe, expect, it } from "vitest"; + +describe("ChunkToVectorTask", () => { + it("should transform chunks and vectors to vector store format", async () => { + const markdown = "# Test\n\nContent."; + const docId = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(docId, markdown, "Test"); + + // Generate chunks using workflow + const chunkResult = (await new Workflow() + .hierarchicalChunker({ + docId, + documentTree: root, + maxTokens: 512, + overlap: 50, + strategy: "hierarchical", + }) + .run()) as HierarchicalChunkerTaskOutput; + + // Mock vectors (would normally come from TextEmbeddingTask) + const mockVectors = chunkResult.chunks.map(() => + new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5]) + ); + + // Transform to vector store format using workflow + const result = (await new Workflow() + .chunkToVector({ + chunks: chunkResult.chunks as ChunkNode[], + vectors: mockVectors, + }) + .run()) as ChunkToVectorTaskOutput; + + // Verify output format + expect(result.ids).toBeDefined(); + expect(result.vectors).toBeDefined(); + expect(result.metadata).toBeDefined(); + expect(result.texts).toBeDefined(); + + expect(result.ids.length).toBe(chunkResult.count); + expect(result.vectors.length).toBe(chunkResult.count); + expect(result.metadata.length).toBe(chunkResult.count); + expect(result.texts.length).toBe(chunkResult.count); + + // Check metadata structure + for (let i = 0; i < result.metadata.length; i++) { + const meta = result.metadata[i]; + expect(meta.docId).toBe(docId); + expect(meta.chunkId).toBeDefined(); + expect(meta.leafNodeId).toBeDefined(); + expect(meta.depth).toBeDefined(); + expect(meta.text).toBeDefined(); + expect(meta.nodePath).toBeDefined(); + } + + // Verify IDs match chunks + for (let i = 0; i < result.ids.length; i++) { + expect(result.ids[i]).toBe(chunkResult.chunks[i].chunkId); + } + }); + + it("should throw error on length mismatch", async () => { + const chunks = [ + { + chunkId: "chunk_1", + docId: "doc_1", + text: "Test", + nodePath: ["node_1"], + depth: 1, + }, + { + chunkId: "chunk_2", + docId: "doc_1", + text: "Test 2", + nodePath: ["node_1"], + depth: 1, + }, + ]; + + const vectors = [new Float32Array([1, 2, 3])]; // Only 1 vector for 2 chunks + + // Using workflow + await expect( + new Workflow() + .chunkToVector({ chunks, 
vectors }) + .run() + ).rejects.toThrow("Mismatch"); + }); + + it("should include enrichment in metadata if present", async () => { + const chunks = [ + { + chunkId: "chunk_1", + docId: "doc_1", + text: "Test", + nodePath: ["node_1"], + depth: 1, + enrichment: { + summary: "Test summary", + entities: [{ text: "Entity", type: "TEST", score: 0.9 }], + }, + }, + ]; + + const vectors = [new Float32Array([1, 2, 3])]; + + const result = (await new Workflow() + .chunkToVector({ chunks, vectors }) + .run()) as ChunkToVectorTaskOutput; + + const metadata = result.metadata as Array<{ + summary?: string; + entities?: Array<{ text: string; type: string; score: number }>; + [key: string]: unknown; + }>; + expect(metadata[0].summary).toBe("Test summary"); + expect(metadata[0].entities).toBeDefined(); + expect(metadata[0].entities!.length).toBe(1); + }); +}); diff --git a/packages/test/src/test/hierarchical/Document.test.ts b/packages/test/src/test/hierarchical/Document.test.ts new file mode 100644 index 00000000..454cc148 --- /dev/null +++ b/packages/test/src/test/hierarchical/Document.test.ts @@ -0,0 +1,52 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { ChunkNode, DocumentNode } from "@workglow/ai"; +import { Document, NodeKind } from "@workglow/ai"; +import { describe, expect, test } from "vitest"; + +describe("Document", () => { + const createTestDocumentNode = (): DocumentNode => ({ + nodeId: "root", + kind: NodeKind.DOCUMENT, + range: { startOffset: 0, endOffset: 100 }, + text: "Test document stuff", + title: "Test document", + children: [], + }); + + const createTestChunks = (): ChunkNode[] => [ + { + chunkId: "chunk1", + docId: "doc1", + text: "Test chunk", + nodePath: ["root"], + depth: 1, + }, + ]; + + test("setChunks and getChunks", () => { + const doc = new Document("doc1", createTestDocumentNode(), { title: "Test" }); + + doc.setChunks(createTestChunks()); + + const chunks = doc.getChunks(); + expect(chunks).toBeDefined(); + expect(chunks.length).toBe(1); + expect(chunks[0].text).toBe("Test chunk"); + }); + + test("findChunksByNodeId", () => { + const doc = new Document("doc1", createTestDocumentNode(), { title: "Test" }); + + doc.setChunks(createTestChunks()); + + const chunks = doc.findChunksByNodeId("root"); + expect(chunks).toBeDefined(); + expect(chunks.length).toBe(1); + expect(chunks[0].text).toBe("Test chunk"); + }); +}); diff --git a/packages/test/src/test/hierarchical/DocumentRepository.test.ts b/packages/test/src/test/hierarchical/DocumentRepository.test.ts new file mode 100644 index 00000000..507f4377 --- /dev/null +++ b/packages/test/src/test/hierarchical/DocumentRepository.test.ts @@ -0,0 +1,209 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + Document, + DocumentRepository, + DocumentStorageSchema, + NodeIdGenerator, + NodeKind, + StructuralParser, +} from "@workglow/ai"; +import { InMemoryTabularRepository, InMemoryVectorRepository } from "@workglow/storage"; +import { beforeEach, describe, expect, it } from "vitest"; + +describe("DocumentRepository", () => { + let repo: DocumentRepository; + + beforeEach(async () => { + const tabularStorage = new InMemoryTabularRepository(DocumentStorageSchema, ["docId"]); + await tabularStorage.setupDatabase(); + + const vectorStorage = new InMemoryVectorRepository(); + await vectorStorage.setupDatabase(); + + repo = new DocumentRepository(tabularStorage, vectorStorage); + }); + + it("should store and 
retrieve documents", async () => { + const markdown = "# Test\n\nContent."; + const docId = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(docId, markdown, "Test"); + + const doc = new Document(docId, root, { title: "Test Document" }); + + await repo.upsert(doc); + const retrieved = await repo.get(docId); + + expect(retrieved).toBeDefined(); + expect(retrieved?.docId).toBe(docId); + expect(retrieved?.metadata.title).toBe("Test Document"); + }); + + it("should retrieve nodes by ID", async () => { + const markdown = "# Section\n\nParagraph."; + const docId = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(docId, markdown, "Test"); + + const doc = new Document(docId, root, { title: "Test" }); + await repo.upsert(doc); + + // Get a child node + const firstChild = root.children[0]; + const retrieved = await repo.getNode(docId, firstChild.nodeId); + + expect(retrieved).toBeDefined(); + expect(retrieved?.nodeId).toBe(firstChild.nodeId); + }); + + it("should get ancestors of a node", async () => { + const markdown = `# Section 1 + +## Subsection 1.1 + +Paragraph.`; + + const docId = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(docId, markdown, "Test"); + + const doc = new Document(docId, root, { title: "Test" }); + await repo.upsert(doc); + + // Find a deeply nested node + const section = root.children.find((c) => c.kind === NodeKind.SECTION); + expect(section).toBeDefined(); + + const subsection = (section as any).children.find((c: any) => c.kind === NodeKind.SECTION); + expect(subsection).toBeDefined(); + + const ancestors = await repo.getAncestors(docId, subsection.nodeId); + + // Should include root, section, and subsection + expect(ancestors.length).toBeGreaterThanOrEqual(3); + expect(ancestors[0].nodeId).toBe(root.nodeId); + expect(ancestors[1].nodeId).toBe(section!.nodeId); + expect(ancestors[2].nodeId).toBe(subsection.nodeId); + }); + + it("should handle chunks", async () => { + const markdown = "# Test\n\nContent."; + const docId = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(docId, markdown, "Test"); + + const doc = new Document(docId, root, { title: "Test" }); + + // Add chunks + const chunks = [ + { + chunkId: "chunk_1", + docId, + text: "Test chunk", + nodePath: [root.nodeId], + depth: 1, + }, + ]; + + doc.setChunks(chunks); + + await repo.upsert(doc); + + // Retrieve chunks + const retrievedChunks = await repo.getChunks(docId); + expect(retrievedChunks).toBeDefined(); + expect(retrievedChunks.length).toBe(1); + }); + + it("should list all documents", async () => { + const markdown1 = "# Doc 1"; + const markdown2 = "# Doc 2"; + + const id1 = await NodeIdGenerator.generateDocId("test1", markdown1); + const id2 = await NodeIdGenerator.generateDocId("test2", markdown2); + + const root1 = await StructuralParser.parseMarkdown(id1, markdown1, "Doc 1"); + const root2 = await StructuralParser.parseMarkdown(id2, markdown2, "Doc 2"); + + const doc1 = new Document(id1, root1, { title: "Doc 1" }); + const doc2 = new Document(id2, root2, { title: "Doc 2" }); + + await repo.upsert(doc1); + await repo.upsert(doc2); + + const list = await repo.list(); + expect(list.length).toBe(2); + expect(list).toContain(id1); + expect(list).toContain(id2); + }); + + it("should delete documents", async () => { + const markdown = "# Test"; + const docId = await 
NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(docId, markdown, "Test"); + + const doc = new Document(docId, root, { title: "Test" }); + await repo.upsert(doc); + + expect(await repo.get(docId)).toBeDefined(); + + await repo.delete(docId); + + expect(await repo.get(docId)).toBeUndefined(); + }); +}); + +describe("Document", () => { + it("should manage chunks", async () => { + const markdown = "# Test"; + const docId = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(docId, markdown, "Test"); + + const doc = new Document(docId, root, { title: "Test" }); + + const chunks = [ + { + chunkId: "chunk_1", + docId, + text: "Chunk 1", + nodePath: [root.nodeId], + depth: 1, + }, + ]; + doc.setChunks(chunks); + + const retrievedChunks = doc.getChunks(); + expect(retrievedChunks.length).toBe(1); + expect(retrievedChunks[0].text).toBe("Chunk 1"); + }); + + it("should serialize and deserialize", async () => { + const markdown = "# Test"; + const docId = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(docId, markdown, "Test"); + + const doc = new Document(docId, root, { title: "Test" }); + + const chunks = [ + { + chunkId: "chunk_1", + docId, + text: "Chunk", + nodePath: [root.nodeId], + depth: 1, + }, + ]; + doc.setChunks(chunks); + + // Serialize + const json = doc.toJSON(); + + // Deserialize + const restored = Document.fromJSON(JSON.stringify(json)); + + expect(restored.docId).toBe(doc.docId); + expect(restored.metadata.title).toBe(doc.metadata.title); + expect(restored.getChunks().length).toBe(1); + }); +}); diff --git a/packages/test/src/test/hierarchical/EndToEnd.test.ts b/packages/test/src/test/hierarchical/EndToEnd.test.ts new file mode 100644 index 00000000..39872801 --- /dev/null +++ b/packages/test/src/test/hierarchical/EndToEnd.test.ts @@ -0,0 +1,138 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + Document, + DocumentRepository, + DocumentStorageSchema, + hierarchicalChunker, + NodeIdGenerator, + StructuralParser, +} from "@workglow/ai"; +import { InMemoryTabularRepository, InMemoryVectorRepository } from "@workglow/storage"; +import { describe, expect, it } from "vitest"; + +describe("End-to-end hierarchical RAG", () => { + it("should demonstrate chainable design (chunks → text array)", async () => { + // Sample markdown document + const markdown = `# Machine Learning + +Machine learning is AI. + +## Supervised Learning + +Uses labeled data. 
+ +## Unsupervised Learning + +Finds patterns in data.`; + + // Parse into hierarchical tree + const docId = await NodeIdGenerator.generateDocId("ml-guide", markdown); + const root = await StructuralParser.parseMarkdown(docId, markdown, "ML Guide"); + + // CHAINABLE DESIGN TEST - Use workflow to verify chaining + const chunkResult = await hierarchicalChunker({ + docId, + documentTree: root, + maxTokens: 256, + overlap: 25, + strategy: "hierarchical", + }); + + // Verify outputs are ready for next task in chain + expect(chunkResult.chunks).toBeDefined(); + expect(chunkResult.text).toBeDefined(); + expect(chunkResult.count).toBe(chunkResult.text.length); + expect(chunkResult.count).toBe(chunkResult.chunks.length); + + // The text array can be directly consumed by TextEmbeddingTask + expect(Array.isArray(chunkResult.text)).toBe(true); + expect(chunkResult.text.every((t) => typeof t === "string")).toBe(true); + }); + + it("should manage document chunks", async () => { + const markdown = "# Test Document\n\nThis is test content."; + const docId = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(docId, markdown, "Test"); + + const doc = new Document(docId, root, { title: "Test" }); + + const chunks = [ + { + chunkId: "chunk_1", + docId, + text: "Test chunk 1", + nodePath: [root.nodeId], + depth: 1, + }, + ]; + + doc.setChunks(chunks); + + // Verify chunks are stored + const retrievedChunks = doc.getChunks(); + expect(retrievedChunks.length).toBe(1); + expect(retrievedChunks[0].text).toBe("Test chunk 1"); + }); + + it("should demonstrate document repository integration", async () => { + const tabularStorage = new InMemoryTabularRepository(DocumentStorageSchema, ["docId"]); + await tabularStorage.setupDatabase(); + + const vectorStorage = new InMemoryVectorRepository(); + await vectorStorage.setupDatabase(); + + const docRepo = new DocumentRepository(tabularStorage, vectorStorage); + + // Create document with enriched hierarchy + const markdown = `# Guide + +## Section 1 + +Content about topic A. 
+ +## Section 2 + +Content about topic B.`; + + const docId = await NodeIdGenerator.generateDocId("guide", markdown); + const root = await StructuralParser.parseMarkdown(docId, markdown, "Guide"); + + const doc = new Document(docId, root, { title: "Guide" }); + + // Enrich (in real workflow this would use DocumentEnricherTask) + // For test, manually add enrichment + const enrichedRoot = { + ...root, + enrichment: { + summary: "A guide covering two sections", + }, + }; + + const enrichedDoc = new Document(docId, enrichedRoot as any, doc.metadata); + await docRepo.upsert(enrichedDoc); + + // Generate chunks using workflow (without embedding to avoid model requirement) + const chunkResult = await hierarchicalChunker({ + docId, + documentTree: enrichedRoot, + maxTokens: 256, + overlap: 25, + strategy: "hierarchical", + }); + expect(chunkResult.count).toBeGreaterThan(0); + + // Add chunks to document + enrichedDoc.setChunks(chunkResult.chunks); + await docRepo.upsert(enrichedDoc); + + // Verify chunks were stored + const retrieved = await docRepo.getChunks(docId); + expect(retrieved).toBeDefined(); + expect(retrieved.length).toBe(chunkResult.count); + }); +}); diff --git a/packages/test/src/test/hierarchical/FullChain.test.ts b/packages/test/src/test/hierarchical/FullChain.test.ts new file mode 100644 index 00000000..3441ab3e --- /dev/null +++ b/packages/test/src/test/hierarchical/FullChain.test.ts @@ -0,0 +1,146 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { type ChunkNode, HierarchicalChunkerTaskOutput, NodeIdGenerator } from "@workglow/ai"; +import { InMemoryVectorRepository } from "@workglow/storage"; +import { Workflow } from "@workglow/task-graph"; +import { describe, expect, it } from "vitest"; + +describe("Complete chainable workflow", () => { + it("should chain from parsing to storage without loops", async () => { + const vectorRepo = new InMemoryVectorRepository<{ + docId: string; + chunkId: string; + leafNodeId: string; + depth: number; + text: string; + }>(); + await vectorRepo.setupDatabase(); + + const markdown = `# Test Document + +## Section 1 + +This is the first section with some content. 
+ +## Section 2 + +This is the second section with more content.`; + + // Parse → Enrich → Chunk + const result = await new Workflow() + .structuralParser({ + text: markdown, + title: "Test Doc", + format: "markdown", + sourceUri: "test.md", + }) + .documentEnricher({ + generateSummaries: true, + extractEntities: true, + }) + .hierarchicalChunker({ + maxTokens: 256, + overlap: 25, + strategy: "hierarchical", + }) + .run(); + + // Verify the chain worked - final output from hierarchicalChunker + expect(result.docId).toBeDefined(); + expect(result.docId).toMatch(/^doc_[0-9a-f]{16}$/); + expect(result.chunks).toBeDefined(); + expect(result.text).toBeDefined(); + expect(result.count).toBeGreaterThan(0); + + // Verify output structure matches expectations + expect(result.chunks.length).toBe(result.count); + expect(result.text.length).toBe(result.count); + }); + + it("should demonstrate data flow through chain", async () => { + const markdown = "# Title\n\nParagraph content."; + + const result = await new Workflow() + .structuralParser({ + text: markdown, + title: "Test", + format: "markdown", + }) + .hierarchicalChunker({ + maxTokens: 512, + overlap: 50, + strategy: "hierarchical", + }) + .run(); + + // Verify data flows correctly (final output from hierarchicalChunker) + expect(result.docId).toBeDefined(); + expect(result.chunks).toBeDefined(); + expect(result.text).toBeDefined(); + + // docId should flow through the chain to all chunks + // PropertyArrayGraphResult makes chunks potentially an array of arrays + const chunks = ( + Array.isArray(result.chunks) && result.chunks.length > 0 + ? Array.isArray(result.chunks[0]) + ? result.chunks.flat() + : result.chunks + : [] + ) as ChunkNode[]; + for (const chunk of chunks) { + expect(chunk.docId).toBe(result.docId); + } + }); + + it("should generate consistent docId across chains", async () => { + const markdown = "# Test\n\nContent."; + + // Run twice with same content + const result1 = await new Workflow() + .structuralParser({ + text: markdown, + title: "Test", + sourceUri: "test.md", + }) + .run(); + + const result2 = await new Workflow() + .structuralParser({ + text: markdown, + title: "Test", + sourceUri: "test.md", + }) + .run(); + + // Should generate same docId (deterministic) + expect(result1.docId).toBe(result2.docId); + }); + + it("should allow docId override for variant creation", async () => { + const markdown = "# Test\n\nContent."; + const customId = await NodeIdGenerator.generateDocId("custom", markdown); + + const result = (await new Workflow() + .structuralParser({ + text: markdown, + title: "Test", + docId: customId, // Override with custom ID + }) + .hierarchicalChunker({ + maxTokens: 512, + }) + .run()) as HierarchicalChunkerTaskOutput; + + // Should use the provided ID + expect(result.docId).toBe(customId); + + // All chunks should reference it + for (const chunk of result.chunks) { + expect(chunk.docId).toBe(customId); + } + }); +}); diff --git a/packages/test/src/test/hierarchical/HierarchicalChunker.test.ts b/packages/test/src/test/hierarchical/HierarchicalChunker.test.ts new file mode 100644 index 00000000..d14aae47 --- /dev/null +++ b/packages/test/src/test/hierarchical/HierarchicalChunker.test.ts @@ -0,0 +1,197 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + estimateTokens, + hierarchicalChunker, + HierarchicalChunkerTask, + NodeIdGenerator, + StructuralParser, +} from "@workglow/ai"; +import { Workflow } from "@workglow/task-graph"; +import { describe, 
expect, it } from "vitest"; + +describe("HierarchicalChunkerTask", () => { + it("should chunk a simple document hierarchically", async () => { + const markdown = `# Section 1 + +This is a paragraph that should fit in one chunk. + +# Section 2 + +This is another paragraph.`; + + const docId = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(docId, markdown, "Test"); + + const result = await hierarchicalChunker({ + docId, + documentTree: root, + maxTokens: 512, + overlap: 50, + strategy: "hierarchical", + }); + + expect(result.chunks).toBeDefined(); + expect(result.text).toBeDefined(); + expect(result.count).toBeGreaterThan(0); + expect(result.chunks.length).toBe(result.count); + expect(result.text.length).toBe(result.count); + + // Each chunk should have required fields + for (const chunk of result.chunks) { + expect(chunk.chunkId).toBeDefined(); + expect(chunk.docId).toBe(docId); + expect(chunk.text).toBeDefined(); + expect(chunk.nodePath).toBeDefined(); + expect(chunk.nodePath.length).toBeGreaterThan(0); + expect(chunk.depth).toBeGreaterThanOrEqual(0); + } + }); + + it("should respect token budgets", async () => { + // Create a long text that requires splitting + const longText = "Lorem ipsum dolor sit amet. ".repeat(100); + const markdown = `# Section\n\n${longText}`; + + const docId = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(docId, markdown, "Long"); + + const maxTokens = 100; + const result = await hierarchicalChunker({ + docId, + documentTree: root, + maxTokens, + overlap: 10, + strategy: "hierarchical", + }); + + // Should create multiple chunks + expect(result.count).toBeGreaterThan(1); + + // Each chunk should respect token budget + for (const chunk of result.chunks) { + const tokens = estimateTokens(chunk.text); + expect(tokens).toBeLessThanOrEqual(maxTokens); + } + }); + + it("should create overlapping chunks", async () => { + const text = "Word ".repeat(200); + const markdown = `# Section\n\n${text}`; + + const docId = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(docId, markdown, "Overlap"); + + const maxTokens = 50; + const overlap = 10; + const result = await hierarchicalChunker({ + docId, + documentTree: root, + maxTokens, + overlap, + strategy: "hierarchical", + }); + + // Should have multiple chunks + expect(result.count).toBeGreaterThan(1); + + // Check for overlap in text content + if (result.chunks.length > 1) { + const chunk0 = result.chunks[0].text; + const chunk1 = result.chunks[1].text; + + // Extract end of first chunk + const chunk0End = chunk0.substring(Math.max(0, chunk0.length - 50)); + // Check if beginning of second chunk overlaps + const hasOverlap = chunk1.includes(chunk0End.substring(0, 20)); + + expect(hasOverlap).toBe(true); + } + }); + + + it("should handle flat strategy", async () => { + const markdown = `# Section 1 + +Paragraph 1. 
+ +# Section 2 + +Paragraph 2.`; + + const docId = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(docId, markdown, "Flat"); + + const result = await new Workflow() + .hierarchicalChunker({ + docId, + documentTree: root, + maxTokens: 512, + overlap: 50, + strategy: "flat", + }) + .run(); + + // Flat strategy should still produce chunks + expect(result.count).toBeGreaterThan(0); + }); + + it("should maintain node paths in chunks", async () => { + const markdown = `# Section 1 + +## Subsection 1.1 + +Paragraph content.`; + + const docId = await NodeIdGenerator.generateDocId("test", markdown); + const root = await StructuralParser.parseMarkdown(docId, markdown, "Paths"); + + const result = await hierarchicalChunker({ + docId, + documentTree: root, + maxTokens: 512, + overlap: 50, + strategy: "hierarchical", + }); + + // Check that chunks have node paths + for (const chunk of result.chunks) { + expect(chunk.nodePath).toBeDefined(); + expect(Array.isArray(chunk.nodePath)).toBe(true); + expect(chunk.nodePath.length).toBeGreaterThan(0); + + // First element should be root node ID + expect(chunk.nodePath[0]).toBe(root.nodeId); + } + }); +}); + +describe("Token estimation", () => { + it("should estimate tokens approximately", () => { + const text = "This is a test string"; + const tokens = estimateTokens(text); + + // Rough approximation: 1 token ~= 4 characters + const expected = Math.ceil(text.length / 4); + expect(tokens).toBe(expected); + }); + + it("should handle empty strings", () => { + const tokens = estimateTokens(""); + expect(tokens).toBe(0); + }); + + it("should increase token count with text length", () => { + const shortText = "Hello"; + const longText = "Hello world this is a much longer text"; + + const shortTokens = estimateTokens(shortText); + const longTokens = estimateTokens(longText); + + expect(longTokens).toBeGreaterThan(shortTokens); + }); +}); diff --git a/packages/test/src/test/hierarchical/StructuralParser.test.ts b/packages/test/src/test/hierarchical/StructuralParser.test.ts new file mode 100644 index 00000000..62793616 --- /dev/null +++ b/packages/test/src/test/hierarchical/StructuralParser.test.ts @@ -0,0 +1,204 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { NodeIdGenerator, NodeKind, StructuralParser } from "@workglow/ai"; +import { describe, expect, it } from "vitest"; + +describe("StructuralParser", () => { + describe("Markdown parsing", () => { + it("should parse markdown with headers into hierarchical tree", async () => { + const markdown = `# Main Title + +This is the intro. + +## Section 1 + +Content for section 1. + +## Section 2 + +Content for section 2. + +### Subsection 2.1 + +Nested content.`; + + const docId = "doc_test123"; + const root = await StructuralParser.parseMarkdown(docId, markdown, "Test Document"); + + expect(root.kind).toBe(NodeKind.DOCUMENT); + expect(root.children.length).toBeGreaterThan(0); + + // Find sections - parser should create sections for headers + const sections = root.children.filter((child) => child.kind === NodeKind.SECTION); + expect(sections.length).toBeGreaterThan(0); + + // Should have some children (sections or paragraphs) + expect(root.children.length).toBeGreaterThanOrEqual(1); + }); + + it("should preserve source offsets", async () => { + const markdown = `# Title + +Paragraph one. 
+ +Paragraph two.`; + + const docId = "doc_test456"; + const root = await StructuralParser.parseMarkdown(docId, markdown, "Test"); + + expect(root.range.startOffset).toBe(0); + expect(root.range.endOffset).toBe(markdown.length); + + // Check children have valid offsets + for (const child of root.children) { + expect(child.range.startOffset).toBeGreaterThanOrEqual(0); + expect(child.range.endOffset).toBeLessThanOrEqual(markdown.length); + expect(child.range.endOffset).toBeGreaterThan(child.range.startOffset); + } + }); + + it("should handle nested sections correctly", async () => { + const markdown = `# Level 1 + +Content. + +## Level 2 + +More content. + +### Level 3 + +Deep content.`; + + const docId = "doc_test789"; + const root = await StructuralParser.parseMarkdown(docId, markdown, "Nested Test"); + + // Find first section (Level 1) + const level1 = root.children.find( + (c) => c.kind === NodeKind.SECTION && (c as any).level === 1 + ); + expect(level1).toBeDefined(); + + // It should have children including level 2 + const level2 = (level1 as any).children.find( + (c: any) => c.kind === NodeKind.SECTION && c.level === 2 + ); + expect(level2).toBeDefined(); + + // Level 2 should have level 3 + const level3 = (level2 as any).children.find( + (c: any) => c.kind === NodeKind.SECTION && c.level === 3 + ); + expect(level3).toBeDefined(); + }); + }); + + describe("Plain text parsing", () => { + it("should parse plain text into paragraphs", async () => { + const text = `First paragraph here. + +Second paragraph here. + +Third paragraph here.`; + + const docId = "doc_plain123"; + const root = await StructuralParser.parsePlainText(docId, text, "Plain Text"); + + expect(root.kind).toBe(NodeKind.DOCUMENT); + expect(root.children.length).toBe(3); + + for (const child of root.children) { + expect(child.kind).toBe(NodeKind.PARAGRAPH); + } + }); + + it("should handle single paragraph", async () => { + const text = "Just one paragraph."; + + const docId = "doc_plain456"; + const root = await StructuralParser.parsePlainText(docId, text, "Single"); + + expect(root.children.length).toBe(1); + expect(root.children[0].kind).toBe(NodeKind.PARAGRAPH); + expect(root.children[0].text).toBe(text); + }); + }); + + describe("Auto-detect", () => { + it("should auto-detect markdown", async () => { + const markdown = "# Header\n\nParagraph."; + const docId = "doc_auto123"; + + const root = await StructuralParser.parse(docId, markdown, "Auto"); + + // Should have detected markdown and created sections + const hasSection = root.children.some((c) => c.kind === NodeKind.SECTION); + expect(hasSection).toBe(true); + }); + + it("should default to plain text when no markdown markers", async () => { + const text = "Just plain text here."; + const docId = "doc_auto456"; + + const root = await StructuralParser.parse(docId, text, "Plain"); + + // Should be plain paragraph + expect(root.children[0].kind).toBe(NodeKind.PARAGRAPH); + }); + }); + + describe("NodeIdGenerator", () => { + it("should generate consistent docIds", async () => { + const id1 = await NodeIdGenerator.generateDocId("source1", "content"); + const id2 = await NodeIdGenerator.generateDocId("source1", "content"); + + expect(id1).toBe(id2); + expect(id1).toMatch(/^doc_[0-9a-f]{16}$/); + }); + + it("should generate different IDs for different content", async () => { + const id1 = await NodeIdGenerator.generateDocId("source", "content1"); + const id2 = await NodeIdGenerator.generateDocId("source", "content2"); + + expect(id1).not.toBe(id2); + }); + + it("should generate consistent
structural node IDs", async () => { + const docId = "doc_test"; + const range = { startOffset: 0, endOffset: 100 }; + + const id1 = await NodeIdGenerator.generateStructuralNodeId(docId, NodeKind.SECTION, range); + const id2 = await NodeIdGenerator.generateStructuralNodeId(docId, NodeKind.SECTION, range); + + expect(id1).toBe(id2); + expect(id1).toMatch(/^node_[0-9a-f]{16}$/); + }); + + it("should generate consistent child node IDs", async () => { + const parentId = "node_parent"; + const ordinal = 2; + + const id1 = await NodeIdGenerator.generateChildNodeId(parentId, ordinal); + const id2 = await NodeIdGenerator.generateChildNodeId(parentId, ordinal); + + expect(id1).toBe(id2); + expect(id1).toMatch(/^node_[0-9a-f]{16}$/); + }); + + it("should generate consistent chunk IDs", async () => { + const docId = "doc_test"; + const leafNodeId = "node_leaf"; + const ordinal = 0; + + const id1 = await NodeIdGenerator.generateChunkId(docId, leafNodeId, ordinal); + const id2 = await NodeIdGenerator.generateChunkId(docId, leafNodeId, ordinal); + + expect(id1).toBe(id2); + expect(id1).toMatch(/^chunk_[0-9a-f]{16}$/); + }); + }); +}); diff --git a/packages/test/src/test/rag/RagWorkflow.test.ts b/packages/test/src/test/rag/RagWorkflow.test.ts new file mode 100644 index 00000000..5a84ce13 --- /dev/null +++ b/packages/test/src/test/rag/RagWorkflow.test.ts @@ -0,0 +1,366 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * RAG (Retrieval Augmented Generation) Workflow End-to-End Test + * + * This test demonstrates a complete RAG pipeline using the Workflow API + * in a way that's compatible with visual node editors. + * + * Node Editor Mapping: + * ==================== + * Each workflow step below represents a node in a visual editor with + * dataflow connections between them: + * + * 1. Document Ingestion Pipeline (per file): + * FileLoader → StructuralParser → DocumentEnricher → HierarchicalChunker + * → [Array Processing] → TextEmbedding (multiple) → ChunkToVector → VectorStoreUpsert + * + * Note: The array processing step (embedding multiple chunks) would use: + * - A "ForEach" or "Map" control node in the visual editor + * - Or an ArrayTask wrapper that replicates TextEmbedding nodes + * - Or a batch TextEmbedding node that accepts arrays + * + * 2. Semantic Search Pipeline: + * Query (input) → RetrievalTask → Results (output) + * + * 3. 
Question Answering Pipeline: + * Question → RetrievalTask → ContextBuilder → TextQuestionAnswerTask → Answer + * + * Models Used: + * - Xenova/all-MiniLM-L6-v2 (Text Embedding - 384D) + * - onnx-community/NeuroBERT-NER-ONNX (Named Entity Recognition) + * - Xenova/distilbert-base-uncased-distilled-squad (Question Answering) + */ + +import { + DocumentRepository, + DocumentStorageSchema, + getGlobalModelRepository, + HierarchicalChunkerTaskOutput, + InMemoryModelRepository, + RetrievalTaskOutput, + setGlobalModelRepository, + TextEmbeddingTaskOutput, + TextQuestionAnswerTaskOutput, + VectorStoreUpsertTaskOutput, +} from "@workglow/ai"; +import { + HF_TRANSFORMERS_ONNX, + HfTransformersOnnxModelRecord, + register_HFT_InlineJobFns, +} from "@workglow/ai-provider"; +import { + InMemoryTabularRepository, + InMemoryVectorRepository, + registerVectorRepository, +} from "@workglow/storage"; +import { getTaskQueueRegistry, setTaskQueueRegistry, Workflow } from "@workglow/task-graph"; +import { FileLoaderTask } from "@workglow/tasks"; +import { readdirSync } from "fs"; +import { join } from "path"; +import { afterAll, beforeAll, describe, expect, it } from "vitest"; + +describe("RAG Workflow End-to-End", () => { + let vectorRepo: InMemoryVectorRepository; + let docRepo: DocumentRepository; + const vectorRepoName = "rag-test-vector-repo"; + const embeddingModel = "onnx:Xenova/all-MiniLM-L6-v2:q8"; + const nerModel = "onnx:onnx-community/NeuroBERT-NER-ONNX:q8"; + const qaModel = "onnx:Xenova/distilbert-base-uncased-distilled-squad:q8"; + + beforeAll(async () => { + // Setup task queue and model repository + setTaskQueueRegistry(null); + setGlobalModelRepository(new InMemoryModelRepository()); + await register_HFT_InlineJobFns(); + + // Register ONNX models + const models: HfTransformersOnnxModelRecord[] = [ + { + model_id: embeddingModel, + title: "All MiniLM L6 V2 384D", + description: "Xenova/all-MiniLM-L6-v2", + tasks: ["TextEmbeddingTask"], + provider: HF_TRANSFORMERS_ONNX, + provider_config: { + pipeline: "feature-extraction", + model_path: "Xenova/all-MiniLM-L6-v2", + native_dimensions: 384, + }, + metadata: {}, + }, + { + model_id: nerModel, + title: "NeuroBERT NER", + description: "onnx-community/NeuroBERT-NER-ONNX", + tasks: ["TextNamedEntityRecognitionTask"], + provider: HF_TRANSFORMERS_ONNX, + provider_config: { + pipeline: "token-classification", + model_path: "onnx-community/NeuroBERT-NER-ONNX", + }, + metadata: {}, + }, + { + model_id: qaModel, + title: "distilbert-base-uncased-distilled-squad", + description: "Xenova/distilbert-base-uncased-distilled-squad quantized to 8bit", + tasks: ["TextQuestionAnswerTask"], + provider: HF_TRANSFORMERS_ONNX, + provider_config: { + pipeline: "question-answering", + model_path: "Xenova/distilbert-base-uncased-distilled-squad", + }, + metadata: {}, + }, + ]; + + for (const model of models) { + await getGlobalModelRepository().addModel(model); + } + + // Setup repositories + vectorRepo = new InMemoryVectorRepository(); + await vectorRepo.setupDatabase(); + + // Register vector repository for use in workflows + registerVectorRepository(vectorRepoName, vectorRepo); + + const tabularRepo = new InMemoryTabularRepository(DocumentStorageSchema, ["docId"]); + await tabularRepo.setupDatabase(); + + docRepo = new DocumentRepository(tabularRepo, vectorRepo); + }); + + afterAll(async () => { + getTaskQueueRegistry().stopQueues().clearQueues(); + setTaskQueueRegistry(null); + }); + + it("should ingest markdown documents with NER enrichment", async () => { + // Find 
markdown files in docs folder + const docsPath = join(process.cwd(), "docs", "background"); + const files = readdirSync(docsPath).filter((f) => f.endsWith(".md")); + + console.log(`Found ${files.length} markdown files to process`); + + // Process first 8 files for testing (to keep test fast) + const filesToProcess = files.slice(0, 8); + + let totalVectors = 0; + + // NODE EDITOR MAPPING: + // In a visual node editor, this loop would be replaced by either: + // - Multiple FileLoader nodes (one per file), or + // - A "ForEach File" control flow node that iterates the pipeline + for (const file of filesToProcess) { + const filePath = join(docsPath, file); + console.log(`Processing: ${file}`); + + // Step 1: Load file + // NODE: FileLoaderTask + const fileLoader = new FileLoaderTask({ url: `file://${filePath}`, format: "markdown" }); + const fileContent = await fileLoader.run(); + expect(fileContent.text).toBeDefined(); + + // Step 2-4: Parse, enrich, and chunk + // NODES: StructuralParserTask → DocumentEnricherTask → HierarchicalChunkerTask + // DATAFLOWS: text → documentTree → documentTree → chunks[] + const ingestionWorkflow = new Workflow(); + ingestionWorkflow + .structuralParser({ + text: fileContent.text!, + title: fileContent.metadata.title, + format: "markdown", + sourceUri: filePath, + }) + .documentEnricher({ + generateSummaries: false, + extractEntities: true, + nerModel, + }) + .hierarchicalChunker({ + maxTokens: 512, + overlap: 50, + strategy: "hierarchical", + }); + + const chunkResult = (await ingestionWorkflow.run()) as HierarchicalChunkerTaskOutput; + console.log(` → Generated ${chunkResult.chunks.length} chunks`); + + // Step 5: Generate embeddings for array of chunks + // NODE EDITOR: This array processing would use one of: + // - ArrayTask wrapper around TextEmbeddingTask (processes each item) + // - ForEach control node that replicates TextEmbedding node per chunk + // - Batch TextEmbedding node that accepts text[] array + const embeddingWorkflows = chunkResult.text.map((text) => { + const embeddingWf = new Workflow(); + embeddingWf.textEmbedding({ + text, + model: embeddingModel, + }); + return embeddingWf.run(); + }); + + const embeddingResults = await Promise.all(embeddingWorkflows); + const vectors = embeddingResults.map((r) => (r as TextEmbeddingTaskOutput).vector); + + // Step 6-7: Transform and store vectors + // NODES: ChunkToVectorTask → VectorStoreUpsertTask + // DATAFLOWS: chunks[] + vectors[] → ids[] + vectors[] + metadata[] → count + const storeWorkflow = new Workflow(); + storeWorkflow + .chunkToVector({ + docId: chunkResult.docId, + chunks: chunkResult.chunks, + vectors, + }) + .vectorStoreUpsert({ + repository: vectorRepoName, + }); + + const result = (await storeWorkflow.run()) as VectorStoreUpsertTaskOutput; + + console.log(` → Stored ${result.count} vectors`); + totalVectors += result.count; + } + + // Verify vectors were stored + expect(totalVectors).toBeGreaterThan(0); + console.log(`Total vectors in repository: ${totalVectors}`); + }, 360000); // 6 minute timeout for model downloads + + it("should search for relevant content", async () => { + const query = "What is retrieval augmented generation?"; + + console.log(`\nSearching for: "${query}"`); + + // Create search workflow + const searchWorkflow = new Workflow(); + + searchWorkflow.retrieval({ + repository: vectorRepoName, + query, + model: embeddingModel, + topK: 5, + scoreThreshold: 0.3, + }); + + const searchResult = (await searchWorkflow.run()) as RetrievalTaskOutput; + + // Verify search
results + expect(searchResult.chunks).toBeDefined(); + expect(Array.isArray(searchResult.chunks)).toBe(true); + expect(searchResult.chunks.length).toBeGreaterThan(0); + expect(searchResult.chunks.length).toBeLessThanOrEqual(5); + expect(searchResult.scores).toBeDefined(); + expect(searchResult.scores!.length).toBe(searchResult.chunks.length); + + console.log(`Found ${searchResult.chunks.length} relevant chunks:`); + for (let i = 0; i < searchResult.chunks.length; i++) { + const chunk = searchResult.chunks[i]; + const score = searchResult.scores![i]; + console.log(` ${i + 1}. Score: ${score.toFixed(3)} - ${chunk.substring(0, 80)}...`); + } + + // Verify scores are in descending order + for (let i = 1; i < searchResult.scores!.length; i++) { + expect(searchResult.scores![i]).toBeLessThanOrEqual(searchResult.scores![i - 1]); + } + }, 60000); // 1 minute timeout + + it("should answer questions using retrieved context", async () => { + const question = "What is RAG?"; + + console.log(`\nAnswering question: "${question}"`); + + // Step 1: Retrieve relevant context + const retrievalWorkflow = new Workflow(); + + retrievalWorkflow.retrieval({ + repository: vectorRepoName, + query: question, + model: embeddingModel, + topK: 3, + scoreThreshold: 0.2, // Lower threshold to find results + }); + + const retrievalResult = (await retrievalWorkflow.run()) as RetrievalTaskOutput; + + expect(retrievalResult.chunks).toBeDefined(); + + if (retrievalResult.chunks.length === 0) { + console.log("No relevant chunks found, skipping QA"); + return; // Skip QA if no relevant context found + } + + console.log(`Retrieved ${retrievalResult.chunks.length} context chunks`); + + // Step 2: Build context from retrieved chunks + const context = retrievalResult.chunks.join("\n\n"); + + console.log(`Context length: ${context.length} characters`); + + // Step 3: Answer question using context + const qaWorkflow = new Workflow(); + + qaWorkflow.textQuestionAnswer({ + context, + question, + model: qaModel, + }); + + const answer = (await qaWorkflow.run()) as TextQuestionAnswerTaskOutput; + + // Verify answer + expect(answer.text).toBeDefined(); + expect(typeof answer.text).toBe("string"); + expect(answer.text.length).toBeGreaterThan(0); + + console.log(`\nAnswer: ${answer.text}`); + }, 60000); // 1 minute timeout + + it("should handle complex multi-step RAG pipeline", async () => { + const question = "How does vector search work?"; + + console.log(`\nComplex RAG pipeline for: "${question}"`); + + // Step 1: Retrieve context + const retrievalWorkflow = new Workflow(); + retrievalWorkflow.retrieval({ + repository: vectorRepoName, + query: question, + model: embeddingModel, + topK: 3, + scoreThreshold: 0.2, + }); + + const retrievalResult = (await retrievalWorkflow.run()) as RetrievalTaskOutput; + + if (retrievalResult.chunks.length === 0) { + console.log("No chunks found, skipping QA step"); + return; + } + + // Step 2: Answer question with retrieved context + const context = retrievalResult.chunks.join("\n\n"); + const qaWorkflow = new Workflow(); + qaWorkflow.textQuestionAnswer({ + context, + question, + model: qaModel, + }); + + const result = (await qaWorkflow.run()) as TextQuestionAnswerTaskOutput; + + expect(result.text).toBeDefined(); + expect(typeof result.text).toBe("string"); + expect(result.text.length).toBeGreaterThan(0); + + console.log(`Answer: ${result.text}`); + }, 60000); // 1 minute timeout +}); diff --git a/packages/test/src/test/task-graph/InputResolver.test.ts 
b/packages/test/src/test/task-graph/InputResolver.test.ts new file mode 100644 index 00000000..d98b463c --- /dev/null +++ b/packages/test/src/test/task-graph/InputResolver.test.ts @@ -0,0 +1,265 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + getGlobalTabularRepositories, + InMemoryTabularRepository, + ITabularRepository, + registerTabularRepository, + TypeTabularRepository, +} from "@workglow/storage"; +import { IExecuteContext, resolveSchemaInputs, Task, TaskRegistry } from "@workglow/task-graph"; +import { + getInputResolvers, + globalServiceRegistry, + registerInputResolver, + type DataPortSchema, +} from "@workglow/util"; +import { afterEach, beforeEach, describe, expect, test } from "vitest"; + +describe("InputResolver", () => { + // Test schema for tabular repository + const testEntitySchema = { + type: "object", + properties: { + id: { type: "string" }, + name: { type: "string" }, + }, + required: ["id", "name"], + additionalProperties: false, + } as const; + + let testRepo: InMemoryTabularRepository; + + beforeEach(async () => { + // Create and register a test repository + testRepo = new InMemoryTabularRepository(testEntitySchema, ["id"] as const); + await testRepo.setupDatabase(); + registerTabularRepository("test-repo", testRepo); + }); + + afterEach(() => { + // Clean up the registry + getGlobalTabularRepositories().delete("test-repo"); + testRepo.destroy(); + }); + + describe("resolveSchemaInputs", () => { + test("should pass through non-string values unchanged", async () => { + const schema: DataPortSchema = { + type: "object", + properties: { + repository: TypeTabularRepository(), + }, + }; + + const input = { repository: testRepo }; + const resolved = await resolveSchemaInputs(input, schema, { + registry: globalServiceRegistry, + }); + + expect(resolved.repository).toBe(testRepo); + }); + + test("should resolve string repository ID to instance", async () => { + const schema: DataPortSchema = { + type: "object", + properties: { + repository: TypeTabularRepository(), + }, + }; + + const input = { repository: "test-repo" }; + const resolved = await resolveSchemaInputs(input, schema, { + registry: globalServiceRegistry, + }); + + expect(resolved.repository).toBe(testRepo); + }); + + test("should throw error for unknown repository ID", async () => { + const schema: DataPortSchema = { + type: "object", + properties: { + repository: TypeTabularRepository(), + }, + }; + + const input = { repository: "non-existent-repo" }; + + await expect( + resolveSchemaInputs(input, schema, { registry: globalServiceRegistry }) + ).rejects.toThrow('Tabular repository "non-existent-repo" not found'); + }); + + test("should not resolve properties without format annotation", async () => { + const schema: DataPortSchema = { + type: "object", + properties: { + name: { type: "string" }, + }, + }; + + const input = { name: "test-name" }; + const resolved = await resolveSchemaInputs(input, schema, { + registry: globalServiceRegistry, + }); + + expect(resolved.name).toBe("test-name"); + }); + + test("should handle boolean schema", async () => { + const input = { foo: "bar" }; + const resolved = await resolveSchemaInputs(input, true as DataPortSchema, { + registry: globalServiceRegistry, + }); + + expect(resolved).toEqual(input); + }); + + test("should handle schema without properties", async () => { + // @ts-expect-error - schema is not a DataPortSchemaObject + const schema: DataPortSchema = { type: "object" }; + const input = { foo: "bar" 
}; + const resolved = await resolveSchemaInputs(input, schema, { + registry: globalServiceRegistry, + }); + + expect(resolved).toEqual(input); + }); + }); + + describe("registerInputResolver", () => { + test("should register custom resolver", async () => { + // Register a custom resolver for a test format + registerInputResolver("custom", (id, format, registry) => { + return { resolved: true, id, format }; + }); + + const schema: DataPortSchema = { + type: "object", + properties: { + data: { type: "string", format: "custom:test" }, + }, + }; + + const input = { data: "my-id" }; + const resolved = await resolveSchemaInputs(input, schema, { + registry: globalServiceRegistry, + }); + + expect(resolved.data).toEqual({ resolved: true, id: "my-id", format: "custom:test" }); + + // Clean up + getInputResolvers().delete("custom"); + }); + + test("should support async resolvers", async () => { + registerInputResolver("async", async (id, format, registry) => { + await new Promise((resolve) => setTimeout(resolve, 10)); + return { asyncResolved: true, id }; + }); + + const schema: DataPortSchema = { + type: "object", + properties: { + data: { type: "string", format: "async" }, + }, + }; + + const input = { data: "async-id" }; + const resolved = await resolveSchemaInputs(input, schema, { + registry: globalServiceRegistry, + }); + + expect(resolved.data).toEqual({ asyncResolved: true, id: "async-id" }); + + // Clean up + getInputResolvers().delete("async"); + }); + }); + + describe("Integration with Task", () => { + // Define a test task that uses a repository + class RepositoryConsumerTask extends Task< + { repository: any; query: string }, + { results: any[] } + > { + public static type = "RepositoryConsumerTask"; + + public static inputSchema(): DataPortSchema { + return { + type: "object", + properties: { + repository: TypeTabularRepository({ + title: "Data Repository", + description: "Repository to query", + }), + query: { type: "string", title: "Query" }, + }, + required: ["repository", "query"], + additionalProperties: false, + }; + } + + public static outputSchema(): DataPortSchema { + return { + type: "object", + properties: { + results: { type: "array", items: { type: "object" } }, + }, + required: ["results"], + additionalProperties: false, + }; + } + + async execute( + input: { repository: ITabularRepository; query: string }, + _context: IExecuteContext + ): Promise<{ results: any[] }> { + const { repository } = input; + // In a real task, we'd search the repository + const results = await repository.getAll(); + return { results: results ?? 
[] }; + } + } + + beforeEach(() => { + TaskRegistry.registerTask(RepositoryConsumerTask); + }); + + afterEach(() => { + TaskRegistry.all.delete(RepositoryConsumerTask.type); + }); + + test("should resolve repository when running task with string ID", async () => { + // Add some test data + await testRepo.put({ id: "1", name: "Test Item" }); + + const task = new RepositoryConsumerTask(); + const result = await task.run({ + repository: "test-repo", + query: "test", + }); + + expect(result.results).toHaveLength(1); + expect(result.results[0]).toEqual({ id: "1", name: "Test Item" }); + }); + + test("should work with direct repository instance", async () => { + await testRepo.put({ id: "2", name: "Direct Item" }); + + const task = new RepositoryConsumerTask(); + const result = await task.run({ + repository: testRepo, + query: "test", + }); + + expect(result.results).toHaveLength(1); + expect(result.results[0]).toEqual({ id: "2", name: "Direct Item" }); + }); + }); +}); diff --git a/packages/test/src/test/task-graph/TaskGraphFormatSemantic.test.ts b/packages/test/src/test/task-graph/TaskGraphFormatSemantic.test.ts index cfd17fcb..554c9cf0 100644 --- a/packages/test/src/test/task-graph/TaskGraphFormatSemantic.test.ts +++ b/packages/test/src/test/task-graph/TaskGraphFormatSemantic.test.ts @@ -5,8 +5,78 @@ */ import { Dataflow, Task, TaskGraph, type TaskInput } from "@workglow/task-graph"; -import type { DataPortSchema } from "@workglow/util"; +import type { DataPortSchema, ServiceRegistry } from "@workglow/util"; import { beforeEach, describe, expect, it } from "vitest"; +import { + MODEL_REPOSITORY, + InMemoryModelRepository, + type ModelRecord, + type ModelRepository, +} from "@workglow/ai"; + +/** + * Test model fixtures for embedding models + */ +const EMBEDDING_MODELS: ModelRecord[] = [ + { + model_id: "text-embedding-ada-002", + tasks: ["EmbeddingTask"], + provider: "openai", + title: "OpenAI Ada Embedding", + description: "OpenAI text embedding model", + provider_config: {}, + metadata: {}, + }, + { + model_id: "all-MiniLM-L6-v2", + tasks: ["EmbeddingTask"], + provider: "local", + title: "MiniLM Embedding", + description: "Local embedding model", + provider_config: {}, + metadata: {}, + }, +]; + +/** + * Test model fixtures for text generation models + */ +const TEXT_GEN_MODELS: ModelRecord[] = [ + { + model_id: "gpt-4", + tasks: ["TextGenerationTask"], + provider: "openai", + title: "GPT-4", + description: "OpenAI GPT-4 text generation model", + provider_config: {}, + metadata: {}, + }, + { + model_id: "claude-3", + tasks: ["TextGenerationTask"], + provider: "anthropic", + title: "Claude 3", + description: "Anthropic Claude 3 model", + provider_config: {}, + metadata: {}, + }, +]; + +/** + * Helper function to create a test-local service registry with a model repository + * @param models - Array of model records to populate the repository with + * @returns Promise resolving to a configured ServiceRegistry + */ +async function createTestRegistry(models: ModelRecord[]): Promise { + const { ServiceRegistry } = await import("@workglow/util"); + const registry = new ServiceRegistry(); + const modelRepo = new InMemoryModelRepository(); + for (const model of models) { + await modelRepo.addModel(model); + } + registry.registerInstance(MODEL_REPOSITORY, modelRepo); + return registry; +} /** * Test task with generic model output (format: "model") @@ -459,16 +529,17 @@ describe("TaskGraph with format annotations", () => { } as const satisfies DataPortSchema; } - // Simulate runtime narrowing of 
models - async narrowInput(input: { - model: string | string[]; - }): Promise<{ model: string | string[] }> { - // In real implementation, this would check ModelRepository for compatible models - // For testing, we simulate filtering - const validEmbeddingModels = ["text-embedding-ada-002", "all-MiniLM-L6-v2"]; + // Runtime narrowing using ModelRepository from the registry + async narrowInput( + input: { model: string | string[] }, + registry: ServiceRegistry + ): Promise<{ model: string | string[] }> { + const modelRepo = registry.get(MODEL_REPOSITORY); + const validModels = await modelRepo.findModelsByTask(this.type); + const validIds = new Set(validModels?.map((m) => m.model_id) ?? []); const models = Array.isArray(input.model) ? input.model : [input.model]; - const narrowedModels = models.filter((m) => validEmbeddingModels.includes(m)); + const narrowedModels = models.filter((m) => validIds.has(m)); return { model: narrowedModels.length === 1 ? narrowedModels[0] : narrowedModels, @@ -482,12 +553,15 @@ describe("TaskGraph with format annotations", () => { const task = new NarrowableModelConsumerTask({}, { id: "consumer" }); + // Create test registry with embedding and text generation models + const registry = await createTestRegistry([...EMBEDDING_MODELS, ...TEXT_GEN_MODELS]); + // Test narrowing with array of models (some compatible, some not) const inputWithMixed = { model: ["text-embedding-ada-002", "gpt-4", "all-MiniLM-L6-v2", "claude-3"], }; - const narrowedResult = await task.narrowInput(inputWithMixed); + const narrowedResult = await task.narrowInput(inputWithMixed, registry); // Should only keep the embedding models expect(narrowedResult.model).toEqual(["text-embedding-ada-002", "all-MiniLM-L6-v2"]); @@ -522,12 +596,16 @@ describe("TaskGraph with format annotations", () => { } as const satisfies DataPortSchema; } - async narrowInput(input: { - model: string | string[]; - }): Promise<{ model: string | string[] }> { - const validModels = ["text-embedding-ada-002"]; + async narrowInput( + input: { model: string | string[] }, + registry: ServiceRegistry + ): Promise<{ model: string | string[] }> { + const modelRepo = registry.get(MODEL_REPOSITORY); + const validModels = await modelRepo.findModelsByTask(this.type); + const validIds = new Set(validModels?.map((m) => m.model_id) ?? []); + const models = Array.isArray(input.model) ? input.model : [input.model]; - const narrowed = models.filter((m) => validModels.includes(m)); + const narrowed = models.filter((m) => validIds.has(m)); return { model: narrowed.length === 1 ? 
narrowed[0] : narrowed }; } @@ -538,12 +616,15 @@ describe("TaskGraph with format annotations", () => { const task = new NarrowableModelTask({}, { id: "task" }); + // Create test registry with only embedding models + const registry = await createTestRegistry(EMBEDDING_MODELS); + // Test with single valid model - const result1 = await task.narrowInput({ model: "text-embedding-ada-002" }); + const result1 = await task.narrowInput({ model: "text-embedding-ada-002" }, registry); expect(result1.model).toBe("text-embedding-ada-002"); // Test with single invalid model (gets filtered out) - const result2 = await task.narrowInput({ model: "gpt-4" }); + const result2 = await task.narrowInput({ model: "gpt-4" }, registry); expect(result2.model).toEqual([]); }); diff --git a/packages/test/src/test/task/ContextBuilderTask.test.ts b/packages/test/src/test/task/ContextBuilderTask.test.ts new file mode 100644 index 00000000..b7282c31 --- /dev/null +++ b/packages/test/src/test/task/ContextBuilderTask.test.ts @@ -0,0 +1,247 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ContextFormat, contextBuilder } from "@workglow/ai"; +import { describe, expect, test } from "vitest"; + +describe("ContextBuilderTask", () => { + const testChunks = [ + "First chunk of text about artificial intelligence.", + "Second chunk discussing machine learning algorithms.", + "Third chunk covering neural networks and deep learning.", + ]; + + const testMetadata = [ + { source: "doc1.txt", page: 1 }, + { source: "doc2.txt", page: 2 }, + { source: "doc3.txt", page: 3 }, + ]; + + const testScores = [0.95, 0.87, 0.82]; + + test("should format chunks with SIMPLE format", async () => { + const result = await contextBuilder({ + chunks: testChunks, + }); + + expect(result.context).toBeDefined(); + expect(result.chunksUsed).toBe(3); + expect(result.totalLength).toBeGreaterThan(0); + expect(result.context).toContain(testChunks[0]); + expect(result.context).toContain(testChunks[1]); + expect(result.context).toContain(testChunks[2]); + }); + + test("should format chunks with NUMBERED format", async () => { + const result = await contextBuilder({ + chunks: testChunks, + format: ContextFormat.NUMBERED, + }); + + expect(result.context).toContain("[1]"); + expect(result.context).toContain("[2]"); + expect(result.context).toContain("[3]"); + expect(result.context).toContain(testChunks[0]); + }); + + test("should format chunks with XML format", async () => { + const result = await contextBuilder({ + chunks: testChunks, + format: ContextFormat.XML, + }); + + expect(result.context).toContain(""); + expect(result.context).toContain('id="1"'); + expect(result.context).toContain(testChunks[0]); + }); + + test("should format chunks with MARKDOWN format", async () => { + const result = await contextBuilder({ + chunks: testChunks, + format: ContextFormat.MARKDOWN, + }); + + expect(result.context).toContain("### Chunk"); + expect(result.context).toContain("### Chunk 1"); + expect(result.context).toContain("### Chunk 2"); + expect(result.context).toContain(testChunks[0]); + }); + + test("should format chunks with JSON format", async () => { + const result = await contextBuilder({ + chunks: testChunks, + format: ContextFormat.JSON, + }); + + // Should contain JSON objects + expect(result.context).toContain('"index"'); + expect(result.context).toContain('"content"'); + expect(result.context).toContain(testChunks[0]); + }); + + test("should include metadata when includeMetadata is true", async () => { 
+ const result = await contextBuilder({ + chunks: testChunks, + metadata: testMetadata, + includeMetadata: true, + format: ContextFormat.NUMBERED, + }); + + expect(result.context).toContain("doc1.txt"); + expect(result.context).toContain("page"); + }); + + test("should include scores when provided and includeMetadata is true", async () => { + const result = await contextBuilder({ + chunks: testChunks, + metadata: testMetadata, + scores: testScores, + includeMetadata: true, + format: ContextFormat.NUMBERED, + }); + + // NUMBERED format includes scores in the formatNumbered method when includeMetadata is true + // The formatNumbered method uses formatMetadataInline which includes scores + expect(result.context).toContain("score="); + expect(result.context).toContain("0.95"); + }); + + test("should respect maxLength constraint", async () => { + const result = await contextBuilder({ + chunks: testChunks, + maxLength: 100, + }); + + expect(result.totalLength).toBeLessThanOrEqual(100); + expect(result.chunksUsed).toBeLessThanOrEqual(testChunks.length); + }); + + test("should use custom separator", async () => { + const separator = "---"; + const result = await contextBuilder({ + chunks: testChunks, + separator: separator, + }); + + // Should contain separator between chunks + const separatorCount = (result.context.match(new RegExp(separator, "g")) || []).length; + expect(separatorCount).toBeGreaterThan(0); + }); + + test("should handle empty chunks array", async () => { + const result = await contextBuilder({ + chunks: [], + }); + + expect(result.context).toBe(""); + expect(result.chunksUsed).toBe(0); + expect(result.totalLength).toBe(0); + }); + + test("should handle single chunk", async () => { + const singleChunk = ["Only one chunk"]; + const result = await contextBuilder({ + chunks: singleChunk, + }); + + expect(result.context).toBe(singleChunk[0]); + expect(result.chunksUsed).toBe(1); + expect(result.totalLength).toBe(singleChunk[0].length); + }); + + test("should handle chunks with mismatched metadata length", async () => { + const result = await contextBuilder({ + chunks: testChunks, + metadata: [testMetadata[0]], // Only one metadata entry + includeMetadata: true, + }); + + // Should handle gracefully, only include metadata where available + expect(result.chunksUsed).toBe(3); + expect(result.context).toBeDefined(); + }); + + test("should handle chunks with mismatched scores length", async () => { + const result = await contextBuilder({ + chunks: testChunks, + scores: [testScores[0]], // Only one score + includeMetadata: true, + }); + + expect(result.chunksUsed).toBe(3); + expect(result.context).toBeDefined(); + }); + + test("should truncate first chunk if maxLength is very small", async () => { + const result = await contextBuilder({ + chunks: testChunks, + maxLength: 50, + }); + + expect(result.totalLength).toBeLessThanOrEqual(50); + expect(result.context.length).toBeLessThanOrEqual(50); + if (result.chunksUsed > 0) { + expect(result.context).toContain("..."); + } + }); + + test("should use default separator when not specified", async () => { + const result = await contextBuilder({ + chunks: testChunks, + }); + + // Default separator is "\n\n" + expect(result.context).toContain("\n\n"); + }); + + test("should escape XML special characters in XML format", async () => { + const chunksWithSpecialChars = ['Text with <tag> & "quotes"']; + const result = await contextBuilder({ + chunks: chunksWithSpecialChars, + format: ContextFormat.XML, + }); + + // Should escape XML characters
expect(result.context).not.toContain("<tag>"); + expect(result.context).toContain("&lt;tag&gt;"); + expect(result.context).toContain("&amp;"); + expect(result.context).toContain("&quot;quotes&quot;"); + }); + + test("should format metadata correctly in different formats", async () => { + // Test MARKDOWN format with metadata + const markdownResult = await contextBuilder({ + chunks: testChunks, + metadata: testMetadata, + includeMetadata: true, + format: ContextFormat.MARKDOWN, + }); + + expect(markdownResult.context).toContain("**Metadata:**"); + expect(markdownResult.context).toContain("- source:"); + + // Test JSON format with metadata + const jsonResult = await contextBuilder({ + chunks: testChunks, + metadata: testMetadata, + includeMetadata: true, + format: ContextFormat.JSON, + }); + + expect(jsonResult.context).toContain('"metadata"'); + }); + + test("should handle very long chunks", async () => { + const longChunk = "A".repeat(10000); + const result = await contextBuilder({ + chunks: [longChunk], + maxLength: 5000, + }); + + expect(result.totalLength).toBeLessThanOrEqual(5000); + }); +}); diff --git a/packages/test/src/test/task/HybridSearchTask.test.ts b/packages/test/src/test/task/HybridSearchTask.test.ts new file mode 100644 index 00000000..88b43756 --- /dev/null +++ b/packages/test/src/test/task/HybridSearchTask.test.ts @@ -0,0 +1,269 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { hybridSearch } from "@workglow/ai"; +import { InMemoryVectorRepository, registerVectorRepository } from "@workglow/storage"; +import { afterEach, beforeEach, describe, expect, test } from "vitest"; + +describe("HybridSearchTask", () => { + let repo: InMemoryVectorRepository<{ text: string; category?: string }>; + + beforeEach(async () => { + repo = new InMemoryVectorRepository<{ text: string; category?: string }>(); + await repo.setupDatabase(); + + // Populate repository with test data + const vectors = [ + new Float32Array([1.0, 0.0, 0.0]), // Similar vector, contains "machine" + new Float32Array([0.8, 0.2, 0.0]), // Somewhat similar, contains "learning" + new Float32Array([0.0, 1.0, 0.0]), // Different vector, contains "cooking" + new Float32Array([0.0, 0.0, 1.0]), // Different vector, contains "travel" + new Float32Array([0.9, 0.1, 0.0]), // Very similar, contains "artificial" + ]; + + const metadata = [ + { text: "Document about machine learning", category: "tech" }, + { text: "Document about deep learning algorithms", category: "tech" }, + { text: "Document about cooking recipes", category: "food" }, + { text: "Document about travel destinations", category: "travel" }, + { text: "Document about artificial intelligence", category: "tech" }, + ]; + + for (let i = 0; i < vectors.length; i++) { + await repo.upsert(`doc${i + 1}`, vectors[i], metadata[i]); + } + }); + + afterEach(() => { + repo.destroy(); + }); + + test("should perform hybrid search with vector and text query", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "machine learning"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 3, + }); + + expect(result.count).toBeGreaterThan(0); + expect(result.chunks).toHaveLength(result.count); + expect(result.ids).toHaveLength(result.count); + expect(result.metadata).toHaveLength(result.count); + expect(result.scores).toHaveLength(result.count); + + // Scores should be in descending order + for (let i = 1; i < result.scores.length; i++) {
expect(result.scores[i - 1]).toBeGreaterThanOrEqual(result.scores[i]); + } + }); + + test("should combine vector and text scores", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "machine"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 5, + }); + + // Results should be ranked by combined score + expect(result.scores.length).toBeGreaterThan(0); + result.scores.forEach((score) => { + expect(score).toBeGreaterThanOrEqual(0); + expect(score).toBeLessThanOrEqual(1); + }); + }); + + test("should respect vectorWeight parameter", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "learning"; + + // Test with high vector weight + const resultHighVector = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 5, + vectorWeight: 0.9, + }); + + // Test with low vector weight (high text weight) + const resultHighText = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 5, + vectorWeight: 0.1, + }); + + // Results might differ based on weight + expect(resultHighVector.count).toBeGreaterThan(0); + expect(resultHighText.count).toBeGreaterThan(0); + }); + + test("should return vectors when returnVectors is true", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "machine"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 3, + returnVectors: true, + }); + + expect(result.vectors).toBeDefined(); + expect(result.vectors).toHaveLength(result.count); + expect(result.vectors![0]).toBeInstanceOf(Float32Array); + }); + + test("should not return vectors when returnVectors is false", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "machine"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 3, + returnVectors: false, + }); + + expect(result.vectors).toBeUndefined(); + }); + + test("should apply metadata filter", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "learning"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 10, + filter: { category: "tech" }, + }); + + // All results should have category "tech" + result.metadata.forEach((meta) => { + expect(meta).toHaveProperty("category", "tech"); + }); + }); + + test("should apply score threshold", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "machine"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 10, + scoreThreshold: 0.5, + }); + + // All scores should be >= threshold + result.scores.forEach((score) => { + expect(score).toBeGreaterThanOrEqual(0.5); + }); + }); + + test("should respect topK limit", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "document"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 2, + }); + + expect(result.count).toBeLessThanOrEqual(2); + expect(result.chunks).toHaveLength(result.count); + }); + + test("should handle default parameters", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = 
"learning"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + }); + + // Default topK is 10, vectorWeight is 0.7 + expect(result.count).toBeGreaterThan(0); + expect(result.count).toBeLessThanOrEqual(10); + }); + + test("should extract chunks from metadata", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "machine"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 5, + }); + + // Chunks should match metadata text + result.chunks.forEach((chunk, idx) => { + expect(chunk).toBe(result.metadata[idx].text); + }); + }); + + test("should work with quantized query vectors", async () => { + const queryVector = new Int8Array([127, 0, 0]); + const queryText = "machine"; + + const result = await hybridSearch({ + repository: repo, + queryVector: queryVector, + queryText: queryText, + topK: 3, + }); + + expect(result.count).toBeGreaterThan(0); + expect(result.chunks).toHaveLength(result.count); + }); + + test("should resolve repository from string ID", async () => { + // Register repository by ID + registerVectorRepository("test-hybrid-repo", repo); + + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + const queryText = "machine learning"; + + // Pass repository as string ID instead of instance + const result = await hybridSearch({ + repository: "test-hybrid-repo" as any, + queryVector: queryVector, + queryText: queryText, + topK: 3, + }); + + expect(result.count).toBeGreaterThan(0); + expect(result.chunks).toHaveLength(result.count); + expect(result.ids).toHaveLength(result.count); + expect(result.metadata).toHaveLength(result.count); + expect(result.scores).toHaveLength(result.count); + + // Scores should be in descending order + for (let i = 1; i < result.scores.length; i++) { + expect(result.scores[i - 1]).toBeGreaterThanOrEqual(result.scores[i]); + } + }); +}); diff --git a/packages/test/src/test/task/RetrievalTask.test.ts b/packages/test/src/test/task/RetrievalTask.test.ts new file mode 100644 index 00000000..45c79c90 --- /dev/null +++ b/packages/test/src/test/task/RetrievalTask.test.ts @@ -0,0 +1,290 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { retrieval } from "@workglow/ai"; +import { InMemoryVectorRepository, registerVectorRepository } from "@workglow/storage"; +import { afterEach, beforeEach, describe, expect, test } from "vitest"; + +describe("RetrievalTask", () => { + let repo: InMemoryVectorRepository<{ + text?: string; + content?: string; + chunk?: string; + category?: string; + title?: string; + author?: string; + }>; + + beforeEach(async () => { + repo = new InMemoryVectorRepository<{ + text?: string; + content?: string; + chunk?: string; + category?: string; + title?: string; + author?: string; + }>(); + await repo.setupDatabase(); + + // Populate repository with test data + const vectors = [ + new Float32Array([1.0, 0.0, 0.0]), + new Float32Array([0.8, 0.2, 0.0]), + new Float32Array([0.0, 1.0, 0.0]), + new Float32Array([0.0, 0.0, 1.0]), + new Float32Array([0.9, 0.1, 0.0]), + ]; + + const metadata = [ + { text: "First chunk about AI" }, + { text: "Second chunk about machine learning" }, + { content: "Third chunk about cooking" }, + { chunk: "Fourth chunk about travel" }, + { text: "Fifth chunk about artificial intelligence" }, + ]; + + for (let i = 0; i < vectors.length; i++) { + await repo.upsert(`doc${i + 1}`, vectors[i], metadata[i]); + 
} + }); + + afterEach(() => { + repo.destroy(); + }); + + test("should retrieve chunks with query vector", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 3, + }); + + expect(result.count).toBe(3); + expect(result.chunks).toHaveLength(3); + expect(result.ids).toHaveLength(3); + expect(result.metadata).toHaveLength(3); + expect(result.scores).toHaveLength(3); + + // Chunks should be extracted from metadata + expect(result.chunks[0]).toBeTruthy(); + expect(typeof result.chunks[0]).toBe("string"); + }); + + test("should extract text from metadata.text field", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 5, + }); + + // Find chunks that have text field + const textChunks = result.chunks.filter((chunk, idx) => { + const meta = result.metadata[idx]; + return meta.text !== undefined; + }); + + expect(textChunks.length).toBeGreaterThan(0); + textChunks.forEach((chunk, idx) => { + const originalIdx = result.chunks.indexOf(chunk); + expect(chunk).toBe(result.metadata[originalIdx].text); + }); + }); + + test("should extract text from metadata.content field as fallback", async () => { + const queryVector = new Float32Array([0.0, 1.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 5, + }); + + // Find the chunk with content field + const contentChunkIdx = result.metadata.findIndex((meta) => meta.content !== undefined); + if (contentChunkIdx >= 0) { + expect(result.chunks[contentChunkIdx]).toBe(result.metadata[contentChunkIdx].content); + } + }); + + test("should extract text from metadata.chunk field as fallback", async () => { + const queryVector = new Float32Array([0.0, 0.0, 1.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 5, + }); + + // Find the chunk with chunk field + const chunkIdx = result.metadata.findIndex((meta) => meta.chunk !== undefined); + if (chunkIdx >= 0) { + expect(result.chunks[chunkIdx]).toBe(result.metadata[chunkIdx].chunk); + } + }); + + test("should return vectors when returnVectors is true", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 3, + returnVectors: true, + }); + + expect(result.vectors).toBeDefined(); + expect(result.vectors).toHaveLength(3); + expect(result.vectors![0]).toBeInstanceOf(Float32Array); + }); + + test("should not return vectors when returnVectors is false", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 3, + returnVectors: false, + }); + + expect(result.vectors).toBeUndefined(); + }); + + test("should respect topK parameter", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 2, + }); + + expect(result.count).toBe(2); + expect(result.chunks).toHaveLength(2); + }); + + test("should apply metadata filter", async () => { + // Add a document with specific metadata for filtering + await repo.upsert("filtered_doc", new Float32Array([1.0, 0.0, 0.0]), { + text: "Filtered document", + category: "test", + }); + + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: 
queryVector, + topK: 10, + filter: { category: "test" }, + }); + + expect(result.count).toBe(1); + expect(result.ids[0]).toBe("filtered_doc"); + }); + + test("should apply score threshold", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 10, + scoreThreshold: 0.9, + }); + + result.scores.forEach((score) => { + expect(score).toBeGreaterThanOrEqual(0.9); + }); + }); + + test("should use queryEmbedding when provided", async () => { + const queryEmbedding = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryEmbedding, + topK: 3, + }); + + expect(result.count).toBe(3); + expect(result.chunks).toHaveLength(3); + }); + + test("should throw error when query is string without model", async () => { + await expect( + // @ts-expect-error - query is string but no model is provided + retrieval({ + repository: repo, + query: "test query string", + topK: 3, + }) + ).rejects.toThrow("model"); + }); + + test("should handle default topK value", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + }); + + // Default topK is 5 + expect(result.count).toBe(5); + expect(result.count).toBeLessThanOrEqual(5); + }); + + test("should JSON.stringify metadata when no text/content/chunk fields", async () => { + // Add document with only non-standard metadata + await repo.upsert("json_doc", new Float32Array([1.0, 0.0, 0.0]), { + title: "Title only", + author: "Author name", + }); + + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const result = await retrieval({ + repository: repo, + query: queryVector, + topK: 10, + }); + + // Find the JSON stringified chunk + const jsonChunk = result.chunks.find((chunk) => chunk.includes("title")); + expect(jsonChunk).toBeDefined(); + expect(jsonChunk).toContain("Title only"); + expect(jsonChunk).toContain("Author name"); + }); + + test("should resolve repository from string ID", async () => { + // Register repository by ID + registerVectorRepository("test-retrieval-repo", repo); + + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + // Pass repository as string ID instead of instance + const result = await retrieval({ + repository: "test-retrieval-repo" as any, + query: queryVector, + topK: 3, + }); + + expect(result.count).toBe(3); + expect(result.chunks).toHaveLength(3); + expect(result.ids).toHaveLength(3); + expect(result.metadata).toHaveLength(3); + expect(result.scores).toHaveLength(3); + + // Chunks should be extracted from metadata + expect(result.chunks[0]).toBeTruthy(); + expect(typeof result.chunks[0]).toBe("string"); + }); +}); diff --git a/packages/test/src/test/task/Task.smartClone.test.ts b/packages/test/src/test/task/Task.smartClone.test.ts new file mode 100644 index 00000000..ac926ce7 --- /dev/null +++ b/packages/test/src/test/task/Task.smartClone.test.ts @@ -0,0 +1,203 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { IExecuteContext } from "@workglow/task-graph"; +import { Task } from "@workglow/task-graph"; +import type { DataPortSchema } from "@workglow/util"; +import { beforeEach, describe, expect, test } from "vitest"; + +// Test task class to access private smartClone method +class TestSmartCloneTask extends Task<{ data: any }, { result: any }> { + static readonly type = "TestSmartCloneTask"; + static readonly 
category = "Test"; + static readonly title = "Test Smart Clone Task"; + static readonly description = "A task for testing smartClone"; + declare runInputData: { data: any }; + declare runOutputData: { result: any }; + + static inputSchema(): DataPortSchema { + return { + type: "object", + properties: { + data: {}, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + static outputSchema(): DataPortSchema { + return { + type: "object", + properties: { + result: {}, + }, + additionalProperties: false, + } as const satisfies DataPortSchema; + } + + async execute(input: { data: any }, context: IExecuteContext): Promise<{ result: any }> { + return { result: input.data }; + } + + // Expose smartClone for testing + public testSmartClone(obj: any): any { + return (this as any).smartClone(obj); + } +} + +describe("Task.smartClone circular reference detection", () => { + let task: TestSmartCloneTask; + + beforeEach(() => { + task = new TestSmartCloneTask({ data: {} }, { id: "test-task" }); + }); + + test("should handle simple objects without circular references", () => { + const obj = { a: 1, b: { c: 2 } }; + const cloned = task.testSmartClone(obj); + + expect(cloned).toEqual(obj); + expect(cloned).not.toBe(obj); + expect(cloned.b).not.toBe(obj.b); + }); + + test("should handle arrays without circular references", () => { + const arr = [1, 2, [3, 4]]; + const cloned = task.testSmartClone(arr); + + expect(cloned).toEqual(arr); + expect(cloned).not.toBe(arr); + expect(cloned[2]).not.toBe(arr[2]); + }); + + test("should throw error on object with circular self-reference", () => { + const obj: any = { a: 1 }; + obj.self = obj; + + expect(() => task.testSmartClone(obj)).toThrow("Circular reference detected in input data"); + }); + + test("should throw error on nested circular reference", () => { + const obj: any = { a: 1, b: { c: 2 } }; + obj.b.parent = obj; + + expect(() => task.testSmartClone(obj)).toThrow("Circular reference detected in input data"); + }); + + test("should throw error on array with circular reference", () => { + const arr: any = [1, 2, 3]; + arr.push(arr); + + expect(() => task.testSmartClone(arr)).toThrow("Circular reference detected in input data"); + }); + + test("should throw error on complex circular reference chain", () => { + const obj1: any = { name: "obj1" }; + const obj2: any = { name: "obj2", ref: obj1 }; + const obj3: any = { name: "obj3", ref: obj2 }; + obj1.ref = obj3; // Create circular chain + + expect(() => task.testSmartClone(obj1)).toThrow("Circular reference detected in input data"); + }); + + test("should handle same object referenced multiple times (not circular)", () => { + const shared = { value: 42 }; + const obj = { a: shared, b: shared }; + + // This should work - same object referenced multiple times is not circular + // Each reference gets cloned independently + const cloned = task.testSmartClone(obj); + + expect(cloned).toEqual(obj); + expect(cloned.a).toEqual(shared); + expect(cloned.b).toEqual(shared); + // The cloned references should be different objects (deep clone) + expect(cloned.a).not.toBe(shared); + expect(cloned.b).not.toBe(shared); + expect(cloned.a).not.toBe(cloned.b); + }); + + test("should preserve class instances by reference (no circular check needed)", () => { + class CustomClass { + constructor(public value: number) {} + } + + const instance = new CustomClass(42); + const obj = { data: instance }; + + const cloned = task.testSmartClone(obj); + + expect(cloned.data).toBe(instance); // Should be same reference 
+ expect(cloned.data.value).toBe(42); + }); + + test("should clone TypedArrays to avoid shared mutation", () => { + const typedArray = new Float32Array([1.0, 2.0, 3.0]); + const obj = { data: typedArray }; + + const cloned = task.testSmartClone(obj); + + expect(cloned.data).not.toBe(typedArray); // Should be a new instance + expect(cloned.data).toEqual(typedArray); // But with the same values + expect(cloned.data).toBeInstanceOf(Float32Array); + }); + + test("should handle null and undefined", () => { + expect(task.testSmartClone(null)).toBe(null); + expect(task.testSmartClone(undefined)).toBe(undefined); + expect(task.testSmartClone({ a: null, b: undefined })).toEqual({ a: null, b: undefined }); + }); + + test("should handle primitives", () => { + expect(task.testSmartClone(42)).toBe(42); + expect(task.testSmartClone("hello")).toBe("hello"); + expect(task.testSmartClone(true)).toBe(true); + expect(task.testSmartClone(false)).toBe(false); + }); + + test("should clone nested structures without circular references", () => { + const obj = { + level1: { + level2: { + level3: { + value: "deep", + }, + }, + array: [1, 2, { nested: true }], + }, + }; + + const cloned = task.testSmartClone(obj); + + expect(cloned).toEqual(obj); + expect(cloned).not.toBe(obj); + expect(cloned.level1).not.toBe(obj.level1); + expect(cloned.level1.level2).not.toBe(obj.level1.level2); + expect(cloned.level1.array).not.toBe(obj.level1.array); + expect(cloned.level1.array[2]).not.toBe(obj.level1.array[2]); + }); + + test("should handle mixed object and array structures", () => { + const obj = { + users: [ + { id: 1, name: "Alice" }, + { id: 2, name: "Bob" }, + ], + settings: { + theme: "dark", + features: ["feature1", "feature2"], + }, + }; + + const cloned = task.testSmartClone(obj); + + expect(cloned).toEqual(obj); + expect(cloned.users).not.toBe(obj.users); + expect(cloned.users[0]).not.toBe(obj.users[0]); + expect(cloned.settings).not.toBe(obj.settings); + expect(cloned.settings.features).not.toBe(obj.settings.features); + }); +}); diff --git a/packages/test/src/test/task/TestTasks.ts b/packages/test/src/test/task/TestTasks.ts index a803799a..d780d921 100644 --- a/packages/test/src/test/task/TestTasks.ts +++ b/packages/test/src/test/task/TestTasks.ts @@ -97,7 +97,7 @@ export class TestIOTask extends Task { /** * Implementation of full run mode - returns complete results */ - async execute(): Promise { + async execute(_input: TestIOTaskInput, _context: IExecuteContext): Promise { return { all: true, key: "full", reactiveOnly: false }; } } @@ -680,8 +680,11 @@ export class StringTask extends Task<{ input: string }, { output: string }, Task /** * Returns the input string as output */ - async execute() { - return { output: this.runInputData.input }; + async executeReactive( + input: { input: string }, + _output: { output: string } + ): Promise<{ output: string }> { + return { output: input.input }; } } @@ -719,8 +722,8 @@ export class NumberToStringTask extends Task<{ input: number }, { output: string /** * Returns the input string as output */ - async execute() { - return { output: String(this.runInputData.input) }; + async execute(input: { input: number }, _context: IExecuteContext): Promise<{ output: string }> { + return { output: String(input.input) }; } } @@ -759,8 +762,8 @@ export class NumberTask extends Task<{ input: number }, { output: number }, Task /** * Returns the input number as output */ - async execute() { - return { output: this.runInputData.input }; + async execute(input: { input: number }, _context: 
IExecuteContext): Promise<{ output: number }> { + return { output: input.input }; } } diff --git a/packages/test/src/test/task/TextChunkerTask.test.ts b/packages/test/src/test/task/TextChunkerTask.test.ts new file mode 100644 index 00000000..425677b1 --- /dev/null +++ b/packages/test/src/test/task/TextChunkerTask.test.ts @@ -0,0 +1,226 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ChunkingStrategy, textChunker } from "@workglow/ai"; +import { describe, expect, test } from "vitest"; + +describe("TextChunkerTask", () => { + const testText = + "This is the first sentence. This is the second sentence! This is the third sentence? " + + "This is the fourth sentence. This is the fifth sentence."; + + test("should chunk text with FIXED strategy", async () => { + const result = await textChunker({ + text: testText, + chunkSize: 50, + chunkOverlap: 10, + strategy: ChunkingStrategy.FIXED, + }); + + expect(result.chunks).toBeDefined(); + expect(result.chunks.length).toBeGreaterThan(0); + expect(result.metadata).toHaveLength(result.chunks.length); + + // Verify metadata structure + result.metadata.forEach((meta, idx) => { + expect(meta).toHaveProperty("index"); + expect(meta).toHaveProperty("startChar"); + expect(meta).toHaveProperty("endChar"); + expect(meta).toHaveProperty("length"); + expect(meta.index).toBe(idx); + }); + }); + + test("should chunk with SENTENCE strategy", async () => { + const result = await textChunker({ + text: testText, + chunkSize: 80, + chunkOverlap: 20, + strategy: ChunkingStrategy.SENTENCE, + }); + + expect(result.chunks.length).toBeGreaterThan(0); + expect(result.metadata).toHaveLength(result.chunks.length); + + // Chunks should respect sentence boundaries + result.chunks.forEach((chunk) => { + expect(chunk.length).toBeGreaterThan(0); + }); + }); + + test("should chunk with PARAGRAPH strategy", async () => { + const paragraphText = + "First paragraph with multiple sentences. It has more content.\n\n" + + "Second paragraph with different content. It also has sentences.\n\n" + + "Third paragraph is here. 
With more text."; + + const result = await textChunker({ + text: paragraphText, + chunkSize: 100, + chunkOverlap: 20, + strategy: ChunkingStrategy.PARAGRAPH, + }); + + expect(result.chunks.length).toBeGreaterThan(0); + expect(result.metadata).toHaveLength(result.chunks.length); + }); + + test("should handle default parameters", async () => { + const result = await textChunker({ + text: testText, + }); + + // Default: chunkSize=512, chunkOverlap=50, strategy=FIXED + expect(result.chunks).toBeDefined(); + expect(result.chunks.length).toBeGreaterThan(0); + expect(result.metadata).toHaveLength(result.chunks.length); + }); + + test("should handle chunkOverlap correctly", async () => { + const shortText = "A".repeat(100); // 100 characters + const result = await textChunker({ + text: shortText, + chunkSize: 30, + chunkOverlap: 10, + strategy: ChunkingStrategy.FIXED, + }); + + // With chunkSize=30 and overlap=10, we move forward by 20 each time + // Should have multiple chunks + expect(result.chunks.length).toBeGreaterThan(1); + + // Verify overlap by checking that chunks share content + if (result.chunks.length > 1) { + const firstChunkEnd = result.chunks[0].slice(-10); + const secondChunkStart = result.chunks[1].slice(0, 10); + // There should be some overlap + expect(firstChunkEnd).toBe(secondChunkStart); + } + }); + + test("should handle zero overlap", async () => { + const result = await textChunker({ + text: testText, + chunkSize: 50, + chunkOverlap: 0, + strategy: ChunkingStrategy.FIXED, + }); + + expect(result.chunks.length).toBeGreaterThan(0); + // With zero overlap, chunks should be adjacent + result.metadata.forEach((meta, idx) => { + if (idx > 0) { + const prevMeta = result.metadata[idx - 1]; + expect(meta.startChar).toBe(prevMeta.endChar); + } + }); + }); + + test("should handle text shorter than chunkSize", async () => { + const shortText = "Short text"; + const result = await textChunker({ + text: shortText, + chunkSize: 100, + chunkOverlap: 10, + }); + + expect(result.chunks.length).toBe(1); + expect(result.chunks[0]).toBe(shortText); + expect(result.metadata[0].length).toBe(shortText.length); + }); + + test("should handle empty text", async () => { + const result = await textChunker({ + text: "", + chunkSize: 50, + }); + + // Empty text should produce empty chunks or handle gracefully + expect(result.chunks).toBeDefined(); + expect(result.metadata).toBeDefined(); + }); + + test("should include all text in chunks (no loss)", async () => { + const result = await textChunker({ + text: testText, + chunkSize: 50, + chunkOverlap: 10, + strategy: ChunkingStrategy.FIXED, + }); + + // Reconstruct text from chunks (accounting for overlap) + const totalChars = result.chunks.reduce((sum, chunk) => sum + chunk.length, 0); + // With overlap, total should be >= original length + expect(totalChars).toBeGreaterThanOrEqual(testText.length); + }); + + test("should handle SEMANTIC strategy (currently same as sentence)", async () => { + const result = await textChunker({ + text: testText, + chunkSize: 80, + chunkOverlap: 20, + strategy: ChunkingStrategy.SEMANTIC, + }); + + expect(result.chunks.length).toBeGreaterThan(0); + expect(result.metadata).toHaveLength(result.chunks.length); + }); + + test("should preserve chunk order", async () => { + const result = await textChunker({ + text: testText, + chunkSize: 50, + chunkOverlap: 10, + }); + + // Metadata indices should be sequential + result.metadata.forEach((meta, idx) => { + expect(meta.index).toBe(idx); + }); + + // Start positions should be in 
order + for (let i = 1; i < result.metadata.length; i++) { + expect(result.metadata[i].startChar).toBeGreaterThanOrEqual( + result.metadata[i - 1].startChar! + ); + } + }); + + test("should handle very large chunkSize", async () => { + const result = await textChunker({ + text: testText, + chunkSize: 10000, + chunkOverlap: 0, + }); + + // Should produce single chunk + expect(result.chunks.length).toBe(1); + expect(result.chunks[0]).toBe(testText); + }); + + test("should handle overlap equal to chunkSize (edge case)", async () => { + // This should be handled to prevent infinite loops + const result = await textChunker({ + text: testText, + chunkSize: 50, + chunkOverlap: 50, + }); + + expect(result.chunks.length).toBeGreaterThan(0); + expect(result.metadata.length).toBe(result.chunks.length); + }); + + test("should handle overlap greater than chunkSize (edge case)", async () => { + // Should handle gracefully + const result = await textChunker({ + text: testText, + chunkSize: 30, + chunkOverlap: 50, + }); + + expect(result.chunks.length).toBeGreaterThan(0); + }); +}); diff --git a/packages/test/src/test/task/VectorQuantizeTask.test.ts b/packages/test/src/test/task/VectorQuantizeTask.test.ts new file mode 100644 index 00000000..f88dab04 --- /dev/null +++ b/packages/test/src/test/task/VectorQuantizeTask.test.ts @@ -0,0 +1,228 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { vectorQuantize } from "@workglow/ai"; +import { TensorType } from "@workglow/util"; +import { describe, expect, test } from "vitest"; + +describe("VectorQuantizeTask", () => { + const testVector = new Float32Array([0.5, -0.5, 0.8, -0.3, 0.0, 1.0, -1.0]); + + test("should quantize to INT8", async () => { + const result = await vectorQuantize({ + vector: testVector, + targetType: TensorType.INT8, + normalize: false, + }); + + expect(result).toBeDefined(); + expect(result.vector).toBeInstanceOf(Int8Array); + expect(result.originalType).toBe(TensorType.FLOAT32); + expect(result.targetType).toBe(TensorType.INT8); + + const quantized = result.vector as Int8Array; + expect(quantized.length).toBe(testVector.length); + // Values should be scaled to [-127, 127] + expect(quantized[0]).toBe(64); // 0.5 * 127 ≈ 64 + expect(quantized[1]).toBe(-63); // -0.5 * 127 ≈ -63 (rounded) + }); + + test("should quantize to UINT8", async () => { + const result = await vectorQuantize({ + vector: testVector, + targetType: TensorType.UINT8, + normalize: false, + }); + + expect(result).toBeDefined(); + expect(result.vector).toBeInstanceOf(Uint8Array); + expect(result.targetType).toBe(TensorType.UINT8); + + const quantized = result.vector as Uint8Array; + expect(quantized.length).toBe(testVector.length); + // Values should be scaled to [0, 255] + expect(quantized.every((v) => v >= 0 && v <= 255)).toBe(true); + }); + + test("should quantize to INT16", async () => { + const result = await vectorQuantize({ + vector: testVector, + targetType: TensorType.INT16, + normalize: false, + }); + + expect(result).toBeDefined(); + expect(result.vector).toBeInstanceOf(Int16Array); + expect(result.targetType).toBe(TensorType.INT16); + + const quantized = result.vector as Int16Array; + expect(quantized.length).toBe(testVector.length); + // Values should be scaled to [-32767, 32767] + expect(quantized[0]).toBeCloseTo(16384, -2); // 0.5 * 32767 + }); + + test("should quantize to UINT16", async () => { + const result = await vectorQuantize({ + vector: testVector, + targetType: TensorType.UINT16, + normalize: 
false, + }); + + expect(result).toBeDefined(); + expect(result.vector).toBeInstanceOf(Uint16Array); + expect(result.targetType).toBe(TensorType.UINT16); + + const quantized = result.vector as Uint16Array; + expect(quantized.length).toBe(testVector.length); + // Values should be scaled to [0, 65535] + expect(quantized.every((v) => v >= 0 && v <= 65535)).toBe(true); + }); + + test("should quantize to FLOAT16", async () => { + const result = await vectorQuantize({ + vector: testVector, + targetType: TensorType.FLOAT16, + normalize: false, + }); + + expect(result).toBeDefined(); + expect(result.vector).toBeInstanceOf(Float16Array); + expect(result.targetType).toBe(TensorType.FLOAT16); + + const quantized = result.vector as Float16Array; + expect(quantized.length).toBe(testVector.length); + }); + + test("should quantize to FLOAT64", async () => { + const result = await vectorQuantize({ + vector: testVector, + targetType: TensorType.FLOAT64, + normalize: false, + }); + + expect(result).toBeDefined(); + expect(result.vector).toBeInstanceOf(Float64Array); + expect(result.targetType).toBe(TensorType.FLOAT64); + + const quantized = result.vector as Float64Array; + expect(quantized.length).toBe(testVector.length); + }); + + test("should handle normalization", async () => { + const unnormalizedVector = new Float32Array([1, 2, 3, 4, 5]); + + const result = await vectorQuantize({ + vector: unnormalizedVector, + targetType: TensorType.INT8, + normalize: true, + }); + + expect(result).toBeDefined(); + expect(result.vector).toBeInstanceOf(Int8Array); + + // With normalization, values should be normalized before quantization + const quantized = result.vector as Int8Array; + expect(quantized.length).toBe(unnormalizedVector.length); + }); + + test("should handle array of vectors", async () => { + const vectors = [ + new Float32Array([0.5, -0.5, 0.8]), + new Float32Array([0.1, 0.2, 0.3]), + new Float32Array([-0.4, -0.5, -0.6]), + ]; + + const result = await vectorQuantize({ + vector: vectors, + targetType: TensorType.INT8, + normalize: false, + }); + + expect(result).toBeDefined(); + expect(Array.isArray(result.vector)).toBe(true); + + const quantizedVectors = result.vector as Int8Array[]; + expect(quantizedVectors.length).toBe(3); + quantizedVectors.forEach((v, idx) => { + expect(v).toBeInstanceOf(Int8Array); + expect(v.length).toBe(vectors[idx].length); + }); + }); + + test("should preserve dimensions when quantizing", async () => { + const largeVector = new Float32Array(384).map(() => Math.random() * 2 - 1); + + const result = await vectorQuantize({ + vector: largeVector, + targetType: TensorType.INT8, + normalize: true, + }); + + expect(result).toBeDefined(); + const quantized = result.vector as Int8Array; + expect(quantized.length).toBe(largeVector.length); + }); + + test("should handle edge cases in INT8 quantization", async () => { + const edgeVector = new Float32Array([1.0, -1.0, 1.5, -1.5, 0.0]); + + const result = await vectorQuantize({ + vector: edgeVector, + targetType: TensorType.INT8, + normalize: false, + }); + + const quantized = result.vector as Int8Array; + // Values clamped to [-1, 1] before scaling + expect(quantized[0]).toBe(127); // 1.0 * 127 + expect(quantized[1]).toBe(-127); // -1.0 * 127 + expect(quantized[2]).toBe(127); // 1.5 clamped to 1.0 + expect(quantized[3]).toBe(-127); // -1.5 clamped to -1.0 + expect(quantized[4]).toBe(0); // 0.0 + }); + + test("should detect original vector type", async () => { + const int8Vector = new Int8Array([10, 20, 30, 40]); + + const result = await 
vectorQuantize({ + vector: int8Vector, + targetType: TensorType.FLOAT32, + normalize: false, + }); + + expect(result.originalType).toBe(TensorType.INT8); + expect(result.targetType).toBe(TensorType.FLOAT32); + expect(result.vector).toBeInstanceOf(Float32Array); + }); + + test("should handle different typed arrays as input", async () => { + const testCases = [ + { input: new Float16Array([0.5, -0.5]), expected: TensorType.FLOAT16 }, + { input: new Float32Array([0.5, -0.5]), expected: TensorType.FLOAT32 }, + { input: new Float64Array([0.5, -0.5]), expected: TensorType.FLOAT64 }, + { input: new Int8Array([10, -10]), expected: TensorType.INT8 }, + { input: new Uint8Array([10, 20]), expected: TensorType.UINT8 }, + { input: new Int16Array([100, -100]), expected: TensorType.INT16 }, + { input: new Uint16Array([100, 200]), expected: TensorType.UINT16 }, + ]; + + for (const testCase of testCases) { + const result = await vectorQuantize({ + vector: testCase.input, + targetType: TensorType.FLOAT32, + normalize: false, + }); + expect(result.originalType).toBe(testCase.expected); + } + }); + + test("should use default normalize value of true", async () => { + const result = await vectorQuantize({ vector: testVector, targetType: TensorType.INT8 }); + + expect(result).toBeDefined(); + expect(result.vector).toBeInstanceOf(Int8Array); + }); +}); diff --git a/packages/test/src/test/task/VectorStoreSearchTask.test.ts b/packages/test/src/test/task/VectorStoreSearchTask.test.ts new file mode 100644 index 00000000..c722a966 --- /dev/null +++ b/packages/test/src/test/task/VectorStoreSearchTask.test.ts @@ -0,0 +1,245 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { VectorStoreSearchTask } from "@workglow/ai"; +import { InMemoryVectorRepository, registerVectorRepository } from "@workglow/storage"; +import { afterEach, beforeEach, describe, expect, test } from "vitest"; + +describe("VectorStoreSearchTask", () => { + let repo: InMemoryVectorRepository<{ text: string; category?: string }>; + + beforeEach(async () => { + repo = new InMemoryVectorRepository<{ text: string; category?: string }>(); + await repo.setupDatabase(); + + // Populate repository with test data + const vectors = [ + new Float32Array([1.0, 0.0, 0.0]), // doc1 - similar to query + new Float32Array([0.8, 0.2, 0.0]), // doc2 - somewhat similar + new Float32Array([0.0, 1.0, 0.0]), // doc3 - different + new Float32Array([0.0, 0.0, 1.0]), // doc4 - different + new Float32Array([0.9, 0.1, 0.0]), // doc5 - very similar + ]; + + const metadata = [ + { text: "Document about AI", category: "tech" }, + { text: "Document about machine learning", category: "tech" }, + { text: "Document about cooking", category: "food" }, + { text: "Document about travel", category: "travel" }, + { text: "Document about artificial intelligence", category: "tech" }, + ]; + + for (let i = 0; i < vectors.length; i++) { + await repo.upsert(`doc${i + 1}`, vectors[i], metadata[i]); + } + }); + + afterEach(() => { + repo.destroy(); + }); + + test("should search and return top K results", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const task = new VectorStoreSearchTask(); + const result = await task.run({ + repository: repo, + query: queryVector, + topK: 3, + }); + + expect(result.count).toBe(3); + expect(result.ids).toHaveLength(3); + expect(result.vectors).toHaveLength(3); + expect(result.metadata).toHaveLength(3); + expect(result.scores).toHaveLength(3); + + // Scores should be in 
descending order + for (let i = 1; i < result.scores.length; i++) { + expect(result.scores[i - 1]).toBeGreaterThanOrEqual(result.scores[i]); + } + + // Most similar should be doc1 (exact match) + expect(result.ids[0]).toBe("doc1"); + }); + + test("should respect topK limit", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const task = new VectorStoreSearchTask(); + const result = await task.run({ + repository: repo, + query: queryVector, + topK: 2, + }); + + expect(result.count).toBe(2); + expect(result.ids).toHaveLength(2); + }); + + test("should filter by metadata", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const task = new VectorStoreSearchTask(); + const result = await task.run({ + repository: repo, + query: queryVector, + topK: 10, + filter: { category: "tech" }, + }); + + expect(result.count).toBeGreaterThan(0); + // All results should have category "tech" + result.metadata.forEach((meta) => { + expect(meta).toHaveProperty("category", "tech"); + }); + }); + + test("should apply score threshold", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const task = new VectorStoreSearchTask(); + const result = await task.run({ + repository: repo, + query: queryVector, + topK: 10, + scoreThreshold: 0.9, + }); + + // All scores should be >= 0.9 + result.scores.forEach((score) => { + expect(score).toBeGreaterThanOrEqual(0.9); + }); + }); + + test("should return empty results when no matches", async () => { + const queryVector = new Float32Array([0.0, 0.0, 1.0]); + + const task = new VectorStoreSearchTask(); + const result = await task.run({ + repository: repo, + query: queryVector, + topK: 10, + filter: { category: "nonexistent" }, + }); + + expect(result.count).toBe(0); + expect(result.ids).toHaveLength(0); + expect(result.vectors).toHaveLength(0); + expect(result.metadata).toHaveLength(0); + expect(result.scores).toHaveLength(0); + }); + + test("should handle default topK value", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const task = new VectorStoreSearchTask(); + const result = await task.run({ + repository: repo, + query: queryVector, + }); + + // Default topK is 10, but we only have 5 documents + expect(result.count).toBe(5); + expect(result.count).toBeLessThanOrEqual(10); + }); + + test("should work with quantized query vectors (Int8Array)", async () => { + const queryVector = new Int8Array([127, 0, 0]); + + const task = new VectorStoreSearchTask(); + const result = await task.run({ + repository: repo, + query: queryVector, + topK: 3, + }); + + expect(result.count).toBeGreaterThan(0); + expect(result.ids).toHaveLength(result.count); + expect(result.scores).toHaveLength(result.count); + }); + + test("should return results sorted by similarity score", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const task = new VectorStoreSearchTask(); + const result = await task.run({ + repository: repo, + query: queryVector, + topK: 5, + }); + + // Verify descending order + for (let i = 1; i < result.scores.length; i++) { + expect(result.scores[i - 1]).toBeGreaterThanOrEqual(result.scores[i]); + } + }); + + test("should handle empty repository", async () => { + const emptyRepo = new InMemoryVectorRepository<{ text: string }>(); + await emptyRepo.setupDatabase(); + + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const task = new VectorStoreSearchTask(); + const result = await task.run({ + repository: emptyRepo, + query: queryVector, + topK: 10, + 
}); + + expect(result.count).toBe(0); + expect(result.ids).toHaveLength(0); + expect(result.scores).toHaveLength(0); + + emptyRepo.destroy(); + }); + + test("should combine filter and score threshold", async () => { + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const task = new VectorStoreSearchTask(); + const result = await task.run({ + repository: repo, + query: queryVector, + topK: 10, + filter: { category: "tech" }, + scoreThreshold: 0.7, + }); + + // All results should pass both filter and threshold + result.metadata.forEach((meta) => { + expect(meta).toHaveProperty("category", "tech"); + }); + result.scores.forEach((score) => { + expect(score).toBeGreaterThanOrEqual(0.7); + }); + }); + + test("should resolve repository from string ID", async () => { + // Register repository by ID + registerVectorRepository("test-vector-repo", repo); + + const queryVector = new Float32Array([1.0, 0.0, 0.0]); + + const task = new VectorStoreSearchTask(); + // Pass repository as string ID instead of instance + const result = await task.run({ + repository: "test-vector-repo" as any, + query: queryVector, + topK: 3, + }); + + expect(result.count).toBe(3); + expect(result.ids).toHaveLength(3); + expect(result.vectors).toHaveLength(3); + expect(result.metadata).toHaveLength(3); + expect(result.scores).toHaveLength(3); + + // Most similar should be doc1 (exact match) + expect(result.ids[0]).toBe("doc1"); + }); +}); diff --git a/packages/test/src/test/task/VectorStoreUpsertTask.test.ts b/packages/test/src/test/task/VectorStoreUpsertTask.test.ts new file mode 100644 index 00000000..db4ba675 --- /dev/null +++ b/packages/test/src/test/task/VectorStoreUpsertTask.test.ts @@ -0,0 +1,233 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { VectorStoreUpsertTask } from "@workglow/ai"; +import { InMemoryVectorRepository, registerVectorRepository } from "@workglow/storage"; +import { afterEach, beforeEach, describe, expect, test } from "vitest"; + +describe("VectorStoreUpsertTask", () => { + let repo: InMemoryVectorRepository<{ text: string; source?: string }>; + + beforeEach(async () => { + repo = new InMemoryVectorRepository<{ text: string; source?: string }>(); + await repo.setupDatabase(); + }); + + afterEach(() => { + repo.destroy(); + }); + + test("should upsert a single vector", async () => { + const vector = new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5]); + const metadata = { text: "Test document", source: "test.txt" }; + + const task = new VectorStoreUpsertTask(); + const result = await task.run({ + repository: repo, + ids: "doc1", + vectors: [vector], + metadata: metadata, + }); + + expect(result.count).toBe(1); + expect(result.ids).toEqual(["doc1"]); + + // Verify vector was stored + const retrieved = await repo.get("doc1"); + expect(retrieved).toBeDefined(); + expect(retrieved?.id).toBe("doc1"); + expect(retrieved?.metadata).toEqual(metadata); + expect(retrieved?.vector).toEqual(vector); + }); + + test("should upsert multiple vectors in bulk", async () => { + const vectors = [ + new Float32Array([0.1, 0.2, 0.3]), + new Float32Array([0.4, 0.5, 0.6]), + new Float32Array([0.7, 0.8, 0.9]), + ]; + const metadata = [ + { text: "Document 1", source: "doc1.txt" }, + { text: "Document 2", source: "doc2.txt" }, + { text: "Document 3", source: "doc3.txt" }, + ]; + + const task = new VectorStoreUpsertTask(); + const result = await task.run({ + repository: repo, + ids: ["doc1", "doc2", "doc3"], + vectors: vectors, + metadata: metadata, + }); + + 
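// The upsert tests around this point suggest the task accepts either a single id and
// metadata value or parallel arrays, and rejects mismatched lengths. A small sketch of
// that normalization, assuming this behavior (illustrative, not the task's source):
const normalizeUpsertArgs = (
  ids: string | string[],
  vectors: ArrayLike<number>[],
  metadata: unknown
) => {
  const idList = Array.isArray(ids) ? ids : [ids];
  const metaList = Array.isArray(metadata) ? metadata : [metadata];
  if (idList.length !== vectors.length || idList.length !== metaList.length) {
    throw new Error("Mismatched array lengths");
  }
  return { idList, vectors, metaList };
};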
expect(result.count).toBe(3); + expect(result.ids).toEqual(["doc1", "doc2", "doc3"]); + + // Verify all vectors were stored + for (let i = 0; i < 3; i++) { + const retrieved = await repo.get(`doc${i + 1}`); + expect(retrieved).toBeDefined(); + expect(retrieved?.metadata).toEqual(metadata[i]); + expect(retrieved?.vector).toEqual(vectors[i]); + } + }); + + test("should handle array of single item (normalized to bulk)", async () => { + const vector = new Float32Array([0.1, 0.2, 0.3]); + const metadata = { text: "Single item as array" }; + + const task = new VectorStoreUpsertTask(); + const result = await task.run({ + repository: repo, + ids: ["doc1"], + vectors: [vector], + metadata: [metadata], + }); + + expect(result.count).toBe(1); + expect(result.ids).toEqual(["doc1"]); + + const retrieved = await repo.get("doc1"); + expect(retrieved).toBeDefined(); + expect(retrieved?.metadata).toEqual(metadata); + }); + + test("should update existing vector when upserting with same ID", async () => { + const vector1 = new Float32Array([0.1, 0.2, 0.3]); + const vector2 = new Float32Array([0.9, 0.8, 0.7]); + const metadata1 = { text: "Original document" }; + const metadata2 = { text: "Updated document", source: "updated.txt" }; + + // First upsert + const task1 = new VectorStoreUpsertTask(); + await task1.run({ + repository: repo, + ids: "doc1", + vectors: [vector1], + metadata: metadata1, + }); + + // Update with same ID + const task2 = new VectorStoreUpsertTask(); + await task2.run({ + repository: repo, + ids: "doc1", + vectors: [vector2], + metadata: metadata2, + }); + + const retrieved = await repo.get("doc1"); + expect(retrieved).toBeDefined(); + expect(retrieved?.vector).toEqual(vector2); + expect(retrieved?.metadata).toEqual(metadata2); + }); + + test("should throw error on mismatched array lengths", async () => { + const vectors = [new Float32Array([0.1, 0.2]), new Float32Array([0.3, 0.4])]; + const metadata = [{ text: "Only one metadata" }]; + + const task = new VectorStoreUpsertTask(); + await expect( + task.run({ + repository: repo, + ids: ["doc1", "doc2"], + vectors: vectors, + metadata: metadata, + }) + ).rejects.toThrow("Mismatched array lengths"); + }); + + test("should handle quantized vectors (Int8Array)", async () => { + const vector = new Int8Array([127, -128, 64, -64, 0]); + const metadata = { text: "Quantized vector" }; + + const task = new VectorStoreUpsertTask(); + const result = await task.run({ + repository: repo, + ids: "doc1", + vectors: [vector], + metadata: metadata, + }); + + expect(result.count).toBe(1); + + const retrieved = await repo.get("doc1"); + expect(retrieved).toBeDefined(); + expect(retrieved?.vector).toBeInstanceOf(Int8Array); + expect(retrieved?.vector).toEqual(vector); + }); + + test("should handle metadata without optional fields", async () => { + const vector = new Float32Array([0.1, 0.2, 0.3]); + const metadata = { text: "Simple metadata" }; + + const task = new VectorStoreUpsertTask(); + const result = await task.run({ + repository: repo, + ids: "doc1", + vectors: [vector], + metadata: metadata, + }); + + expect(result.count).toBe(1); + + const retrieved = await repo.get("doc1"); + expect(retrieved?.metadata).toEqual(metadata); + }); + + test("should handle large batch upsert", async () => { + const count = 100; + const vectors = Array.from( + { length: count }, + (_, i) => new Float32Array([i * 0.01, i * 0.02, i * 0.03]) + ); + const metadata = Array.from({ length: count }, (_, i) => ({ + text: `Document ${i + 1}`, + })); + const ids = Array.from({ length: 
count }, (_, i) => `doc${i + 1}`); + + const task = new VectorStoreUpsertTask(); + const result = await task.run({ + repository: repo, + ids: ids, + vectors: vectors, + metadata: metadata, + }); + + expect(result.count).toBe(count); + expect(result.ids).toHaveLength(count); + + const size = await repo.size(); + expect(size).toBe(count); + }); + + test("should resolve repository from string ID", async () => { + // Register repository by ID + registerVectorRepository("test-upsert-repo", repo); + + const vector = new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5]); + const metadata = { text: "Test document", source: "test.txt" }; + + const task = new VectorStoreUpsertTask(); + // Pass repository as string ID instead of instance + const result = await task.run({ + repository: "test-upsert-repo" as any, + ids: "doc1", + vectors: [vector], + metadata: metadata, + }); + + expect(result.count).toBe(1); + expect(result.ids).toEqual(["doc1"]); + + // Verify vector was stored + const retrieved = await repo.get("doc1"); + expect(retrieved).toBeDefined(); + expect(retrieved?.id).toBe("doc1"); + expect(retrieved?.metadata).toEqual(metadata); + expect(retrieved?.vector).toEqual(vector); + }); +}); diff --git a/packages/test/src/test/util/Document.test.ts b/packages/test/src/test/util/Document.test.ts new file mode 100644 index 00000000..9e4e1041 --- /dev/null +++ b/packages/test/src/test/util/Document.test.ts @@ -0,0 +1,52 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { ChunkNode, DocumentNode } from "@workglow/ai"; +import { Document, NodeKind } from "@workglow/ai"; +import { describe, expect, test } from "vitest"; + +describe("Document", () => { + const createTestDocumentNode = (): DocumentNode => ({ + nodeId: "root", + kind: NodeKind.DOCUMENT, + range: { startOffset: 0, endOffset: 100 }, + text: "Test document", + title: "Test document", + children: [], + }); + + const createTestChunks = (): ChunkNode[] => [ + { + chunkId: "chunk1", + docId: "doc1", + text: "Test chunk", + nodePath: ["root"], + depth: 1, + }, + ]; + + test("setChunks and getChunks", () => { + const doc = new Document("doc1", createTestDocumentNode(), { title: "Test" }); + + doc.setChunks(createTestChunks()); + + const chunks = doc.getChunks(); + expect(chunks).toBeDefined(); + expect(chunks.length).toBe(1); + expect(chunks[0].text).toBe("Test chunk"); + }); + + test("findChunksByNodeId", () => { + const doc = new Document("doc1", createTestDocumentNode(), { title: "Test" }); + + doc.setChunks(createTestChunks()); + + const chunks = doc.findChunksByNodeId("root"); + expect(chunks).toBeDefined(); + expect(chunks.length).toBe(1); + expect(chunks[0].text).toBe("Test chunk"); + }); +}); diff --git a/packages/test/src/test/util/VectorSimilarityUtils.test.ts b/packages/test/src/test/util/VectorSimilarityUtils.test.ts new file mode 100644 index 00000000..01408040 --- /dev/null +++ b/packages/test/src/test/util/VectorSimilarityUtils.test.ts @@ -0,0 +1,390 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + cosineSimilarity, + jaccardSimilarity, + hammingDistance, + hammingSimilarity, +} from "@workglow/util"; +import { describe, expect, test } from "vitest"; + +describe("VectorSimilarityUtils", () => { + describe("cosineSimilarity", () => { + test("should calculate cosine similarity for identical vectors", () => { + const a = new Float32Array([1, 2, 3, 4]); + const b = new Float32Array([1, 2, 3, 4]); + 
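// For reference: cosine similarity is dot(a, b) / (|a| * |b|), and the zero-vector tests
// below imply 0 is returned when either magnitude is 0. A minimal sketch (not the
// @workglow/util implementation):
const cosineSketch = (a: ArrayLike<number>, b: ArrayLike<number>): number => {
  let dot = 0;
  let na = 0;
  let nb = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    na += a[i] * a[i];
    nb += b[i] * b[i];
  }
  return na === 0 || nb === 0 ? 0 : dot / Math.sqrt(na * nb);
};
// cosineSketch([3, 0], [0, 4]) === 0; cosineSketch([1, 2], [2, 4]) === 1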
expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should calculate cosine similarity for orthogonal vectors", () => { + const a = new Float32Array([1, 0, 0]); + const b = new Float32Array([0, 1, 0]); + expect(cosineSimilarity(a, b)).toBeCloseTo(0.0, 5); + }); + + test("should calculate cosine similarity for opposite vectors", () => { + const a = new Float32Array([1, 2, 3]); + const b = new Float32Array([-1, -2, -3]); + expect(cosineSimilarity(a, b)).toBeCloseTo(-1.0, 5); + }); + + test("should handle zero vectors", () => { + const a = new Float32Array([0, 0, 0]); + const b = new Float32Array([1, 2, 3]); + expect(cosineSimilarity(a, b)).toBe(0); + }); + + test("should handle both zero vectors", () => { + const a = new Float32Array([0, 0, 0]); + const b = new Float32Array([0, 0, 0]); + expect(cosineSimilarity(a, b)).toBe(0); + }); + + test("should work with Int8Array", () => { + const a = new Int8Array([10, 20, 30]); + const b = new Int8Array([10, 20, 30]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should work with Uint8Array", () => { + const a = new Uint8Array([10, 20, 30]); + const b = new Uint8Array([10, 20, 30]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should work with Int16Array", () => { + const a = new Int16Array([100, 200, 300]); + const b = new Int16Array([100, 200, 300]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should work with Uint16Array", () => { + const a = new Uint16Array([100, 200, 300]); + const b = new Uint16Array([100, 200, 300]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should work with Float64Array", () => { + const a = new Float64Array([1.5, 2.5, 3.5]); + const b = new Float64Array([1.5, 2.5, 3.5]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should calculate cosine similarity for partially similar vectors", () => { + const a = new Float32Array([1, 2, 3, 4]); + const b = new Float32Array([2, 3, 4, 5]); + const result = cosineSimilarity(a, b); + expect(result).toBeGreaterThan(0.9); + expect(result).toBeLessThan(1.0); + }); + + test("should throw error for mismatched vector lengths", () => { + const a = new Float32Array([1, 2, 3]); + const b = new Float32Array([1, 2]); + expect(() => cosineSimilarity(a, b)).toThrow("Vectors must have the same length"); + }); + + test("should handle negative values correctly", () => { + const a = new Float32Array([-1, -2, -3]); + const b = new Float32Array([-1, -2, -3]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should handle mixed positive and negative values", () => { + const a = new Float32Array([1, -2, 3, -4]); + const b = new Float32Array([1, -2, 3, -4]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should handle large vectors", () => { + const size = 1000; + const a = new Float32Array(size).fill(1); + const b = new Float32Array(size).fill(1); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + }); + + describe("jaccardSimilarity", () => { + test("should calculate Jaccard similarity for identical vectors", () => { + const a = new Float32Array([1, 2, 3, 4]); + const b = new Float32Array([1, 2, 3, 4]); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should calculate Jaccard similarity for completely different vectors", () => { + const a = new Float32Array([5, 5, 5]); + const b = new Float32Array([1, 1, 1]); + const result = jaccardSimilarity(a, b); + expect(result).toBeGreaterThan(0); + 
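// This expectation is consistent with a weighted (Ruzicka-style) Jaccard similarity,
// sum(min(a_i, b_i)) / sum(max(a_i, b_i)), over the non-negative vectors used here:
// for a = [5, 5, 5] and b = [1, 1, 1] that gives 3 / 15 = 0.2, strictly between 0 and 1.
// A sketch of that assumed formula (the actual implementation may differ, for example
// in how negative inputs are handled):
const ruzickaSketch = (a: ArrayLike<number>, b: ArrayLike<number>): number => {
  let minSum = 0;
  let maxSum = 0;
  for (let i = 0; i < a.length; i++) {
    minSum += Math.min(a[i], b[i]);
    maxSum += Math.max(a[i], b[i]);
  }
  return maxSum === 0 ? 0 : minSum / maxSum;
};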
expect(result).toBeLessThan(1); + }); + + test("should handle zero vectors", () => { + const a = new Float32Array([0, 0, 0]); + const b = new Float32Array([1, 2, 3]); + expect(jaccardSimilarity(a, b)).toBe(0); + }); + + test("should handle both zero vectors", () => { + const a = new Float32Array([0, 0, 0]); + const b = new Float32Array([0, 0, 0]); + expect(jaccardSimilarity(a, b)).toBe(0); + }); + + test("should work with Int8Array", () => { + const a = new Int8Array([10, 20, 30]); + const b = new Int8Array([10, 20, 30]); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should work with Uint8Array", () => { + const a = new Uint8Array([10, 20, 30]); + const b = new Uint8Array([10, 20, 30]); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should work with Int16Array", () => { + const a = new Int16Array([100, 200, 300]); + const b = new Int16Array([100, 200, 300]); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should work with Uint16Array", () => { + const a = new Uint16Array([100, 200, 300]); + const b = new Uint16Array([100, 200, 300]); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should work with Float64Array", () => { + const a = new Float64Array([1.5, 2.5, 3.5]); + const b = new Float64Array([1.5, 2.5, 3.5]); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should calculate correct similarity for partially overlapping vectors", () => { + const a = new Float32Array([1, 2, 3]); + const b = new Float32Array([2, 3, 4]); + const result = jaccardSimilarity(a, b); + expect(result).toBeGreaterThan(0); + expect(result).toBeLessThan(1); + }); + + test("should throw error for mismatched vector lengths", () => { + const a = new Float32Array([1, 2, 3]); + const b = new Float32Array([1, 2]); + expect(() => jaccardSimilarity(a, b)).toThrow("Vectors must have the same length"); + }); + + test("should handle all positive values", () => { + const a = new Float32Array([1, 2, 3]); + const b = new Float32Array([1, 2, 3]); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should handle negative values by using min/max", () => { + const a = new Float32Array([-1, -2, -3]); + const b = new Float32Array([-2, -3, -4]); + const result = jaccardSimilarity(a, b); + expect(result).toBeGreaterThan(0); + expect(result).toBeLessThan(1); + }); + }); + + describe("hammingDistance", () => { + test("should calculate Hamming distance for identical vectors", () => { + const a = new Float32Array([1, 2, 3, 4]); + const b = new Float32Array([1, 2, 3, 4]); + expect(hammingDistance(a, b)).toBe(0); + }); + + test("should calculate Hamming distance for completely different vectors", () => { + const a = new Float32Array([1, 2, 3, 4]); + const b = new Float32Array([5, 6, 7, 8]); + expect(hammingDistance(a, b)).toBe(1.0); + }); + + test("should calculate Hamming distance for partially different vectors", () => { + const a = new Float32Array([1, 2, 3, 4]); + const b = new Float32Array([1, 2, 5, 6]); + expect(hammingDistance(a, b)).toBe(0.5); + }); + + test("should handle zero vectors", () => { + const a = new Float32Array([0, 0, 0]); + const b = new Float32Array([0, 0, 0]); + expect(hammingDistance(a, b)).toBe(0); + }); + + test("should work with Int8Array", () => { + const a = new Int8Array([10, 20, 30]); + const b = new Int8Array([10, 20, 30]); + expect(hammingDistance(a, b)).toBe(0); + }); + + test("should work with Uint8Array", () => { + const a = new Uint8Array([10, 20, 30]); + const b = new 
Uint8Array([10, 20, 40]); + expect(hammingDistance(a, b)).toBeCloseTo(1 / 3, 5); + }); + + test("should work with Int16Array", () => { + const a = new Int16Array([100, 200, 300]); + const b = new Int16Array([100, 200, 300]); + expect(hammingDistance(a, b)).toBe(0); + }); + + test("should work with Uint16Array", () => { + const a = new Uint16Array([100, 200, 300]); + const b = new Uint16Array([100, 200, 300]); + expect(hammingDistance(a, b)).toBe(0); + }); + + test("should work with Float64Array", () => { + const a = new Float64Array([1.5, 2.5, 3.5]); + const b = new Float64Array([1.5, 2.5, 3.5]); + expect(hammingDistance(a, b)).toBe(0); + }); + + test("should throw error for mismatched vector lengths", () => { + const a = new Float32Array([1, 2, 3]); + const b = new Float32Array([1, 2]); + expect(() => hammingDistance(a, b)).toThrow("Vectors must have the same length"); + }); + + test("should handle negative values", () => { + const a = new Float32Array([-1, -2, -3]); + const b = new Float32Array([-1, -2, -3]); + expect(hammingDistance(a, b)).toBe(0); + }); + + test("should distinguish between close but not equal values", () => { + const a = new Float32Array([1.0, 2.0, 3.0]); + const b = new Float32Array([1.0001, 2.0, 3.0]); + expect(hammingDistance(a, b)).toBeCloseTo(1 / 3, 5); + }); + + test("should normalize distance by vector length", () => { + const a = new Float32Array([1, 2, 3, 4, 5, 6, 7, 8]); + const b = new Float32Array([1, 2, 3, 4, 9, 10, 11, 12]); + expect(hammingDistance(a, b)).toBe(0.5); + }); + }); + + describe("hammingSimilarity", () => { + test("should calculate Hamming similarity for identical vectors", () => { + const a = new Float32Array([1, 2, 3, 4]); + const b = new Float32Array([1, 2, 3, 4]); + expect(hammingSimilarity(a, b)).toBe(1.0); + }); + + test("should calculate Hamming similarity for completely different vectors", () => { + const a = new Float32Array([1, 2, 3, 4]); + const b = new Float32Array([5, 6, 7, 8]); + expect(hammingSimilarity(a, b)).toBe(0); + }); + + test("should calculate Hamming similarity for partially different vectors", () => { + const a = new Float32Array([1, 2, 3, 4]); + const b = new Float32Array([1, 2, 5, 6]); + expect(hammingSimilarity(a, b)).toBe(0.5); + }); + + test("should be inverse of Hamming distance", () => { + const a = new Float32Array([1, 2, 3, 4, 5]); + const b = new Float32Array([1, 6, 3, 8, 5]); + const distance = hammingDistance(a, b); + const similarity = hammingSimilarity(a, b); + expect(similarity).toBeCloseTo(1 - distance, 5); + }); + + test("should work with Int8Array", () => { + const a = new Int8Array([10, 20, 30]); + const b = new Int8Array([10, 20, 30]); + expect(hammingSimilarity(a, b)).toBe(1.0); + }); + + test("should work with Uint8Array", () => { + const a = new Uint8Array([10, 20, 30]); + const b = new Uint8Array([10, 20, 40]); + expect(hammingSimilarity(a, b)).toBeCloseTo(2 / 3, 5); + }); + + test("should work with Int16Array", () => { + const a = new Int16Array([100, 200, 300]); + const b = new Int16Array([100, 200, 300]); + expect(hammingSimilarity(a, b)).toBe(1.0); + }); + + test("should work with Uint16Array", () => { + const a = new Uint16Array([100, 200, 300]); + const b = new Uint16Array([100, 200, 300]); + expect(hammingSimilarity(a, b)).toBe(1.0); + }); + + test("should work with Float64Array", () => { + const a = new Float64Array([1.5, 2.5, 3.5]); + const b = new Float64Array([1.5, 2.5, 3.5]); + expect(hammingSimilarity(a, b)).toBe(1.0); + }); + + test("should throw error for mismatched vector 
lengths", () => { + const a = new Float32Array([1, 2, 3]); + const b = new Float32Array([1, 2]); + expect(() => hammingSimilarity(a, b)).toThrow("Vectors must have the same length"); + }); + + test("should handle zero vectors", () => { + const a = new Float32Array([0, 0, 0]); + const b = new Float32Array([0, 0, 0]); + expect(hammingSimilarity(a, b)).toBe(1.0); + }); + }); + + describe("Edge cases and cross-function consistency", () => { + test("should handle single element vectors", () => { + const a = new Float32Array([5]); + const b = new Float32Array([5]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + expect(hammingDistance(a, b)).toBe(0); + expect(hammingSimilarity(a, b)).toBe(1.0); + }); + + test("should handle empty vectors", () => { + const a = new Float32Array([]); + const b = new Float32Array([]); + // For empty vectors, the functions should handle them gracefully + expect(hammingDistance(a, b)).toBeNaN(); // 0/0 + expect(hammingSimilarity(a, b)).toBeNaN(); + }); + + test("should handle very small values", () => { + const a = new Float32Array([0.0001, 0.0002, 0.0003]); + const b = new Float32Array([0.0001, 0.0002, 0.0003]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("should handle very large values", () => { + const a = new Float32Array([10000, 20000, 30000]); + const b = new Float32Array([10000, 20000, 30000]); + expect(cosineSimilarity(a, b)).toBeCloseTo(1.0, 5); + expect(jaccardSimilarity(a, b)).toBeCloseTo(1.0, 5); + }); + + test("all functions should throw same error for length mismatch", () => { + const a = new Float32Array([1, 2, 3]); + const b = new Float32Array([1, 2]); + const errorMessage = "Vectors must have the same length"; + + expect(() => cosineSimilarity(a, b)).toThrow(errorMessage); + expect(() => jaccardSimilarity(a, b)).toThrow(errorMessage); + expect(() => hammingDistance(a, b)).toThrow(errorMessage); + expect(() => hammingSimilarity(a, b)).toThrow(errorMessage); + }); + }); +}); diff --git a/packages/test/src/test/util/VectorUtils.test.ts b/packages/test/src/test/util/VectorUtils.test.ts new file mode 100644 index 00000000..135ef740 --- /dev/null +++ b/packages/test/src/test/util/VectorUtils.test.ts @@ -0,0 +1,382 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { inner, magnitude, normalize, normalizeNumberArray } from "@workglow/util"; +import { describe, expect, test } from "vitest"; + +describe("VectorUtils", () => { + describe("magnitude", () => { + test("should calculate magnitude for Float32Array", () => { + const vector = new Float32Array([3, 4]); + const result = magnitude(vector); + expect(result).toBe(5); + }); + + test("should calculate magnitude for Float64Array", () => { + const vector = new Float64Array([1, 2, 2]); + const result = magnitude(vector); + expect(result).toBe(3); + }); + + test("should calculate magnitude for Int8Array", () => { + const vector = new Int8Array([6, 8]); + const result = magnitude(vector); + expect(result).toBe(10); + }); + + test("should calculate magnitude for Uint8Array", () => { + const vector = new Uint8Array([5, 12]); + const result = magnitude(vector); + expect(result).toBe(13); + }); + + test("should calculate magnitude for Int16Array", () => { + const vector = new Int16Array([3, 4]); + const result = magnitude(vector); + expect(result).toBe(5); + }); + + test("should calculate magnitude for Uint16Array", 
() => { + const vector = new Uint16Array([8, 15]); + const result = magnitude(vector); + expect(result).toBe(17); + }); + + test("should calculate magnitude for Float16Array", () => { + const vector = new Float16Array([3, 4]); + const result = magnitude(vector); + expect(result).toBeCloseTo(5, 1); + }); + + test("should calculate magnitude for number array", () => { + const vector = [3, 4]; + const result = magnitude(vector); + expect(result).toBe(5); + }); + + test("should return 0 for zero vector", () => { + const vector = new Float32Array([0, 0, 0]); + const result = magnitude(vector); + expect(result).toBe(0); + }); + + test("should handle single element vector", () => { + const vector = new Float32Array([5]); + const result = magnitude(vector); + expect(result).toBe(5); + }); + + test("should handle negative values", () => { + const vector = new Float32Array([-3, -4]); + const result = magnitude(vector); + expect(result).toBe(5); + }); + + test("should handle mixed positive and negative values", () => { + const vector = new Float32Array([3, -4]); + const result = magnitude(vector); + expect(result).toBe(5); + }); + + test("should handle large vectors", () => { + const vector = new Float32Array(1000).fill(1); + const result = magnitude(vector); + expect(result).toBeCloseTo(Math.sqrt(1000), 5); + }); + }); + + describe("inner", () => { + test("should calculate dot product for Float32Array", () => { + const arr1 = new Float32Array([1, 2, 3]); + const arr2 = new Float32Array([4, 5, 6]); + const result = inner(arr1, arr2); + expect(result).toBe(32); // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32 + }); + + test("should calculate dot product for Float64Array", () => { + const arr1 = new Float64Array([2, 3]); + const arr2 = new Float64Array([4, 5]); + const result = inner(arr1, arr2); + expect(result).toBe(23); // 2*4 + 3*5 = 8 + 15 = 23 + }); + + test("should calculate dot product for Int8Array", () => { + const arr1 = new Int8Array([1, 2, 3]); + const arr2 = new Int8Array([4, 5, 6]); + const result = inner(arr1, arr2); + expect(result).toBe(32); + }); + + test("should calculate dot product for Uint8Array", () => { + const arr1 = new Uint8Array([1, 2, 3]); + const arr2 = new Uint8Array([4, 5, 6]); + const result = inner(arr1, arr2); + expect(result).toBe(32); + }); + + test("should calculate dot product for Int16Array", () => { + const arr1 = new Int16Array([10, 20]); + const arr2 = new Int16Array([5, 3]); + const result = inner(arr1, arr2); + expect(result).toBe(110); // 10*5 + 20*3 = 50 + 60 = 110 + }); + + test("should calculate dot product for Uint16Array", () => { + const arr1 = new Uint16Array([10, 20]); + const arr2 = new Uint16Array([5, 3]); + const result = inner(arr1, arr2); + expect(result).toBe(110); + }); + + test("should calculate dot product for Float16Array", () => { + const arr1 = new Float16Array([1, 2, 3]); + const arr2 = new Float16Array([4, 5, 6]); + const result = inner(arr1, arr2); + expect(result).toBeCloseTo(32, 0); + }); + + test("should return 0 for zero vectors", () => { + const arr1 = new Float32Array([0, 0, 0]); + const arr2 = new Float32Array([1, 2, 3]); + const result = inner(arr1, arr2); + expect(result).toBe(0); + }); + + test("should handle orthogonal vectors", () => { + const arr1 = new Float32Array([1, 0, 0]); + const arr2 = new Float32Array([0, 1, 0]); + const result = inner(arr1, arr2); + expect(result).toBe(0); + }); + + test("should handle negative values", () => { + const arr1 = new Float32Array([-1, -2, -3]); + const arr2 = new Float32Array([4, 5, 6]); + 
const result = inner(arr1, arr2); + expect(result).toBe(-32); // -1*4 + -2*5 + -3*6 = -4 - 10 - 18 = -32 + }); + + test("should handle single element vectors", () => { + const arr1 = new Float32Array([5]); + const arr2 = new Float32Array([3]); + const result = inner(arr1, arr2); + expect(result).toBe(15); + }); + + test("should handle large vectors", () => { + const size = 1000; + const arr1 = new Float32Array(size).fill(1); + const arr2 = new Float32Array(size).fill(2); + const result = inner(arr1, arr2); + expect(result).toBe(2000); + }); + }); + + describe("normalize", () => { + test("should normalize Float32Array to unit length", () => { + const vector = new Float32Array([3, 4]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Float32Array); + expect(result.length).toBe(2); + expect(result[0]).toBeCloseTo(0.6, 5); + expect(result[1]).toBeCloseTo(0.8, 5); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should normalize Float64Array to unit length", () => { + const vector = new Float64Array([3, 4]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Float64Array); + expect(result[0]).toBeCloseTo(0.6, 5); + expect(result[1]).toBeCloseTo(0.8, 5); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should normalize Int8Array to unit length", () => { + const vector = new Int8Array([3, 4]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Int8Array); + expect(result.length).toBe(2); + // Int8Array will truncate the decimal values (0.6, 0.8 -> 0, 0) + // So magnitude will be 0, which is expected behavior for integer arrays + expect(magnitude(result)).toBe(0); + }); + + test("should normalize Uint8Array to unit length", () => { + const vector = new Uint8Array([3, 4]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Uint8Array); + expect(result.length).toBe(2); + // Uint8Array will truncate the decimal values (0.6, 0.8 -> 0, 0) + // So magnitude will be 0, which is expected behavior for integer arrays + expect(magnitude(result)).toBe(0); + }); + + test("should normalize Int16Array to unit length", () => { + const vector = new Int16Array([3, 4]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Int16Array); + expect(result.length).toBe(2); + // Int16Array will truncate the decimal values (0.6, 0.8 -> 0, 0) + // So magnitude will be 0, which is expected behavior for integer arrays + expect(magnitude(result)).toBe(0); + }); + + test("should normalize Uint16Array to unit length", () => { + const vector = new Uint16Array([3, 4]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Uint16Array); + expect(result.length).toBe(2); + // Uint16Array will truncate the decimal values (0.6, 0.8 -> 0, 0) + // So magnitude will be 0, which is expected behavior for integer arrays + expect(magnitude(result)).toBe(0); + }); + + test("should normalize Float16Array and convert to Float32Array", () => { + const vector = new Float16Array([3, 4]); + const result = normalize(vector, true, true); + // For Float16Array, the function should return Float32Array + expect(result).toBeInstanceOf(Float32Array); + expect(result[0]).toBeCloseTo(0.6, 1); + expect(result[1]).toBeCloseTo(0.8, 1); + }); + + test("should throw error for zero vector by default", () => { + const vector = new Float32Array([0, 0, 0]); + expect(() => normalize(vector)).toThrow("Cannot normalize a zero vector."); + }); + + test("should return original zero vector when throwOnZero is false", () => { + const vector = new 
Float32Array([0, 0, 0]); + const result = normalize(vector, false); + expect(result).toBe(vector); + expect(result[0]).toBe(0); + expect(result[1]).toBe(0); + expect(result[2]).toBe(0); + }); + + test("should handle negative values", () => { + const vector = new Float32Array([-3, -4]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Float32Array); + expect(result[0]).toBeCloseTo(-0.6, 5); + expect(result[1]).toBeCloseTo(-0.8, 5); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should handle mixed positive and negative values", () => { + const vector = new Float32Array([3, -4]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Float32Array); + expect(result[0]).toBeCloseTo(0.6, 5); + expect(result[1]).toBeCloseTo(-0.8, 5); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should handle single element vector", () => { + const vector = new Float32Array([5]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Float32Array); + expect(result[0]).toBe(1); + }); + + test("should handle already normalized vector", () => { + const vector = new Float32Array([0.6, 0.8]); + const result = normalize(vector); + expect(result).toBeInstanceOf(Float32Array); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should handle large vectors", () => { + const vector = new Float32Array(1000).fill(1); + const result = normalize(vector); + expect(result).toBeInstanceOf(Float32Array); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should preserve type for other integer arrays", () => { + // Test Int32Array which is not explicitly handled + const vector = new Int32Array([3, 4]); + // @ts-ignore - Int32Array is not explicitly handled by normalize + const result = normalize(vector); + // Should fall through to Float32Array for unhandled types + expect(result).toBeInstanceOf(Float32Array); + }); + }); + + describe("normalizeNumberArray", () => { + test("should normalize number array to unit length", () => { + const values = [3, 4]; + const result = normalizeNumberArray(values); + expect(Array.isArray(result)).toBe(true); + expect(result.length).toBe(2); + expect(result[0]).toBeCloseTo(0.6, 5); + expect(result[1]).toBeCloseTo(0.8, 5); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should return original array for zero vector by default", () => { + const values = [0, 0, 0]; + const result = normalizeNumberArray(values); + expect(result).toBe(values); + expect(result[0]).toBe(0); + expect(result[1]).toBe(0); + expect(result[2]).toBe(0); + }); + + test("should throw error for zero vector when throwOnZero is true", () => { + const values = [0, 0, 0]; + expect(() => normalizeNumberArray(values, true)).toThrow("Cannot normalize a zero vector."); + }); + + test("should handle negative values", () => { + const values = [-3, -4]; + const result = normalizeNumberArray(values); + expect(result[0]).toBeCloseTo(-0.6, 5); + expect(result[1]).toBeCloseTo(-0.8, 5); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should handle mixed positive and negative values", () => { + const values = [3, -4]; + const result = normalizeNumberArray(values); + expect(result[0]).toBeCloseTo(0.6, 5); + expect(result[1]).toBeCloseTo(-0.8, 5); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should handle single element array", () => { + const values = [5]; + const result = normalizeNumberArray(values); + expect(result[0]).toBe(1); + }); + + test("should handle already normalized array", () => { + const 
values = [0.6, 0.8]; + const result = normalizeNumberArray(values); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should handle large arrays", () => { + const values = new Array(1000).fill(1); + const result = normalizeNumberArray(values); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should handle decimal values", () => { + const values = [0.1, 0.2, 0.3]; + const result = normalizeNumberArray(values); + expect(magnitude(result)).toBeCloseTo(1, 5); + }); + + test("should not mutate original array", () => { + const values = [3, 4]; + const original = [...values]; + normalizeNumberArray(values); + expect(values).toEqual(original); + }); + }); +}); diff --git a/packages/util/README.md b/packages/util/README.md index d0ce9501..7acc1bc8 100644 --- a/packages/util/README.md +++ b/packages/util/README.md @@ -108,6 +108,37 @@ container.register("UserService", UserService); const userService = container.resolve("UserService"); ``` +### Input Resolver Registry + +The input resolver registry enables automatic resolution of string identifiers to object instances based on JSON Schema format annotations. This is used by the TaskRunner to resolve inputs like model names or repository IDs before task execution. + +```typescript +import { + registerInputResolver, + getInputResolvers, + INPUT_RESOLVERS, +} from "@workglow/util"; + +// Register a custom resolver for a format prefix +registerInputResolver("myformat", async (id, format, registry) => { + // id: the string value to resolve (e.g., "my-item-id") + // format: the full format string (e.g., "myformat:subtype") + // registry: ServiceRegistry for accessing other services + + const myRepo = registry.get(MY_REPOSITORY_TOKEN); + const item = await myRepo.findById(id); + if (!item) { + throw new Error(`Item "${id}" not found`); + } + return item; +}); + +// Get all registered resolvers +const resolvers = getInputResolvers(); +``` + +When a task input schema includes a property with `format: "myformat:subtype"`, and the input value is a string, the resolver is called automatically to convert it to the resolved instance. 
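For illustration, a task input property that opts into this resolution could be declared as in the sketch below (the property name and schema are hypothetical; the `model:TextEmbedding` format string follows the convention shown in `InputResolverRegistry.ts`):

```typescript
// Hypothetical task input schema. Because the "model" property carries a format
// annotation whose prefix ("model") matches a registered resolver, a string value
// such as "my-embedding-model" is handed to that resolver and replaced with the
// resolved instance before the task executes.
const taskInputSchema = {
  type: "object",
  properties: {
    model: {
      type: "string",
      format: "model:TextEmbedding",
      description: "Model name to resolve before execution",
    },
  },
  required: ["model"],
} as const;
```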
+ ### Event System ```typescript @@ -260,6 +291,7 @@ type User = z.infer; - Decorator-based injection - Singleton and transient lifetimes - Circular dependency detection +- Input resolver registry for schema-based resolution ### Event System (`/events`) diff --git a/packages/util/package.json b/packages/util/package.json index 07d311be..bb841966 100644 --- a/packages/util/package.json +++ b/packages/util/package.json @@ -36,7 +36,7 @@ "access": "public" }, "dependencies": { - "json-schema-library": "^10.5.1", + "@sroussey/json-schema-library": "^10.5.3", "@sroussey/json-schema-to-ts": "3.1.3" } } \ No newline at end of file diff --git a/packages/util/src/common.ts b/packages/util/src/common.ts index fb650c38..e686798a 100644 --- a/packages/util/src/common.ts +++ b/packages/util/src/common.ts @@ -16,5 +16,9 @@ export * from "./utilities/BaseError"; export * from "./utilities/Misc"; export * from "./utilities/objectOfArraysAsArrayOfObjects"; export * from "./utilities/TypeUtilities"; +export * from "./vector/Tensor"; +export * from "./vector/TypedArray"; +export * from "./vector/VectorSimilarityUtils"; +export * from "./vector/VectorUtils"; export * from "./worker/WorkerManager"; export * from "./worker/WorkerServer"; diff --git a/packages/util/src/di/InputResolverRegistry.ts b/packages/util/src/di/InputResolverRegistry.ts new file mode 100644 index 00000000..064fb9f9 --- /dev/null +++ b/packages/util/src/di/InputResolverRegistry.ts @@ -0,0 +1,83 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { createServiceToken, globalServiceRegistry } from "./ServiceRegistry"; +import type { ServiceRegistry } from "./ServiceRegistry"; + +/** + * A resolver function that converts a string ID to an instance. + * Returns undefined if the resolver cannot handle this format. + * Throws an error if the ID is not found. + * + * @param id The string ID to resolve + * @param format The full format string (e.g., "model:TextEmbedding", "repository:tabular") + * @param registry The service registry to use for lookups + */ +export type InputResolverFn = ( + id: string, + format: string, + registry: ServiceRegistry +) => unknown | Promise<unknown>; + +/** + * Service token for the input resolver registry. + * Maps format prefixes to resolver functions. + */ +export const INPUT_RESOLVERS = createServiceToken<Map<string, InputResolverFn>>( + "task.input.resolvers" +); + +// Register default factory if not already registered +if (!globalServiceRegistry.has(INPUT_RESOLVERS)) { + globalServiceRegistry.register( + INPUT_RESOLVERS, + (): Map<string, InputResolverFn> => new Map(), + true + ); +} + +/** + * Gets the global input resolver registry + * @returns Map of format prefix to resolver function + */ +export function getInputResolvers(): Map<string, InputResolverFn> { + return globalServiceRegistry.get(INPUT_RESOLVERS); +} + +/** + * Registers an input resolver for a format prefix. + * The resolver will be called for any format that starts with this prefix.
+ * + * @param formatPrefix The format prefix to match (e.g., "model", "repository") + * @param resolver The resolver function + * + * @example + * ```typescript + * // Register model resolver + * registerInputResolver("model", async (id, format, registry) => { + * const modelRepo = registry.get(MODEL_REPOSITORY); + * const model = await modelRepo.findByName(id); + * if (!model) throw new Error(`Model "${id}" not found`); + * return model; + * }); + * + * // Register repository resolver + * registerInputResolver("repository", (id, format, registry) => { + * const repoType = format.split(":")[1]; // "tabular", "vector", etc. + * if (repoType === "tabular") { + * const repos = registry.get(TABULAR_REPOSITORIES); + * const repo = repos.get(id); + * if (!repo) throw new Error(`Repository "${id}" not found`); + * return repo; + * } + * throw new Error(`Unknown repository type: ${repoType}`); + * }); + * ``` + */ +export function registerInputResolver(formatPrefix: string, resolver: InputResolverFn): void { + const resolvers = getInputResolvers(); + resolvers.set(formatPrefix, resolver); +} diff --git a/packages/util/src/di/ServiceRegistry.ts b/packages/util/src/di/ServiceRegistry.ts index 66aa94b6..eeafa954 100644 --- a/packages/util/src/di/ServiceRegistry.ts +++ b/packages/util/src/di/ServiceRegistry.ts @@ -27,7 +27,7 @@ export function createServiceToken(id: string): ServiceToken { * Service registry for managing and accessing services */ export class ServiceRegistry { - private container: Container; + public container: Container; /** * Create a new service registry diff --git a/packages/util/src/di/index.ts b/packages/util/src/di/index.ts index 4163b35f..a221c727 100644 --- a/packages/util/src/di/index.ts +++ b/packages/util/src/di/index.ts @@ -5,4 +5,5 @@ */ export * from "./Container"; +export * from "./InputResolverRegistry"; export * from "./ServiceRegistry"; diff --git a/packages/util/src/json-schema/SchemaValidation.ts b/packages/util/src/json-schema/SchemaValidation.ts index 4a6b2d5f..7e1585cc 100644 --- a/packages/util/src/json-schema/SchemaValidation.ts +++ b/packages/util/src/json-schema/SchemaValidation.ts @@ -4,5 +4,5 @@ * SPDX-License-Identifier: Apache-2.0 */ -export { compileSchema } from "json-schema-library"; -export type { SchemaNode } from "json-schema-library"; +export { compileSchema } from "@sroussey/json-schema-library"; +export type { SchemaNode } from "@sroussey/json-schema-library"; diff --git a/packages/util/src/vector/Tensor.ts b/packages/util/src/vector/Tensor.ts new file mode 100644 index 00000000..34897fa7 --- /dev/null +++ b/packages/util/src/vector/Tensor.ts @@ -0,0 +1,62 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { FromSchema } from "../json-schema/FromSchema"; +import { JsonSchema } from "../json-schema/JsonSchema"; +import { TypedArraySchema, TypedArraySchemaOptions } from "./TypedArray"; + +export const TensorType = { + FLOAT16: "float16", + FLOAT32: "float32", + FLOAT64: "float64", + INT8: "int8", + UINT8: "uint8", + INT16: "int16", + UINT16: "uint16", +} as const; + +export type TensorType = (typeof TensorType)[keyof typeof TensorType]; + +/** + * Tensor schema for representing tensors as arrays of numbers + * @param annotations - Additional annotations for the schema + * @returns The tensor schema + */ +export const TensorSchema = (annotations: Record<string, unknown> = {}) => + ({ + type: "object", + properties: { + type: { + type: "string", + enum: Object.values(TensorType), + title: "Type", +
description: "The type of the tensor", + }, + data: TypedArraySchema({ + title: "Data", + description: "The data of the tensor", + }), + shape: { + type: "array", + items: { type: "number" }, + title: "Shape", + description: "The shape of the tensor (dimensions)", + minItems: 1, + default: [1], + }, + normalized: { + type: "boolean", + title: "Normalized", + description: "Whether the tensor data is normalized", + default: false, + }, + }, + required: ["data"], + additionalProperties: false, + ...annotations, + }) as const satisfies JsonSchema; + +export type Tensor = FromSchema<ReturnType<typeof TensorSchema>, TypedArraySchemaOptions>; diff --git a/packages/util/src/vector/TypedArray.ts b/packages/util/src/vector/TypedArray.ts new file mode 100644 index 00000000..61bc59b1 --- /dev/null +++ b/packages/util/src/vector/TypedArray.ts @@ -0,0 +1,101 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import { FromSchemaDefaultOptions, FromSchemaOptions } from "../json-schema/FromSchema"; +import { JsonSchema } from "../json-schema/JsonSchema"; + +/** + * Supported typed array types + * - Float16Array: 16-bit floating point (medium precision) + * - Float32Array: Standard 32-bit floating point (most common) + * - Float64Array: 64-bit floating point (high precision) + * - Int8Array: 8-bit signed integer (binary quantization) + * - Uint8Array: 8-bit unsigned integer (quantization) + * - Int16Array: 16-bit signed integer (quantization) + * - Uint16Array: 16-bit unsigned integer (quantization) + */ +export type TypedArray = + | Float32Array + | Float16Array + | Float64Array + | Int8Array + | Uint8Array + | Int16Array + | Uint16Array; + +// Type-only value for use in deserialize patterns +const TypedArrayType = null as any as TypedArray; + +const TypedArraySchemaOptions = { + ...FromSchemaDefaultOptions, + deserialize: [ + // { + // pattern: { + // type: "number"; + // "format": "Float64Array"; + // }; + // output: Float64Array; + // }, + // { + // pattern: { + // type: "number"; + // "format": "Float32Array"; + // }; + // output: Float32Array; + // }, + // { + // pattern: { + // type: "number"; + // "format": "Float16Array"; + // }; + // output: Float16Array; + // }, + // { + // pattern: { + // type: "number"; + // "format": "Int16Array"; + // }; + // output: Int16Array; + // }, + // { + // pattern: { + // type: "number"; + // "format": "Int8Array"; + // }; + // output: Int8Array; + // }, + // { + // pattern: { + // type: "number"; + // "format": "Uint8Array"; + // }; + // output: Uint8Array; + // }, + // { + // pattern: { + // type: "number"; + // "format": "Uint16Array"; + // }; + // output: Uint16Array; + // }, + { + pattern: { format: "TypedArray" }, + output: TypedArrayType, + }, + ], +} as const satisfies FromSchemaOptions; + +export type TypedArraySchemaOptions = typeof TypedArraySchemaOptions; + +export const TypedArraySchema = (annotations: Record<string, unknown> = {}) => + ({ + type: "array", + items: { type: "number" }, + format: "TypedArray", + title: "Typed Array", + description: "A typed array (Float32Array, Int8Array, etc.)
or regular number array", + ...annotations, + }) as const satisfies JsonSchema; diff --git a/packages/util/src/vector/VectorSimilarityUtils.ts b/packages/util/src/vector/VectorSimilarityUtils.ts new file mode 100644 index 00000000..74151950 --- /dev/null +++ b/packages/util/src/vector/VectorSimilarityUtils.ts @@ -0,0 +1,92 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { TypedArray } from "./TypedArray"; + +/** + * Calculates cosine similarity between two vectors + * Returns a value between -1 and 1, where 1 means identical direction + */ +export function cosineSimilarity(a: TypedArray, b: TypedArray): number { + if (a.length !== b.length) { + throw new Error("Vectors must have the same length"); + } + let dotProduct = 0; + let normA = 0; + let normB = 0; + for (let i = 0; i < a.length; i++) { + dotProduct += a[i] * b[i]; + normA += a[i] * a[i]; + normB += b[i] * b[i]; + } + const denominator = Math.sqrt(normA) * Math.sqrt(normB); + if (denominator === 0) { + return 0; + } + return dotProduct / denominator; +} + +/** + * Calculates Jaccard similarity between two vectors + * Uses the formula: sum(min(a[i], b[i])) / sum(max(a[i], b[i])) + * Returns a value between 0 and 1 + * For negative values, normalizes by finding the global min and shifting to non-negative range + */ +export function jaccardSimilarity(a: TypedArray, b: TypedArray): number { + if (a.length !== b.length) { + throw new Error("Vectors must have the same length"); + } + + // Find global min across both vectors to handle negative values + let globalMin = a[0]; + for (let i = 0; i < a.length; i++) { + globalMin = Math.min(globalMin, a[i], b[i]); + } + + // Shift values to non-negative range if needed + const shift = globalMin < 0 ? -globalMin : 0; + + let minSum = 0; + let maxSum = 0; + + for (let i = 0; i < a.length; i++) { + const shiftedA = a[i] + shift; + const shiftedB = b[i] + shift; + minSum += Math.min(shiftedA, shiftedB); + maxSum += Math.max(shiftedA, shiftedB); + } + + return maxSum === 0 ? 
0 : minSum / maxSum; +} + +/** + * Calculates Hamming distance between two vectors (normalized) + * Counts the number of positions where vectors differ + * Returns a value between 0 and 1 (0 = identical, 1 = completely different) + */ +export function hammingDistance(a: TypedArray, b: TypedArray): number { + if (a.length !== b.length) { + throw new Error("Vectors must have the same length"); + } + + let differences = 0; + + for (let i = 0; i < a.length; i++) { + if (a[i] !== b[i]) { + differences++; + } + } + + return differences / a.length; +} + +/** + * Calculates Hamming similarity (inverse of distance) + * Returns a value between 0 and 1 (1 = identical, 0 = completely different) + */ +export function hammingSimilarity(a: TypedArray, b: TypedArray): number { + return 1 - hammingDistance(a, b); +} diff --git a/packages/util/src/vector/VectorUtils.ts b/packages/util/src/vector/VectorUtils.ts new file mode 100644 index 00000000..e7044415 --- /dev/null +++ b/packages/util/src/vector/VectorUtils.ts @@ -0,0 +1,95 @@ +/** + * @license + * Copyright 2025 Steven Roussey + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { TypedArray } from "./TypedArray"; + +/** + * Calculates the magnitude (L2 norm) of a vector + */ +export function magnitude(arr: TypedArray | number[]): number { + // @ts-ignore - Vector reduce works but TS doesn't recognize it + return Math.sqrt(arr.reduce((acc, val) => acc + val * val, 0)); +} + +/** + * Calculates the inner (dot) product of two vectors + */ +export function inner(arr1: TypedArray, arr2: TypedArray): number { + if (arr1.length !== arr2.length) { + throw new Error("Vectors must have the same length to compute inner product."); + } + // @ts-ignore - Vector reduce works but TS doesn't recognize it + return arr1.reduce((acc, val, i) => acc + val * arr2[i], 0); +} + +/** + * Normalizes a vector to unit length (L2 normalization) + * + * @param vector - The vector to normalize + * @param throwOnZero - If true, throws an error for zero vectors. If false, returns the original vector. + * @returns Normalized vector with the same type as input + */ +export function normalize(vector: TypedArray, throwOnZero = true, float32 = false): TypedArray { + const mag = magnitude(vector); + + if (mag === 0) { + if (throwOnZero) { + throw new Error("Cannot normalize a zero vector."); + } + return vector; + } + + const normalized = Array.from(vector).map((val) => Number(val) / mag); + + if (float32) { + return new Float32Array(normalized); + } + + // Preserve the original Vector type + if (vector instanceof Float64Array) { + return new Float64Array(normalized); + } + if (vector instanceof Float16Array) { + return new Float16Array(normalized); + } + if (vector instanceof Float32Array) { + return new Float32Array(normalized); + } + if (vector instanceof Int8Array) { + return new Int8Array(normalized); + } + if (vector instanceof Uint8Array) { + return new Uint8Array(normalized); + } + if (vector instanceof Int16Array) { + return new Int16Array(normalized); + } + if (vector instanceof Uint16Array) { + return new Uint16Array(normalized); + } + // For other integer arrays, use Float32Array since normalization produces floats + return new Float32Array(normalized); +} + +/** + * Normalizes an array of numbers to unit length (L2 normalization) + * + * @param values - The array of numbers to normalize + * @param throwOnZero - If true, throws an error for zero vectors. If false, returns the original array. 
+ * @returns Normalized array of numbers + */ +export function normalizeNumberArray(values: number[], throwOnZero = false): number[] { + const norm = magnitude(values); + + if (norm === 0) { + if (throwOnZero) { + throw new Error("Cannot normalize a zero vector."); + } + return values; + } + + return values.map((v) => v / norm); +}
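A minimal usage sketch of the new vector helpers, assuming the `@workglow/util` exports added to `packages/util/src/common.ts` above (vector values are illustrative):

```typescript
import { cosineSimilarity, hammingSimilarity, magnitude, normalize } from "@workglow/util";

// Two small embedding-like vectors
const a = new Float32Array([0.2, 0.1, 0.7]);
const b = new Float32Array([0.25, 0.05, 0.68]);

// L2-normalize, then compare directions
const na = normalize(a); // Float32Array in, Float32Array out
const nb = normalize(b);

console.log(magnitude(na)); // ~1 after normalization
console.log(cosineSimilarity(na, nb)); // close to 1 for vectors pointing the same way
console.log(hammingSimilarity(a, b)); // 0 here: no element-wise exact matches
```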